diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md index fc9e09f35030f71a8b23b5bc9fe86b120820b8bc..e9cc1deb82ff0498f1a8267cd288ecde798f308c 100644 --- a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md +++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md @@ -17,6 +17,11 @@ --- +## 3. 分支合并要求 +- [ ] **代码合并**(请确保将 master 分支的最新代码同步合并至 poc 分支及 pre-research 分支,同时保证 poc 分支的代码也已正确合并到 pre-research 分支。) + +--- + ## 3. 代码检视 - **要求:** - 合入代码超过 200 行,需三人以上会议检视。 diff --git a/.gitignore b/.gitignore index 2f15a00811101c8743f981fecb6976c7066fb941..2417a7f3477ee3d635fb09975cbe0473f2637031 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ __pycache__/ *.py[cod] *$py.class -.idea # C extensions *.so @@ -143,4 +142,7 @@ cython_debug/ att_advisor*.html *.xlsx operator_tuning_file*.cfg -.ipynb_checkpoints/ \ No newline at end of file +.ipynb_checkpoints/ + +# pycharm settings +.idea \ No newline at end of file diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index b08433f072bf89f62edf88b3aff40d24c1040ea8..0000000000000000000000000000000000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "dynolog_npu/third_party/dynolog"] - path = dynolog_npu/third_party/dynolog - url = https://github.com/facebookincubator/dynolog.git diff --git a/OWNERS b/OWNERS index 415d737ed907c577bc61e71c2839a485395b899c..2e949debf181a6e75fdb5b1e1e091ce7a39c7e69 100644 --- a/OWNERS +++ b/OWNERS @@ -1,6 +1,7 @@ approvers: - leo920320 - wo-wenjie +- ma-dongfang - xhahn - aerfaliang - wangchao285 @@ -10,14 +11,16 @@ approvers: - ly-qianxiao - blian - kun_8 +- binghamhuang reviewers: - lv-kaimeng +- litian_drinksnow +- binghamhuang - wo-wenjie - ly-qianxiao - leo920320 - sunboquan +- stby - Seanesmhxocism - TAJh -- czr9775 -- kali20gakki -- wjchuee \ No newline at end of file +- czr9775 \ No newline at end of file diff --git a/README.md b/README.md index 5ae0bf742fced7ed86452d03d013670cc3528316..400cb2673be8f40077b79b7af0d388f91535aa85 100644 --- a/README.md +++ b/README.md @@ -1,75 +1,78 @@ -# 🚨 重要通知 +# 变更通知 -**1. Ascend Training Tools 更名为 MindStudio Training Tools (mstt)。** +原Ascend Training Tools工具更名为MindStudio Training Tools,MindStudio训练工具链。变更计划如下: -**2. 本代码仓 URL 变更为 [https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原 URL 仍然可用(2024.07.04 )。** +1. 2024.06.25本代码仓名称变更为mstt。 +2. 2024.07.04 URL变更为[https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原始URL仍然可用,但建议使用新URL。 ---- - -# 🧰 MindStudio Training Tools +# MindStudio Training Tools ![Build Status](https://img.shields.io/badge/build-passing-brightgreen) ![Commit Activity](https://img.shields.io/badge/commit%20activity-high-red) ![License: Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue) -## [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍) +## [模型训练开发全流程](https://www.hiascend.com/software/mindstudio/training) -1. [脚本分析工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E5%88%86%E6%9E%90%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) +mstt包括精度工具(msprobe)和性能工具(msprof-analyze),分析迁移工具请参见[昇腾社区](https://www.hiascend.com/software/mindstudio/training)。 - 脚本分析工具可以帮助用户在执行迁移操作前,分析基于 GPU 平台的 PyTorch 训练脚本中算子、三方库套件、API 亲和性以及动态 shape 的支持情况。 +![training_process](debug/resources/training_process.png) -2. [(推荐)自动迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%87%AA%E5%8A%A8%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) +# 使用说明 - 自动迁移工具只需在训练脚本中导入库代码即可完成模型脚本的迁移,使用方式简单,且修改内容少。 +## [精度工具](./debug/accuracy_tools/) -3. [脚本迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%84%9A%E6%9C%AC%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) +[MindStudio Probe(msprobe,MindStudio 精度调试工具)](./debug/accuracy_tools/msprobe)。 - 脚本迁移工具通过后端命令行,将 GPU 上训练的 PyTorch 脚本迁移至 NPU 上,得到新的训练脚本用于训练。 -## [精度工具](./debug/accuracy_tools/) +### [性能工具](https://gitee.com/ascend/mstt/tree/master/profiler) -[MindStudio Probe(msprobe,MindStudio 精度调试工具)](./debug/accuracy_tools/msprobe)。 +1. [compare_tools(性能比对工具)](https://gitee.com/ascend/mstt/tree/master/profiler/compare_tools) + + 提供NPU与GPU性能拆解功能以及算子、通信、内存性能的比对功能。 -## [性能工具](./profiler/msprof_analyze) +2. [cluster_analyse(集群分析工具)](https://gitee.com/ascend/mstt/tree/master/profiler/cluster_analyse) -1. [compare_tools(性能比对工具)](./profiler/msprof_analyze/compare_tools) + 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合MindStudio Insight的集群分析功能使用。 - 提供 NPU 与 GPU 性能拆解功能以及算子、通信、内存性能的比对功能。 +3. [affinity_cpu_bind (亲和性cpu绑核工具) ](https://gitee.com/ascend/mstt/tree/master/profiler/affinity_cpu_bind) -2. [cluster_analyse(集群分析工具)](./profiler/msprof_analyze/cluster_analyse) + 提供亲和性CPU绑核能力,改善host_bound调度问题。 - 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合 MindStudio Insight 的集群分析功能使用。 +### [Tensorboard](https://gitee.com/ascend/mstt/tree/master/plugins/tensorboard-plugins/tb_plugin) -3. [advisor](./profiler/msprof_analyze/advisor) +Tensorboard支持NPU性能数据可视化插件PyTorch Profiler TensorBoard NPU Plugin。 - 将 Ascend PyTorch Profiler 或者 msprof 采集的 PyTorch 场景性能数据进行分析,并输出性能调优建议。 +支持将Ascend平台采集、解析的Pytorch Profiling数据可视化呈现,也兼容GPU数据采集、解析可视化。 -4. [bind_core](./profiler/affinity_cpu_bind) +## 分支维护策略 - 绑核脚本,支持非侵入修改工程代码,实现一键式绑核功能。 +MindStudio Training Tools工具版本分支的维护阶段如下: -## [Tensorboard](./plugins/tensorboard-plugins/tb_plugin) +| **状态** | **时间** | **说明** | +| ------------------- | -------- | ------------------------------------------------ | +| 计划 | 1—3 个月 | 计划特性 | +| 开发 | 3个月 | 开发特性 | +| 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 | +| 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 | +| 生命周期终止(EOL) | N/A | 分支不再接受任何修改 | -Tensorboard 支持 NPU 性能数据可视化插件 PyTorch Profiler TensorBoard NPU Plugin。 +## 现有分支的维护状态 -支持将 Ascend 平台采集、解析的 PyTorch Profiling 数据可视化呈现,也兼容 GPU 数据采集、解析可视化。 +MindStudio Training Tools分支版本号命名规则如下: -## 分支维护策略 +mstt仓每年发布4个版本,每个版本都将对应一个分支;以v6.0为例,其将对应v6.0.RC1、v6.0.RC2、v6.0.RC3以及v6.0.0四个版本,在仓库中将存在与之对应的分支。 -1. MindStudio Training Tools 工具版本分支的维护阶段如下: +| **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** | +| ------------- | -------- | ------------ | ------------------------ | ----------- | +| **v6.0.0** | 维护 | 2023/12/12 | 预计2024/12/12起无维护 | | - | **状态** | **时间** | **说明** | - | ------------------- | -------- | ------------------------------------------------ | - | 计划 | 1—3 个月 | 计划特性 | - | 开发 | 3个月 | 开发特性 | - | 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 | - | 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 | - | 生命周期终止(EOL) | N/A | 分支不再接受任何修改 | +## 参与贡献 -2. MindStudio Training Tools 分支版本号命名规则如下: +1. Fork 本仓库 +2. 新建 xxx 分支 +3. 提交代码 +4. 新建 Pull Request - mstt 仓每年发布 4 个版本,每个版本都将对应一个分支;以 v6.0 为例,其将对应 v6.0.RC1、v6.0.RC2、v6.0.RC3 以及 v6.0.0 四个版本,在仓库中将存在与之对应的分支。 +## 版本过渡提示 - | **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** | - | ------------- | -------- | ------------ | ------------------------ | ----------- | - | **v6.0.0** | 维护 | 2023.12.12 | 预计 2024.12.12 起无维护 | | +当前版本预检和ptdbg维护到2024/09/30,准备于2024/09/30下线,相关目录mstt/debug/accuracy_tools/api_accuracy_checker和mstt/debug/accuracy_tools/ptdbg_ascend将于2024/09/30删除。新版本的预检和ptdbg已经合到mstt/debug/accuracy_tools/atat目录下。 diff --git a/debug/OWNERS b/debug/OWNERS index 0bda9243569f0b6bcd0ce761d7817d512b487ddd..8a038d942355503529f874050378bf2204ac88e1 100644 --- a/debug/OWNERS +++ b/debug/OWNERS @@ -4,13 +4,15 @@ approvers: - wangchao285 - kun_8 - brightlyking +- wqc01202410 +- shawnzhu1 +- pengxiaopeng1 reviewers: - lv-kaimeng - TAJh - jiandaobao -- pengxiaopeng1 - zhengxinqian - louyujing - yang_chen_2001_02_14 -- shawnzhu1 -- wqc01202410 +- li-changwei4 +- qiangge123a diff --git a/debug/accuracy_tools/cmake/Findgtest.cmake b/debug/accuracy_tools/cmake/Findgtest.cmake index dbfe76abcc9b5d3c2f61642cc8c6e270fc441a0f..d4dd8d8895466d3367dff2032a7de03c829e3dc6 100644 --- a/debug/accuracy_tools/cmake/Findgtest.cmake +++ b/debug/accuracy_tools/cmake/Findgtest.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 1.12.1) set(PKG_NAME gtest) -set(URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz") set(SHA256_VALUE "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/googletest-release-1.12.1") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/googletest-release-1.12.1") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findmockcpp.cmake b/debug/accuracy_tools/cmake/Findmockcpp.cmake index c360702c187bfdef553a6b67344ea132a18373f6..73b1729aa5bec968c3e127560db981885c80ba83 100644 --- a/debug/accuracy_tools/cmake/Findmockcpp.cmake +++ b/debug/accuracy_tools/cmake/Findmockcpp.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 2.7) set(PKG_NAME mockcpp) -set(URL "https://gitee.com/sinojelly/mockcpp/repository/archive/v2.7.zip") set(SHA256_VALUE "0dc7111c5be9785d0550ed3b68db7e12fd5d7802b7bc6548c52ac7b9e727fcc1") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/mockcpp-v2.7") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/mockcpp-v2.7") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findnlohmannjson.cmake b/debug/accuracy_tools/cmake/Findnlohmannjson.cmake index 0f85cc00a0d30a3896a8f47cac95911929070e33..7acac96ca3ff8025745a6eeddbdf568e453a58f1 100644 --- a/debug/accuracy_tools/cmake/Findnlohmannjson.cmake +++ b/debug/accuracy_tools/cmake/Findnlohmannjson.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 3.10.1) set(PKG_NAME nlohmannjson) -set(URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip") set(SHA256_VALUE "5c7d0a0542431fef628f8dc4c34fd022fe8747ccb577012d58f38672d8747e0d") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/JSON-for-Modern-CPP-v3.10.1") @@ -9,7 +8,6 @@ set(DIR_NAME "${DOWNLOAD_PATH}/JSON-for-Modern-CPP-v3.10.1") if (NOT ${PKG_NAME}_FOUND) download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findopenssl.cmake b/debug/accuracy_tools/cmake/Findopenssl.cmake index d361095242917df8accbb81a51de65c5ca5ac980..cc33bfc5902aa4c1651029789f04c8a4d2dc10bf 100644 --- a/debug/accuracy_tools/cmake/Findopenssl.cmake +++ b/debug/accuracy_tools/cmake/Findopenssl.cmake @@ -1,7 +1,6 @@ set(PACKAGE_VERSION 1.1.1) set(PKG_NAME openssl) -set(URL "https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz") set(SHA256_VALUE "b92f9d3d12043c02860e5e602e50a73ed21a69947bcc74d391f41148e9f6aa95") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") set(DIR_NAME "${DOWNLOAD_PATH}/openssl-OpenSSL_1_1_1k") @@ -23,7 +22,6 @@ endif() endif() download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/Findprotobuf.cmake b/debug/accuracy_tools/cmake/Findprotobuf.cmake index 4d70515e980f7a921447250fe58400f600419e4c..62c1fe7fbbebc6e0d76fec309a0154d5b102d3aa 100644 --- a/debug/accuracy_tools/cmake/Findprotobuf.cmake +++ b/debug/accuracy_tools/cmake/Findprotobuf.cmake @@ -1,10 +1,9 @@ -set(PACKAGE_VERSION 3.13.0) +set(PACKAGE_VERSION 3.15.0) set(PKG_NAME protobuf) -set(URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") -set(SHA256_VALUE "ab9b39e7053a6fb06b01bf75fb6ec6a71a1ada5a5f8e2446f927336e97b9e7bb") +set(SHA256_VALUE "a1ce078c369f46a3277fdc7ce462ac73cb7cb0edec8bc9d90d23fdb34491c575") set(DOWNLOAD_PATH "$ENV{PROJECT_ROOT_PATH}/third_party") -set(DIR_NAME "${DOWNLOAD_PATH}/protobuf_source-v3.13.0") +set(DIR_NAME "${DOWNLOAD_PATH}/protobuf_source-v3.15.0") if (NOT ${PKG_NAME}_FOUND) @@ -32,7 +31,6 @@ endif() endif() download_opensource_pkg(${PKG_NAME} - URL ${URL} SHA256 ${SHA256_VALUE} DOWNLOAD_PATH ${DOWNLOAD_PATH} ) diff --git a/debug/accuracy_tools/cmake/config.ini b/debug/accuracy_tools/cmake/config.ini new file mode 100644 index 0000000000000000000000000000000000000000..57e544d540aafa1ddf67245d95a78cdc9a151fae --- /dev/null +++ b/debug/accuracy_tools/cmake/config.ini @@ -0,0 +1,14 @@ +[gtest] +url = https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz + +[mockcpp] +url = https://gitee.com/sinojelly/mockcpp/repository/archive/v2.7.zip + +[nlohmannjson] +url = https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip + +[openssl] +url = https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz + +[protobuf] +url = https://gitee.com/mirrors/protobuf_source/repository/archive/v3.15.0.tar.gz \ No newline at end of file diff --git a/debug/accuracy_tools/cmake/download_opensource.sh b/debug/accuracy_tools/cmake/download_opensource.sh index 725e971621434c32d9954c80b9efe234502eefcc..671dc218bb135a39ffc8937777815d84df654187 100644 --- a/debug/accuracy_tools/cmake/download_opensource.sh +++ b/debug/accuracy_tools/cmake/download_opensource.sh @@ -1,11 +1,11 @@ #!/bin/bash if [ "$#" -lt 2 ]; then - echo "Usage: $0 [ ] [ ]" + echo "Usage: $0 [ ] [ ]" exit 1 fi -url=$1 +pkg_name=$1 path=$2 if [ "$#" -ge 3 ]; then @@ -15,6 +15,16 @@ if [ "$#" -ge 4 ]; then tag=$4 fi +url=$(awk -F " = " '/\['${pkg_name}'\]/{a=1}a==1&&$1~/url/{print $2;exit}' config.ini) +lib_path=$MSTT_LIB_PATH +if [ -n "$lib_path" ]; then + url=${lib_path}$(echo $url | awk -F '/' -v OFS='/' '{print $5,$8}') +fi +if [[ ! $url = https* ]]; then + echo "The URL of $pkg_name is illegal." + exit 1 +fi + echo "Start to download ${url}..." if [ ! -d "$path" ]; then diff --git a/debug/accuracy_tools/cmake/utils.cmake b/debug/accuracy_tools/cmake/utils.cmake index e3e963d63e99da4e0bb1fd2973051278feb04435..738afff874f37bea442c33f6cf607a21bdd6cbe7 100644 --- a/debug/accuracy_tools/cmake/utils.cmake +++ b/debug/accuracy_tools/cmake/utils.cmake @@ -2,13 +2,10 @@ function(download_opensource_pkg pkg_name) message("start to download ${pkg_name}...") set(options) - set(oneValueArgs URL SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) + set(oneValueArgs SHA256 GIT_TAG DOWNLOAD_PATH DIR_NAME BUILD_CMD) set(multiValueArgs PATCHES) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if (NOT PKG_URL) - message(FATAL_ERROR "${pkg_name} need URL.") - endif() if (NOT PKG_DOWNLOAD_PATH) set(PKG_DOWNLOAD_PATH "${CMAKE_SOURCE_DIR}/../third_party") endif() @@ -16,7 +13,7 @@ function(download_opensource_pkg pkg_name) execute_process( WORKING_DIRECTORY $ENV{PROJECT_ROOT_PATH}/cmake - COMMAND bash download_opensource.sh ${PKG_URL} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} + COMMAND bash download_opensource.sh ${pkg_name} ${PKG_DOWNLOAD_PATH} ${PKG_SHA256} ${PKG_GIT_TAG} RESULT_VARIABLE RESULT ) if (NOT RESULT EQUAL 0) diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md index 0e68d1f8d9bdaba93a2f65220f85d08eb45f8586..f6725410b43b8460872bf65d1670be8b576fd642 100644 --- a/debug/accuracy_tools/msprobe/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x} - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。 +- msprobe支持MSAdapter 2.1.0。 - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 @@ -53,7 +54,9 @@ export MSPROBE_LOG_LEVEL={x} **2. 工具读写的所有路径,如config_path、dump_path等,只允许包含大小写字母、数字、下划线、斜杠、点和短横线。** -## ⚙️ [安装](./docs/01.installation.md) +## ⚙️ 安装 + +请参见[安装指导说明](./docs/01.installation.md)。 ## 🌟 新版本特性 @@ -69,35 +72,37 @@ export MSPROBE_LOG_LEVEL={x} ### 1 数据采集 -msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作,对应 config.json 中的 task 为 statistics 或 tensor。 +msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics" 或 "tensor" task。 [PyTorch 场景的数据采集](./docs/05.data_dump_PyTorch.md) [MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md) +[MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md) + ### 2 精度预检 -精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task 为 run_ut。 +精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。 PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md) MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md) -### 3 精度比对 +### 3 分级可视化构图比对 -该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 +该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 -[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) +[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) -[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) +[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) -### 4 溢出检测与解析 +### 4 精度比对 -溢出检测与解析是在执行精度数据 dump 时,判断是否存在输入正常但输出存在溢出的 API,从而判断是否为正常溢出。对应 config.json 中的 overflow_check。 +该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 -[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) +[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) -[MindSpore 场景的溢出检测与解析](./docs/13.overflow_check_MindSpore.md) +[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) ### 5 数据解析 @@ -129,26 +134,46 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore. [兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md) -### 10 分级可视化构图比对 +### 10 单算子API自动生成脚本 -该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 +该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 -[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) +[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) -[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) +### 11 数码关联 +该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 -### 11 单算子API自动生成脚本 +[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) -该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 +### 12 溢出检测与解析 -[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) +溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。 +推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。 -### 12 数码关联 +[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) -该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 +[MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md) -[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) +[MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md) + +### 13 训练前配置检查 + +该工具主要适用于对比两个环境下可能影响训练精度的配置差异, 推荐在精度对比前使用。 + +[PyTorch 训练前配置检查](./docs/31.config_checking.md) + +### 14 权重比对 + +权重比对功能用于训练过程中保存的checkpoint,计算对应参数间的余弦相似度、欧式距离等指标。当前支持pytorch下megatron/mindspeed不同模型并行策略下的权重互相比对。 + +[Megatron权重比对](./docs/32.checkpoint_compare.md) + +### 15 整网首个溢出节点分析 + +多rank场景下通过dump数据找到首个出现Nan或Inf的节点。 + +[PyTorch 场景整网首个溢出节点分析](./docs/33.nan_analyze) ## 📑 补充材料 diff --git a/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt b/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt index 2579a3a0e785c0e0ca384b4d52118a5d828249f8..8472c1ad714f37f045e2c41b7e17ec6f3d709bb6 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt +++ b/debug/accuracy_tools/msprobe/ccsrc/CMakeLists.txt @@ -26,6 +26,8 @@ compile_protobuf_file( ${PROTO_SRC} ) +set(CMAKE_SKIP_RPATH TRUE) + add_library(_msprobe_c SHARED) target_compile_options(_msprobe_c PRIVATE "-Wall") @@ -33,8 +35,9 @@ target_compile_options(_msprobe_c PRIVATE "-fPIC") target_compile_options(_msprobe_c PRIVATE "-fstack-protector-all") target_compile_options(_msprobe_c PRIVATE "-ftrapv") target_compile_options(_msprobe_c PRIVATE "-fstack-check") +target_compile_options(_msprobe_c PRIVATE "-D_FORTIFY_SOURCE=2") -target_link_options(_msprobe_c PRIVATE "-Wl,-z,relor") +target_link_options(_msprobe_c PRIVATE "-Wl,-z,relro") target_link_options(_msprobe_c PRIVATE "-Wl,-z,now") target_link_options(_msprobe_c PRIVATE "-Wl,-z,noexecstack") @@ -50,6 +53,7 @@ if(DEFINED BUILD_TYPE AND "${BUILD_TYPE}" STREQUAL "debug") target_compile_definitions(_msprobe_c PRIVATE __DEBUG__) else() target_compile_options(_msprobe_c PRIVATE "-O2") + target_link_options(_msprobe_c PRIVATE "-s") endif() target_include_directories(_msprobe_c PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp index 9f61e03a31f6d4dfa2ca0b258d589bbcd29356fa..9cd5c8cfbfb7e51ee54e70076f50d15b5a1371d4 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.cpp @@ -30,7 +30,7 @@ namespace MindStudioDebugger { template DebuggerErrno ParseJsonBaseObj2Var(const nlohmann::json& content, const std::string& field, T& output, - bool mandatory=false) + bool mandatory = false) { nlohmann::json::const_iterator iter = content.find(field); if (iter == content.end()) { @@ -52,7 +52,7 @@ DebuggerErrno ParseJsonBaseObj2Var(const nlohmann::json& content, const std::str template DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, const std::string& field, - const std::map& enum2name, T& output, bool mandatory=false) { + const std::map& enum2name, T& output, bool mandatory = false) { DebuggerErrno ret; std::string value; @@ -93,14 +93,16 @@ DebuggerErrno ParseJsonStringAndTrans(const nlohmann::json& content, const std:: static bool DebuggerCfgParseUIntRangeGetBorder(const std::string& exp, uint32_t& left, uint32_t& right) { if (std::count(exp.begin(), exp.end(), '-') != 1) { - LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "When using a range expression, it should be formatted as \"a-b\"."); + LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, + "When using a range expression, it should be formatted as \"a-b\"."); return false; } std::istringstream iss(exp); char dash; iss >> left >> dash >> right; if (iss.fail() || dash != '-') { - LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "When using a range expression, it should be formatted as \"a-b\"."); + LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, + "When using a range expression, it should be formatted as \"a-b\"."); return false; } if (left >= right) { @@ -140,7 +142,12 @@ void DebuggerCfgParseUIntRange(const nlohmann::json& content, const std::string& LOG_ERROR(DebuggerErrno::ERROR_INVALID_FORMAT, "Failed to parse " + name + "."); return; } - realLen += (end - begin + 1); + uint32_t rangeSize = end - begin; + if (realLen > UINT32_MAX - (rangeSize + 1)) { + LOG_ERROR(DebuggerErrno::ERROR_VALUE_OVERFLOW, name + " size exceeds limit"); + return; + } + realLen += (rangeSize + 1); buf.emplace_back(std::make_pair(begin, end)); } } diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp index 15ea9e6fda47c0380d9718f135a1baf0658788eb..1d40c1cefa216fcc93b9f2a1b9ed99e5510426f2 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp @@ -14,7 +14,8 @@ * limitations under the License. */ -#pragma once +#ifndef DEBUGGERCONFIG_H +#define DEBUGGERCONFIG_H #include #include @@ -199,7 +200,7 @@ public: OverflowCheckCfg() = default; ~OverflowCheckCfg() = default; - uint32_t overflowNums{1}; + int32_t overflowNums{1}; DebuggerOpCheckLevel checkMode{DebuggerOpCheckLevel::CHECK_LEVEL_ALL}; private: @@ -262,4 +263,6 @@ private: std::shared_ptr overflowCheckCfg{nullptr}; }; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp index 0fe3443fa1f9286fe77c710c955d543d94c4b3a4..94fd19b33f03f86bba5d818a57d59300d600dc42 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp @@ -56,23 +56,30 @@ constexpr const char* kStatsHeaderShape = "Shape"; constexpr const char* kStatsHeaderMax = "Max Value"; constexpr const char* kStatsHeaderMin = "Min Value"; constexpr const char* kStatsHeaderAvg = "Avg Value"; -constexpr const char* kStatsHeaderL2Norm = "L2 Norm Value"; +constexpr const char* kStatsHeaderL2Norm = "l2norm"; +constexpr const char* kStatsHeaderL2NormInCsv = "L2Norm Value"; constexpr const char* kStatsHeaderMD5 = "MD5 Value"; constexpr const char* kStatsHeaderNan = "Nan Count"; +constexpr const char* kStatsHeaderNanInCsv = "NaN Count"; constexpr const char* kStatsHeaderNegInf = "Negative Inf Count"; constexpr const char* kStatsHeaderPosInf = "Positive Inf Count"; constexpr const char* kRankId = "RANK_ID"; constexpr const char* kDigitalNumbers = "0123456789"; -static const std::map summaryOptionHeaderStrMap = { - {DebuggerSummaryOption::MAX, kStatsHeaderMax}, - {DebuggerSummaryOption::MIN, kStatsHeaderMin}, - {DebuggerSummaryOption::MEAN, kStatsHeaderAvg}, - {DebuggerSummaryOption::L2NORM, kStatsHeaderL2Norm}, - {DebuggerSummaryOption::NAN_CNT, kStatsHeaderNan}, - {DebuggerSummaryOption::NEG_INF_CNT, kStatsHeaderNegInf}, - {DebuggerSummaryOption::POS_INF_CNT, kStatsHeaderPosInf}, - {DebuggerSummaryOption::MD5, kStatsHeaderMD5}, +static const std::map> summaryOptionHeaderStrMap = { + {DebuggerSummaryOption::MAX, {kStatsHeaderMax, kStatsHeaderMax}}, + {DebuggerSummaryOption::MIN, {kStatsHeaderMin, kStatsHeaderMin}}, + {DebuggerSummaryOption::MEAN, {kStatsHeaderAvg, kStatsHeaderAvg}}, + {DebuggerSummaryOption::L2NORM, {kStatsHeaderL2Norm, kStatsHeaderL2NormInCsv}}, + {DebuggerSummaryOption::NAN_CNT, {kStatsHeaderNan, kStatsHeaderNanInCsv}}, + {DebuggerSummaryOption::NEG_INF_CNT, {kStatsHeaderNegInf, kStatsHeaderNegInf}}, + {DebuggerSummaryOption::POS_INF_CNT, {kStatsHeaderPosInf, kStatsHeaderPosInf}}, + {DebuggerSummaryOption::MD5, {kStatsHeaderMD5, kStatsHeaderMD5}}, +}; + +const static std::map kDtypeTransMap = { + {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, + {AclDtype::DT_INT4, AclDtype::DT_INT8}, }; class AclTensorStats { @@ -152,7 +159,8 @@ AclTensorStats::AclTensorStats(const AclTensorInfo& tensor, const std::map& opt) +AclTensorStats AclTensorStats::CalTensorSummary(const AclTensorInfo& tensor, + const std::vector& opt) { DEBUG_FUNC_TRACE(); std::map summary; @@ -170,7 +178,7 @@ static std::map ParseTensorSummaryHeaderOrder(c for (uint32_t pos = 0; pos < segs.size(); ++pos) { const std::string& opt = segs[pos]; for (auto it = summaryOptionHeaderStrMap.begin(); it != summaryOptionHeaderStrMap.end(); ++it) { - if (opt == it->second) { + if (opt == it->second.first) { ret[pos] = it->first; break; } @@ -233,7 +241,7 @@ std::string AclTensorStats::GetCsvHeader() const ret.append("Op Type,Op Name,Task ID,Stream ID,Timestamp,Input/Output,Slot,Data Size,Data Type,Format,Shape"); for (auto it = stats.begin(); it != stats.end(); it++) { ret.append(","); - ret.append(summaryOptionHeaderStrMap.at(it->first)); + ret.append(summaryOptionHeaderStrMap.at(it->first).second); } ret.append("\n"); @@ -290,8 +298,15 @@ DebuggerErrno AclDumpDataProcessor::PushData(const acldumpChunk *chunk) } size_t len = chunk->bufLen; + if (len == 0) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, ToString() + ": invalid value(cached size " + + std::to_string(totalLen) + ", receiving size " + std::to_string(len) + ")."); + errorOccurred = true; + return DebuggerErrno::ERROR_INVALID_VALUE; + } + /* 防止正负翻转 */ - if (SIZE_MAX - len < totalLen || totalLen + len > kMaxDataLen || len == 0) { + if (SIZE_MAX - len < totalLen || totalLen + len > kMaxDataLen) { LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, ToString() + ": buffer overflow(cached size " + std::to_string(totalLen) + ", receiving size " + std::to_string(len) + ")."); errorOccurred = true; @@ -306,7 +321,10 @@ DebuggerErrno AclDumpDataProcessor::PushData(const acldumpChunk *chunk) return DebuggerErrno::ERROR_NO_MEMORY; } - if (memcpy(p->data(), chunk->dataBuf, len) == nullptr) { + /* vector p根据chunk->dataBuf的长度,即len,申请创建,所以无需校验空间大小 */ + try { + std::copy(chunk->dataBuf, chunk->dataBuf + len, p->begin()); + } catch (const std::exception& e) { LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, ToString() + ": Failed to copy data;"); delete p; errorOccurred = true; @@ -354,9 +372,11 @@ DebuggerErrno AclDumpDataProcessor::ConcatenateData() } size_t offset = 0; - uint8_t* msg = p->data(); while (!buffer.empty()) { - if (memcpy(msg + offset, buffer.front()->data(), buffer.front()->size()) == nullptr) { + /* vector p根据buffer里所有vector的总长度,即totalLen,申请创建,所以无需校验空间大小 */ + try { + std::copy(buffer.front()->begin(), buffer.front()->end(), p->begin() + offset); + } catch (const std::exception& e) { delete p; LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, "Data processor(" + dumpPath + "): Failed to copy."); return DebuggerErrno::ERROR_SYSCALL_FAILED; @@ -524,7 +544,11 @@ static std::string MappingFilePath(const std::string& originPath) } DebuggerErrno ret; - FileUtils::CreateDir(dir); + ret = FileUtils::CreateDir(dir); + if (ret != DebuggerErrno::OK) { + LOG_ERROR(DebuggerErrno::ERROR, "Failed to create directory " + dir + "."); + return std::string(); + } std::ofstream ofs; constexpr const char* mapFileName = "mapping.csv"; @@ -585,7 +609,8 @@ static std::string GenDataPath(const std::string& path) { } /* * ACL 接口返回数据的路径格式如下 - * {dump_path}/rank_{rank_id}/{time stamp}/step_{step_id}/{time}/{device_id}/{model_name}/{model_id}/{iteration_id}/{data name} + * {dump_path}/rank_{rank_id}/{time stamp}/step_{step_id}/{time} + /{device_id}/{model_name}/{model_id}/{iteration_id}/{data name} * items[0] 表示 rank_{rank_id} * items[1] 表示 {time stamp} * items[2] 表示 step_{step_id} @@ -603,7 +628,7 @@ static std::string GenDataPath(const std::string& path) { inline std::string GetTensorInfoSuffix(AclTensorInfo& tensor) { return "." + tensor.inout + "." + std::to_string(tensor.slot) + - "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.dtype); + "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.oriDtype); } static DebuggerErrno DumpOneAclTensorFmtBin(AclTensorInfo& tensor) @@ -640,10 +665,14 @@ static DebuggerErrno DumpOneAclTensorFmtNpy(AclTensorInfo& tensor) return DebuggerErrno::OK; } - if (tensor.dtype == AclDtype::DT_BF16) { - ret = AclTensor::TransDtype(tensor, AclDtype::DT_FLOAT); + auto it = kDtypeTransMap.find(tensor.dtype); + if (it != kDtypeTransMap.end()) { + AclDtype dstDtype = it->second; + ret = AclTensor::TransDtype(tensor, dstDtype); if (ret != DebuggerErrno::OK) { - LOG_ERROR(ret, tensor + ": Failed to transform dtype from bf16 to fp32."); + LOG_ERROR(ret, tensor + ": Failed to transform dtype from " + + DataUtils::GetDTypeString(it->first) + " to " + + DataUtils::GetDTypeString(it->second)+ "."); return ret; } } @@ -684,7 +713,7 @@ static DebuggerErrno WriteOneTensorStatToDisk(const AclTensorStats& stat) /* 此处防止多进程间竞争,使用文件锁,故使用C风格接口 */ uint32_t retry = 100; uint32_t interval = 10; - if (FileUtils::IsPathExist(dumpfile) && !FileUtils::IsRegularFile(dumpfile)) { + if (FileUtils::CheckFileBeforeCreateOrWrite(dumpfile, true) != DebuggerErrno::OK) { LOG_ERROR(DebuggerErrno::ERROR_FILE_ALREADY_EXISTS, "File " + dumpfile + " exists and has invalid format."); return DebuggerErrno::ERROR_FILE_ALREADY_EXISTS; } @@ -705,6 +734,7 @@ static DebuggerErrno WriteOneTensorStatToDisk(const AclTensorStats& stat) if (i >= retry) { LOG_ERROR(DebuggerErrno::ERROR_SYSCALL_FAILED, "Failed to occupy file " + dumpfile); + close(fd); return DebuggerErrno::ERROR_SYSCALL_FAILED; } @@ -736,7 +766,9 @@ static DebuggerErrno DumpOneAclTensor(AclTensorInfo& tensor, std::vector #include @@ -30,8 +31,8 @@ constexpr size_t kMaxDataLen = 4ULL * 1024 * 1024 * 1024; class AclDumpDataProcessor { public: - AclDumpDataProcessor(const std::string& path, const std::vector& opts) : - dumpPath{path}, hostAnalysisOpts{opts} {}; + AclDumpDataProcessor(const std::string& path, const std::vector& opts) + : dumpPath{path}, hostAnalysisOpts{opts} {}; ~AclDumpDataProcessor(); bool IsCompleted() const {return completed;} @@ -57,3 +58,5 @@ private: } +#endif + diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp index 80769d7fc5fbc9d36115a544e05dd00f2a7541c3..bce101949db4e8f7906b5d94bd34e690b6400dd2 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp @@ -31,7 +31,7 @@ namespace MindStudioDebugger { constexpr const char* kAclDumpScene = "dump_scene"; constexpr const char* kSceneNormal = "normal"; -constexpr const char* kSceneException ="lite_exception"; +constexpr const char* kSceneException = "lite_exception"; constexpr const char* kAclDumpPath = "dump_path"; constexpr const char* kAclDumpStep = "dump_step"; @@ -151,6 +151,26 @@ bool AclDumper::IsCfgEnableAclDumper() ELE_IN_VECTOR(tasks, DebuggerTaskType::TASK_OVERFLOW_CHECK)); } +bool AclDumper::IsOverflowCompleted() +{ + return overflowNums != -1 && realOverflowNums > overflowNums; +} + +void AclDumper::CountOverflowNumbers(const acldumpChunk* chunk) +{ + if (IsOverflowCompleted() || !isOverflowDump || !chunk->isLastChunk) { + return; + } + const std::string fileName = chunk->fileName; + auto separator = fileName.rfind("/"); + auto fileBaseName = fileName.substr(separator + 1); + if (fileBaseName.rfind("Opdebug.Node_OpDebug.") == 0) { + // count according to the first file: Node_OpDebug + realOverflowNums++; + } + return; +} + std::string AclDumper::GetDumpPath(uint32_t curStep) const { if (!initialized || foreDumpPath.empty()) { @@ -357,6 +377,11 @@ DebuggerErrno AclDumper::Initialize() void AclDumper::OnAclDumpCallBack(const acldumpChunk* chunk, int32_t len) { DEBUG_FUNC_TRACE(); + CountOverflowNumbers(chunk); + if (IsOverflowCompleted()) { + return; + } + std::string dumpPath = FileUtils::GetAbsPath(chunk->fileName); auto it = dataProcessors.find(dumpPath); if (it == dataProcessors.end()) { @@ -404,7 +429,7 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) if (!initialized) { ret = Initialize(); - if(ret != DebuggerErrno::OK) { + if (ret != DebuggerErrno::OK) { LOG_ERROR(ret, "AclDumper initialization failed."); return; } @@ -424,6 +449,8 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) ret = AclDumpGenStatJson(statisticsCfg, rank, curStep, kernels); } else if (overflowCheckCfg != nullptr) { ret = AclDumpGenOverflowJson(overflowCheckCfg, rank, curStep); + overflowNums = overflowCheckCfg->overflowNums; + isOverflowDump = true; } if (ret != DebuggerErrno::OK) { diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp index dcfad5fafcabdf944e1d4b0b0a3cd77251ce047d..6985df65e166101c08501e5e206e003bda494b9a 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp @@ -58,11 +58,17 @@ private: uint32_t curStep, const char** kernels); DebuggerErrno AclDumpGenOverflowJson(std::shared_ptr overflowCfg, uint32_t rank, uint32_t curStep); + void CountOverflowNumbers(const acldumpChunk* chunk); + bool IsOverflowCompleted(); + bool initialized{false}; bool aclDumpHasSet{false}; std::string foreDumpPath; std::vector hostAnalysisOpt; std::map> dataProcessors; + bool isOverflowDump{false}; + int32_t overflowNums{1}; + int32_t realOverflowNums{0}; }; void KernelInitDump(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp index 45adff4962156f87f52c17166bc3b381f07f2978..4bbbaec5a0cea436fb7231954b26409d23595130 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp @@ -164,7 +164,8 @@ const static std::unordered_map formatTrans {AclDumpMsg::OutputFormat::FORMAT_NC1HWC0_C04, AclFormat::FORMAT_NC1HWC0_C04}, {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_Z_C04, AclFormat::FORMAT_FRACTAL_Z_C04}, {AclDumpMsg::OutputFormat::FORMAT_CHWN, AclFormat::FORMAT_CHWN}, - {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, AclFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {AclDumpMsg::OutputFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, + AclFormat::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, {AclDumpMsg::OutputFormat::FORMAT_HWCN, AclFormat::FORMAT_HWCN}, {AclDumpMsg::OutputFormat::FORMAT_NC1KHKWHWC0, AclFormat::FORMAT_NC1KHKWHWC0}, {AclDumpMsg::OutputFormat::FORMAT_BN_WEIGHT, AclFormat::FORMAT_BN_WEIGHT}, @@ -291,7 +292,11 @@ static inline void AssertDim(const AclShape& shape, size_t dim) static inline void AssertConsis(const AclTensorInfo& tensor) { - if (EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor) != tensor.dataSize) { + size_t tensor_size = EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor); + // Processing dtype whose size < 1 + // The ele num of quantization type(qint4*2) in MindSpore must be even. + if (tensor.dtype == AclDtype::DT_INT4) tensor_size = EleNumOfTensor(tensor, false) / 2; + if (tensor_size != tensor.dataSize) { throw std::runtime_error(tensor + ": The internal data of Tensor is inconsistent."); } } @@ -343,7 +348,8 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string& dumpPath, const uint8_t* } int32_t subFormat = tensor.sub_format(); - return AclTensorInfo{dumpPath, data, dtype, dFmt, hFmt, dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; + return AclTensorInfo{dumpPath, data, dtype, dtype, dFmt, hFmt, + dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; } template AclTensorInfo ParseAttrsFromDumpData( @@ -390,8 +396,9 @@ static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) auto coutOpt = AlignCeil(eMult * coutOri, kCubeSize); auto c1Dim = cinOpt / cubeK; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t g = 0; g < groups; ++g) { for (int64_t c = 0; c < cDim; ++c) { @@ -407,8 +414,13 @@ static DebuggerErrno FRAC_Z_TO_NCHW_WITH_GROUPS(AclTensorInfo& tensor) (dstCi / cubeK) * hDim * wDim * coutOpt * cubeK + h * wDim * coutOpt * cubeK + w * coutOpt * cubeK + dstCo * cubeK + temporary; int64_t hstIdx = srcCo * cDim * hDim * wDim + c * hDim * wDim + h * wDim + w; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + hstIdx * dtypeSize, src + devIdx * dtypeSize, dtypeSize); + int64_t devOffset = devIdx * dtypeSize; + int64_t hstOffset = hstIdx * dtypeSize; + if (hstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + devOffset, src + devOffset + dtypeSize, + dst + hstOffset); } } } @@ -446,8 +458,9 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) } const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t nIdx = 0; nIdx < n; nIdx++) { int64_t nHeadAddr = nIdx * chw; for (int64_t cIdx = 0; cIdx < c; cIdx++) { @@ -460,8 +473,13 @@ static DebuggerErrno FRAC_Z_TO_NCHW(AclTensorInfo& tensor) auto c0Idx = cIdx % c0; auto ncIdx = nIdx; auto srcIdx = c1Idx * hwncc0 + hIdx * wncc0 + wIdx * ncc0 + ncIdx * c0 + c0Idx; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + auto dstOffset = dstIdx * dtypeSize; + auto srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -511,11 +529,16 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) auto w0 = tensor.deviceShape[shapeSize - fnzW0]; auto h1h0w0 = h1 * h0 * w0; auto w1h1h0w0 = w1 * h1h0w0; + if (w0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto numW1 = w / w0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t timesIdx = 0; timesIdx < times; timesIdx++) { auto timesHead = timesIdx * w1h1h0w0; @@ -527,8 +550,13 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) for (int64_t i = 0; i < w0; ++i) { int64_t srcIdx = h1h0Head + w1Idx * h1h0w0 + i; int64_t dstIdx = srcHHead + w1Idx * w0 + i; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } auto w1Head = numW1 * w0; @@ -536,8 +564,12 @@ static DebuggerErrno FRAC_NZ_TO_NCHW(AclTensorInfo& tensor) auto srcWIdx = w1Head + w0Idx; int64_t srcIdx = h1h0Head + numW1 * h1h0w0 + w0Idx; int64_t dstIdx = srcHHead + srcWIdx; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, dst + dstOffset); } } } @@ -556,6 +588,10 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) auto w = tensor.hostShape[kW]; auto c1 = tensor.deviceShape[kDim1]; auto c0 = tensor.deviceShape[kDim4]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto hw = h * w; auto chw = c * hw; @@ -564,8 +600,9 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) auto c1hwc0 = c1 * hwc0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t nIndex = 0; nIndex < n; nIndex++) { int64_t nHeadAddr = nIndex * chw; for (int64_t cIndex = 0; cIndex < c; cIndex++) { @@ -577,8 +614,13 @@ static DebuggerErrno NC1HWC0_TO_NCHW(AclTensorInfo& tensor) int64_t c1Index = cIndex / c0; int64_t c0Index = cIndex % c0; int64_t srcIdx = nIndex * c1hwc0 + c1Index * hwc0 + hIndex * wc0 + wIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -599,6 +641,10 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) auto w = tensor.hostShape[W_ncdhw]; auto c1 = tensor.deviceShape[C1_ndc1hwc0]; auto c0 = tensor.deviceShape[C0_ndc1hwc0]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } const int64_t cdhw = c * d * h * w; const int64_t dhw = d * h * w; @@ -609,8 +655,9 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) const int64_t wc0 = w * c0; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t nIndex = 0; nIndex < n; nIndex++) { int64_t nHead = nIndex * cdhw; for (int64_t cIndex = 0; cIndex < c; cIndex++) { @@ -625,8 +672,13 @@ static DebuggerErrno NDC1HWC0_TO_NCDHW(AclTensorInfo& tensor) int64_t c0Index = cIndex % c0; auto srcIdx = nIndex * dc1hwc0 + dIndex * c1hwc0 + c1Index * hwc0 + hIndex * wc0 + wIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -652,8 +704,9 @@ static DebuggerErrno C1HWNCoC0_TO_NCHW(AclTensorInfo& tensor) auto cubeK = GetCubeSizeByType(tensor.dtype); const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t nIndex = 0; nIndex < n; nIndex++) { for (int64_t cIndex = 0; cIndex < c; cIndex++) { for (int64_t hIndex = 0; hIndex < h; hIndex++) { @@ -664,8 +717,13 @@ static DebuggerErrno C1HWNCoC0_TO_NCHW(AclTensorInfo& tensor) int64_t coIndex = c0Index; int64_t srcIdx = c1Index * h * w * n * co * c0 + hIndex * w * n * co * c0 + wIndex * n * co * c0 + nIndex * co * c0 + coIndex * c0 + c0Index; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -691,6 +749,10 @@ static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) auto w = tensor.hostShape[W_ncdhw]; constexpr int kFZ3D_C0 = 3; auto c0 = tensor.deviceShape[kFZ3D_C0]; + if (c0 == 0) { + LOG_WARNING(DebuggerErrno::ERROR_INVALID_VALUE, tensor + ": Invalid shape size."); + return DebuggerErrno::ERROR_INVALID_VALUE; + } auto cube_k = GetCubeSizeByType(tensor.dtype); auto c1 = DivCeil(c, cube_k); constexpr int64_t kNiSize = 16; @@ -704,8 +766,9 @@ static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) auto cdhw = c * dhw; const uint8_t* src = tensor.aclData; - uint8_t* dst = tensor.transBuf.data(); + auto dst = tensor.transBuf.begin(); auto dtypeSize = SizeOfAclDType(tensor); + auto dstSize = tensor.transBuf.size(); for (int64_t nIdx = 0; nIdx < n; nIdx++) { int64_t nHead = nIdx * cdhw; for (int64_t cIdx = 0; cIdx < c; cIdx++) { @@ -721,8 +784,13 @@ static DebuggerErrno FRAC_Z3D_TO_NCDHW(AclTensorInfo& tensor) int64_t ncIdx = nIdx; int64_t srcIdx = dIdx * c1hwn1n0c0 + c1I * c1hwn1n0c0 + hIdx * wn1n0c0 + wI * n1n0c0 + ncIdx * c0 + c0I; - /* 此处由偏移计算逻辑保障不会越界读写 */ - std::memcpy(dst + dstIdx * dtypeSize, src + srcIdx * dtypeSize, dtypeSize); + int64_t dstOffset = dstIdx * dtypeSize; + int64_t srcOffset = srcIdx * dtypeSize; + if (dstOffset + dtypeSize > dstSize) { + return DebuggerErrno::ERROR_INVALID_VALUE; + } + std::copy(src + srcOffset, src + srcOffset + dtypeSize, + dst + dstOffset); } } } @@ -749,11 +817,11 @@ DebuggerErrno TransFormatD2H(AclTensorInfo& tensor) } } -static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, size_t bufferSize) +static DebuggerErrno TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, size_t bufferSize) { if (bufferSize < num * sizeof(float)) { LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "Insufficient space for converting data from bf16 to fp32."); - return; + return DebuggerErrno::ERROR_BUFFER_OVERFLOW; } const DataUtils::BFloat16* in = reinterpret_cast(input); float* out = reinterpret_cast(output); @@ -761,36 +829,93 @@ static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, s for (size_t i = 0; i < num; i++) { out[i] = static_cast(in[i]); } + return DebuggerErrno::OK; } -DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +static DebuggerErrno TransInt4ToInt8(const uint8_t* input, size_t elemNums, uint8_t* output, size_t bufferSize) { + if (bufferSize < elemNums * sizeof(int8_t)) { + LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "Insufficient space for converting data from int4 to int8."); + return DebuggerErrno::ERROR_BUFFER_OVERFLOW; + } + const int8_t *srcData = reinterpret_cast(input); + int8_t *dstData = reinterpret_cast(output); + size_t inputLength = elemNums / 2; + int maxValue = 7; + int minValue = -8; + int signBitShift = 3; + int signBitMask = 0x08; + for (size_t i = 0; i < inputLength; ++i) { + int8_t s = *srcData; + int8_t t = s & 0xf; + // keep the sign bit not change + int8_t signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 value."); + } + *dstData = t; + ++dstData; + + int highByteShift = 4; + t = s >> highByteShift; + signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 value."); + } + *dstData = t; + ++dstData; + ++srcData; + } + return DebuggerErrno::OK; +} - const static std::set> kSupportedDtypeTrans = { - {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, - }; +DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +{ if (tensor.dtype == to) { return DebuggerErrno::OK; } - if (kSupportedDtypeTrans.find({tensor.dtype, to}) == kSupportedDtypeTrans.end()) { - return DebuggerErrno::ERROR_UNKNOWN_TRANS; - } - + tensor.oriDtype = tensor.dtype; std::vector buffer; - AssertConsis(tensor); + try { + AssertConsis(tensor); + } catch (const std::runtime_error& e) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_OPERATION, e.what()); + return DebuggerErrno::ERROR_INVALID_OPERATION; + } size_t bufferSize = EleNumOfTensor(tensor) * SizeOfAclDType(to); - buffer.reserve(bufferSize); + buffer.resize(bufferSize); const uint8_t* input = tensor.transBuf.empty() ? tensor.aclData : tensor.transBuf.data(); uint8_t* output = buffer.data(); + DebuggerErrno ret; - /* 目前仅支持bf16->fp32,若有通用转换需求再用更泛化的方式重写 */ if (tensor.dtype == AclDtype::DT_BF16 && to == AclDtype::DT_FLOAT) { - TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + ret = TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + } else if (tensor.dtype == AclDtype::DT_INT4 && to == AclDtype::DT_INT8) { + ret = TransInt4ToInt8(input, EleNumOfTensor(tensor), output, bufferSize); + } else { + LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_TRANS, tensor + ": Trans " + DataUtils::GetDTypeString(tensor.dtype) + + " to " + DataUtils::GetDTypeString(to) + " is not supported."); + return DebuggerErrno::ERROR_UNKNOWN_TRANS; + } + + if (ret != DebuggerErrno::OK) { + return ret; } tensor.transBuf = std::move(buffer); + tensor.dtype = to; return DebuggerErrno::OK; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp index 8b5ba5b06d935d5aaa2dff35e921b9072db6aa1a..f9d289ac17ffb7ab5cb2a9b655b6f3dcf45be98d 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp @@ -40,6 +40,7 @@ struct AclTensorInfo { std::string dumpPath; const uint8_t* aclData; AclDtype dtype; + AclDtype oriDtype; AclFormat deviceFmt; AclFormat hostFmt; AclShape deviceShape; @@ -52,7 +53,7 @@ struct AclTensorInfo { std::vector transBuf; std::string ToString() const { - return "AclTensor(path=" + dumpPath + ",dtype=" + std::to_string(dtype) + ",inout=" + inout + ")"; + return "AclTensor(path=" + dumpPath + ",dtype=" + DataUtils::GetDTypeString(dtype) + ",inout=" + inout + ")"; } }; @@ -65,12 +66,13 @@ inline std::string operator+(const AclTensorInfo& tensor, const std::string& s) } namespace AclTensor { -size_t SizeOfTensor(const AclTensorInfo& tensor, bool host=true); +size_t SizeOfTensor(const AclTensorInfo& tensor, bool host = true); template AclTensorInfo ParseAttrsFromDumpData(const std::string &dumpPath, const uint8_t* data, const T& tensor, const std::string& io, uint32_t slot); DebuggerErrno TransFormatD2H(AclTensorInfo& tensor); DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to); +bool IsDtypeSupportTrans(AclDtype dtype); } } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp index d4d74f1962222558c88c576b8ffbd8c474e152f2..16a640253b726841bd35fbdf51b66e1c0cae27dd 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/PrecisionDebugger.cpp @@ -19,6 +19,7 @@ #include "base/ErrorInfos.hpp" #include "base/DebuggerConfig.hpp" #include "third_party/ACL/AclApi.hpp" +#include "core/mindspore/MSAclDumper.hpp" #include "PrecisionDebugger.hpp" namespace MindStudioDebugger { @@ -83,12 +84,12 @@ int32_t PrecisionDebugger::Initialize(const std::string& framework, const std::s return ret; } - if(AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { + if (AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { return -1; } const DebuggerConfig& cfg = DebuggerConfig::GetInstance(); - for (auto iter = subDebuggers.begin(); iter != subDebuggers.end(); ) { + for (auto iter = subDebuggers.begin(); iter != subDebuggers.end();) { if (!(*iter)->Condition(cfg)) { iter = subDebuggers.erase(iter); } else { @@ -133,25 +134,7 @@ void PrecisionDebugger::Stop() void PrecisionDebugger::Step() { - return Step(1); -} - -void PrecisionDebugger::Step(uint32_t step) -{ - DEBUG_FUNC_TRACE(); - if (!initialized) { - return; - } - - if (step > UINT32_MAX - curStep) { - throw std::runtime_error("Step over upper limit(4294967295)."); - } - curStep += step; - CALL_ACL_API(aclrtSynchronizeDevice); - - for (auto task : subDebuggers) { - task->OnStep(curStep); - } + MSAclDumper::GetInstance().Step(); } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp index 2d80ed3ce1ab11ee5ddf9bad18583a6813f32529..e4acc246c5b70d1a035aafeadcc6275e1aeedf37 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.cpp @@ -24,7 +24,7 @@ namespace MindStudioDebugger { -void MSAclDumper::OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) +void MSAclDumper::OnStepBegin(uint32_t device, ExtArgs& args) { DEBUG_FUNC_TRACE(); if (!PrecisionDebugger::GetInstance().IsEnable()) { @@ -41,7 +41,7 @@ void MSAclDumper::OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) rank = static_cast(device); } - AclDumper::GetInstance().SetDump(rank, curStep, args); + AclDumper::GetInstance().SetDump(rank, msprobeStep, args); return; } @@ -51,6 +51,11 @@ void MSAclDumper::OnStepEnd(ExtArgs& args) AclDumper::GetInstance().FinalizeDump(args); } +void MSAclDumper::Step() +{ + msprobeStep++; +} + __attribute__((constructor)) void RegisterMSAclDumper() { MSAclDumper::GetInstance().Register(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp index cd09bf51af0dac67065d51b8ce60c20f011cd585..579bdb8009629b12017ec251fc5a69be72053f7a 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MSAclDumper.hpp @@ -36,8 +36,9 @@ public: cfg.GetDebugLevel() == DebuggerLevel::L2; } - void OnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args); + void OnStepBegin(uint32_t device, ExtArgs& args); void OnStepEnd(ExtArgs& args); + void Step(); private: MSAclDumper() = default; @@ -46,6 +47,7 @@ private: MSAclDumper& operator=(const MSAclDumper &obj) = delete; explicit MSAclDumper(MSAclDumper &&obj) = delete; MSAclDumper& operator=(MSAclDumper &&obj) = delete; + uint32_t msprobeStep{0}; }; } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp index 631ea7c4acf4666b911a3bb5f28a3c6cc4fe0d54..2223808fc540c78f01c3574e769da482600c3c9a 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/mindspore/MindSporeTrigger.cpp @@ -23,12 +23,13 @@ namespace MindStudioDebugger { bool MindSporeTrigger::stepBeginFlag = false; -void MindSporeTrigger::TriggerOnStepBegin(uint32_t device, uint32_t curStep, ExtArgs& args) +void MindSporeTrigger::TriggerOnStepBegin(uint32_t device, uint32_t /* curStep */, ExtArgs& args) { DEBUG_FUNC_TRACE(); CleanErrorInfoCache(); - MSAclDumper::GetInstance().OnStepBegin(device, curStep, args); + MSAclDumper::GetInstance().OnStepBegin(device, args); + stepBeginFlag = true; CleanErrorInfoCache(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp index 42f3a2e5b61d5da021b2ef7da4a7b88c6dc2abbb..db279a33f17311a7c3681e7d899c2fa85a6fdcc8 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/mindspore/MindSporeDbgHook.cpp @@ -34,6 +34,9 @@ EXPORT_SYMBOL void MS_DbgOnStepBegin(uint32_t device, int32_t curStep, } /* mindspore使用了_GLIBCXX_USE_CXX11_ABI=0,为了解决CXX版本兼容问题,此处将string转char*使用 */ if (ext.first == static_cast(MindStudioDebugger::MindStudioExtensionArgs::ALL_KERNEL_NAMES)) { + if (ext.second == nullptr) { + continue; + } std::vector* ss = reinterpret_cast*>(ext.second); strBuf = new const char*[(*ss).size() + 1]; strBuf[(*ss).size()] = nullptr; diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp index 4b8fc03491e2c0792c3c707c272e7b587d60c7ad..faee46d8e42a5fc5daa809d0e8b9590d0dc5ce21 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/CPythonAgent.cpp @@ -56,7 +56,7 @@ static PyObject* CPythonAgentRegister(PyObject *module, PyObject *args) static PyObject* CPythonAgentUnRegister(PyObject *module, PyObject *obj) { CPythonUtils::PythonStringObject name(obj); - if(name.IsNone()) { + if (name.IsNone()) { PyErr_SetString(PyExc_TypeError, "\"name\" should be a string."); Py_RETURN_NONE; } @@ -68,7 +68,7 @@ static PyObject* CPythonAgentUnRegister(PyObject *module, PyObject *obj) static PyObject* CPythonAgentGetContext(PyObject *module, PyObject *obj) { CPythonUtils::PythonStringObject name(obj); - if(name.IsNone()) { + if (name.IsNone()) { PyErr_SetString(PyExc_TypeError, "\"name\" should be a string."); Py_RETURN_NONE; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp index da1cf3cf1c5d4c8894d0b12b5518657b5928a8d6..26997e7eefd225a18a3805d5edd8a1eddeabd137 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/PrecisionDebuggerIfPython.cpp @@ -99,20 +99,9 @@ static PyObject* PrecisionDebuggerStop(PyObject *self) Py_RETURN_NONE; } -static PyObject* PrecisionDebuggerStep(PyObject *self, PyObject *args) +static PyObject* PrecisionDebuggerStep(PyObject *self) { - if (args == nullptr || PyTuple_GET_SIZE(args) == 0) { - PrecisionDebugger::GetInstance().Step(); - Py_RETURN_NONE; - } - - PyObject* increment = PyTuple_GetItem(args, 0); - if (!PyLong_Check(increment)) { - PyErr_SetString(PyExc_TypeError, "\'step\' should be a int."); - Py_RETURN_NONE; - } - - PrecisionDebugger::GetInstance().Step(PyLong_AsUnsignedLong(increment)); + PrecisionDebugger::GetInstance().Step(); Py_RETURN_NONE; } @@ -126,7 +115,7 @@ PyDoc_STRVAR(StepDoc, static PyMethodDef PrecisionDebuggerMethods[] = { {"start", reinterpret_cast(PrecisionDebuggerStart), METH_NOARGS, StartDoc}, {"stop", reinterpret_cast(PrecisionDebuggerStop), METH_NOARGS, StopDoc}, - {"step", reinterpret_cast(PrecisionDebuggerStep), METH_VARARGS, StepDoc}, + {"step", reinterpret_cast(PrecisionDebuggerStep), METH_NOARGS, StepDoc}, {nullptr, nullptr, 0, nullptr} }; diff --git a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp index 1636c6998d9096b62e9a7f281c7e5ac1b4de4818..edbf8292a1494957dff93d807508b32066fb1d43 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/third_party/ACL/AclApi.cpp @@ -52,7 +52,7 @@ DebuggerErrno LoadAclApi() return DebuggerErrno::OK; } - hLibAscendcl = dlopen(kLibAscendclName, RTLD_LAZY); + hLibAscendcl = dlopen(kLibAscendclName, RTLD_LAZY | RTLD_NOLOAD); if (hLibAscendcl == nullptr) { LOG_ERROR(DebuggerErrno::ERROR_DEPENDENCY_NOT_FIND, "Failed to search libascendcl.so." + std::string(dlerror())); @@ -83,7 +83,7 @@ DebuggerErrno LoadAclApi() } /* 规避adump的bug,mindspore场景优先使用libmindspore_ascend.so中的符号 */ - void* handler = dlopen(kLibMSAscendName, RTLD_LAZY); + void* handler = dlopen(kLibMSAscendName, RTLD_LAZY | RTLD_NOLOAD); std::string libName = kLibMSAscendName; if (handler == nullptr) { handler = hLibAscendcl; @@ -152,5 +152,5 @@ aclError ACLAPI_aclrtSynchronizeDevice() return aclrtSynchronizeDeviceFunc(); } -} +} } diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp index fd944f62db4ff728d1aa2c5d1d5ff818bd5dcf62..c255aab193b163c94a42f85789de9dc35d71ab25 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.cpp @@ -108,7 +108,7 @@ int32_t PythonObject::To(uint32_t& output) const if (!PyLong_Check(ptr)) { return -1; } - output = static_cast(PyLong_AsUnsignedLong(ptr)); + output = static_cast(PyLong_AsUnsignedLong(ptr)); return 0; } @@ -155,7 +155,7 @@ PythonObject PythonObject::Get(const std::string& name, bool ignore) const return ret; } -PythonObject PythonObject::Call(bool ignore) +PythonObject PythonObject::Call(bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -173,7 +173,7 @@ PythonObject PythonObject::Call(bool ignore) return ret; } -PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) +PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -191,7 +191,7 @@ PythonObject PythonObject::Call(PythonTupleObject& args, bool ignore) return ret; } -PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore) +PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore) noexcept { if (!PyCallable_Check(ptr)) { if (!ignore) { @@ -203,7 +203,7 @@ PythonObject PythonObject::Call(PythonTupleObject& args, PythonDictObject& kwarg if (args.IsNone() || kwargs.IsNone()) { if (!ignore) { PyErr_SetString(PyExc_TypeError, "Call python object with invalid parameters."); - } + } return PythonObject(); } @@ -230,7 +230,7 @@ PythonObject PythonObject::GetGlobal(const std::string& name, bool ignore) } -PythonObject PythonObject::Import(const std::string& name, bool ignore) +PythonObject PythonObject::Import(const std::string& name, bool ignore) noexcept { PyObject* m = PyImport_ImportModule(name.c_str()); if (m == nullptr) { diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp index 40ebcb1dafd505fd7dfa3bda1c2c1609cb60297a..8aa2a4a02c64bae174a9d1e2b783b2e41b74bb95 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/CPythonUtils.hpp @@ -40,14 +40,14 @@ namespace CPythonUtils { * | tuple | PythonTupleObject | * | dict | PythonDictObject | * ------------------------------------------- - * + * * 创建对象的方式: * 1、通过原生PyObject*类型创建,PythonObject生命周期内会持有原生对象的一个引用 * 2、通过From方法从c++对象创建 * 3、通过GetGlobal、Import等方法从解释器上下文获取 * 4、通过GetRegisteredPyObj获取到上下文的python对象 * 5、通过已有PythonObject对象的Get、GetItem等方法获取子对象 - * + * * 对象转换: * 1、对于转换成PyObject*、bool、string的场景,支持隐式转换 * 2、对于非通用类型转换,调用To方法,返回0表示成功 @@ -56,7 +56,7 @@ namespace CPythonUtils { * python维度支持bool()的都可以转bool(即并非只有bool类型支持转换,下同) * 支持str()的都可以转string * 可迭代对象(且元素支持转换)都可以转vector - * + * * 对象传递: * 1、子类可以安全传递或拷贝给PythonObject对象 * 2、PythonObject传给子类时,若类型匹配,可以安全转递,否则会转为None @@ -101,9 +101,9 @@ public: } /* 获取全局对象 */ - static PythonObject GetGlobal(const std::string& name, bool ignore=true); + static PythonObject GetGlobal(const std::string& name, bool ignore = true); /* 获取模块对象;若其还未加载至缓存,则加载一遍 */ - static PythonObject Import(const std::string& name, bool ignore=true); + static PythonObject Import (const std::string& name, bool ignore = true) noexcept; /* From/To转换,统一放一份在基类,用于遍历迭代器等场景 */ static PythonObject From(const PythonObject& input); @@ -136,12 +136,12 @@ public: bool IsCallable() const {return PyCallable_Check(ptr);} /* 用于调用可调用对象,相当于python代码中的obj(),为了简单只实现了args+kwargs参数形式 */ - PythonObject Call(bool ignore=true); - PythonObject Call(PythonTupleObject& args, bool ignore=true); - PythonObject Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore=true); + PythonObject Call(bool ignore = true) noexcept; + PythonObject Call(PythonTupleObject& args, bool ignore = true) noexcept; + PythonObject Call(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore = true) noexcept; /* 用于获取对象属性,相当于python代码中的obj.xx */ - PythonObject Get(const std::string& name, bool ignore=true) const; + PythonObject Get(const std::string& name, bool ignore = true) const; PythonObject& NewRef() { Py_XINCREF(ptr); return *this; @@ -159,9 +159,9 @@ public: operator std::string() const { return ToString(); } - PythonObject operator()(bool ignore=true) {return Call(ignore);} - PythonObject operator()(PythonTupleObject& args, bool ignore=true) {return Call(args, ignore);} - PythonObject operator()(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore=true) { + PythonObject operator()(bool ignore = true) {return Call(ignore);} + PythonObject operator()(PythonTupleObject& args, bool ignore = true) {return Call(args, ignore);} + PythonObject operator()(PythonTupleObject& args, PythonDictObject& kwargs, bool ignore = true) { return Call(args, kwargs, ignore); } @@ -170,7 +170,7 @@ protected: Py_XDECREF(ptr); if (o == nullptr) { o = Py_None; - } + } Py_INCREF(o); ptr = o; } @@ -220,11 +220,11 @@ public: size_t Size() const; template - PythonListObject& Append(T value, bool ignore=true); - PythonObject GetItem(size_t pos, bool ignore=true); - PythonListObject& SetItem(size_t pos, PythonObject& item, bool ignore=true); - PythonListObject& Insert(int64_t pos, PythonObject& item, bool ignore=true); - PythonTupleObject ToTuple(bool ignore=true); + PythonListObject& Append(T value, bool ignore = true); + PythonObject GetItem(size_t pos, bool ignore = true); + PythonListObject& SetItem(size_t pos, PythonObject& item, bool ignore = true); + PythonListObject& Insert(int64_t pos, PythonObject& item, bool ignore = true); + PythonTupleObject ToTuple(bool ignore = true); }; class PythonTupleObject : public PythonObject { @@ -236,7 +236,7 @@ public: static PythonTupleObject From(const std::vector& input); size_t Size() const; - PythonObject GetItem(size_t pos, bool ignore=true); + PythonObject GetItem(size_t pos, bool ignore = true); }; class PythonDictObject : public PythonObject { @@ -248,11 +248,11 @@ public: static PythonDictObject From(const std::map& input); template - PythonDictObject& Add(T1 key, T2 value, bool ignore=true); + PythonDictObject& Add(T1 key, T2 value, bool ignore = true); template - PythonDictObject& Delete(T key, bool ignore=true); + PythonDictObject& Delete(T key, bool ignore = true); template - PythonObject GetItem(T key, bool ignore=true); + PythonObject GetItem(T key, bool ignore = true); }; /**************************************************************************************************/ diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp index c2d7df85294f7c96f0fe1a1b9458dfd2ad2e502c..23088d48e31a15af69bdc19939e490b27c4a50a6 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.cpp @@ -56,10 +56,15 @@ BFloat16::BFloat16(float f32) BFloat16::operator float() const { - float f32 = 0; - uint32_t tmp = value_; - tmp <<= 16; - std::memcpy(&f32, &tmp, sizeof(f32)); + /* 为了兼容性,不要用c++20的bit_cast */ + union + { + float f32; + uint32_t ui32; + }; + + ui32 = static_cast(value_); + ui32 <<= 16; // 将ui32左移16位 return f32; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp index f58e15a8c77719f62ddeef8ebbcd25a5b5ebf624..5dbb83bb35e35a5e324707a9f2bacccdecd6f4b3 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/DataUtils.hpp @@ -14,7 +14,8 @@ * limitations under the License. */ -#pragma once +#ifndef DATAUTILS_H +#define DATAUTILS_H #include #include @@ -166,4 +167,6 @@ std::string GetFormatString(TensorFormat fmt); std::string GetShapeString(const TensorShape& shape); } -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp index 7f025e568abdfe95830902d1e72bdb77300f7de5..35342914369fa6f34e4a2caa646b4bf8b26ab5bf 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileOperation.cpp @@ -100,7 +100,8 @@ inline static std::vector NpyLen2Bytes(size_t length, size_t lengthLen) { return buff; } -static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, DataUtils::DataType dt, bool fortranOrder=false) +static std::string GenerateNpyHeader(const DataUtils::TensorShape &shape, + DataUtils::DataType dt, bool fortranOrder = false) { auto typeDesc = npyTypeDescMap.find(dt); if (typeDesc == npyTypeDescMap.end()) { diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp index 246f899690ccd0e306f5b6b550870406086430cc..8c3cd20883d26d68a5e3504bec47a9c3d76d3023 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.cpp @@ -600,63 +600,5 @@ DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwri } return DebuggerErrno::OK; } - -/* 其他文件操作工具 */ -static DebuggerErrno ListAllAux(const std::string &path, std::vector& output, uint32_t depth) -{ - if (depth > PATH_DEPTH_MAX) { - return DebuggerErrno::ERROR_PATH_TOO_DEEP; - } - - DIR* dir = opendir(path.c_str()); - if (dir == nullptr) { - return DebuggerErrno::ERROR_FAILED_TO_OPEN_FILE; - } - - DebuggerErrno ret = DebuggerErrno::OK; - size_t max = output.capacity(); - size_t num = output.size(); - if (num >= max) { - return DebuggerErrno::OK; - } - - struct dirent* entry = nullptr; - while ((entry = readdir(dir)) != nullptr) { - if (strcmp(entry->d_name, ".") == 0 || (strcmp(entry->d_name, "..") == 0)) { - continue; - } - std::string entryPath = path + "/" + entry->d_name; - if (entry->d_type == DT_DIR) { - ret = ListAllAux(entryPath, output, depth + 1); - if (ret != DebuggerErrno::OK) { - closedir(dir); - return ret; - } - } else if (entry->d_type == DT_REG) { - output.emplace_back(entryPath); - if (++num >= max) { - break; - } - } - } - closedir(dir); - return DebuggerErrno::OK; -} - -std::vector ListAll(const std::string &path, size_t max) -{ - std::vector ret; - std::string realPath = GetAbsPath(path); - if (CheckDirCommon(realPath) != DebuggerErrno::OK) { - return ret; - } - ret.reserve(max); - - uint32_t depth = std::count(realPath.begin(), realPath.end(), pathSeparator); - ListAllAux(realPath, ret, depth); - ret.resize(ret.size()); - return ret; -} - } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp index 70b47137fc40fd7fb73be11ddb8d3551550e2b8d..f944b606747c45b1e6c9dc86d74aa5401f6014ab 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/FileUtils.hpp @@ -64,7 +64,7 @@ constexpr const uint32_t FILE_NAME_MAX = 255; /* 基础检查函数库,不做过多校验,路径有效性由调用者保证 */ bool IsPathExist(const std::string& path); -std::vector SplitPath(const std::string &path, char separator=pathSeparator); +std::vector SplitPath(const std::string &path, char separator = pathSeparator); std::string GetAbsPath(const std::string &path); bool IsDir(const std::string& path); bool IsRegularFile(const std::string& path); @@ -85,23 +85,19 @@ bool IsFileOwner(const std::string& path); /* 文件操作函数库,会对入参做基本检查 */ DebuggerErrno DeleteFile(const std::string &path); -DebuggerErrno DeleteDir(const std::string &path, bool recursion=false); -DebuggerErrno CreateDir(const std::string &path, bool recursion=false, mode_t mode=NORMAL_DIR_MODE_DEFAULT); +DebuggerErrno DeleteDir(const std::string &path, bool recursion = false); +DebuggerErrno CreateDir(const std::string &path, bool recursion = false, mode_t mode = NORMAL_DIR_MODE_DEFAULT); DebuggerErrno Chmod(const std::string& path, const mode_t& mode); DebuggerErrno GetFileSize(const std::string &path, size_t& size); -DebuggerErrno OpenFile(const std::string& path, std::ifstream& ifs, std::ios::openmode mode=std::ios::in); -DebuggerErrno OpenFile(const std::string& path, std::ofstream& ofs, std::ios::openmode mode=std::ios::out, - mode_t permission=NORMAL_FILE_MODE_DEFAULT); +DebuggerErrno OpenFile(const std::string& path, std::ifstream& ifs, std::ios::openmode mode = std::ios::in); +DebuggerErrno OpenFile(const std::string& path, std::ofstream& ofs, std::ios::openmode mode = std::ios::out, + mode_t permission = NORMAL_FILE_MODE_DEFAULT); /* 通用检查函数 */ DebuggerErrno CheckFileSuffixAndSize(const std::string &path, FileType type); DebuggerErrno CheckDirCommon(const std::string &path); DebuggerErrno CheckFileBeforeRead(const std::string &path, const std::string& authority="r", - FileType type=FileType::COMMON); -DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwrite=false); - -/* 其他文件操作工具 */ -std::vector ListAll(const std::string &path, size_t max = 1024); - + FileType type = FileType::COMMON); +DebuggerErrno CheckFileBeforeCreateOrWrite(const std::string &path, bool overwrite = false); } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp index 141471ac8ce284ac1a7ab4b6db59f5d0da9a9fe2..d11fdf338706513d59da54906b1cb8def00a3013 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/utils/MathUtils.hpp @@ -62,7 +62,7 @@ T AlignCeil(T v, T block) float Random(); float Random(float floor, float ceil); int32_t RandomInt(int32_t floor, int32_t ceil); -std::string RandomString(uint32_t len, char min=' ', char max='~'); +std::string RandomString(uint32_t len, char min = ' ', char max = '~'); std::string CalculateMD5(const uint8_t* data, size_t length); diff --git a/debug/accuracy_tools/msprobe/config.json b/debug/accuracy_tools/msprobe/config.json index 553b7f9ee3b89215647b00fb14b70af44ea5f00c..9bf9579b80770210bdda668b782a41540e7cb763 100644 --- a/debug/accuracy_tools/msprobe/config.json +++ b/debug/accuracy_tools/msprobe/config.json @@ -25,7 +25,9 @@ "run_ut": { "white_list": [], "black_list": [], - "error_data_path": "./" + "error_data_path": "./", + "master_ip": "127.0.0.1", + "master_port": "8888" }, "grad_probe": { "grad_level": "L1", diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index d9623b807121ea129484a535fe8a9e2293e662f3..b46144e5c94482210751fc170d3eafe57c3e9f5e 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -27,6 +27,8 @@ class Const: ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" SEP = "." + COLON = ":" + DOUBLE_SLASH = "//" REGEX_PREFIX_MAX_LENGTH = 20 REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.' @@ -51,7 +53,10 @@ class Const: FOUR_SEGMENT = 4 SIX_SEGMENT = 6 SEVEN_SEGMENT = 7 + MAX_DEPTH = 10 + CPU_QUARTER = 4 + DUMP_MAX_DEPTH = 50 # dump mode ALL = "all" @@ -67,7 +72,7 @@ class Const: SUMMARY = "summary" MD5 = "md5" VALUE = "value" - SUMMARY_MODE = [ALL, SUMMARY, MD5] + SUMMARY_MODE = ["statistics", "md5"] WRITE_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR @@ -77,6 +82,8 @@ class Const: NUMPY_SUFFIX = ".npy" NUMPY_PATTERN = "*.npy" PT_SUFFIX = ".pt" + PY_SUFFIX = ".py" + INIT_PY = "init.py" ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024 TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024 ONE_MB = 1048576 # 1 * 1024 * 1024 @@ -92,6 +99,7 @@ class Const: GRAD_OUTPUT = 'grad_output' PARAMS = 'parameters' PARAMS_GRAD = 'parameters_grad' + DEBUG = 'debug' START = "start" STOP = "stop" ENV_ENABLE = "1" @@ -163,10 +171,15 @@ class Const: LEFT_MOVE_INDEX = -1 RIGHT_MOVE_INDEX = 1 LAST_INDEX = -1 + MAX_TRAVERSAL_DEPTH = 5 TOP_LAYER = "TopLayer" CELL = "Cell" MODULE = "Module" + API = "api" + PYNATIVE_MODE = "pynative" + PYNATIVE_GRAPH_MODE = "pynative_graph" + FRAME_FILE_LIST = ["site-packages/torch", "package/torch", "site-packages/mindspore", "package/mindspore"] INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", @@ -188,7 +201,11 @@ class Const: FILL_CHAR_NUMS = 50 TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully." + WITHOUT_CALL_STACK = "The call stack retrieval failed." + STACK_FILTER_KEYWORDS = ["msprobe/core", "msprobe/pytorch", "msprobe/mindspore"] + CALL_STACK_FLAG = "data_dump/api_registry" + NEW_STACK_FLAG = "0" STEP = "step" RANK = "rank" @@ -206,12 +223,16 @@ class Const: TORCH_FLOAT32 = "torch.float32" TORCH_BFLOAT16 = "torch.bfloat16" + TYPE = 'type' DTYPE = 'dtype' SHAPE = 'shape' + STACK_INFO = 'stack_info' MAX = 'Max' MIN = 'Min' MEAN = 'Mean' NORM = 'Norm' + DATA_NAME = 'data_name' + TENSOR_STAT_INDEX = 'tensor_stat_index' CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' @@ -224,12 +245,133 @@ class Const: SCOPE_SEPARATOR = "/" REPLACEMENT_CHARACTER = "_" + FORWARD_PATTERN = SEP + FORWARD + SEP + BACKWARD_PATTERN = SEP + BACKWARD + SEP + OPTIMIZER = "optimizer" CLIP_GRAD = "clip_grad" END_PREFIX = "end_" TENSOR_STAT_LEN = 2 + SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml" + + PT_API_TYPE_FUNCTIONAL = "functional" + PT_API_TYPE_TENSOR = "tensor" + PT_API_TYPE_TORCH = "torch" + PT_API_TYPE_VF = "_VF" + PT_API_TYPE_NPU = "torch_npu" + PT_API_TYPE_ATEN = "aten" + PT_API_TYPE_DIST = "distributed" + PT_API_TYPE_NPU_DIST = "npu_distributed" + PT_API_TYPE_MINDSPEED = "mindspeed" + + MS_API_TYPE_OPS = "ops" + MS_API_TYPE_TENSOR = "tensor" + MS_API_TYPE_STUB_TENSOR = "stubtensor" + MS_API_TYPE_MINT = "mint.ops" + MS_API_TYPE_MINT_FUNC = "mint.nn.functional" + MS_API_TYPE_COM = "communication.comm_func" + MS_API_TYPE_MINT_DIST = "mint.distributed" + + FUNCTIONAL_API_TYPE_PREFIX = "Functional" + TENSOR_API_TYPE_PREFIX = "Tensor" + DIST_API_TYPE_PREFIX = "Distributed" + + TORCH_API_TYPE_PREFIX = "Torch" + NPU_API_TYPE_PREFIX = "NPU" + ATEN_API_TYPE_PREFIX = "Aten" + VF_API_TYPE_PREFIX = "VF" + MINDSPEED_API_TYPE_PREFIX = "MindSpeed" + + MINT_API_TYPE_PREFIX = "Mint" + MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + MINT_DIST_API_TYPE_PREFIX = "MintDistributed" + + SUPPORT_API_DICT_KEY_MAP = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_VF: PT_API_TYPE_VF, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_ATEN: PT_API_TYPE_ATEN, + PT_API_TYPE_DIST: PT_API_TYPE_DIST, + PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST, + PT_API_TYPE_MINDSPEED: PT_API_TYPE_MINDSPEED + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: MS_API_TYPE_OPS, + MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_MINT: MS_API_TYPE_MINT, + MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, + MS_API_TYPE_COM: MS_API_TYPE_COM, + MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_DIST: PT_API_TYPE_DIST + } + } + + API_DATA_PREFIX = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_VF: VF_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX, + PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX, + PT_API_TYPE_MINDSPEED: MINDSPEED_API_TYPE_PREFIX + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX, + MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, + MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX, + MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX + } + } + + def _fused_adamw_( + self, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + *, + lr, + beta1, + beta2, + weight_decay, + eps, + amsgrad, + maximize, + grad_scale=None, + found_inf=None + ): + pass + + API_WITH_SELF_ARG = { + 'Torch._fused_adamw_': _fused_adamw_ + } + class CompareConst: """ @@ -256,6 +398,7 @@ class CompareConst: MEAN_DIFF = "Mean diff" NORM_DIFF = "L2norm diff" COSINE = "Cosine" + EUC_DIST = "EucDist" MAX_ABS_ERR = "MaxAbsErr" MAX_RELATIVE_ERR = "MaxRelativeErr" MIN_RELATIVE_ERR = "MinRelativeErr" @@ -330,8 +473,8 @@ class CompareConst: ULP_ERR_STATUS = "ulp_err_status" COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, - ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST, + MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE ] @@ -357,18 +500,16 @@ class CompareConst: Const.MD5: MD5_COMPARE_RESULT_HEADER } - ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO] + ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, + FIVE_THOUSANDTHS_ERR_RATIO] SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR] # dtype match - MS_TYPE = [ - [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16], - [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16] - ] - TORCH_TYPE = [ - [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16], - [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16] + + DTYPE_MATCH_GROUPS = [ + {Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16}, + {Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16} ] # read_op @@ -389,13 +530,6 @@ class CompareConst: Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT } - STRUCT_COMPARE_KEY = [ - INPUT_STRUCT, - OUTPUT_STRUCT, - PARAMS_STRUCT, - PARAMS_GRAD_STRUCT - ] - # compare standard HUNDRED_RATIO_THRESHOLD = 0.01 THOUSAND_RATIO_THRESHOLD = 0.001 @@ -467,22 +601,42 @@ class CompareConst: BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: '' } MS_GRAPH_NPY = { - COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, + COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, FIVE_THOUSANDTHS_ERR_RATIO: None } MS_GRAPH_STATISTIC = { MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None, MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None } + + API_MAPPING_KEYS_TO_COMPARE = [ + ('ms_args', 'pt_args'), + ('ms_outputs', 'pt_outputs'), + ('ms_parameters', 'pt_parameters'), + ('ms_parameters_grad', 'pt_parameters_grad') + ] + INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP - COMPARE_KEY = 'compare_key' - COMPARE_SHAPE = 'compare_shape' + + CMP_KEY = 'compare_key' + CMP_SHAPE = 'compare_shape' + + OP_NAME_X = 'op_name_x' + MATCH_RESULT_COLUMNS = [ + OP_NAME_X, 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x', + CMP_KEY, CMP_SHAPE, + 'op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y', + ] + INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml' UNREADABLE = 'unreadable data' + NPU_DUMP_DATA_DIR = 'npu_dump_data_dir' + BENCH_DUMP_DATA_DIR = 'bench_dump_data_dir' + NO_REAL_DATA_FLAG = '-1' class FileCheckConst: @@ -504,6 +658,7 @@ class FileCheckConst: XLSX_SUFFIX = ".xlsx" YAML_SUFFIX = ".yaml" IR_SUFFIX = ".ir" + ZIP_SUFFIX = ".zip" MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 @@ -512,6 +667,9 @@ class FileCheckConst: MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_ZIP_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_FILE_IN_ZIP_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_FILE_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 DIR = "dir" FILE = "file" @@ -525,7 +683,8 @@ class FileCheckConst: CSV_SUFFIX: MAX_CSV_SIZE, XLSX_SUFFIX: MAX_XLSX_SIZE, YAML_SUFFIX: MAX_YAML_SIZE, - IR_SUFFIX: MAX_IR_SIZE + IR_SUFFIX: MAX_IR_SIZE, + ZIP_SUFFIX: MAX_ZIP_SIZE } CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]' @@ -538,61 +697,6 @@ class OverflowConst: OVERFLOW_DEBUG_MODE = 1 -class MsCompareConst: - # api_info field - MINT = "Mint" - MINT_FUNCTIONAL = "MintFunctional" - TENSOR_API = "Tensor" - - API_NAME_STR_LENGTH = 4 - MAX_RECURSION_DEPTH = 20 - - # Mindtorch api_info field - MINDTORCH_TENSOR = "Tensor" - MINDTORCH = "Torch" - MINDTORCH_FUNC = "Functional" - MINDTORCH_NPU = "NPU" - MINDTORCH_DIST = "Distributed" - - - - MT_VALID_API_TYPES = [ - MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR - ] - - TASK_FIELD = "task" - STATISTICS_TASK = "statistics" - FRAMEWORK = "framework" - TENSOR_TASK = "tensor" - DUMP_DATA_DIR_FIELD = "dump_data_dir" - DATA_FIELD = "data" - - # supported api yaml - SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" - SUPPORTED_TENSOR_LIST_KEY = "tensor" - - # detail_csv - DETAIL_CSV_API_NAME = "API Name" - DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" - DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" - DETAIL_CSV_SHAPE = "Shape" - DETAIL_CSV_PASS_STATUS = "Status" - DETAIL_CSV_MESSAGE = "Message" - DETAIL_CSV_FILE_NAME = "accuracy_checking_details" - - # result_csv - RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" - RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" - RESULT_CSV_FILE_NAME = "accuracy_checking_result" - - EPSILON = 1e-8 - - class ProcessStatus: - SUCCESS = "success" - API_NOT_FOUND = "api_not_found" - EXCEPTION_SKIP = "exception_skip" - - class MsgConst: """ Class for log messages const @@ -629,7 +733,16 @@ class MonitorConst: """ Class for monitor const """ - OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"] + + # monitor config set default values + DEFAULT_GRAD_ACC_STEPS = 1 + DEFAULT_START_ITERATION = 0 + DEFAULT_START_STEP = 0 + DEFAULT_MAX_COLLECT_TIMES = 1e8 + DEFAULT_MIN_COLLECT_TIMES = 0 + DEFAULT_STEP_INTERVAL = 1 + + OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"] MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR" DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output" DATABASE = "database" @@ -641,7 +754,7 @@ class MonitorConst: "DeepSpeedZeroOptimizer_Stage3" ) DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer" - RULE_NAME = ['AnomalyTurbulence'] + RULE_NAME = ['AnomalyTurbulence', 'AnomalyNan'] SLICE_SIZE = 20480 # used for name @@ -664,9 +777,11 @@ class MonitorConst: EXP_AVG = "exp_avg" EXP_AVG_SQ = "exp_avg_sq" PARAM = "param" + PRE_PARAM = "param_origin" + POST_PARAM = "param_updated" CSV_HEADER = ["vpp_stage", "name", "step"] - CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"] + CSV_HEADER_MICRO_STEP = ["vpp_stage", "name", "step", "micro_step"] OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-" ANOMALY_JSON = "anomaly.json" ANALYSE_JSON = "anomaly_analyse.json" @@ -674,3 +789,86 @@ class MonitorConst: CSV = "csv" API = "api" HEADER_NAME = 'name' + + MAX_NDIGITS = 20 + + DEFAULT_STAGE = -1 + FORWARD_STAGE = 0 + BACKWARD_STAGE = 1 + OPTIMIZER_STAGE = 2 + FORWARD_KEY = [ACTV] + BACKWARD_KEY = [ACTVGRAD, PRE_GRAD, POST_GRAD, ACC_GRAD] + OPTIMIZER_KEY = [EXP_AVG, EXP_AVG_SQ] + + CAL_SIM_SEQ_LENGTH = 1024 + CAL_SIM_H_CHUNK_RATIO = 4 + + TRAIN_STAGE = {} + for key in FORWARD_KEY: + TRAIN_STAGE[key] = FORWARD_STAGE + for key in BACKWARD_KEY: + TRAIN_STAGE[key] = BACKWARD_STAGE + for key in OPTIMIZER_KEY: + TRAIN_STAGE[key] = OPTIMIZER_STAGE + + +class DistributedCheckConst: + API_FULL_NAME = "api_full_name" + API_NAME = "api_name" + GROUP = "group" + GROUP_RANKS = "group_ranks" + GROUP_INDEX = "group_index" + SRC = "src" + SRC_INDEX = "src_index" + OP = "op" + SCATTER_LIST = "scatter_list" + TORCH_PROCESS_GROUP = "torch.ProcessGroup" + ALL_ARGS = "all_args" + ALL_KWARGS = "all_kwargs" + RESULT_FILE_PATH = "result_file_path" + BENCHMARK_RESULT = "benchmark_result" + MASTER_IP = "master_ip" + MASTER_PORT = "master_port" + WORLD_SIZE = "world_size" + HCCL = "hccl" + TCP = "tcp" + BROADCAST = "broadcast" + REDUCE = "reduce" + ALL_REDUCE = "all_reduce" + SCATTER = "scatter" + GATHER = "gather" + ALL_GATHER = "all_gather" + ALL_TO_ALL = "all_to_all" + ALL_TO_ALL_SINGLE = "all_to_all_single" + BROADCAST_SRC_INDEX = 1 + FIRST_TENSOR_INDEX = 0 + MAX_CUMSUM_CHECK_NUM = 1000 + + REDOPTYPE_SUM = "RedOpType.SUM" + REDOPTYPE_PRODUCT = "RedOpType.PRODUCT" + REDOPTYPE_MIN = "RedOpType.MIN" + REDOPTYPE_MAX = "RedOpType.MAX" + REDOPTYPE_BAND = "RedOpType.BAND" + REDOPTYPE_BOR = "RedOpType.BOR" + REDOPTYPE_BXOR = "RedOpType.BXOR" + + API_ARGS_INDEX = { + "broadcast": { + "group": 2, + "src": 1 + }, + "reduce": { + "op": 2, + "dst": 1 + }, + "all_reduce": { + "reduce_op": 2 + }, + "scatter": { + "src": 2, + "scatter_list": 1 + }, + "gather": { + "dst": 2 + } + } diff --git a/debug/accuracy_tools/msprobe/core/common/decorator.py b/debug/accuracy_tools/msprobe/core/common/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..d3710002bcc281be2fd0f19fc7abda1af35ec936 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/decorator.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from functools import wraps + +from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.log import logger + +# 记录工具函数递归的深度 +recursion_depth = defaultdict(int) + + +def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH): + """装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。""" + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + func_id = id(func) + recursion_depth[func_id] += 1 + if recursion_depth[func_id] > max_depth: + msg = f"call {func_info} exceeds the recursion limit." + logger.error_log_with_exp( + msg, + MsprobeException( + MsprobeException.RECURSION_LIMIT_ERROR, msg + ), + ) + try: + result = func(*args, **kwargs) + finally: + recursion_depth[func_id] -= 1 + return result + + return wrapper + + return decorator diff --git a/debug/accuracy_tools/msprobe/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py index d71d30224b677fb19361f62de0ee25b2d32d389f..f4bb39db4c9534a6fe553d75e396dc5b91d71f33 100644 --- a/debug/accuracy_tools/msprobe/core/common/exceptions.py +++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py @@ -21,19 +21,21 @@ class CodedException(Exception): def __str__(self): return self.error_info - - + + class MsprobeException(CodedException): INVALID_PARAM_ERROR = 0 OVERFLOW_NUMS_ERROR = 1 RECURSION_LIMIT_ERROR = 2 INTERFACE_USAGE_ERROR = 3 + UNSUPPORTED_TYPE_ERROR = 4 err_strs = { INVALID_PARAM_ERROR: "[msprobe] 无效参数:", OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:", RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:", - INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: " + INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ", + UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: " } diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index fdc626ca6a1a90e9060cefa237f9d5d8d7e42844..8cb3b6f614439328cb1b5b41ea7c7866be38c222 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -13,22 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit import csv import fcntl +import io import os +import pickle +from multiprocessing import shared_memory import stat import json import re import shutil -from datetime import datetime, timezone -from dateutil import parser +import sys +import zipfile +import multiprocessing import yaml import numpy as np import pandas as pd +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.log import logger from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.const import FileCheckConst, CompareConst +from msprobe.core.common.global_lock import global_lock, is_main_process + +proc_lock = multiprocessing.Lock() class FileChecker: @@ -245,8 +254,8 @@ def check_path_type(file_path, file_type): def check_others_writable(directory): dir_stat = os.stat(directory) is_writable = ( - bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写 - bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写 + bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写 + bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写 ) return is_writable @@ -266,6 +275,7 @@ def make_dir(dir_path): file_check.common_check() +@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16) def create_directory(dir_path): """ Function Description: @@ -297,13 +307,14 @@ def check_path_before_create(path): def check_dirpath_before_read(path): path = os.path.realpath(path) dirpath = os.path.dirname(path) - if check_others_writable(dirpath): - logger.warning(f"The directory is writable by others: {dirpath}.") - try: - check_path_owner_consistent(dirpath) - except FileCheckException: - logger.warning(f"The directory {dirpath} is not yours.") - + if dedup_log('check_dirpath_before_read', dirpath): + if check_others_writable(dirpath): + logger.warning(f"The directory is writable by others: {dirpath}.") + try: + check_path_owner_consistent(dirpath) + except FileCheckException: + logger.warning(f"The directory {dirpath} is not yours.") + def check_file_or_directory_path(path, isdir=False): """ @@ -332,6 +343,23 @@ def change_mode(path, mode): 'Failed to change {} authority. {}'.format(path, str(ex))) from ex +@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod') +def recursive_chmod(path): + """ + 递归地修改目录及其子目录和文件的权限,文件修改为640,路径修改为750 + + :param path: 要修改权限的目录路径 + """ + for _, dirs, files in os.walk(path): + for file_name in files: + file_path = os.path.join(path, file_name) + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + for dir_name in dirs: + dir_path = os.path.join(path, dir_name) + change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY) + recursive_chmod(dir_path) + + def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH @@ -427,6 +455,43 @@ def save_excel(path, data): return "list" raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.") + def check_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formular injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + def malicious_check(df): + for row_name in df.index: + if not check_value_is_valid(row_name): + raise RuntimeError(f"Malicious value [{row_name}] not allowed to be written into the excel: {path}.") + + for col_name in df.columns: + if not check_value_is_valid(col_name): + raise RuntimeError(f"Malicious value [{col_name}] not allowed to be written into the excel: {path}.") + + for _, row in df.iterrows(): + for _, value in row.items(): + if not check_value_is_valid(value): + raise RuntimeError(f"Malicious value [{value}] not allowed to be written into the excel: {path}.") + + def save_in_slice(df, base_name): + malicious_check(df) + df_length = len(df) + if df_length < CompareConst.MAX_EXCEL_LENGTH: + df.to_excel(writer, sheet_name=base_name if base_name else 'Sheet1', index=False) + else: + slice_num = (df_length + CompareConst.MAX_EXCEL_LENGTH - 1) // CompareConst.MAX_EXCEL_LENGTH + slice_size = (df_length + slice_num - 1) // slice_num + for i in range(slice_num): + df.iloc[i * slice_size: min((i + 1) * slice_size, df_length)] \ + .to_excel(writer, sheet_name=f'{base_name}_part_{i}' if base_name else f'part_{i}', index=False) + check_path_before_create(path) path = os.path.realpath(path) @@ -434,20 +499,18 @@ def save_excel(path, data): data_type = validate_data(data) try: - if data_type == "single": - data.to_excel(path, index=False) - elif data_type == "list": - with pd.ExcelWriter(path) as writer: + with pd.ExcelWriter(path) as writer: + if data_type == "single": + save_in_slice(data, None) + elif data_type == "list": for data_df, sheet_name in data: - data_df.to_excel(writer, sheet_name=sheet_name, index=False) + save_in_slice(data_df, sheet_name) except Exception as e: logger.error(f'Save excel file "{os.path.basename(path)}" failed.') raise RuntimeError(f"Save excel file {path} failed.") from e change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY) - - def move_file(src_path, dst_path): check_file_or_directory_path(src_path) check_path_before_create(dst_path) @@ -511,7 +574,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False): if not isinstance(value, str): return True try: - # -1.00 or +1.00 should be consdiered as digit numbers + # -1.00 or +1.00 should be considered as digit numbers float(value) except ValueError: # otherwise, they will be considered as formular injections @@ -557,7 +620,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False if not isinstance(value, str): return True try: - # -1.00 or +1.00 should be consdiered as digit numbers + # -1.00 or +1.00 should be considered as digit numbers float(value) except ValueError: # otherwise, they will be considered as formular injections @@ -588,8 +651,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False def remove_path(path): if not os.path.exists(path): return + if os.path.islink(path): + logger.error(f"Failed to delete {path}, it is a symbolic link.") + raise RuntimeError("Delete file or directory failed.") try: - if os.path.islink(path) or os.path.isfile(path): + if os.path.isfile(path): os.remove(path) else: shutil.rmtree(path) @@ -598,7 +664,7 @@ def remove_path(path): raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err except Exception as e: logger.error("Failed to delete {}. Please check.".format(path)) - raise RuntimeError(f"Delete {path} failed.") from e + raise RuntimeError("Delete file or directory failed.") from e def get_json_contents(file_path): @@ -632,42 +698,236 @@ def os_walk_for_files(path, depth): return res -def check_crt_valid(pem_path): +def check_zip_file(zip_file_path): + with zipfile.ZipFile(zip_file_path, 'r') as zip_file: + total_size = 0 + if len(zip_file.infolist()) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE: + raise ValueError(f"Too many files in {os.path.basename(zip_file_path)}") + for file_info in zip_file.infolist(): + if file_info.file_size > FileCheckConst.MAX_FILE_SIZE: + raise ValueError(f"File {file_info.filename} is too large to extract") + + total_size += file_info.file_size + if total_size > FileCheckConst.MAX_ZIP_SIZE: + raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") + + +def read_xlsx(file_path, sheet_name=None): + check_file_or_directory_path(file_path) + check_zip_file(file_path) + try: + if sheet_name: + result_df = pd.read_excel(file_path, keep_default_na=False, sheet_name=sheet_name) + else: + result_df = pd.read_excel(file_path, keep_default_na=False) + except Exception as e: + logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.") + raise RuntimeError(f"Read xlsx file {file_path} failed.") from e + return result_df + + +def create_file_with_list(result_list, filepath): + check_path_before_create(filepath) + filepath = os.path.realpath(filepath) + try: + with FileOpen(filepath, 'w', encoding='utf-8') as file: + fcntl.flock(file, fcntl.LOCK_EX) + for item in result_list: + file.write(item + '\n') + fcntl.flock(file, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save list to file "{os.path.basename(filepath)}" failed.') + raise RuntimeError(f"Save list to file {os.path.basename(filepath)} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def create_file_with_content(data, filepath): + check_path_before_create(filepath) + filepath = os.path.realpath(filepath) + try: + with FileOpen(filepath, 'w', encoding='utf-8') as file: + fcntl.flock(file, fcntl.LOCK_EX) + file.write(data) + fcntl.flock(file, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save content to file "{os.path.basename(filepath)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(filepath)} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def add_file_to_zip(zip_file_path, file_path, arc_path=None): + """ + Add a file to a ZIP archive, if zip does not exist, create one. + + :param zip_file_path: Path to the ZIP archive + :param file_path: Path to the file to add + :param arc_path: Optional path inside the ZIP archive where the file should be added """ - Check the validity of the SSL certificate. + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE) + zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0 + if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE: + raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") + check_path_before_create(zip_file_path) + try: + proc_lock.acquire() + with zipfile.ZipFile(zip_file_path, 'a') as zip_file: + zip_file.write(file_path, arc_path) + except Exception as e: + logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e + finally: + proc_lock.release() + change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - Load the SSL certificate from the specified path, parse and check its validity period. - If the certificate is expired or invalid, raise a RuntimeError. - Parameters: - pem_path (str): The file path of the SSL certificate. +def create_file_in_zip(zip_file_path, file_name, content): + """ + Create a file with content inside a ZIP archive. - Raises: - RuntimeError: If the SSL certificate is invalid or expired. + :param zip_file_path: Path to the ZIP archive + :param file_name: Name of the file to create + :param content: Content to write to the file """ - import OpenSSL + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + check_path_before_create(zip_file_path) + zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0 + if zip_size + sys.getsizeof(content) > FileCheckConst.MAX_ZIP_SIZE: + raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes") try: - with FileOpen(pem_path, "r") as f: - pem_data = f.read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data) - pem_start = parser.parse(cert.get_notBefore().decode("UTF-8")) - pem_end = parser.parse(cert.get_notAfter().decode("UTF-8")) - logger.info(f"The SSL certificate passes the verification and the validity period " - f"starts from {pem_start} ends at {pem_end}.") + proc_lock.acquire() + with zipfile.ZipFile(zip_file_path, 'a') as zip_file: + zip_info = zipfile.ZipInfo(file_name) + zip_info.compress_type = zipfile.ZIP_DEFLATED + zip_file.writestr(zip_info, content) except Exception as e: - logger.error("Failed to parse the SSL certificate. Check the certificate.") - raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e + logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e + finally: + proc_lock.release() + change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - now_utc = datetime.now(tz=timezone.utc) - if cert.has_expired() or not (pem_start <= now_utc <= pem_end): - raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}") +def extract_zip(zip_file_path, extract_dir): + """ + Extract the contents of a ZIP archive to a specified directory. -def read_xlsx(file_path): - check_file_or_directory_path(file_path) + :param zip_file_path: Path to the ZIP archive + :param extract_dir: Directory to extract the contents to + """ + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) try: - result_df = pd.read_excel(file_path, keep_default_na=False) + proc_lock.acquire() + check_zip_file(zip_file_path) except Exception as e: - logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.") - raise RuntimeError(f"Read xlsx file {file_path} failed.") from e - return result_df + logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.') + raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e + finally: + proc_lock.release() + with zipfile.ZipFile(zip_file_path, 'r') as zip_file: + zip_file.extractall(extract_dir) + + +def split_zip_file_path(zip_file_path): + check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX) + zip_file_path = os.path.realpath(zip_file_path) + return os.path.dirname(zip_file_path), os.path.basename(zip_file_path) + + +def dedup_log(func_name, filter_name): + with SharedDict() as shared_dict: + exist_names = shared_dict.get(func_name, set()) + if filter_name in exist_names: + return False + exist_names.add(filter_name) + shared_dict[func_name] = exist_names + return True + + +class SharedDict: + def __init__(self): + self._changed = False + self._dict = None + self._shm = None + + def __enter__(self): + self._load_shared_memory() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if self._changed: + data = pickle.dumps(self._dict) + global_lock.acquire() + try: + self._shm.buf[0:len(data)] = bytearray(data) + finally: + global_lock.release() + self._shm.close() + except FileNotFoundError: + name = self.get_shared_memory_name() + logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.') + + def __setitem__(self, key, value): + self._dict[key] = value + self._changed = True + + def __contains__(self, item): + return item in self._dict + + @classmethod + def destroy_shared_memory(cls): + if is_main_process(): + name = cls.get_shared_memory_name() + try: + shm = shared_memory.SharedMemory(create=False, name=name) + shm.close() + shm.unlink() + logger.debug(f'destroy shared memory, name: {name}') + except FileNotFoundError: + logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.') + + @classmethod + def get_shared_memory_name(cls): + if is_main_process(): + return f'shared_memory_{os.getpid()}' + return f'shared_memory_{os.getppid()}' + + def get(self, key, default=None): + return self._dict.get(key, default) + + def _load_shared_memory(self): + name = self.get_shared_memory_name() + try: + self._shm = shared_memory.SharedMemory(create=False, name=name) + except FileNotFoundError: + try: + self._shm = shared_memory.SharedMemory(create=True, name=name, size=1024 * 1024 * 5) + data = pickle.dumps({}) + self._shm.buf[0:len(data)] = bytearray(data) + logger.debug(f'create shared memory, name: {name}') + except FileExistsError: + self._shm = shared_memory.SharedMemory(create=False, name=name) + self._safe_load() + + def _safe_load(self): + with io.BytesIO(self._shm.buf[:]) as buff: + try: + self._dict = SafeUnpickler(buff).load() + except Exception as e: + logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.') + self._dict = {} + self._shm.buf[:] = bytearray(b'\x00' * len(self._shm.buf)) # 清空内存 + self._changed = True + + +class SafeUnpickler(pickle.Unpickler): + WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}} + + def find_class(self, module, name): + if module in self.WHITELIST and name in self.WHITELIST[module]: + return super().find_class(module, name) + raise pickle.PicklingError(f'Unpickling {module}.{name} is illegal!') + + +atexit.register(SharedDict.destroy_shared_memory) diff --git a/debug/accuracy_tools/msprobe/core/common/global_lock.py b/debug/accuracy_tools/msprobe/core/common/global_lock.py new file mode 100644 index 0000000000000000000000000000000000000000..2090f009ea5a78a7c5fbda61c12b6c0a842b7d25 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/global_lock.py @@ -0,0 +1,86 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +from multiprocessing.shared_memory import SharedMemory +import random +import time +import atexit +import os + +from msprobe.core.common.log import logger + + +def is_main_process(): + return multiprocessing.current_process().name == 'MainProcess' + + +class GlobalLock: + def __init__(self): + self.name = self.get_lock_name() + try: + self._shm = SharedMemory(create=False, name=self.name) + time.sleep(random.randint(0, 500) / 10000) # 等待随机时长以避免同时获得锁 + except FileNotFoundError: + try: + self._shm = SharedMemory(create=True, name=self.name, size=1) + self._shm.buf[0] = 0 + logger.debug(f'{self.name} is created.') + except FileExistsError: + self.__init__() + + @classmethod + def get_lock_name(cls): + if is_main_process(): + return f'global_lock_{os.getpid()}' + return f'global_lock_{os.getppid()}' + + @classmethod + def is_lock_exist(cls): + try: + SharedMemory(create=False, name=cls.get_lock_name()).close() + return True + except FileNotFoundError: + return False + + def cleanup(self): + self._shm.close() + if is_main_process(): + try: + self._shm.unlink() + logger.debug(f'{self.name} is unlinked.') + except FileNotFoundError: + logger.warning(f'{self.name} has already been unlinked.') + + def acquire(self, timeout=180): + """ + acquire global lock, default timeout is 3 minutes. + + :param float timeout: timeout(seconds), default value is 180. + """ + start = time.time() + while time.time() - start < timeout: + if self._shm.buf[0] == 0: + self._shm.buf[0] = 1 + return + time.sleep(random.randint(10, 500) / 10000) # 自旋,等待1-50ms + self._shm.buf[0] = 1 + + def release(self): + self._shm.buf[0] = 0 + + +global_lock = GlobalLock() +atexit.register(global_lock.cleanup) diff --git a/debug/accuracy_tools/msprobe/core/common/runtime.py b/debug/accuracy_tools/msprobe/core/common/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..69719cd9236a9860364dfd7f0b6b3d2692ea0be7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/runtime.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.common.const import Const + + +class Runtime: + step_count: int = 0 + rank_id: int = -1 + is_running: bool = False + run_mode: str = Const.PYNATIVE_MODE + current_iter: int = 0 + current_rank: None diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index c06b5b64927bf47da1573df3b1d4db34dfa24cb1..880a462563b7f39e7915a1285671c0ec93e3915f 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -18,9 +18,8 @@ import os import re import subprocess import time -from collections import defaultdict +import inspect from datetime import datetime, timezone -from functools import wraps import numpy as np @@ -75,6 +74,10 @@ class MsprobeBaseException(Exception): MERGE_COMPARE_RESULT_ERROR = 33 NAMES_STRUCTS_MATCH_ERROR = 34 INVALID_STATE_ERROR = 35 + INVALID_API_NAME_ERROR = 36 + CROSS_FRAME_ERROR = 37 + MISSING_THRESHOLD_ERROR = 38 + WRONG_THRESHOLD_ERROR = 38 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -191,27 +194,6 @@ def check_regex_prefix_format_valid(prefix): raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}") -def execute_command(cmd): - """ - Function Description: - run the following command - Parameter: - cmd: command - Exception Description: - when invalid command throw exception - """ - logger.info('Execute command:%s' % cmd) - process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - while process.poll() is None: - line = process.stdout.readline() - line = line.strip() - if line: - logger.info(line) - if process.returncode != 0: - logger.error('Failed to execute command:%s' % " ".join(cmd)) - raise CompareException(CompareException.INVALID_DATA_ERROR) - - def add_time_as_suffix(name): return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) @@ -220,6 +202,10 @@ def add_time_with_xlsx(name): return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) +def add_time_with_json(name): + return '{}_{}.json'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + def add_time_with_yaml(name): return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) @@ -247,6 +233,10 @@ def md5_find(data): def detect_framework_by_dump_json(file_path): + json_data = load_json(file_path) + framework = json_data.get("framework", None) + if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]: + return framework pattern_ms = r'"type":\s*"mindspore' pattern_pt = r'"type":\s*"torch' with FileOpen(file_path, 'r') as file: @@ -279,10 +269,25 @@ def set_dump_path(input_param): npu_path_valid = npu_path is not None and npu_path.endswith("dump.json") bench_path_valid = bench_path is not None and bench_path.endswith("dump.json") if not npu_path_valid or not bench_path_valid: - logger.error(f"Please check the json path is valid. npu_path: {npu_path}, bench_path: {bench_path}") + logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.") raise CompareException(CompareException.INVALID_PATH_ERROR) - input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) - input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) + input_param[CompareConst.NPU_DUMP_DATA_DIR] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) + input_param[CompareConst.BENCH_DUMP_DATA_DIR] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) + + +def check_dump_json_key(json_data, device_type): + task = json_data.get('task', None) + if not task: + logger.error(f"Task for {device_type} is empty, please check.") + raise CompareException(CompareException.INVALID_TASK_ERROR) + if 'data' not in json_data: + logger.error(f"Missing 'data' in dump.json, please check dump.json of {device_type}.") + raise CompareException(CompareException.INVALID_DATA_ERROR) + api_data = json_data.get('data') + if not isinstance(api_data, dict): + logger.error(f"Invalid type for 'data': expected a dict. Please check dump.json of {device_type}.") + raise CompareException(CompareException.INVALID_DATA_ERROR) + return task, api_data def get_dump_mode(input_param): @@ -291,12 +296,8 @@ def get_dump_mode(input_param): npu_json_data = load_json(npu_path) bench_json_data = load_json(bench_path) - npu_task = npu_json_data.get('task', None) - bench_task = bench_json_data.get('task', None) - - if not npu_task or not bench_task: - logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.") - raise CompareException(CompareException.INVALID_TASK_ERROR) + npu_task, npu_api_data = check_dump_json_key(npu_json_data, 'npu') + bench_task, bench_api_data = check_dump_json_key(bench_json_data, 'bench') if npu_task != bench_task: logger.error(f"Please check the dump task is consistent.") @@ -309,8 +310,8 @@ def get_dump_mode(input_param): return Const.STRUCTURE if npu_task == Const.STATISTICS: - npu_md5_compare = md5_find(npu_json_data['data']) - bench_md5_compare = md5_find(bench_json_data['data']) + npu_md5_compare = md5_find(npu_api_data) + bench_md5_compare = md5_find(bench_api_data) if npu_md5_compare == bench_md5_compare: return Const.MD5 if npu_md5_compare else Const.SUMMARY else: @@ -424,6 +425,37 @@ def get_real_step_or_rank(step_or_rank_input, obj): return real_step_or_rank +def check_init_step(step): + if not is_int(step): + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be an integer") + if not step >= 0: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be greater than or equal to 0") + + +def check_token_range(token_range): + if token_range is None: + return + if not isinstance(token_range, (list, tuple)): + logger.error("Token_range must be a list or tuple.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if len(token_range) != 2: + logger.error("Token_range must contains exactly 2 elements.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + start, end = token_range + if not isinstance(start, int) or not isinstance(end, int): + logger.error("Start and end in token_range must be integer.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if start > end: + logger.error("Start in token_range must less than the end.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if start < 0: + logger.error("Start in token_range must >= 0.") + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + def check_seed_all(seed, mode, rm_dropout): if is_int(seed): if seed < 0 or seed > Const.MAX_SEED_VALUE: @@ -467,36 +499,6 @@ def safe_get_value(container, index, container_name, key=None): raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e -# 记录工具函数递归的深度 -recursion_depth = defaultdict(int) - - -# 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。 -def recursion_depth_decorator(func_info): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - func_id = id(func) - recursion_depth[func_id] += 1 - if recursion_depth[func_id] > Const.MAX_DEPTH: - msg = f"call {func_info} exceeds the recursion limit." - logger.error_log_with_exp( - msg, - MsprobeException( - MsprobeException.RECURSION_LIMIT_ERROR, msg - ), - ) - try: - result = func(*args, **kwargs) - finally: - recursion_depth[func_id] -= 1 - return result - - return wrapper - - return decorator - - def check_str_param(param): if not re.match(Const.REGEX_PREFIX_PATTERN, param): logger.error('The parameter {} contains special characters.'.format(param)) @@ -509,4 +511,71 @@ class DumpPathAggregation: construct_file_path = None dump_tensor_data_dir = None free_benchmark_file_path = None - debug_file_path = None \ No newline at end of file + debug_file_path = None + + +def is_save_variable_valid(variable, valid_special_types, depth=0): + if depth > Const.DUMP_MAX_DEPTH: + return False + if isinstance(variable, valid_special_types): + return True + elif isinstance(variable, (list, tuple)): + return all(is_save_variable_valid(item, valid_special_types, depth + 1) for item in variable) + elif isinstance(variable, dict): + return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1) + for key, value in variable.items()) + else: + return False + + +def replace_last_occurrence(text, old, new): + if text is None: + return text + index = text.rfind(old) + if index != -1: + return text[:index] + text[index:].replace(old, new, 1) + return text + + +def load_stack_json(stack_path): + stack_dict = load_json(stack_path) + + if not isinstance(stack_dict, dict): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, + "The format of the stack.json is incorrect, the outermost layer of stack.json should be a dict type." + ) + + if not stack_dict.get(Const.NEW_STACK_FLAG): + return stack_dict + + new_stack_dict = {} + for stack_info in stack_dict.values(): + if not isinstance(stack_info, list) or len(stack_info) != 2: + continue + + api_list, stack_str = stack_info + if not isinstance(api_list, list): + continue + + for api_name in api_list: + new_stack_dict.update({api_name: stack_str}) + return new_stack_dict + + +def analyze_api_call_stack(name): + try: + api_stack = inspect.stack()[2:] + except Exception as e: + logger.warning(f"The call stack of {name} failed to retrieve, {e}.") + api_stack = None + stack_str = [] + if api_stack: + for (_, path, line, func, code, _) in api_stack: + if not code: + continue + stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()} \n" + stack_str.append(stack_line) + else: + stack_str.append(Const.WITHOUT_CALL_STACK) + return "".join(stack_str) diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index b9a717c0c52f11e52ac055e3cfe6a0e77fe7e44c..836a7b89d3008c8e2fc34053eddd186e875279d6 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -111,3 +111,10 @@ class BaseConfig: f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.", MsprobeException(MsprobeException.INVALID_PARAM_ERROR) ) + + def _check_summary_mode(self): + if self.summary_mode and self.summary_mode not in Const.SUMMARY_MODE: + logger.error_log_with_exp( + f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + ) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/__init__.py b/debug/accuracy_tools/msprobe/core/compare/__init__.py similarity index 100% rename from debug/accuracy_tools/msprobe/pytorch/monitor/unittest/__init__.py rename to debug/accuracy_tools/msprobe/core/compare/__init__.py diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 55229d72657c67428186bcb233371e3b9eee73e0..1a611d8b7650ff00e99779655bf0d1b6ce83f4b0 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -13,111 +13,235 @@ # See the License for the specific language governing permissions and # limitations under the License. -import multiprocessing import os import re -from copy import deepcopy +from dataclasses import dataclass +from collections import defaultdict +import numpy as np import pandas as pd from tqdm import tqdm from msprobe.core.advisor.advisor import Advisor from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import load_json, remove_path +from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_json from msprobe.core.common.log import logger -from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value -from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \ - check_struct_match, fuzzy_check_op -from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx -from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result -from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg -from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ - print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list +from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \ + set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, add_time_with_json +from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping +from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \ + reorder_op_x_list, set_stack_json_path +from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict +from msprobe.core.compare.multiprocessing_compute import CompareRealData +from msprobe.core.compare.highlight import HighLight +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze + + +@dataclass +class ComparisonConfig: + dump_mode: str + stack_mode: bool + auto_analyze: bool + fuzzy_match: bool + data_mapping: dict + suffix: str + cell_mapping: dict + api_mapping: dict + layer_mapping: dict + first_diff_analyze: bool -class ModeConfig: - def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None): - self.stack_mode = stack_mode - self.auto_analyze = auto_analyze - self.fuzzy_match = fuzzy_match - self.dump_mode = dump_mode +class Comparator: + def __init__(self, file_reader, mode_config: ModeConfig, mapping_config: MappingConfig, is_cross_framework=False): + self.file_reader = file_reader + self.mode_config = mode_config + self.mapping_config = mapping_config + self.cross_frame = is_cross_framework + self.mapping_dict = MappingDict(mapping_config) -class Comparator: - def __init__(self, mode_config: ModeConfig): - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode + def process_output_file(self, output_path, suffix): + if self.mode_config.first_diff_analyze: + file_name = add_time_with_json("compare_result" + suffix) + else: + file_name = add_time_with_xlsx("compare_result" + suffix) + file_path = os.path.join(os.path.realpath(output_path), file_name) + if os.path.exists(file_path): + logger.warning(f"{file_path} will be deleted.") + remove_path(file_path) + return file_path - @staticmethod - def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args): - npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) - bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) + def compare_core(self, input_param, output_path, **kwargs): + """ + Compares data from multiple JSON files and generates a comparison report. - if len(npu_struct) < 3 or len(bench_struct) < 3: - logger.error(f"The length of npu_struct and bench_struct must be >= 3, " - f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!") - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + Args: + input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", + "stack_path"). + output_path (str): The path where the output Excel report will be saved. + **kwargs: Additional keyword arguments including: + - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. + - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. + - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. + - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. + - dump_mode (str): ALL, SUMMARY, MD5. - result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0], - npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2], - CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF] + Returns: + """ + logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") - if len(args) >= 2 and args[0]: - result_item.extend(args[1]) - else: - result_item.append(CompareConst.NONE) - return result_item + # get kwargs or set default value + suffix = kwargs.get('suffix', '') - @staticmethod - def calculate_summary_data(npu_summary_data, bench_summary_data, result_item): - err_msg = "" - result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data, - bench_summary_data, err_msg) - result_item.append(accuracy_check) - result_item.append(err_msg) + # process output file + file_path = self.process_output_file(output_path, suffix) - @staticmethod - def _generate_na_data(ops_all): - if not ops_all: - return {} - key = next(iter(ops_all)) - value = deepcopy(ops_all[key]) - for k, v in value.items(): - if isinstance(v, tuple): - value[k] = tuple(CompareConst.N_A for _ in range(len(v))) - elif isinstance(v, list): - value[k] = [CompareConst.N_A] * len(v) - else: - value[k] = CompareConst.N_A - return value + # initialize the compare result table and compare general data(name, dtype, shape, statistics/md5, etc.) + npu_json = input_param.get("npu_json_path") + bench_json = input_param.get("bench_json_path") + stack_json = input_param.get("stack_json_path") + result_df = self.compare_statistics([npu_json, bench_json, stack_json]) + if not result_df.values.tolist(): + logger.warning("Can`t match any op. No compare result file generated.") + return - def make_result_table(self, result): - header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:] + if self.mode_config.first_diff_analyze: + first_diff_analyze = FirstDiffAnalyze() + check_result = first_diff_analyze.check(result_df) + save_json(file_path, check_result, indent=4) + logger.info(f"Saving json file to disk: {file_path}") + return - if self.stack_mode: - header.append(CompareConst.STACK) - if self.dump_mode == Const.ALL: - header.append(CompareConst.DATA_NAME) - else: - if self.dump_mode == Const.ALL: - for row in result: - del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列 - header.append(CompareConst.DATA_NAME) - else: - for row in result: - del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列 - result_df = pd.DataFrame(result, columns=header, dtype='object') - return result_df + # compare real data + if self.mode_config.dump_mode == Const.ALL: + compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame) + result_df = compare_real_data.do_multi_process(input_param, result_df) + + # highlight suspicious API + highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + highlight = HighLight(self.mode_config) + highlight.find_compare_result_error_rows(result_df, highlight_dict) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + + # output compare analysis suggestions + if self.mode_config.auto_analyze: + advisor = Advisor(result_df, output_path, suffix) + advisor.analysis() + + print_compare_ends_info() + + def compare_statistics(self, file_list): + # load and parse json data + parse_data = ParseData(self.mode_config) + npu_df, bench_df = parse_data.parse(file_list) + + npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) + bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str) + + # create new columns for compare op_name and shape + # process npu_df's COMPARE_KEY whether same or different framework + process_df = ProcessDf(self.mode_config, self.mapping_config, self.mapping_dict) + npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df, bench_df) + + # match npu and bench, match_result contains both npu_info and bench_info + match = Match(self.mode_config, self.mapping_config, self.cross_frame) + match_result = match.match_api_infos(npu_df, bench_df) + # 筛选出npu_name存在的行并填充筛选出行中的缺失值为N/A + match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A) + bench_columns = [i + '_y' for i in bench_df.columns] + match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A + + # organize compare result table by renaming columns + create_table = CreateTable(self.mode_config) + result_df, header = create_table.make_result_df(match_result) + + # calculate statistics diff + calc_stats_diff = CalcStatsDiff(self.mode_config) + return calc_stats_diff.calc_accuracy(result_df, header) + + +class ParseData: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config + + def parse(self, file_list): + npu_json_path, bench_json_path, stack_json_path = file_list + npu_json_data = load_json(npu_json_path) + bench_json_data = load_json(bench_json_path) + stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None + + # parse json data and generate df + npu_df = self.gen_data_df(npu_json_data, stack_json_data) + bench_df = self.gen_data_df(bench_json_data, stack_json_data) + + return npu_df, bench_df + + def gen_data_df(self, data_json, stack_json_data): + result = { + CompareConst.OP_NAME: [], + Const.DTYPE: [], + Const.SHAPE: [], + Const.SUMMARY: [], + Const.STACK_INFO: [] + } + if self.mode_config.dump_mode == Const.ALL: + result['data_name'] = [] + elif self.mode_config.dump_mode == Const.MD5: + result[Const.MD5] = [] + + apis_data = data_json.get('data', None) + if not apis_data: + logger.warning('No APIs found in dump.json.') + return pd.DataFrame(result) + + api_nums = len(apis_data) + progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100) + + # 从json中循环解析API数据,遍历所有API + for data_name in apis_data: + check_op_str_pattern_valid(data_name) + merge_list = self.gen_merge_list(data_json, data_name, stack_json_data) + if not merge_list: + continue + + op_name_list = merge_list.get(CompareConst.OP_NAME) + summary_list = merge_list.get(Const.SUMMARY) + data_name_list = merge_list.get('data_name') + op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, + summary_list, + data_name_list) + # 遍历单个API的所有item + for index, op_name in enumerate(op_name_reorder): + result[CompareConst.OP_NAME].append(op_name) + if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): + struct = merge_list[CompareConst.INPUT_STRUCT].pop(0) + elif CompareConst.OUTPUT_PATTERN in op_name: + struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0) + elif CompareConst.PARAMS_PATTERN in op_name: + struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0) + else: + struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0) + result[Const.DTYPE].append(struct[0]) + result[Const.SHAPE].append(struct[1]) + if self.mode_config.dump_mode == Const.MD5: + result[Const.MD5].append(struct[2]) + result[Const.SUMMARY].append(summary_reorder.pop(0)) + result[Const.STACK_INFO].append( + merge_list[Const.STACK_INFO][0] if index == 0 and self.mode_config.stack_mode else None) + if self.mode_config.dump_mode == Const.ALL: + result['data_name'].append(data_name_reorder.pop(0)) + + progress_bar.update(1) + progress_bar.close() + return pd.DataFrame(result) def gen_merge_list(self, json_data, op_name, stack_json_data): op_data = json_data['data'][op_name] check_dump_json_str(op_data, op_name) op_parsed_list = read_op(op_data, op_name) - if self.stack_mode: + if self.mode_config.stack_mode: stack_info = stack_json_data.get(op_name) if stack_info is not None: check_stack_json_str(stack_info, op_name) @@ -127,423 +251,487 @@ class Comparator: 'full_info': stack_info }) - merge_list = merge_tensor(op_parsed_list, self.dump_mode) + merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode) return merge_list - def check_op(self, npu_dict, bench_dict): - npu_op_name = npu_dict[CompareConst.OP_NAME] - bench_op_name = bench_dict[CompareConst.OP_NAME] - graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"), - safe_get_value(bench_op_name, 0, "bench_op_name")) - - frame_name = getattr(self, "frame_name") - if frame_name == "PTComparator": - from msprobe.pytorch.compare.match import graph_mapping - if graph_mode: - return graph_mapping.match(npu_op_name[0], bench_op_name[0]) - struct_match = check_struct_match(npu_dict, bench_dict) - if not self.fuzzy_match: - name_match = npu_op_name == bench_op_name - return name_match and struct_match - try: - name_match = fuzzy_check_op(npu_op_name, bench_op_name) - except Exception as err: - logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name)) - name_match = False - return name_match and struct_match - - def match_op(self, npu_queue, bench_queue): - for b_index, b_op in enumerate(bench_queue[0: -1]): - if self.check_op(npu_queue[-1], b_op): - return len(npu_queue) - 1, b_index - if self.check_op(npu_queue[-1], bench_queue[-1]): - return len(npu_queue) - 1, len(bench_queue) - 1 - for n_index, n_op in enumerate(npu_queue[0: -1]): - if self.check_op(n_op, bench_queue[-1]): - return n_index, len(bench_queue) - 1 - return -1, -1 - def compare_process(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None +class ProcessDf: + def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, mapping_dict: MappingDict): + self.mode_config = mode_config + self.mapping_config = mapping_config + self.mapping_dict = mapping_dict - if self.fuzzy_match: - logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") + @staticmethod + def get_api_name(api_list): + try: + api_name = api_list[0] + Const.SEP + api_list[1] + except IndexError as error: + logger.error('Failed to retrieve API name, please check if the dump data is reasonable') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error + return api_name + + def process_compare_key_and_shape(self, npu_df, bench_df): + npu_df = self.assign_npu_df_compare_key(npu_df, bench_df) + npu_df[CompareConst.CMP_SHAPE] = npu_df[Const.SHAPE] + bench_df[CompareConst.CMP_KEY] = bench_df[CompareConst.OP_NAME] + bench_df[CompareConst.CMP_SHAPE] = bench_df[Const.SHAPE] + return npu_df, bench_df + + def assign_npu_df_compare_key(self, npu_df, bench_df): + """ + 处理 npu_df 的 COMPARE_KEY 赋值逻辑 - npu_ops_queue = [] - bench_ops_queue = [] - result = [] + :param npu_df: DataFrame,NPU 对比数据 + :param bench_df: DataFrame,Bench 对比数据 + :return: compare_key(name)处理后的 npu_df + """ + # 处理api_mapping映射 + if self.mapping_config.api_mapping: + # 如果用户不传api_mapping.yaml,先使用内置api_mapping.yaml替换npu_op_name + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping) + # 如果用户传入api_mapping.yaml,再使用传入api_mapping.yaml进一步替换npu_op_name + if isinstance(self.mapping_config.api_mapping, str): + self.modify_compare_data_with_user_mapping(npu_df, bench_df) + # 处理cell_mapping映射 + elif self.mapping_config.cell_mapping: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping) + # 处理data_mapping映射 + elif self.mapping_config.data_mapping: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_data_mapping) + else: + npu_df[CompareConst.CMP_KEY] = npu_df[CompareConst.OP_NAME] + return npu_df + + def process_internal_api_mapping(self, npu_op_name): + # get api name & class name from op_name + ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP)) + class_name = ms_api_name.split(Const.SEP)[0] + if class_name == "Mint": + return npu_op_name.replace("Mint", "Torch") + elif class_name == "MintFunctional": + return npu_op_name.replace("MintFunctional", "Functional") + elif self.mapping_dict.ms_to_pt_mapping.get(ms_api_name): + return npu_op_name.replace(ms_api_name, self.mapping_dict.ms_to_pt_mapping.get(ms_api_name)) + else: + return npu_op_name + + def modify_compare_data_with_user_mapping(self, npu_df, bench_df): + def gen_input_compare_key(pattern, term): + is_unmatched = True + for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): + if op_name.split(pattern)[1].startswith(str(prefix)): + npu_df.loc[index, CompareConst.CMP_KEY] = ( + op_name.replace(pattern + str(prefix), + pattern + str(mapping_dict.get(f'pt_{term}')[i]))) + is_unmatched = False + return is_unmatched + + ms_api_indices_dict = self.get_api_indices_dict(npu_df) + pt_api_indices_dict = self.get_api_indices_dict(bench_df) + + for mapping_dict in self.mapping_dict.api_mapping_dict: + all_length_equal = True + for k1, k2 in CompareConst.API_MAPPING_KEYS_TO_COMPARE: + if len(mapping_dict.get(k1, [])) != len(mapping_dict.get(k2, [])): + all_length_equal = False + if not all_length_equal: + logger.warning('The user-defined mapping table is incorrect,\ + make sure that the number of parameters is equal') + continue - ops_npu_iter = iter(npu_json_data['data']) - ops_bench_iter = iter(bench_json_data['data']) - read_err_npu = True - read_err_bench = True - last_npu_ops_len = 0 - last_bench_ops_len = 0 + ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api') + if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict: + continue + for index in ms_api_indices_dict.get(ms_api): + op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) + if CompareConst.INPUT_PATTERN in op_name: + is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') + elif CompareConst.KWARGS_PATTERN in op_name: + is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') + elif CompareConst.OUTPUT_PATTERN in op_name: + is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') + elif CompareConst.PARAMS_PATTERN in op_name: + is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') + elif CompareConst.PARAMS_GRAD_PATTERN in op_name: + is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') + else: + logger.error(f'Excepted op_name: {op_name}') + raise CompareException(CompareException.INVALID_DATA_ERROR) + if is_abandoned: + npu_df.loc[index, CompareConst.CMP_KEY] = op_name + 'abandoned' - npu_api_nums = len(npu_json_data['data']) - progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100) + def get_api_indices_dict(self, op_name_df): + """ + 生成多个api对应的各自的所有的input、output等的index的键值对字典 + 示例: + {'Functional.conv2d': [0, 1, 2, 3], + 'Functional.batch_norm': [4, 5, 6, 7, 8] + } + """ + api_indices_dict = defaultdict(list) + for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]): + api_name = self.get_api_name(name.split(Const.SEP)) + api_indices_dict[api_name].append(op_index) + return api_indices_dict + + def process_cell_mapping(self, npu_op_name): + if not npu_op_name: + return CompareConst.N_A + param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP) + if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name): + return CompareConst.N_A + npu_op_name = npu_op_name.replace("Cell", "Module", 1) + if self.mapping_dict.cell_mapping_dict: + # get cell name & class name from op_name + # Cell.fc1.Dense.forward.0.input.0 + cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0] + if cell_name in self.mapping_dict.cell_mapping_dict: + npu_op_name = npu_op_name.replace(cell_name, self.mapping_dict.cell_mapping_dict[cell_name], 1) + return npu_op_name + + def process_data_mapping(self, npu_op_name): + return self.mapping_dict.data_mapping_dict.get(npu_op_name, npu_op_name) + + +class Match: + def __init__(self, mode_config: ModeConfig, mapping_config: MappingConfig, cross_frame): + self.mode_config = mode_config + self.mapping_config = mapping_config + self.cross_frame = cross_frame - while True: - if not read_err_npu and not read_err_bench: - break - try: - last_npu_ops_len = len(npu_ops_queue) - op_name_npu = next(ops_npu_iter) - check_op_str_pattern_valid(op_name_npu) - npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data) - if npu_merge_list: - npu_ops_queue.append(npu_merge_list) - except StopIteration: - read_err_npu = False - try: - last_bench_ops_len = len(bench_ops_queue) - op_name_bench = next(ops_bench_iter) - check_op_str_pattern_valid(op_name_bench) - bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data) - if bench_merge_list: - bench_ops_queue.append(bench_merge_list) - except StopIteration: - read_err_bench = False + @staticmethod + def put_unmatched_in_table(match_result, npu_op_item): + npu_columns = npu_op_item.index.tolist()[:-2] + new_columns = [name[:-1] + 'y' for name in npu_columns] + na_series = pd.Series([CompareConst.N_A] * len(new_columns), index=new_columns) + new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T + new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS + match_result = pd.concat([match_result, new_result_item]) + return match_result - progress_bar.update(1) + @staticmethod + def put_matched_in_table(match_result, npu_op_item, bench_op_item): + head_len = len(CompareConst.MATCH_RESULT_COLUMNS) + new_result_item = pd.concat([npu_op_item, bench_op_item]).head(head_len).to_frame().T + new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS + match_result = pd.concat([match_result, new_result_item]) + return match_result - # merge all boolean expressions - both_empty = not npu_ops_queue and not bench_ops_queue - no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) - if both_empty or no_change: - continue + @staticmethod + def rename_api(op_name): + """ + 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} + rename后: {api_type}.{api_name}.{前向反向}.{input/output}.{参数序号} + """ + if Const.FORWARD not in op_name and Const.BACKWARD not in op_name: + return op_name + process = Const.FORWARD if Const.FORWARD in op_name else Const.BACKWARD + name_split = op_name.split(process) + try: + torch_func_index, in_out = name_split[0], name_split[1] + except IndexError as error: + logger.error(f'{op_name} can not be split with {process}, please check!') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error + torch_func_split = torch_func_index.rsplit(Const.SEP, 2) + torch_func = str(torch_func_split[0]) + Const.SEP + process + str(in_out) + return torch_func + + def check_op_item(self, npu_op_item, bench_op_item): + name_match = self.rename_api(npu_op_item[CompareConst.CMP_KEY]) == self.rename_api( + bench_op_item[CompareConst.CMP_KEY]) + shape_match = npu_op_item[CompareConst.CMP_SHAPE] == bench_op_item[CompareConst.CMP_SHAPE] + if name_match and shape_match: + return True + else: + npu_op_name = npu_op_item[CompareConst.OP_NAME] + bench_op_name = bench_op_item[CompareConst.OP_NAME] + check_op_str_pattern_valid(npu_op_name) + check_op_str_pattern_valid(bench_op_name) + logger.warning(f"{npu_op_name} and {bench_op_name} can not fuzzy match") + return False + + def match_api_infos(self, npu_df, bench_df): + """ + 正常匹配和模糊匹配 + """ + if self.mapping_config.data_mapping: + match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY], how='left') + + # reorder match_result by op_name of npu + op_name_order = npu_df[CompareConst.OP_NAME].tolist() + match_result[CompareConst.OP_NAME_X] = pd.Categorical(match_result[CompareConst.OP_NAME_X], + categories=op_name_order, ordered=True) + match_result = match_result.sort_values(CompareConst.OP_NAME_X).reset_index(drop=True) + match_result[CompareConst.OP_NAME_X] = match_result[CompareConst.OP_NAME_X].astype('object') + + elif not self.mode_config.fuzzy_match: + match_result = pd.merge(npu_df, bench_df, on=[CompareConst.CMP_KEY, CompareConst.CMP_SHAPE], + how='outer') + else: + match_result = self.process_fuzzy_match(npu_df, bench_df) + return match_result - # APIs in NPU and Bench models unconsistent judgment + def process_fuzzy_match(self, npu_df, bench_df): + """ + 模糊匹配通过循环方式匹配api + """ + npu_ops_queue = [] + bench_ops_queue = [] + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + + max_len = max(len(npu_df), len(bench_df)) + min_len = min(len(npu_df), len(bench_df)) + for i in range(max_len): + if i < min_len: + npu_ops_queue.append(npu_df.iloc[i]) + bench_ops_queue.append(bench_df.iloc[i]) + else: + try: + npu_ops_queue.append(npu_df.iloc[i]) + except IndexError: + pass + try: + bench_ops_queue.append(bench_df.iloc[i]) + except IndexError: + pass + + # 如果append之后queue状态不一致,则判断结束 if bool(npu_ops_queue) ^ bool(bench_ops_queue): - logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.") break - n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue) + npu_match_point, bench_match_point = self.match_op(npu_ops_queue, bench_ops_queue) - # 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中 - if n_match_point == -1 and b_match_point == -1: + # 如果没有匹配到,数据放到队列中,跳过。直到后面匹配到,把匹配之前的api放到不匹配中 + if npu_match_point == -1 and bench_match_point == -1: continue - n_match_data = npu_ops_queue[n_match_point] - b_match_data = bench_ops_queue[b_match_point] - un_match_data = npu_ops_queue[0: n_match_point] - for npu_data in un_match_data: - get_un_match_accuracy(result, npu_data, self.dump_mode) - get_accuracy(result, n_match_data, b_match_data, self.dump_mode) - del npu_ops_queue[0: n_match_point + 1] - del bench_ops_queue[0: b_match_point + 1] - progress_bar.close() + npu_op_item = npu_ops_queue[npu_match_point] + bench_op_item = bench_ops_queue[bench_match_point] + unmatched_data = npu_ops_queue[0: npu_match_point] + for op_item in unmatched_data: + match_result = self.put_unmatched_in_table(match_result, op_item) + match_result = self.put_matched_in_table(match_result, npu_op_item, bench_op_item) + del npu_ops_queue[0: npu_match_point + 1] + del bench_ops_queue[0: bench_match_point + 1] + if npu_ops_queue: - for npu_data in npu_ops_queue: - get_un_match_accuracy(result, npu_data, self.dump_mode) - - result_df = self.make_result_table(result) - return result_df - - def merge_data(self, json_data, stack_json_data): - ops_all = {} - for op_name in json_data.get('data', {}): - merge_list = self.gen_merge_list(json_data, op_name, stack_json_data) - if merge_list: - struct_to_index_mapping = { - CompareConst.INPUT_STRUCT: 0, - CompareConst.OUTPUT_STRUCT: 0, - CompareConst.PARAMS_STRUCT: 0, - CompareConst.PARAMS_GRAD_STRUCT: 0 - } - - op_name_list = merge_list.get(CompareConst.OP_NAME) - summary_list = merge_list.get(Const.SUMMARY) - data_name_list = merge_list.get('data_name') - op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for index, op_full_name in enumerate(op_name_reorder): - data_name = data_name_reorder[index] if data_name_reorder else None - - _, state = get_name_and_state(op_full_name) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - ops_all[op_full_name] = { - CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key), - "merge_list", key=struct_key), - CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"), - 'data_name': data_name, - 'stack_info': merge_list.get('stack_info') - } - struct_to_index_mapping[struct_key] += 1 - return ops_all - - def get_accuracy(self, npu_ops_all, bench_ops_all): - result = [] - bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all) - for ms_op_name, bench_op_name in self.data_mapping_dict.items(): - if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all: - npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None) - bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None) - has_stack = npu_stack_info and bench_stack_info - if self.dump_mode == Const.MD5: - result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, - bench_ops_all, has_stack, npu_stack_info)) - continue - - npu_struct = npu_ops_all.get(ms_op_name).get('struct', []) - bench_struct = bench_ops_all.get(bench_op_name).get('struct', []) - - if len(npu_struct) < 2 or len(bench_struct) < 2: - logger.error( - f"The length of npu_struct and bench_struct must be >= 2, " - f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. " - f"Please check!" - ) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - base_result_item = [ - ms_op_name, bench_op_name, - npu_struct[0], - bench_struct[0], - npu_struct[1], - bench_struct[1] - ] - - if self.dump_mode == Const.SUMMARY: - result_item = base_result_item + [" "] * 8 - else: - result_item = base_result_item + [" "] * 5 - - npu_summary_data = npu_ops_all.get(ms_op_name).get("summary") - result_item.extend(npu_summary_data) - bench_summary_data = bench_ops_all.get(bench_op_name).get("summary") - result_item.extend(bench_summary_data) - if self.dump_mode == Const.SUMMARY: - self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - else: - result_item.append(CompareConst.ACCURACY_CHECK_YES) - result_item.append("") - if has_stack: - result_item.extend(npu_stack_info) - else: - result_item.append(CompareConst.NONE) - if self.dump_mode == Const.ALL: - result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None)) - result.append(result_item) - elif ms_op_name not in npu_ops_all: - logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.') - elif bench_op_name not in npu_ops_all: - logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.') - return result + for op_item in npu_ops_queue: + match_result = self.put_unmatched_in_table(match_result, op_item) - def compare_process_custom(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None - npu_ops_all = self.merge_data(npu_json_data, stack_json_data) - bench_ops_all = self.merge_data(bench_json_data, stack_json_data) + match_result.reset_index(drop=True, inplace=True) + return match_result - result = self.get_accuracy(npu_ops_all, bench_ops_all) - result_df = self.make_result_table(result) - return result_df + def match_op(self, npu_queue, bench_queue): + for b_index, b_op in enumerate(bench_queue[0: -1]): + if self.check_op_item(npu_queue[-1], b_op): + return len(npu_queue) - 1, b_index + if self.check_op_item(npu_queue[-1], bench_queue[-1]): + return len(npu_queue) - 1, len(bench_queue) - 1 + for n_index, n_op in enumerate(npu_queue[0: -1]): + if self.check_op_item(n_op, bench_queue[-1]): + return n_index, len(bench_queue) - 1 + return -1, -1 - def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data): + def gen_dtype_condition(self, match_result): """ - :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 - :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 - :param op_name_mapping_dict: op_name和npy或pt文件的映射关系 - :param input_param: npu_json_path/bench_json_path/stack_json_path等参数 - :param bench_data: bench的dump数据中"data"字段 - :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息 - 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、 - 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息 + dtype匹配条件为npu、bench的dtype一致或属于规定的映射关系 """ - npu_bench_name_list = op_name_mapping_dict[npu_op_name] - data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list") - error_file, relative_err, error_flag = None, None, False - bench_data_name = get_bench_data_name(bench_op_name, bench_data) - if data_name == '-1' or data_name == -1: # 没有真实数据路径 - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - elif not bench_data_name: - n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True - error_file = 'no_bench_data' - else: - try: - read_npy_data = getattr(self, "read_npy_data") - frame_name = getattr(self, "frame_name") - if frame_name == "MSComparator": - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX) - if self.cross_frame: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name, - load_pt_file=True) - else: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name) - else: - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX) - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name) - except IOError as error: - error_file = error.filename - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - except (FileCheckException, CompareException): - error_file = data_name - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - - # 通过n_value, b_value同时得到错误标志和错误信息 - n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, - error_flag=error_flag, error_file=error_file) - - result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg) - - if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A: - err_msg += " Fuzzy matching data, the comparison accuracy may be affected." - result_list.append(err_msg) - return result_list + # 如果使用了data_mapping,不校验dtype,返回全True的DataFrame + if self.mapping_config.data_mapping: + return pd.Series(True, index=match_result.index) + + npu_dtype = match_result['dtype_x'] + bench_dtype = match_result['dtype_y'] + npu_dtype = self.process_cross_frame_dtype(npu_dtype) + bench_dtype = self.process_cross_frame_dtype(bench_dtype) + + equal_condition = npu_dtype == bench_dtype + match_condition = ( + (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[0])) | + (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[1])) + ) + return equal_condition | match_condition - def compare_core(self, input_param, output_path, **kwargs): - """ - Compares data from multiple JSON files and generates a comparison report. + def process_cross_frame_dtype(self, dtype): + if self.cross_frame: + dtype = dtype.map(cross_dtype_mapping).fillna(dtype) + return dtype - Args: - input_param (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", - "stack_path"). - output_path (str): The path where the output Excel report will be saved. - **kwargs: Additional keyword arguments including: - - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. - - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. - - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. - - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. - - dump_mode (str): ALL, SUMMARY, MD5. - Returns: - """ - # get kwargs or set default value - suffix = kwargs.get('suffix', '') +class CreateTable: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config - logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") - file_name = add_time_with_xlsx("compare_result" + suffix) - file_path = os.path.join(os.path.realpath(output_path), file_name) - remove_path(file_path) - highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + @staticmethod + def process_data_name(result): + result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1) + return result - npu_json = input_param.get("npu_json_path") - bench_json = input_param.get("bench_json_path") - stack_json = input_param.get("stack_json_path") - if self.data_mapping: - result_df = self.compare_process_custom([npu_json, bench_json, stack_json]) - else: - result_df = self.compare_process([npu_json, bench_json, stack_json]) + @staticmethod + def set_summary(summary): + if summary == CompareConst.N_A: + return [CompareConst.N_A] * 4 # 4为统计值个数 + summary_list = [] + for i in summary: + if str(i).lower() == 'nan': + summary_list.append(CompareConst.NAN) + else: + summary_list.append(i) + return summary_list - if not result_df.values.tolist(): - logger.warning("Can`t match any op.") - return + def make_result_df(self, result): + # get header + header = CompareConst.HEAD_OF_COMPARE_MODE[self.mode_config.dump_mode][:] + if self.mode_config.stack_mode: + header.append(CompareConst.STACK) + if self.mode_config.dump_mode == Const.ALL: + header.append(CompareConst.DATA_NAME) + result = self.process_data_name(result) + + # rename match_result columns + result.rename(columns={'op_name_x': CompareConst.NPU_NAME, + 'op_name_y': CompareConst.BENCH_NAME, + 'dtype_x': CompareConst.NPU_DTYPE, + 'dtype_y': CompareConst.BENCH_DTYPE, + 'shape_x': CompareConst.NPU_SHAPE, + 'shape_y': CompareConst.BENCH_SHAPE, + 'md5_x': CompareConst.NPU_MD5, + 'md5_y': CompareConst.BENCH_MD5, + 'data_name_x': CompareConst.DATA_NAME, + 'stack_info_x': CompareConst.STACK}, inplace=True) + + # process summary data + npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM] + bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, + CompareConst.BENCH_NORM] + if result.empty: + result[npu_summary] = pd.DataFrame(columns=npu_summary) + result[bench_summary] = pd.DataFrame(columns=bench_summary) + else: + result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist() + result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist() - if self.dump_mode == Const.ALL: - result_df = self.do_multi_process(input_param, result_df) + result_df = pd.DataFrame(columns=header) + for h in header: + if h in result.columns: + result_df[h] = result[h] + return result_df, header - find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode) - highlight_rows_xlsx(result_df, highlight_dict, file_path) - if self.auto_analyze: - advisor = Advisor(result_df, output_path, suffix) - advisor.analysis() +class CalcStatsDiff: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config - print_compare_ends_info() + @staticmethod + def type_check(val): + """ + 检查是否为数值或字符串形式的nan, 如果是返回True + """ + check_series = pd.Series(False, index=val.index) + val_str = val.astype(str) + check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True + return check_series - def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param): - cos_result = [] - max_err_result = [] - max_relative_err_result = [] - err_mess = [] - one_thousand_err_ratio_result = [] - five_thousand_err_ratio_result = [] - is_print_compare_log = input_param.get("is_print_compare_log") - bench_data = load_json(input_param.get("bench_json_path")).get('data') - for i in range(len(result_df)): - npu_op_name = result_df.iloc[i, 0] - bench_op_name = result_df.iloc[i, 1] - if is_print_compare_log: - logger.info("start compare: {}".format(npu_op_name)) - - cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \ - self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data) - - if is_print_compare_log: - logger.info( - "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \ - one_thousand_err_ratio {}, " - "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, - err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) - cos_result.append(cos_sim) - max_err_result.append(max_abs_err) - max_relative_err_result.append(max_relative_err) - err_mess.append(err_msg) - one_thousand_err_ratio_result.append(one_thousand_err_ratio) - five_thousand_err_ratio_result.append(five_thousand_err_ratio) - - cr = ComparisonResult( - cos_result=cos_result, - max_err_result=max_err_result, - max_relative_err_result=max_relative_err_result, - err_msgs=err_mess, - one_thousand_err_ratio_result=one_thousand_err_ratio_result, - five_thousand_err_ratio_result=five_thousand_err_ratio_result + @staticmethod + def get_number(val): + return pd.to_numeric(val.astype(str), errors='coerce') + + def calc_summary_diff(self, result_df, cond_no_bench, stats_index: str): + npu_val = result_df['NPU ' + stats_index] + bench_val = result_df['Bench ' + stats_index] + diff_name = stats_index.capitalize() + ' diff' + rel_err_name = ('norm' if stats_index == 'l2norm' else stats_index).capitalize() + 'RelativeErr' + + # npu、bench中统计量均为数字或nan + cond_num_nan = self.type_check(npu_val) & self.type_check(bench_val) + + # 如果统计量不是数字或nan,就赋值统计量差异为N/A + result_df.loc[~cond_num_nan, [diff_name, rel_err_name]] = CompareConst.N_A + cond_valid_stat = ~cond_no_bench & cond_num_nan # 有效统计条件:bench_name不是N/A,并且NPU和bench的统计量都是数字或nan + result_df.loc[cond_valid_stat, diff_name] = self.get_number(npu_val) - self.get_number(bench_val) + + cond_diff_nan = result_df[diff_name].isna() # 统计量差异是nan + cond_nan_diff = cond_valid_stat & cond_diff_nan + result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN + + cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan + condition_pt_zero = bench_val == 0 + result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A + + # 相对误差转成百分比字符串 + cond_ref_err = cond_not_nan_diff & ~condition_pt_zero + result_df.loc[cond_ref_err, rel_err_name] = ( + result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100) + result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%') + + magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series( + np.maximum(self.get_number(npu_val), self.get_number(bench_val))).abs() + CompareConst.EPSILON) + return magnitude > CompareConst.MAGNITUDE + + def calc_accuracy(self, result_df, header): + # bench name N/A represents no bench data, err_msg adds "No bench data matched." + condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A + result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A) + result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH + + if self.mode_config.first_diff_analyze or self.mode_config.dump_mode == Const.SUMMARY: + warning_list = [ + self.calc_summary_diff(result_df, condition_no_bench, stats_index) + for stats_index in ['max', 'min', 'mean', 'l2norm'] + ] + warning_flag = pd.DataFrame(warning_list).any() + result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = '' + result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING + result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.' + elif self.mode_config.dump_mode == Const.MD5: + condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5] + result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS + result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF + else: + fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST, + CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.ERROR_MESSAGE] + result_df.loc[~condition_no_bench, fill_cols] = '' + result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES + + return result_df[header] + + +def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig: + """公共的前置处理逻辑,返回封装后的 ComparisonConfig 对象""" + try: + config = ComparisonConfig( + dump_mode='', + stack_mode=False, + auto_analyze=kwargs.get('auto_analyze', True), + fuzzy_match=kwargs.get('fuzzy_match', False), + data_mapping=kwargs.get('data_mapping', {}), + suffix=kwargs.get('suffix', ''), + cell_mapping=kwargs.get('cell_mapping', {}), + api_mapping=kwargs.get('api_mapping', {}), + layer_mapping=kwargs.get('layer_mapping', {}), + first_diff_analyze=kwargs.get('first_diff_analyze', False) ) - return _save_cmp_result(idx, cr, result_df, lock) + set_dump_path(input_param) + config.dump_mode = get_dump_mode(input_param) - def do_multi_process(self, input_parma, result_df): - try: - result_df = _handle_multi_process(self.compare_ops, input_parma, result_df, - multiprocessing.Manager().RLock()) - return result_df - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - - -def get_bench_data_name(bench_op_name, bench_data): - bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name) - if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD: - bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {}) - else: - bench_data_bundle = bench_data.get(bench_name_list[0], {}) - if not bench_data_bundle or len(bench_name_list) < 3: - return None - layers = bench_name_list[2].split(Const.SEP) - - def _get(key, container): - if isinstance(container, dict): - return container.get(key) - if isinstance(container, list): - try: - return container[int(key)] - except (ValueError, IndexError): - return None - return None - - def get_by_layer(container, params_grad=False): - data = container - # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0' - if params_grad: - layers.append('0') - for layer in layers: - data = _get(layer, data) - return _get(CompareConst.DATA_NAME.lower(), data) - - if Const.INPUT == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS))) - elif Const.KWARGS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS)) - elif Const.OUTPUT == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.OUTPUT)) - elif Const.PARAMS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.PARAMS)) - elif Const.PARAMS_GRAD == bench_name_list[1]: - return get_by_layer(bench_data_bundle, params_grad=True) - else: - return None + # set stack_mode and set "stack_json_path" in input_param + if 'stack_json_path' in input_param: + config.stack_mode = kwargs.get('stack_mode', False) + else: + config.stack_mode = set_stack_json_path(input_param) + + check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match, + input_param.get('is_print_compare_log', True)) + create_directory(output_path) + check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode) + + return config + + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. Please check the arguments and do it again!') + raise CompareException(error.code) from error diff --git a/debug/accuracy_tools/msprobe/core/compare/check.py b/debug/accuracy_tools/msprobe/core/compare/check.py index 653823e20b29b14b6e7ede929f3bd2865bffaa18..a88ddb8f5e088a9f72ef2d2b721b03dbc539c385 100644 --- a/debug/accuracy_tools/msprobe/core/compare/check.py +++ b/debug/accuracy_tools/msprobe/core/compare/check.py @@ -14,117 +14,46 @@ # limitations under the License. from msprobe.core.common.log import logger -from msprobe.core.compare.utils import rename_api from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException -from msprobe.core.common.const import CompareConst, Const - -dtype_mapping = { - "Int8": "torch.int8", - "UInt8": "torch.uint8", - "Int16": "torch.int16", - "UInt16": "torch.uint16", - "Int32": "torch.int32", - "UInt32": "torch.uint32", - "Int64": "torch.int64", - "UInt64": "torch.uint64", - "Float16": "torch.float16", - "Float32": "torch.float32", - "Float64": "torch.float64", - "Bool": "torch.bool", - "BFloat16": "torch.bfloat16", - "Complex64": "torch.complex64", - "Complex128": "torch.complex128" +from msprobe.core.common.const import Const + +cross_dtype_mapping = { + "Int8": "int", + "torch.int8": "int", + "UInt8": "int", + "torch.uint8": "int", + "Int16": "int", + "torch.int16": "int", + "UInt16": "int", + "torch.uint16": "int", + "Int32": "int", + "torch.int32": "int", + "UInt32": "int", + "torch.uint32": "int", + "Int64": "int", + "torch.int64": "int", + "UInt64": "int", + "torch.uint64": "int", + + "Float16": "float", + "torch.float16": "float", + "Float32": "float", + "torch.float32": "float", + "Float64": "float", + "torch.float64": "float", + "BFloat16": "float", + "torch.bfloat16": "float", + + "Bool": "bool", + "torch.bool": "bool", + + "Complex64": "complex", + "torch.complex64": "complex", + "Complex128": "complex", + "torch.complex128": "complex", } -def compare_op_dict_struct(npu_dict, bench_dict): - return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY) - - -def check_struct_match(npu_dict, bench_dict): - is_match = compare_op_dict_struct(npu_dict, bench_dict) - if not is_match: - struct_match_list = [] - try: - for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY): - # 首先额外检查input_struct是否空,input_struct不可能为空 - if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])): - return False - struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, []))) - except CompareException as error: - err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \ - f'npu_dict: {npu_dict}' \ - f'bench_dict: {bench_dict}' - logger.error(err_msg) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - is_match = all(struct_match_list) - return is_match - - -def check_type_shape_match(npu_struct, bench_struct): - """ - further check dtypes with a dtype mapping list when dtypes are not entirely consistent. - """ - if len(npu_struct) != len(bench_struct): - return False - if not npu_struct and not bench_struct: - return True - - struct_match = False - for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct): - try: - npu_type = npu_type_shape[0] - npu_shape = npu_type_shape[1] - bench_type = bench_type_shape[0] - bench_shape = bench_type_shape[1] - except IndexError as error: - logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} ' - f'should both be 2, please check!') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - shape_match = npu_shape == bench_shape - type_match = npu_type == bench_type - if not type_match: - if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE): - type_match = True - else: - type_match = False - struct_match = shape_match and type_match - if not struct_match: - return False - return struct_match - - -def check_graph_mode(a_op_name, b_op_name): - if Const.ATEN in a_op_name and Const.ATEN not in b_op_name: - return True - if Const.ATEN not in a_op_name and Const.ATEN in b_op_name: - return True - return False - - -def fuzzy_check_op(npu_name_list, bench_name_list): - # 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0 - # 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1 - if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list): - return False - is_match = True - for npu_name, bench_name in zip(npu_name_list, bench_name_list): - is_match = fuzzy_check_name(npu_name, bench_name) - if not is_match: - break - return is_match - - -def fuzzy_check_name(npu_name, bench_name): - if Const.FORWARD in npu_name and Const.FORWARD in bench_name: - is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD) - elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name: - is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD) - else: - is_match = npu_name == bench_name - return is_match - - def check_dump_json_str(op_data, op_name): input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get( Const.INPUT, None) diff --git a/debug/accuracy_tools/msprobe/core/compare/config.py b/debug/accuracy_tools/msprobe/core/compare/config.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe857453d31c79776c7b1c5f55ee85b83ca426 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/config.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.file_utils import load_yaml + + +class ModeConfig: + def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY, + first_diff_analyze=False): + self.stack_mode = stack_mode + self.auto_analyze = auto_analyze + self.fuzzy_match = fuzzy_match + self.dump_mode = dump_mode + self.first_diff_analyze = first_diff_analyze + + +class MappingConfig: + def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): + self.cell_mapping = cell_mapping + self.api_mapping = api_mapping + self.data_mapping = data_mapping + + +class MappingDict: + def __init__(self, mapping_config: MappingConfig): + self.cell_mapping_dict = self.load_mapping_file(mapping_config.cell_mapping) + self.api_mapping_dict = self.load_mapping_file(mapping_config.api_mapping) + if mapping_config.api_mapping is not None: + self.ms_to_pt_mapping = self.load_internal_api() + self.data_mapping_dict = self.init_data_mapping(mapping_config.data_mapping) + + @staticmethod + def load_internal_api(): + cur_path = os.path.dirname(os.path.realpath(__file__)) + yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE)) + return load_yaml(yaml_path) + + @staticmethod + def load_mapping_file(mapping_file): + if isinstance(mapping_file, str): + mapping_dict = load_yaml(mapping_file) + else: + mapping_dict = {} + return mapping_dict + + def init_data_mapping(self, data_mapping): + """ + 初始化data_mapping_dict + """ + if isinstance(data_mapping, str) or data_mapping is None: + data_mapping_dict = self.load_mapping_file(data_mapping) + elif isinstance(data_mapping, dict): + data_mapping_dict = data_mapping + else: + raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " + f"{type(data_mapping)}") + return data_mapping_dict diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/__init__.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py similarity index 100% rename from profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/__init__.py rename to debug/accuracy_tools/msprobe/core/compare/diff_analyze/__init__.py diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b35d10bfb99f00a93e2fd6ad69112c6a40efce1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml @@ -0,0 +1,14 @@ +compare_metrics: + - MaxRelativeErr + - MinRelativeErr + - MeanRelativeErr + - NormRelativeErr + +MaxRelativeErr: + - 0.5 +MinRelativeErr: + - 0.5 +MeanRelativeErr: + - 0.5 +NormRelativeErr: + - 0.5 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2f7c5487e0e6fb9d19878b1e333e8ef077cbeb --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/diff_analyze/first_diff_analyze.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.utils import safe_get_value, logger, CompareException +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.compare.utils import api_batches_update, get_name_and_state + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml') +thresholds = load_yaml(diff_threshold_yaml_path) +cmp_metrics = thresholds.get('compare_metrics') + + +class FirstDiffAnalyze: + @staticmethod + def single_metric_diff_check(cmp_metric, metric_value): + threshold = thresholds.get(cmp_metric, None) + if threshold is None: + logger.error(f"Check diff or {cmp_metric} need to configure the threshold. " + f"Please configure it in 'diff_analyze_threshold.yaml'.") + raise CompareException(CompareException.MISSING_THRESHOLD_ERROR) + if not isinstance(threshold, list) or len(threshold) != 1: + logger.error(f"{cmp_metric} threshold configure wrong. Please check.") + raise CompareException(CompareException.WRONG_THRESHOLD_ERROR) + if isinstance(metric_value, str) and metric_value.endswith('%'): + metric_value_float = float(metric_value[:-1]) / 100 + if metric_value_float > threshold[0]: + return True + return False + + def single_api_check(self, result_slice, header): + """ + 单个api差异检查 + + :param result_slice: 数据切片 + :param header: 列名列表 + :return: {'is_same': bool, 'op_items': list[dict]} + """ + single_check_result = { + 'is_same': True, + 'op_items': [] + } + + column_indices = {name: idx for idx, name in enumerate(header)} + + for line in result_slice: + op_item = { + column_name: line[column_indices[column_name]] + for column_name in header + } + single_check_result['op_items'].append(op_item) + + for cmp_metric in cmp_metrics: + metric_value = line[column_indices[cmp_metric]] + if self.single_metric_diff_check(cmp_metric, metric_value): + single_check_result['is_same'] = False + break + return single_check_result + + def check(self, result_df): + """ + 比对后循环遍历api检查norm差异 + example: + { + 'Functional.conv2d.0.forward': { + 'is_same': true, + 'op_items': [ + { + 'NPU name': 'Functional.conv2d.0.forward.input.0', + 'Bench name': 'Functional.conv2d.0.forward.input.0', + 'xxx': 1, + 'NormRelativeErr': 2, + 'yyy': 3, + ... + } + ] + } + } + """ + result = result_df.values + header = result_df.columns.tolist() + + api_batches = [] + for i, res_i in enumerate(result): + api_full_name = safe_get_value(res_i, 0, "res_i") + api_name, state = get_name_and_state(api_full_name) + api_batches_update(api_batches, api_name, state, i) + + check_result = {} + for api_batch in api_batches: + result_slice = result[api_batch.start: api_batch.params_grad_end_index] + check_result[api_batch.api_name[: -1]] = self.single_api_check(result_slice, header) + + return check_result diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index cf3e1c4c03e9553f5566870b7c5ebe2d890e9774..560ebcdc7a5e5265ff4e869f54487b5c0b2ed83f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -29,13 +29,8 @@ from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.file_utils import save_workbook from msprobe.core.common.log import logger from msprobe.core.common.utils import get_header_index, safe_get_value -from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException - - -class HighlightCheck(abc.ABC): - @abc.abstractmethod - def apply(self, info, color_columns, dump_mode): - raise NotImplementedError +from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException, api_batches_update +from msprobe.core.compare.config import ModeConfig def add_highlight_row_info(color_list, num, highlight_err_msg): @@ -46,6 +41,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg): color_list.append((num, [highlight_err_msg])) +class HighlightCheck(abc.ABC): + @abc.abstractmethod + def apply(self, info, color_columns, dump_mode): + raise NotImplementedError + + class CheckOrderMagnitude(HighlightCheck): """检查Max diff的数量级差异""" @@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck): if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED): add_highlight_row_info(color_columns.red, num, - "The input/parameters's one thousandth err ratio exceeds 0.9, " + "The input/parameter's one thousandth err ratio exceeds 0.9, " "while the output's is below 0.6") elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW: add_highlight_row_info(color_columns.yellow, num, "The output's one thousandth err ratio decreases by more than 0.1 " - "compared to the input/parameters's") + "compared to the input/parameter's") class CheckCosineSimilarity(HighlightCheck): @@ -94,7 +95,7 @@ class CheckCosineSimilarity(HighlightCheck): if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW: add_highlight_row_info(color_columns.yellow, num, "The output's cosine decreases by more than 0.1 " - "compared to the input/parameters's") + "compared to the input/parameter's") class CheckMaxRelativeDiff(HighlightCheck): @@ -117,7 +118,7 @@ class CheckMaxRelativeDiff(HighlightCheck): input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW): add_highlight_row_info(color_columns.yellow, num, "The output's maximum relative error exceeds 0.1, " - "while the input/parameters's is below 0.01") + "while the input/parameter's is below 0.01") class CheckOverflow(HighlightCheck): @@ -146,270 +147,216 @@ class HighlightRules: } # 用于比较输入和输出的规则 + # 真实数据检查规则 compare_rules = { "check_order_magnitude": CheckOrderMagnitude(), "check_one_thousand_error": CheckOneThousandErrorRatio(), "check_cosine_similarity": CheckCosineSimilarity() } + # 统计量数据检查规则 summary_compare_rules = { "check_order_magnitude": CheckOrderMagnitude(), "check_max_relative_diff": CheckMaxRelativeDiff(), } -def check_indices_numeric(api_items, indices: list): - """检查指定索引处的值是否都为数字类型(int 或 float)""" - return all(isinstance(api_items[i], (float, int)) for i in indices) - - -def apply_comparison_rules(api_info, dump_mode, color_columns): - """output与input/params的比较""" - if dump_mode == Const.SUMMARY: - for rule in HighlightRules.summary_compare_rules.values(): - rule.apply(api_info, color_columns, dump_mode) - else: - for rule in HighlightRules.compare_rules.values(): - rule.apply(api_info, color_columns, dump_mode) - - -def find_error_rows(result, api_batch, highlight_dict, dump_mode): - """找到单个API中需要高亮的行""" - if dump_mode == Const.MD5: - return - npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode) - bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode) - max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY - else CompareConst.MAX_ABS_ERR, dump_mode) - - red_lines, yellow_lines = [], [] - LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) - ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) - ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) - color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - - api_batch_start = api_batch.start # result_df的input起始全局索引 - api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1 - api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1 - api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引 - api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引 - - # 对单行API的输入或输出进行误差判断 - for i, line in enumerate(result): - index = api_batch_start + i - line_info = LineInfo(line_data=line, num_pointer=index) - for rule in HighlightRules.basic_rules.values(): - rule.apply(line_info, color_columns, dump_mode) - - # 对API的输出与输入比较,进行误差判断 - for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]): - index = api_batch_start + api_batch_params_slice_index_local + n - # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查 - if index in red_lines: - continue - if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]): - continue - - # input/parameters的比较检查, 这里api_in包括input、parameters - for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]): - if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]): - continue - api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index) - apply_comparison_rules(api_info, dump_mode, color_columns) - - red_lines_num_set = {x[0] for x in red_lines} - yellow_lines_num_set = {x[0] for x in yellow_lines} - highlight_dict.get('red_rows', set()).update(red_lines_num_set) - highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set) - highlight_dict.get('red_lines', []).extend(red_lines) - highlight_dict.get('yellow_lines', []).extend(yellow_lines) - - -class ApiBatch: - def __init__(self, api_name: str, start: int): - self.api_name = api_name - self.start = start - self.input_len = 1 # input的数量 - self.params_end_index = start + 1 # params的结束index - self.output_end_index = start + 1 # output的结束index - self.params_grad_end_index = start + 1 # params_grad的结束index - # 内部state的标志("input", "output", "parameters", "parameters_grad"), - # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index - self._state = Const.INPUT # api_batch初始化为input - - def set_state(self, state: str): - """设置当前状态""" - if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: - self._state = state - else: - raise ValueError(f"Invalid state: {state}") - - def increment(self, state: str): - self.set_state(state) - if self._state == Const.INPUT or self._state == Const.KWARGS: - self.input_len += 1 - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.PARAMS: - self.params_end_index += 1 - self.output_end_index += 1 - if self._state == Const.OUTPUT: - self.output_end_index += 1 - self.params_grad_end_index += 1 - - -def api_batches_update(api_batches, api_name, state, index): - """ - 当一个api的所有item更新完后,input, output的索引范围: - input: [start: start+input_len] - output: [start+input_len: output_end_index] - params: [output_end_index: params_end_index] - """ - if not api_batches: - api_batches.append(ApiBatch(api_name, index)) - else: - api_batch = api_batches[-1] - if api_batch.api_name == api_name or ( - not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): - try: - api_batch.increment(state) - except ValueError as e: - logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") - raise CompareException(CompareException.INVALID_STATE_ERROR) from e - else: - api_batches.append(ApiBatch(api_name, index)) - - -def find_compare_result_error_rows(result_df, highlight_dict, dump_mode): - """将dataframe根据API分组,并找到有误差的算子用于高亮""" - result = result_df.values - api_batches = [] - for i, res_i in enumerate(result): - api_full_name = safe_get_value(res_i, 0, "res_i") - api_name, state = get_name_and_state(api_full_name) - api_batches_update(api_batches, api_name, state, i) - with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar: - for api_batch in api_batches: - find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict, - dump_mode) - progress_bar.update(1) - - -def value_check(value, api_name=None, i=None, result_df_columns=None): - if not table_value_is_valid(value): - if result_df_columns: - logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], " - f"is not allowed to be written into the compare result xlsx.") - else: - logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.") - - -def df_malicious_value_check(df_chunk, result_df_columns): - for row in df_chunk.itertuples(index=False): - api_name = row[0] - for i, value in enumerate(row): - value_check(value, api_name, i, result_df_columns) - - -def handle_multi_process_malicious_value_check(func, result_df): - result_total_nums = len(result_df) - process_num = int((multiprocessing.cpu_count() + 1) / 2) - - if result_total_nums <= process_num: - process_num = 1 - chunks = [result_df] - else: - chunk_size = result_total_nums // process_num - chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)] - - pool = multiprocessing.Pool(process_num) - - def err_call(args): - logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) - try: - pool.terminate() - except OSError: - logger.error("Pool terminate failed") - - result_df_columns = result_df.columns.tolist() - for column in result_df_columns: - value_check(column) - for df_chunk in chunks: - pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) - - pool.close() - pool.join() - +class HighLight: + def __init__(self, mode_config: ModeConfig): + self.mode_config = mode_config -def compare_result_df_convert(value): - if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str - value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value) - if isinstance(value, float): - value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value - return value + @staticmethod + def check_indices_numeric(api_items, indices: list): + """检查指定索引处的值是否都为数字类型(int 或 float)""" + return all(isinstance(api_items[i], (float, int)) for i in indices) + @staticmethod + def update_highlight_err_msg(result_df, highlight_dict): + if result_df.shape[1] <= 1: + return -def highlight_rows_xlsx(result_df, highlight_dict, file_path): - """Write and highlight results in Excel""" - - update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg + if CompareConst.NPU_MD5 in result_df.columns: + return - wb = openpyxl.Workbook() - ws = wb.active + err_msg = result_df.get(CompareConst.ERROR_MESSAGE) + red_lines_num_set = highlight_dict.get('red_rows') + + for color in ['red', 'yellow']: + line_key = f'{color}_lines' + lines = highlight_dict.get(line_key, []) + for line_index, messages in lines: + if color == 'yellow' and line_index in red_lines_num_set: + continue # 如果是 yellow 行,且已被 red 行覆盖,跳过 + + for msg in messages: + if err_msg[line_index] == '': + err_msg[line_index] = msg + else: + err_msg[line_index] += '\n' + msg + + if color == 'red': + red_lines_num_set.add(line_index) + + result_df[CompareConst.ERROR_MESSAGE] = err_msg + + @staticmethod + def compare_result_df_convert(value): + if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str + value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value) + if isinstance(value, float): + value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value + return value + + @staticmethod + def value_check(value, api_name=None, i=None, result_df_columns=None): + if not table_value_is_valid(value): + if result_df_columns: + logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], " + f"is not allowed to be written into the compare result xlsx.") + else: + logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.") + + def find_compare_result_error_rows(self, result_df, highlight_dict): + """将dataframe根据API分组,并找到有误差的算子用于高亮""" + result = result_df.values + api_batches = [] + for i, res_i in enumerate(result): + api_full_name = safe_get_value(res_i, 0, "res_i") + api_name, state = get_name_and_state(api_full_name) + api_batches_update(api_batches, api_name, state, i) + with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar: + for api_batch in api_batches: + self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, + highlight_dict) + progress_bar.update(1) + + def find_error_rows(self, result, api_batch, highlight_dict): + """找到单个API中需要高亮的行""" + if self.mode_config.dump_mode == Const.MD5: + return + npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode) + bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode) + max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY + else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode) + + red_lines, yellow_lines = [], [] + LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) + ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) + ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) + color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) + + api_batch_start = api_batch.start # result_df的input起始全局索引 + api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1 + api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1 + api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引 + api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引 + + # 对单行API的输入或输出进行误差判断 + for i, line in enumerate(result): + index = api_batch_start + i + line_info = LineInfo(line_data=line, num_pointer=index) + for rule in HighlightRules.basic_rules.values(): + rule.apply(line_info, color_columns, self.mode_config.dump_mode) + + # 对API的输出与输入比较,进行误差判断 + for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]): + index = api_batch_start + api_batch_params_slice_index_local + n + # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查 + if index in red_lines: + continue + if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]): + continue - # write header - logger.info('Initializing Excel file.') + # input/parameters的比较检查, 这里api_in包括input、parameters + for api_in in result[0: api_batch_params_slice_index_local]: + if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]): + continue + api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index) + self.apply_comparison_rules(api_info, color_columns) + + red_lines_num_set = {x[0] for x in red_lines} + yellow_lines_num_set = {x[0] for x in yellow_lines} + highlight_dict.get('red_rows', set()).update(red_lines_num_set) + highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set) + highlight_dict.get('red_lines', []).extend(red_lines) + highlight_dict.get('yellow_lines', []).extend(yellow_lines) + + def apply_comparison_rules(self, api_info, color_columns): + """output与input/params的比较""" + if self.mode_config.dump_mode == Const.SUMMARY: + for rule in HighlightRules.summary_compare_rules.values(): + rule.apply(api_info, color_columns, self.mode_config.dump_mode) + else: + for rule in HighlightRules.compare_rules.values(): + rule.apply(api_info, color_columns, self.mode_config.dump_mode) - handle_multi_process_malicious_value_check(df_malicious_value_check, result_df) + def highlight_rows_xlsx(self, result_df, highlight_dict, file_path): + """Write and highlight results in Excel""" - result_df_convert = result_df.applymap(compare_result_df_convert) + self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg - for row in dataframe_to_rows(result_df_convert, index=False, header=True): - ws.append(row) + wb = openpyxl.Workbook() + ws = wb.active - # 对可疑数据标色 - logger.info('Coloring Excel in progress.') - col_len = len(result_df.columns) - red_fill = PatternFill( - start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid" - ) - yellow_fill = PatternFill( - start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid", - ) - for i in highlight_dict.get("red_rows", []): - for j in range(1, col_len + 1): - ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 - for i in highlight_dict.get("yellow_rows", []): - for j in range(1, col_len + 1): - ws.cell(row=i + 2, column=j).fill = yellow_fill + # write header + logger.info('Initializing Excel file.') - logger.info('Saving Excel file to disk: %s' % file_path) - save_workbook(wb, file_path) + self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df) + result_df_convert = result_df.applymap(self.compare_result_df_convert) -def update_highlight_err_msg(result_df, highlight_dict): - if result_df.shape[1] <= 1: - return + for row in dataframe_to_rows(result_df_convert, index=False, header=True): + ws.append(row) - if CompareConst.NPU_MD5 in result_df.columns: - return + # 对可疑数据标色 + logger.info('Coloring Excel in progress.') + col_len = len(result_df.columns) + red_fill = PatternFill( + start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid" + ) + yellow_fill = PatternFill( + start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid", + ) + for i in highlight_dict.get("red_rows", []): + for j in range(1, col_len + 1): + ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始 + for i in highlight_dict.get("yellow_rows", []): + for j in range(1, col_len + 1): + ws.cell(row=i + 2, column=j).fill = yellow_fill - err_msg = result_df.get(CompareConst.ERROR_MESSAGE) - red_lines_num_set = highlight_dict.get('red_rows') + logger.info('Saving Excel file to disk: %s' % file_path) + save_workbook(wb, file_path) - for color in ['red', 'yellow']: - line_key = f'{color}_lines' - lines = highlight_dict.get(line_key, []) - for line_index, messages in lines: - if color == 'yellow' and line_index in red_lines_num_set: - continue # 如果是 yellow 行,且已被 red 行覆盖,跳过 + def handle_multi_process_malicious_value_check(self, func, result_df): + result_total_nums = len(result_df) + process_num = int((multiprocessing.cpu_count() + 1) / 2) - for msg in messages: - if err_msg[line_index] == '': - err_msg[line_index] = msg - else: - err_msg[line_index] += '\n' + msg + if result_total_nums <= process_num: + process_num = 1 + chunks = [result_df] + else: + chunk_size = result_total_nums // process_num + chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)] - if color == 'red': - red_lines_num_set.add(line_index) + pool = multiprocessing.Pool(process_num) - result_df[CompareConst.ERROR_MESSAGE] = err_msg + def err_call(args): + logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args)) + try: + pool.close() + except OSError: + logger.error("Pool terminate failed") + + result_df_columns = result_df.columns.tolist() + for column in result_df_columns: + self.value_check(column) + for df_chunk in chunks: + pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call) + + pool.close() + pool.join() + + def df_malicious_value_check(self, df_chunk, result_df_columns): + for row in df_chunk.itertuples(index=False): + api_name = row[0] + for i, value in enumerate(row): + self.value_check(value, api_name, i, result_df_columns) diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py index d0f19462ee1ccf4d72c69885c18174cec32df056..4845adb0482b1a6cca988e876a1315e56589e87a 100644 --- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py +++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py @@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml, get_stack_construct_by_dump_json_path) from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items from msprobe.core.compare.utils import read_op, reorder_op_name_list - +from msprobe.core.common.decorator import recursion_depth_decorator class LayerTrie: @@ -71,6 +71,7 @@ class LayerTrie: file_path = os.path.join(os.path.realpath(output_path), file_name) save_yaml(file_path, result) + @recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100) def convert_to_dict(self, node): result = {} result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()} @@ -163,6 +164,8 @@ def preprocess_layer_mapping(mapping): for key, value in name_map.items(): key_list = key.split('.') prefix = key_list[0] # 取前缀 + value_list = value.split('(') + value = value_list[0] # 取前缀 key_len = len(key_list) if prefix not in final_mapping[type_name]: final_mapping[type_name][prefix] = [] diff --git a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py index b605bd59fca0b2b3a510a7a686caa94383488bd2..9edc6d9a9dc36d05325c5af98f18a296f3627e2f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py +++ b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +21,8 @@ from functools import partial import pandas as pd from tqdm import tqdm -from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory +from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \ + remove_path from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.utils import CompareException, add_time_with_xlsx from msprobe.core.compare.utils import table_value_is_valid @@ -32,8 +33,8 @@ def check_compare_result_name(file_name): """ check whether the compare result name is as expected """ - single_rank_pattern = r"^compare_result_rank-rank_\d{14}.xlsx$" - multi_ranks_pattern = r"^compare_result_rank(\d+)-rank\1_\d{14}.xlsx$" + single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$" + multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$" if re.match(multi_ranks_pattern, file_name): return True if re.match(single_rank_pattern, file_name): @@ -47,7 +48,7 @@ def reorder_path(compare_result_path_list): """ reorder compare results by rank num """ - rank_pattern = r"compare_result_rank(\d+)-rank" + rank_pattern = r"compare_result_rank(\d+)" reorder_path_list = sorted( compare_result_path_list, key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1)) @@ -63,6 +64,7 @@ def get_result_path(input_dir): for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)] filt_compare_result_path_list = [] for file_path in compare_result_path_list: + FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() file_name = os.path.basename(file_path) if check_compare_result_name(file_name): compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) @@ -236,7 +238,7 @@ def handle_multi_process(func, func_args, lock): def err_call(args): logger.error('Multiprocess merge result failed! Reason: {}'.format(args)) try: - pool.terminate() + pool.close() except OSError: logger.error("Pool terminate failed") @@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co for i, df in enumerate(merge_df_list): # merge_df_list中df与compare_index_list中compare_index一一对应 final_result_df_list.append((df, compare_index_list[i])) + + if os.path.exists(output_path): + logger.warning(f"{output_path} will be deleted.") + remove_path(output_path) save_excel(output_path, final_result_df_list) logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.") diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_to_pt_api.yaml b/debug/accuracy_tools/msprobe/core/compare/ms_to_pt_api.yaml similarity index 100% rename from debug/accuracy_tools/msprobe/mindspore/compare/ms_to_pt_api.yaml rename to debug/accuracy_tools/msprobe/core/compare/ms_to_pt_api.yaml diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index c2c1461e452f9d2c7f4e0e2803dfe51be2a132c0..8e23f9f8b9f0f1cb33e58013aaa325da5ffdf11b 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,51 +15,28 @@ import multiprocessing from dataclasses import dataclass +from functools import partial + import pandas as pd from tqdm import tqdm + from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.const import CompareConst +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg +from msprobe.core.compare.config import ModeConfig -def _handle_multi_process(func, input_parma, result_df, lock): - process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) - op_name_mapping_dict = read_dump_data(result_df) - - df_chunk_size = len(result_df) // process_num - if df_chunk_size > 0: - df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] - else: - df_chunks = [result_df] - - results = [] - pool = multiprocessing.Pool(process_num) - - def err_call(args): - logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.terminate() - except OSError as e: - logger.error("pool terminate failed") - - progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) - - def update_progress(size, progress_lock): - with progress_lock: - progress_bar.update(size) - - for process_idx, df_chunk in enumerate(df_chunks): - idx = df_chunk_size * process_idx - chunk_size = len(df_chunk) - result = pool.apply_async(func, - args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma), - error_callback=err_call, - callback=update_progress(chunk_size, lock)) - results.append(result) - final_results = [r.get() for r in results] - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) +@dataclass +class ComparisonResult: + cos_result: list + euc_dist_result: list + max_err_result: list + max_relative_err_result: list + one_thousand_err_ratio_result: list + five_thousand_err_ratio_result: list + err_msgs: list def _ms_graph_handle_multi_process(func, result_df, mode): @@ -76,9 +53,9 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def err_call(args): logger.error('multiprocess compare failed! Reason: {}'.format(args)) try: - pool.terminate() + pool.close() except OSError as e: - logger.error("pool terminate failed") + logger.error(f'pool terminate failed: {str(e)}') for df_chunk in df_chunks: result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call) @@ -89,72 +66,6 @@ def _ms_graph_handle_multi_process(func, result_df, mode): return pd.concat(final_results, ignore_index=True) -def read_dump_data(result_df): - try: - npu_dump_name_list = result_df.iloc[0:, 0].tolist() - npu_dump_tensor_list = result_df.iloc[0:, -1].tolist() - op_name_mapping_dict = {} - for index, _ in enumerate(npu_dump_name_list): - npu_dump_name = npu_dump_name_list[index] - npu_dump_tensor = npu_dump_tensor_list[index] - op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor] - return op_name_mapping_dict - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - except IndexError as e: - logger.error('result dataframe elements can not be access.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - - -@dataclass -class ComparisonResult: - cos_result: list - max_err_result: list - max_relative_err_result: list - err_msgs: list - one_thousand_err_ratio_result: list - five_thousand_err_ratio_result: list - - -def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): - """ - Save comparison results into the result DataFrame with thread safety. - Args: - offset: offset for index - result: data struct of ComparisonResult - result_df: result of DataFrame - lock: thread lock - - Returns: - comparison results in DataFrame - """ - - lock.acquire() - try: - for i, _ in enumerate(result.cos_result): - process_index = i + offset - result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i] - result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] - result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] - result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] - result_df.loc[process_index, CompareConst.ACCURACY] = ( - check_accuracy(result.cos_result[i], result.max_err_result[i])) - result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = ( - result.one_thousand_err_ratio_result)[i] - result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = ( - result.five_thousand_err_ratio_result)[i] - return result_df - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - except IndexError as e: - logger.error('result dataframe elements can not be access.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - finally: - lock.release() - - def check_accuracy(cos, max_abs_err): if cos == CompareConst.SHAPE_UNMATCH: return CompareConst.ACCURACY_CHECK_UNMATCH @@ -172,3 +83,212 @@ def check_accuracy(cos, max_abs_err): if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD: return CompareConst.ACCURACY_CHECK_NO return CompareConst.ACCURACY_CHECK_YES + + +class CompareRealData: + def __init__(self, file_reader, mode_config: ModeConfig, cross_frame): + self.file_reader = file_reader + self.mode_config = mode_config + self.cross_frame = cross_frame + + @staticmethod + def read_dump_data(result_df): + try: + npu_dump_name_list = result_df.iloc[0:, 0].tolist() + dump_tensor_pair_list = result_df.iloc[0:, -1].tolist() + op_name_mapping_dict = {} + for index, npu_dump_name in enumerate(npu_dump_name_list): + dump_tensor_pair = dump_tensor_pair_list[index] + op_name_mapping_dict[npu_dump_name] = dump_tensor_pair + return op_name_mapping_dict + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + except IndexError as e: + logger.error('result dataframe elements can not be access.') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e + + @staticmethod + def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): + """ + Save comparison results into the result DataFrame with thread safety. + Args: + offset: offset for index + result: data struct of ComparisonResult + result_df: result of DataFrame + lock: thread lock + + Returns: + comparison results in DataFrame + """ + + lock.acquire() + try: + for i, cos_item in enumerate(result.cos_result): + process_index = i + offset + result_df.loc[process_index, CompareConst.COSINE] = cos_item + result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i] + result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] + result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] + result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = ( + result.one_thousand_err_ratio_result)[i] + result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = ( + result.five_thousand_err_ratio_result)[i] + result_df.loc[process_index, CompareConst.ACCURACY] = ( + check_accuracy(result.cos_result[i], result.max_err_result[i])) + result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] + return result_df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + except IndexError as e: + logger.error('result dataframe elements can not be access.') + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e + finally: + lock.release() + + def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param): + """ + :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 + :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 + :param op_name_mapping_dict: op_name和npy或pt文件的映射关系 + :param input_param: npu_json_path/bench_json_path/stack_json_path等参数 + :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息 + 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离 + 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息 + """ + error_file, relative_err, error_flag = None, None, False + + data_name_pair = op_name_mapping_dict.get(npu_op_name) + npu_data_name = data_name_pair[0] + bench_data_name = data_name_pair[1] + + if str(npu_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有npu真实数据 + n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True + elif str(bench_data_name) == CompareConst.NO_REAL_DATA_FLAG: # 没有bench真实数据 + n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True + error_file = 'no_bench_data' + elif str(bench_data_name) == CompareConst.N_A: # bench没匹配 + n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True + error_file = None + else: + npu_dir = input_param.get(CompareConst.NPU_DUMP_DATA_DIR) + bench_dir = input_param.get(CompareConst.BENCH_DUMP_DATA_DIR) + try: + n_value, b_value = self.file_reader(npu_dir, npu_data_name, bench_dir, bench_data_name, + self.cross_frame) + except IOError as error: + error_file = error.filename + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + except (FileCheckException, CompareException): + error_file = data_name_pair + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + + # 通过n_value, b_value同时得到错误标志和错误信息 + n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, + error_flag=error_flag, error_file=error_file) + + result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg) + + if self.mode_config.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A: + err_msg += " Fuzzy matching data, the comparison accuracy may be affected." + result_list.append(err_msg) + return result_list + + def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param): + cos_result = [] + euc_dist_result = [] + max_err_result = [] + max_relative_err_result = [] + one_thousand_err_ratio_result = [] + five_thousand_err_ratio_result = [] + err_mess = [] + + is_print_compare_log = input_param.get("is_print_compare_log") + + for i in range(len(result_df)): + npu_op_name = result_df.iloc[i, 0] + bench_op_name = result_df.iloc[i, 1] + if is_print_compare_log: + logger.info("start compare: {}".format(npu_op_name)) + + cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \ + = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param) + + if is_print_compare_log: + logger.info( + "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \ + one_thousand_err_ratio {}, " + "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, + err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) + cos_result.append(cos_sim) + euc_dist_result.append(euc_dist) + max_err_result.append(max_abs_err) + max_relative_err_result.append(max_relative_err) + one_thousand_err_ratio_result.append(one_thousand_err_ratio) + five_thousand_err_ratio_result.append(five_thousand_err_ratio) + err_mess.append(err_msg) + + cr = ComparisonResult( + cos_result=cos_result, + euc_dist_result=euc_dist_result, + max_err_result=max_err_result, + max_relative_err_result=max_relative_err_result, + one_thousand_err_ratio_result=one_thousand_err_ratio_result, + five_thousand_err_ratio_result=five_thousand_err_ratio_result, + err_msgs=err_mess + ) + + return self._save_cmp_result(idx, cr, result_df, lock) + + def do_multi_process(self, input_param, result_df): + try: + result_df = self._handle_multi_process(self.compare_ops, input_param, result_df, + multiprocessing.Manager().RLock()) + return result_df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + + def _handle_multi_process(self, func, input_param, result_df, lock): + process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) + op_name_mapping_dict = self.read_dump_data(result_df) + + df_chunk_size = len(result_df) // process_num + if df_chunk_size > 0: + df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] + else: + df_chunks = [result_df] + + results = [] + pool = multiprocessing.Pool(process_num) + + def err_call(args): + logger.error('multiprocess compare failed! Reason: {}'.format(args)) + try: + pool.close() + except OSError: + logger.error("pool terminate failed") + + progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) + + def update_progress(size, progress_lock, extra_param=None): + with progress_lock: + progress_bar.update(size) + + for process_idx, df_chunk in enumerate(df_chunks): + idx = df_chunk_size * process_idx + chunk_size = len(df_chunk) + result = pool.apply_async(func, + args=(idx, op_name_mapping_dict, df_chunk, lock, input_param), + error_callback=err_call, + callback=partial(update_progress, chunk_size, lock) + ) + results.append(result) + + final_results = [r.get() for r in results] + pool.close() + pool.join() + return pd.concat(final_results, ignore_index=True) diff --git a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py index c551985780cb9b56e32573727f9bf88f274da24e..b6b27b1772fc9ec70fdb47688c1642856dfb391d 100644 --- a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -59,7 +59,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None): if error_file == "no_bench_data": err_msg = "Bench does not have data file." elif error_file: - err_msg = f"Dump file: {error_file} not found." + err_msg = f"Dump file: {error_file} not found or read failed." else: err_msg = CompareConst.NO_BENCH error_flag = True @@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None): error_flag = True return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg if not n_value.shape: # 判断数据是否为0维张量 - err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', " + err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', " f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ") error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理 return n_value, b_value, error_flag, err_msg @@ -168,8 +168,9 @@ def statistics_data_check(result_dict): class TensorComparisonBasic(abc.ABC): """NPU和bench中npy数据的比较模板""" + @abc.abstractmethod - def apply(self, n_value, b_value, relative_err): + def apply(self, n_value, b_value, relative_err, err_msg): raise NotImplementedError @@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value): class GetCosineSimilarity(TensorComparisonBasic): """计算cosine相似度""" + @staticmethod def correct_data(result): if result == CompareConst.NAN: @@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic): return round(float(result), 6) return result - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg with np.errstate(divide="ignore", invalid="ignore"): if len(n_value) == 1: @@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic): return result, "" +class GetEuclideanDistance(TensorComparisonBasic): + """计算欧式距离""" + + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg + + distance = np.linalg.norm(n_value - b_value, ord=2) + + return distance, "" + + class GetMaxAbsErr(TensorComparisonBasic): """计算最大绝对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): temp_res = n_value - b_value max_value = np.max(np.abs(temp_res)) if np.isnan(max_value): @@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic): class GetMaxRelativeErr(TensorComparisonBasic): """计算最大相对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): max_relative_err = np.max(np.abs(relative_err)) if np.isnan(max_relative_err): msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data." @@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic): class GetErrRatio(TensorComparisonBasic): """计算相对误差小于指定阈值(千分之一、千分之五)的比例""" + def __init__(self, threshold): self.threshold = threshold - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg if not np.size(relative_err): return CompareConst.NAN, "" @@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic): class CompareOps: compare_ops = { "cosine_similarity": GetCosineSimilarity(), + "euclidean_distance": GetEuclideanDistance(), "max_abs_error": GetMaxAbsErr(), "max_relative_error": GetMaxRelativeErr(), "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD), @@ -272,10 +290,8 @@ class CompareOps: def error_value_process(n_value): - if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE: + if n_value in [CompareConst.READ_NONE, CompareConst.UNREADABLE, CompareConst.NONE]: return CompareConst.UNSUPPORTED, "" - if n_value == CompareConst.NONE: - return 0, "" if n_value == CompareConst.SHAPE_UNMATCH: return CompareConst.SHAPE_UNMATCH, "" if n_value == CompareConst.NAN: @@ -295,7 +311,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg): n_value, b_value = reshape_value(n_value, b_value) for op in CompareOps.compare_ops.values(): - result, msg = op.apply(n_value, b_value, relative_err) + result, msg = op.apply(n_value, b_value, relative_err, err_msg) result_list.append(result) err_msg += msg return result_list, err_msg diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index a2edf57e5bb91400675fe01734ea7fbf0e1df893..6da9f3e4bd1bc5aac30a848f9503ba33c4948b1c 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -20,6 +20,7 @@ import zlib from dataclasses import dataclass import numpy as np +import pandas as pd from msprobe.core.common.const import Const, CompareConst, FileCheckConst from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value @@ -81,22 +82,6 @@ def check_and_return_dir_contents(dump_dir, prefix): return contents -def rename_api(npu_name, process): - """ - 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} - rename后: {api_type}.{api_name}.{input/output}.{参数序号} - """ - npu_split = npu_name.split(process) - try: - torch_func_index, in_out = npu_split[0], npu_split[1] - except IndexError as error: - logger.error(f'{npu_name} can not be split with {process}, please check!') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - torch_func_split = torch_func_index.rsplit(Const.SEP, 2) - torch_func = str(torch_func_split[0]) + str(in_out) - return torch_func - - def read_op(op_data, op_name): if Const.PARAMS_GRAD in op_name.split(Const.SEP): op_parsed_list = op_item_parse(op_data, op_name) @@ -191,35 +176,201 @@ def gen_op_item(op_data, op_name): return op_item -def resolve_api_special_parameters(data_dict, full_op_name, item_list): +@dataclass +class ApiItemInfo: + name: str + struct: tuple + stack_info: list + + +def merge_tensor(tensor_list, dump_mode): + keys = [ + CompareConst.OP_NAME, + CompareConst.INPUT_STRUCT, + CompareConst.KWARGS_STRUCT, + CompareConst.OUTPUT_STRUCT, + CompareConst.PARAMS_STRUCT, + CompareConst.PARAMS_GRAD_STRUCT, + Const.SUMMARY, + Const.STACK_INFO + ] + op_dict = {key: [] for key in keys} + + if dump_mode == Const.ALL: + op_dict["data_name"] = [] + + for tensor in tensor_list: + # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True + if len(tensor) == 2: + op_dict[Const.STACK_INFO].append(tensor['full_info']) + break + + op_dict[CompareConst.OP_NAME].append(tensor['full_op_name']) + + _, state = get_name_and_state(tensor['full_op_name']) + struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) + if not struct_key: + continue + if dump_mode == Const.MD5: + op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) + else: + op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) + op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) + + if dump_mode == Const.ALL: + op_dict["data_name"].append(tensor['data_name']) + + if not op_dict[CompareConst.KWARGS_STRUCT]: + del op_dict[CompareConst.KWARGS_STRUCT] + return op_dict if op_dict[CompareConst.OP_NAME] else {} + + +def print_compare_ends_info(): + total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS + logger.info('*' * total_len) + logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*") + logger.info('*' * total_len) + + +def table_value_is_valid(value: str) -> bool: + if not isinstance(value, str): + return True + try: + # -1.00 or +1.00 should be considered as digit numbers + float(value) + except ValueError: + # otherwise, they will be considered as formular injections + return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) + return True + + +class ApiBatch: + def __init__(self, api_name: str, start: int): + self.api_name = api_name + self.start = start + self.input_len = 1 # input的数量 + self.params_end_index = start + 1 # params的结束index + self.output_end_index = start + 1 # output的结束index + self.params_grad_end_index = start + 1 # params_grad的结束index + # 内部state的标志("input", "output", "parameters", "parameters_grad"), + # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index + self._state = Const.INPUT # api_batch初始化为input + + def set_state(self, state: str): + """设置当前状态""" + if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}: + self._state = state + else: + raise ValueError(f"Invalid state: {state}") + + def increment(self, state: str): + self.set_state(state) + if self._state == Const.INPUT or self._state == Const.KWARGS: + self.input_len += 1 + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.PARAMS: + self.params_end_index += 1 + self.output_end_index += 1 + if self._state == Const.OUTPUT: + self.output_end_index += 1 + self.params_grad_end_index += 1 + + +def api_batches_update(api_batches, api_name, state, index): + """ + 当一个api的所有item更新完后,input, output的索引范围: + input: [start: start+input_len] + output: [start+input_len: output_end_index] + params: [output_end_index: params_end_index] + """ + if not api_batches: + api_batches.append(ApiBatch(api_name, index)) + else: + api_batch = api_batches[-1] + if api_batch.api_name == api_name or ( + not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name): + try: + api_batch.increment(state) + except ValueError as e: + logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}") + raise CompareException(CompareException.INVALID_STATE_ERROR) from e + else: + api_batches.append(ApiBatch(api_name, index)) + + +def get_name_and_state(name): """ - Function Description: - 解析下面格式的数据, 是api参数的一种特殊格式 - { - "last_hidden_state": { - "type": "torch.Tensor", - "dtype": "torch.bfloat16", - ... - }, - "loss": { - "type": "torch.Tensor", - "dtype": "torch.float32", - ... - } - } - Parameter: - data_dict: 字典格式的数据 - full_op_name: 参数的全名字符串 - item_list: 参数信息集合 + Get api/module name and state + example: + name = 'conv2d.forward.1.input.0' + return: ('conv2d.forward.1.', 'input') + + name = 'Functional.pad.0.backward.output.0' + return: ('Functional.pad.0.backward.', 'output') + + state type: input, output, kwargs, parameters, parameters_grad """ - for key, value in data_dict.items(): - if isinstance(value, dict): - parsed_item = value - parts = full_op_name.split(Const.SEP) - parts.insert(-1, key) - full_op_name_new = ".".join(parts) - parsed_item['full_op_name'] = full_op_name_new - item_list.append(parsed_item) + if not isinstance(name, str): + logger.error(f'Invalid name: {name}, type should be string, please check.') + raise CompareException(CompareException.INVALID_API_NAME_ERROR) + + if Const.PARAMS_GRAD in name.split(Const.SEP): + return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD + + split = re.split(Const.REGEX_FORWARD_BACKWARD, name) + if len(split) < 3: + logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.') + raise CompareException(CompareException.INVALID_API_NAME_ERROR) + api = f'{split[0]}.{split[1]}.' + state_str = split[2] + match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str) + if not match: + raise CompareException(f'Invalid name string: {name}') + if match.group(1): + api = f'{api}{match.group(1)}' + state = match.group(2) + return api, state + + +def reorder_op_name_list(op_name_list): + if not op_name_list: + return op_name_list + + parameters = [] + output = [] + parameters_grad = [] + others = [] + for x in op_name_list: + state = get_name_and_state(x)[1] + if state == Const.PARAMS: + parameters.append(x) + elif state == Const.OUTPUT: + output.append(x) + elif state == Const.PARAMS_GRAD: + parameters_grad.append(x) + else: + others.append(x) + # 合并others, parameters, 和output,确保parameters排在output前面 + op_name_reorder = others + parameters + output + parameters_grad + return op_name_reorder + + +def reorder_op_x_list(op_name_list, summary_list, data_name_list): + """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理""" + if not op_name_list or not summary_list: + return op_name_list, summary_list, data_name_list + + index_map = {name: index for index, name in enumerate(op_name_list)} + + op_name_reorder = reorder_op_name_list(op_name_list) + summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder] + if data_name_list: + data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder] + else: + data_name_reorder = data_name_list + + return op_name_reorder, summary_reorder, data_name_reorder def process_summary_data(summary_data): @@ -285,9 +436,9 @@ def result_item_init(n_info, b_info, dump_mode): md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result]) elif dump_mode == Const.SUMMARY: - result_item.extend([" "] * 8) + result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标 else: - result_item.extend([" "] * 5) + result_item.extend([" "] * 6) # 6个真实数据情况的比对指标 else: err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \ f"npu_info_struct is {n_info.struct}\n" \ @@ -321,8 +472,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): has_stack = npu_stack_info and bench_stack_info if dump_mode == Const.ALL: - npu_data_name = n_dict.get("data_name", None) - bench_data_name = b_dict.get("data_name", None) + npu_data_name_list = n_dict.get("data_name", None) + bench_data_name_list = b_dict.get("data_name", None) for index in range(min_len): n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name") @@ -353,7 +504,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list") + result_item.append([npu_data_name, bench_data_name]) result.append(result_item) @@ -371,7 +524,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): continue result_item = [ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, - " ", " ", " ", " ", " " + " ", " ", " ", " ", " ", " " ] summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index] result_item.extend(summary_data) @@ -388,12 +541,13 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + result_item.append([npu_data_name, "-1"]) result.append(result_item) - n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict) - b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict) + _, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict) + _, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict) get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT) get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params, @@ -404,197 +558,23 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): CompareConst.PARAMS_GRAD_STRUCT) -def append_stack_info(result_item, npu_stack_info, index): - """添加堆栈信息到 result_item""" - if npu_stack_info and index == 0: - result_item.extend(npu_stack_info) - else: - result_item.append(CompareConst.NONE) - - -def get_un_match_accuracy(result, n_dict, dump_mode): - npu_stack_info = n_dict.get("stack_info", None) - bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A +def make_result_table(result, dump_mode, stack_mode): + header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:] - struct_to_index_mapping = { - CompareConst.INPUT_STRUCT: 0, - CompareConst.OUTPUT_STRUCT: 0, - CompareConst.PARAMS_STRUCT: 0, - CompareConst.PARAMS_GRAD_STRUCT: 0 - } - - op_name_list = n_dict.get(CompareConst.OP_NAME) - summary_list = n_dict.get(Const.SUMMARY) - data_name_list = n_dict.get('data_name') - op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for index, n_name in enumerate(op_name_reorder): - _, state = get_name_and_state(n_name) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key) - struct_to_index_mapping[struct_key] += 1 - - try: - result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] - except IndexError as e: - err_msg = "index out of bounds error occurs, please check!\n" \ - f"op_name of n_dict is {n_dict['op_name']}\n" \ - f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \ - f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}" - logger.error(err_msg) - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - - if dump_mode == Const.MD5: - result_item.extend([CompareConst.N_A] * 3) - append_stack_info(result_item, npu_stack_info, index) - result.append(result_item) - continue - if dump_mode == Const.SUMMARY: - result_item.extend([CompareConst.N_A] * 8) + if stack_mode: + header.append(CompareConst.STACK) if dump_mode == Const.ALL: - result_item.extend([CompareConst.N_A] * 5) - - npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder") - bench_summary_data = [CompareConst.N_A] * 4 - result_item.extend(npu_summary_data) - result_item.extend(bench_summary_data) - err_msg = CompareConst.NO_BENCH - accuracy_check_res = CompareConst.N_A - result_item.append(accuracy_check_res) - result_item.append(err_msg) - append_stack_info(result_item, npu_stack_info, index) - if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A: - result_item.extend(["-1"]) - result.append(result_item) - - -def merge_tensor(tensor_list, dump_mode): - op_dict = {} - op_dict["op_name"] = [] - op_dict[CompareConst.INPUT_STRUCT] = [] - op_dict[CompareConst.KWARGS_STRUCT] = [] - op_dict[CompareConst.OUTPUT_STRUCT] = [] - op_dict[CompareConst.PARAMS_STRUCT] = [] - op_dict[CompareConst.PARAMS_GRAD_STRUCT] = [] - op_dict[Const.SUMMARY] = [] - op_dict["stack_info"] = [] - - if dump_mode == Const.ALL: - op_dict["data_name"] = [] - - for tensor in tensor_list: - # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True - if len(tensor) == 2: - op_dict['stack_info'].append(tensor['full_info']) - break - - op_dict["op_name"].append(tensor['full_op_name']) - - _, state = get_name_and_state(tensor['full_op_name']) - struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) - if not struct_key: - continue - if dump_mode == Const.MD5: - op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) - else: - op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) - + header.append(CompareConst.DATA_NAME) + else: if dump_mode == Const.ALL: - op_dict["data_name"].append(tensor['data_name']) - - if not op_dict[CompareConst.KWARGS_STRUCT]: - del op_dict[CompareConst.KWARGS_STRUCT] - return op_dict if op_dict["op_name"] else {} - - -def print_compare_ends_info(): - total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS - logger.info('*' * total_len) - logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*") - logger.info('*' * total_len) - - -def table_value_is_valid(value: str) -> bool: - if not isinstance(value, str): - return True - try: - # -1.00 or +1.00 should be consdiered as digit numbers - float(value) - except ValueError: - # otherwise, they will be considered as formular injections - return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value)) - return True - - -def get_name_and_state(name): - """ - Get api/module name and state - example: - name = 'conv2d.forward.1.input.0' - return: ('conv2d.forward.1.', 'input') - - name = 'Functional.pad.0.backward.output.0' - return: ('Functional.pad.0.backward.', 'output') - - state type: input, output, kwargs, parameters, parameters_grad - """ - if Const.PARAMS_GRAD in name.split(Const.SEP): - return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD - - split = re.split(Const.REGEX_FORWARD_BACKWARD, name) - api = f'{split[0]}.{split[1]}.' - state_str = split[2] - match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str) - if not match: - raise CompareException(f'Invalid name string: {name}') - if match.group(1): - api = f'{api}{match.group(1)}' - state = match.group(2) - return api, state - - -def reorder_op_name_list(op_name_list): - if not op_name_list: - return op_name_list - - parameters = [] - output = [] - parameters_grad = [] - others = [] - for x in op_name_list: - state = get_name_and_state(x)[1] - if state == Const.PARAMS: - parameters.append(x) - elif state == Const.OUTPUT: - output.append(x) - elif state == Const.PARAMS_GRAD: - parameters_grad.append(x) + for row in result: + del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列 + header.append(CompareConst.DATA_NAME) else: - others.append(x) - # 合并others, parameters, 和output,确保parameters排在output前面 - op_name_reorder = others + parameters + output + parameters_grad - return op_name_reorder - - -def reorder_op_x_list(op_name_list, summary_list, data_name_list): - """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理""" - if not op_name_list or not summary_list: - return op_name_list, summary_list, data_name_list - - index_map = {name: index for index, name in enumerate(op_name_list)} - - op_name_reorder = reorder_op_name_list(op_name_list) - summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder] - if data_name_list: - data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder] - else: - data_name_reorder = data_name_list - - return op_name_reorder, summary_reorder, data_name_reorder + for row in result: + del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列 + result_df = pd.DataFrame(result, columns=header, dtype='object') + return result_df def _compare_parser(parser): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..9090c1fa206f7149d3094ac2e2066c580b6ec1f7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -0,0 +1,239 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Dict, Any, Optional, Callable, Union, List, Tuple + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.log import logger + + +def _get_attr(module, attr_name): + if Const.SEP in attr_name: + sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1) + sub_module = getattr(module, sub_module_name, None) + attr = getattr(sub_module, sub_attr, None) + else: + attr = getattr(module, attr_name, None) + return attr + + +class ApiWrapper: + def __init__( + self, api_types: Dict[str, Dict[str, Any]], + api_list_paths: Union[str, List[str], Tuple[str]], + backlist: Union[List[str], Tuple[str]] = None + ): + self.api_types = api_types + if not isinstance(api_list_paths, (list, tuple)): + api_list_paths = [api_list_paths] * len(self.api_types) + elif len(api_list_paths) != len(self.api_types): + raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', " + "when api_list_paths is a list or tuple.") + self.api_list_paths = api_list_paths + self.backlist = backlist if backlist else [] + self.api_names = self._get_api_names() + self.wrapped_api_functions = dict() + + @staticmethod + def deal_with_self_kwargs(api_name, api_func, args, kwargs): + if kwargs and 'self' in kwargs: + func_params = None + try: + func_params = inspect.signature(api_func).parameters + except Exception: + if api_name in Const.API_WITH_SELF_ARG: + func_params = inspect.signature(Const.API_WITH_SELF_ARG.get(api_name)).parameters + if func_params is None: + return False, args, kwargs + + for name, param in func_params.items(): + if name == 'self' and param.kind == inspect.Parameter.KEYWORD_ONLY: + return False, args, kwargs + args_ = list(args) + names_and_values = [] + self_index = 0 + for i, item in enumerate(func_params.items()): + names_and_values.append((item[0], item[1].default)) + if item[0] == 'self': + self_index = i + break + for i in range(len(args), self_index + 1): + if names_and_values[i][0] in kwargs: + args_.append(kwargs.pop(names_and_values[i][0])) + else: + args_.append(names_and_values[i][1]) + args = tuple(args_) + + return True, args, kwargs + + def wrap_api( + self, api_templates, hook_build_func: Optional[Callable] + ): + api_types_num = sum([len(v) for v in self.api_types.values()]) + if not isinstance(api_templates, (list, tuple)): + api_templates = [api_templates] * api_types_num + elif len(api_templates) != api_types_num: + raise RuntimeError("The number of api_templates must be equal to the number of api_types, " + "when api_templates is a list or tuple.") + + self.wrapped_api_functions.clear() + index = 0 + for framework, api_types in self.api_types.items(): + wrapped_functions_in_framework = dict() + for api_type, api_modules in api_types.items(): + wrapped_functions = dict() + name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API") + api_template = api_templates[index] + index += 1 + for api_name in self.api_names.get(framework, {}).get(api_type, []): + ori_api = _get_attr(api_modules[0], api_name) + if callable(ori_api): + def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + def api_function(*args, **kwargs): + api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, + api_func, args, kwargs) + if not enable_wrap: + logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. ' + 'It may be fixed by passing the value of "self" ' + 'as a positional argument instead of a keyword argument. ') + return api_func(*args, **kwargs) + return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + api_function.__name__ = api_name + return api_function + wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix, + hook_build_func, api_template) + wrapped_functions_in_framework[api_type] = wrapped_functions + self.wrapped_api_functions[framework] = wrapped_functions_in_framework + return self.wrapped_api_functions + + def _get_api_names(self): + api_names = dict() + + for index, framework in enumerate(self.api_types.keys()): + api_list = load_yaml(self.api_list_paths[index]) + valid_names = dict() + for api_type, api_modules in self.api_types.get(framework, {}).items(): + key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type) + api_from_file = api_list.get(key_in_file, []) + names = set() + for api_name in api_from_file: + if f'{key_in_file}.{api_name}' in self.backlist: + continue + target_attr = api_name + target_module = api_modules[0] + if Const.SEP in api_name: + sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1) + target_module = getattr(api_modules[0], sub_module_name, None) + if target_module and target_attr in dir(target_module): + names.add(api_name) + valid_names[api_type] = names + api_names[framework] = valid_names + + return api_names + + +class ApiRegistry: + """ + Base class for api registry. + """ + + def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None): + self.ori_api_attr = dict() + self.wrapped_api_attr = dict() + self.inner_used_ori_attr = dict() + self.inner_used_wrapped_attr = dict() + self.api_types = api_types + self.inner_used_api = inner_used_api + self.supported_api_list_path = supported_api_list_path + self.api_templates = api_templates + self.backlist = backlist if backlist else [] + self.all_api_registered = False + + @staticmethod + def store_ori_attr(ori_api_group, api_list, api_ori_attr): + for api in api_list: + api_ori_attr[api] = _get_attr(ori_api_group, api) + + @staticmethod + def set_api_attr(api_group, attr_dict): + for api, api_attr in attr_dict.items(): + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, api_attr) + else: + setattr(api_group, api, api_attr) + + @staticmethod + def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template): + def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + def api_function(*args, **kwargs): + return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + + api_function.__name__ = api_name + return api_function + + setattr(module, api_name, + wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template)) + + def register_all_api(self): + self.all_api_registered = True + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {})) + + def register_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {})) + + def restore_all_api(self): + self.all_api_registered = False + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {})) + + def restore_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {})) + + def initialize_hook(self, hook_build_func): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist) + wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func) + + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + ori_attr = dict() + self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr) + api_type_with_framework = framework + Const.SEP + api_type + self.ori_api_attr[api_type_with_framework] = ori_attr + self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type) + + for inner_used_api_type, inner_used_api_list in self.inner_used_api.items(): + ori_attr = dict() + wrapped_attr = dict() + for api_name in inner_used_api_list[1:]: + if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name): + ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name) + wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name) + self.inner_used_ori_attr[inner_used_api_type] = ori_attr + self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index 20e4489f89e4bd345595e6a1db1e39ab427d4908..c29a3ff56cde97d14dd1ca28b8d0524963f69ee1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -41,7 +41,7 @@ class DataCollector: self.backward_module_names = {} self.optimizer_status = "" self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True} - atexit.register(self.write_json) + atexit.register(self.write_json_at_exit) @property def dump_data_dir(self): @@ -78,6 +78,11 @@ class DataCollector: def write_json(self): self.data_writer.write_json() + def write_json_at_exit(self): + if self.config.async_dump and self.config.task == Const.TENSOR: + self.data_processor.dump_async_data() + self.data_writer.write_json() + def update_data(self, name, data_info): msg = f"msprobe is collecting data on {name}." if self.config.task == Const.OVERFLOW_CHECK: @@ -89,6 +94,10 @@ class DataCollector: logger.debug(msg) self.data_writer.update_data(data_info) + def call_stack_collect(self, name): + stack_info = self.data_processor.analyze_api_call_stack(name) + self.data_writer.update_stack(name, stack_info) + def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None): if self.config.task == Const.FREE_BENCHMARK: backward_name = name.replace(Const.FORWARD, Const.BACKWARD) @@ -105,6 +114,7 @@ class DataCollector: self.set_is_recomputable(data_info, is_recompute) if self.config.level == Const.LEVEL_L2: return + self.call_stack_collect(name) self.handle_data(name, data_info, flush=self.data_processor.is_terminated) def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None): @@ -118,7 +128,7 @@ class DataCollector: self.set_is_recomputable(data_info, is_recompute) if self.config.level == Const.LEVEL_L2: return - self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) + self.handle_data(name, data_info, flush=self.data_processor.is_terminated) def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None): @@ -130,7 +140,7 @@ class DataCollector: if self.config.task != Const.STRUCTURE: data_info = self.data_processor.analyze_forward(name, module, module_input_output) self.set_is_recomputable(data_info, is_recompute) - self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) + self.call_stack_collect(name) self.handle_data(name, data_info, flush=self.data_processor.is_terminated) def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None): @@ -180,7 +190,10 @@ class DataCollector: self.optimizer_status_first_start[self.optimizer_status] = False self.data_writer.update_construct({name: self.optimizer_status}) else: - self.data_writer.update_construct({name: self.module_processor.api_parent_node}) + if self.config.level == Const.LEVEL_MIX and \ + not (name.startswith(Const.MODULE) or name.startswith(Const.CELL)): + self.data_writer.update_construct({name: self.module_processor.api_parent_node}) + self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, flush=False): @@ -204,6 +217,7 @@ class DataCollector: def params_data_collect(self, name, param_name, pid, data): grad_name = name + Const.SEP + Const.PARAMS_GRAD + self.update_api_or_module_name(grad_name) # 校验scope和pid,以及当前name是否有过反向计算 if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name): # 如果没有反向计算,则需要清除之前占位写入的grad数据 @@ -213,18 +227,18 @@ class DataCollector: data_info = self.data_processor.analyze_params(grad_name, param_name, data) self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated) - def fill_stack_tensor_data(self): - self.data_writer.fill_stack_tensor_data() - def debug_data_collect_forward(self, variable, name_with_count): data_info = self.data_processor.analyze_debug_forward(variable, name_with_count) - self.data_writer.update_debug({name_with_count: data_info}) + name_with_count_category = name_with_count + Const.SEP + Const.DEBUG + self.data_writer.update_debug({name_with_count_category: data_info}) def debug_data_collect_backward(self, variable, grad_name_with_count): # prepare all None nested data structure all_none_data_info = self.data_processor.analyze_element_to_all_none(variable) - self.data_writer.update_debug({grad_name_with_count: all_none_data_info}) + grad_name_with_count_category = grad_name_with_count + Const.SEP + Const.DEBUG + self.data_writer.update_debug({grad_name_with_count_category: all_none_data_info}) # register tensor backward hook - self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data']) + self.data_processor.analyze_debug_backward(variable, grad_name_with_count_category, + self.data_writer.cache_debug['data']) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 775a80b2418ef356867228b4ca09fad8c86cce25..6ff6b771976f731604df5aac7dee0eee2b2c13b4 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,17 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import os from dataclasses import dataclass, is_dataclass -from typing import Tuple, Dict, Optional, Any from functools import partial -import copy -from typing import Union +from typing import Tuple, Dict, Optional, Any, Union import numpy as np from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import save_npy from msprobe.core.common.log import logger from msprobe.core.common.utils import convert_tuple, CompareException @@ -79,21 +79,17 @@ class ModuleBackwardOutputs: class TensorStatInfo: - def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None): + def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): self.max = max_val self.min = min_val self.mean = mean_val self.norm = norm_val - self.stack_tensor_stat = stack_tensor_stat class BaseDataProcessor: _recursive_key_stack = [] - special_type = ( - np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray, - bool, int, float, str, slice, - type(Ellipsis) - ) + builtin_type = (bool, int, float, str, slice, type(Ellipsis)) + np_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray) def __init__(self, config, data_writer): self.data_writer = data_writer @@ -120,7 +116,10 @@ class BaseDataProcessor: @staticmethod def analyze_api_call_stack(name): try: - api_stack = inspect.stack()[5:] + if name.startswith("Primitive"): + api_stack = inspect.stack()[4:] + else: + api_stack = inspect.stack()[5:] except Exception as e: logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.") api_stack = None @@ -129,12 +128,14 @@ class BaseDataProcessor: for (_, path, line, func, code, _) in api_stack: if not code: continue + if any(filter_path in path for filter_path in Const.STACK_FILTER_KEYWORDS) and \ + Const.CALL_STACK_FLAG not in path: + continue stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}" stack_str.append(stack_line) else: stack_str.append(Const.WITHOUT_CALL_STACK) - stack_info_struct = {name: stack_str} - return stack_info_struct + return tuple(stack_str) @staticmethod def transfer_type(data): @@ -178,20 +179,8 @@ class BaseDataProcessor: "invalid data_structure type or invalid index") @staticmethod - def _convert_numpy_to_builtin(arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.byte: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), type(arg).__name__ - return arg, '' + def is_distributed_op(module): + return getattr(module, "op_is_distributed", False) @staticmethod def _analyze_builtin(arg): @@ -217,21 +206,39 @@ class BaseDataProcessor: return single_arg @staticmethod - def _analyze_numpy(ndarray, numpy_type): + def _analyze_numpy(arg): + return {"type": type(arg).__name__, "value": arg.item()} + + @staticmethod + def _analyze_ndarray(ndarray, _): ndarray_json = {} ndarray_json.update({'type': 'numpy.ndarray'}) ndarray_json.update({'dtype': str(ndarray.dtype)}) ndarray_json.update({'shape': ndarray.shape}) - if ndarray.size > 0: - ndarray_json.update({"Max": np.max(ndarray).item()}) - ndarray_json.update({"Min": np.min(ndarray).item()}) - ndarray_json.update({"Mean": np.mean(ndarray).item()}) - ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()}) - else: - ndarray_json.update({"Max": None}) - ndarray_json.update({"Min": None}) - ndarray_json.update({"Mean": None}) - ndarray_json.update({"Norm": None}) + + # 先初始化默认值 + stats = { + "Max": None, + "Min": None, + "Mean": None, + "Norm": None + } + + try: + # 只有非空时才尝试计算 + if ndarray.size > 0: + stats = { + "Max": np.max(ndarray).item(), + "Min": np.min(ndarray).item(), + "Mean": np.mean(ndarray).item(), + "Norm": np.linalg.norm(ndarray).item() + } + except Exception as e: + logger.warning(f"Error analyzing ndarray stats: {e}") + + # 最后一次性更新 + ndarray_json.update(stats) + return ndarray_json @staticmethod @@ -248,12 +255,12 @@ class BaseDataProcessor: @classmethod def get_special_types(cls): - return cls.special_type + return cls.builtin_type + cls.np_type @classmethod def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]: - if depth > Const.MAX_DEPTH: - logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.") + if depth > Const.DUMP_MAX_DEPTH: + logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.") raise CompareException(CompareException.RECURSION_LIMIT_ERROR) if isinstance(args, cls.get_special_types()): arg_transform = transform(args, cls._recursive_key_stack) @@ -303,6 +310,7 @@ class BaseDataProcessor: def real_hook_fn(grad): return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) def if_return_forward_new_output(self): @@ -350,6 +358,8 @@ class BaseDataProcessor: return api_info_struct def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): + if self.is_distributed_op(module): + module_input_output.update_output_with_args_and_kwargs() api_info_struct = {} # check whether data_mode contains forward or input if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): @@ -427,6 +437,7 @@ class BaseDataProcessor: api_info_struct = {} self.save_name = name + Const.SEP + param_name data_info = self.analyze_element(grad) + self.save_name = None grad_info_dict = {param_name: [data_info]} api_info_struct[name] = grad_info_dict return api_info_struct @@ -435,10 +446,10 @@ class BaseDataProcessor: file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX if self.save_name is not None: dump_data_name = (self.save_name + file_format) - self.save_name = None else: - dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + - suffix + file_format) + suffix_with_seq = (Const.SEP + suffix) if suffix else "" + dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + suffix_with_seq + + file_format) file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) return dump_data_name, file_path @@ -447,23 +458,32 @@ class BaseDataProcessor: def analyze_debug_forward(self, variable, name_with_count): self.current_api_or_module_name = name_with_count - self.api_data_category = Const.TENSOR - # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt + self.api_data_category = Const.DEBUG + # these two attributes are used to construct tensor file name {name_with_count}.debug.{indexes}.npy/pt data_info = self.analyze_element(variable) return data_info - def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure): + def analyze_debug_backward(self, variable, grad_name_with_count_category, nested_data_structure): def hook_fn(grad, indexes): suffix = Const.SEP.join([str(index) for index in indexes]) - self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix + suffix_with_sep = (Const.SEP + suffix) if suffix else "" + self.save_name = grad_name_with_count_category + suffix_with_sep grad_data_info = self.analyze_element(grad) self.save_name = None - full_index = [grad_name_with_count] + indexes + full_index = [grad_name_with_count_category] + indexes try: self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info) except (ValueError, IndexError) as e: - logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, " - f"skip current recording, detailed infomation: {e}") + logger.warning(f"error occurred while recording statistics of {grad_name_with_count_category} variable," + f"skip current recording, detailed information: {e}") return grad + wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn) - self.recursive_apply_transform(variable, wrap_register_hook_single_element) \ No newline at end of file + self.recursive_apply_transform(variable, wrap_register_hook_single_element) + + def _analyze_and_save_ndarray(self, ndarray, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + save_npy(ndarray, file_path) + ndarray_json = BaseDataProcessor._analyze_ndarray(ndarray, suffix) + ndarray_json.update({"data_name": dump_data_name}) + return ndarray_json diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 8c4542a1917b76809aad21971e148ec17bd6045e..587857d080f5bf44d234a140af28b5d1c5c51e24 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -17,16 +17,17 @@ import zlib import mindspore as ms from mindspore import mint, ops, hal +from mindspore.mint import distributed from mindspore._c_expression.typing import Number import numpy as np from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo, ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs) -from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy +from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register has_adump = True try: @@ -36,7 +37,7 @@ except ImportError: class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number]) + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -44,6 +45,7 @@ class MindsporeDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self.api_register = get_api_register() @staticmethod def get_md5_for_tensor(x): @@ -60,11 +62,10 @@ class MindsporeDataProcessor(BaseDataProcessor): def get_stat_info_sync(data): tensor_stat = TensorStatInfo() if data.dtype == ms.bool_: - data_np = data.asnumpy() - tensor_stat.max = np.max(data_np).item() - tensor_stat.min = np.min(data_np).item() + tensor_stat.max = mint.any(data) + tensor_stat.min = mint.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data elif data.dtype == ms.complex64 or data.dtype == ms.complex128: data_abs = np.abs(data.asnumpy()) tensor_stat.max = np.max(data_abs).item() @@ -74,83 +75,98 @@ class MindsporeDataProcessor(BaseDataProcessor): else: if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.max = get_max_value(data).item() - tensor_stat.min = get_min_value(data).item() - tensor_stat.mean = get_mean_value(data).item() - tensor_stat.norm = get_norm_value(data).item() - api_register.norm_inner_op_set_hook_func() + get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm + tensor_stat.max = mint.max(data) + tensor_stat.min = mint.min(data) + tensor_stat.mean = mint.mean(data) + tensor_stat.norm = get_norm_value(data) return tensor_stat @staticmethod def get_stat_info_async(data): tensor_stat = TensorStatInfo() - stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack) - if data.dtype == ms.complex64 or data.dtype == ms.complex128: + if data.dtype == ms.bool_: + tensor_stat.max = mint.any(data) + tensor_stat.min = mint.all(data) + elif not data.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data + elif data.dtype == ms.complex64 or data.dtype == ms.complex128: logger.warning("Async dump do not support complex data!") return tensor_stat - elif data.dtype == ms.bool_: - tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()])) - elif not data.shape: - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data])) else: if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method( - [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)])) - api_register.norm_inner_op_set_hook_func() + get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm + tensor_stat.max = mint.max(data) + tensor_stat.min = mint.min(data) + tensor_stat.mean = mint.mean(data) + tensor_stat.norm = get_norm_value(data) return tensor_stat @staticmethod def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + @staticmethod + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + return f"{group_ranks_hash:08x}" + @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_tensor_as_npy(tensor, file_path) + self._async_dump_cache.clear() + def get_stat_info(self, data): + self.api_register.restore_inner_used_api() tensor_stat = TensorStatInfo() if data.numel() == 0: - return tensor_stat + stat_info = tensor_stat else: if self.config.async_dump: - return MindsporeDataProcessor.get_stat_info_async(data) + stat_info = MindsporeDataProcessor.get_stat_info_async(data) else: - return MindsporeDataProcessor.get_stat_info_sync(data) + stat_info = MindsporeDataProcessor.get_stat_info_sync(data) + self.api_register.register_inner_used_api() + return stat_info def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: return self.mindspore_object_key[suffix_stack[-1]](element) - converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) - if converted_numpy is not element: - return {"type": numpy_type, "value": converted_numpy} - if isinstance(element, Number): - return self.analyze_dtype_in_kwargs(element) - if isinstance(element, ms.Tensor): - return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, np.ndarray): - return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): - return self._analyze_builtin(element) + suffix_str = Const.SEP.join(str(s) for s in suffix_stack) + type_analyzer = [ + (MindsporeDataProcessor.builtin_type, self._analyze_builtin), + (ms.Tensor, lambda e: self._analyze_tensor(e, suffix_str)), + (Number, self.analyze_dtype_in_kwargs), + (MindsporeDataProcessor.np_type[:-1], self._analyze_numpy), + (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), + (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)) + ] + for type_key, analyze_fn in type_analyzer: + if isinstance(element, type_key): + return analyze_fn(element) return {} + def _analyze_p2pop(self, arg, suffix): + p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, suffix) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = self.process_group_hash(arg.group) if arg.group else None + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = { @@ -159,45 +175,54 @@ class MindsporeDataProcessor(BaseDataProcessor): 'shape': tensor.shape } - if tensor_stat.stack_tensor_stat is None: - tensor_json.update({'Max': self.transfer_type(tensor_stat.max)}) - tensor_json.update({'Min': self.transfer_type(tensor_stat.min)}) - tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)}) - tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)}) - else: - tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat}) + # 将统计值存入全局 buffer,并返回占位索引 + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) + if self.config.summary_mode == Const.MD5 and not self.config.async_dump: tensor_md5 = self.get_md5_for_tensor(tensor) tensor_json.update({Const.MD5: tensor_md5}) return tensor_json - -class StatisticsDataProcessor(MindsporeDataProcessor): - pass - - -class TensorDataProcessor(MindsporeDataProcessor): - def dump_async_data(self): - for file_path, tensor in self._async_dump_cache.items(): - save_tensor_as_npy(tensor, file_path) - self._async_dump_cache.clear() - - def _analyze_tensor(self, tensor, suffix): + def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - single_arg = super()._analyze_tensor(tensor, suffix) + single_arg = MindsporeDataProcessor._analyze_tensor(self, tensor, suffix) single_arg.update({"data_name": dump_data_name}) if self.config.async_dump: self._async_dump_cache[file_path] = tensor.copy() else: save_tensor_as_npy(tensor, file_path) return single_arg - - def _analyze_numpy(self, ndarray, suffix): - dump_data_name, file_path = self.get_save_file_path(suffix) - save_npy(ndarray, file_path) - ndarray_json = super()._analyze_numpy(ndarray, suffix) - ndarray_json.update({"data_name": dump_data_name}) - return ndarray_json + + +class StatisticsDataProcessor(MindsporeDataProcessor): + def _analyze_tensor(self, tensor, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_tensor(tensor, suffix) + else: + return super()._analyze_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_ndarray(ndarray, suffix) + else: + return super()._analyze_ndarray(ndarray, suffix) + + +class TensorDataProcessor(MindsporeDataProcessor): + def _analyze_tensor(self, tensor, suffix): + return self._analyze_and_save_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + return self._analyze_and_save_ndarray(ndarray, suffix) class OverflowCheckDataProcessor(MindsporeDataProcessor): @@ -262,11 +287,20 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.cached_tensors_and_file_paths = {} def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None: + tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX) + if tensor_stat_index is None: + logger.warning("tensor_stat_index does not exist in tensor_json.") + return + max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index) + min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index) + if max_tensor is None or min_tensor is None: return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + + if mint.isinf(max_tensor) or mint.isnan(max_tensor): self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + return + + if mint.isinf(min_tensor) or mint.isnan(min_tensor): self.has_overflow = True def _analyze_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 64253aa4260cab608e5ca84a5d006b28b94a33ab..2d6198d8369a8f073156c2dc3859913130174009 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib import zlib from dataclasses import asdict from typing import List @@ -24,14 +23,15 @@ from torch import distributed as dist from torch.distributed.distributed_c10d import _get_default_group from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.core.common.log import logger from msprobe.core.common.utils import convert_tuple +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo -from msprobe.pytorch.common.utils import save_pt, load_pt +from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow -from msprobe.core.common.utils import recursion_depth_decorator is_gpu = False try: @@ -78,20 +78,23 @@ class PytorchDataProcessor(BaseDataProcessor): def analyze_device_in_kwargs(element): single_arg = {} single_arg.update({'type': "torch.device"}) - if not isinstance(element, str): + if isinstance(element, (int, str)): + single_arg.update({"value": element}) + elif isinstance(element, torch.device): if hasattr(element, "index"): device_value = element.type + ":" + str(element.index) else: device_value = element.type single_arg.update({"value": device_value}) else: - single_arg.update({"value": element}) + logger.debug(f"Device type {type(element)} is not supported.") return single_arg @staticmethod def analyze_dtype_in_kwargs(element): return {"type": "torch.dtype", "value": str(element)} + @staticmethod def get_stat_info_async(data): tensor_stat = TensorStatInfo() @@ -99,19 +102,17 @@ class PytorchDataProcessor(BaseDataProcessor): logger.warning("Async dump do not support complex data!") return tensor_stat elif data.dtype == torch.bool: - tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack( - [torch.any(data), torch.all(data)])) + tensor_stat.max = torch.any(data) + tensor_stat.min = torch.all(data) elif not data.shape: - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data])) + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data else: - if not data.is_floating_point() or data.dtype == torch.float64: + if data.dtype == torch.float64 or not data.is_floating_point(): data = data.float() - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([ - torch.max(data), - torch.min(data), - torch.mean(data), - torch.norm(data) - ])) + tensor_stat.max = torch.max(data) + tensor_stat.min = torch.min(data) + tensor_stat.mean = torch.mean(data) + tensor_stat.norm = torch.norm(data) return tensor_stat @staticmethod @@ -124,17 +125,17 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat.min = np.min(data_abs).item() tensor_stat.mean = np.mean(data_abs).item() elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data).item() - tensor_stat.min = torch.all(data).item() + tensor_stat.max = torch.any(data) + tensor_stat.min = torch.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data else: - if not data.is_floating_point() or data.dtype == torch.float64: + if data.dtype == torch.float64 or not data.is_floating_point(): data = data.float() - tensor_stat.max = torch.max(data).item() - tensor_stat.min = torch.min(data).item() - tensor_stat.mean = torch.mean(data).item() - tensor_stat.norm = torch.norm(data).item() + tensor_stat.max = torch.max(data) + tensor_stat.min = torch.min(data) + tensor_stat.mean = torch.mean(data) + tensor_stat.norm = torch.norm(data) return tensor_stat @staticmethod @@ -143,7 +144,7 @@ class PytorchDataProcessor(BaseDataProcessor): if data.is_meta: return tensor_stat data_clone = data.detach() - if data_clone.numel() == 0: + if not data_clone.numel() or not data_clone.data_ptr(): return tensor_stat else: if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump: @@ -171,12 +172,8 @@ class PytorchDataProcessor(BaseDataProcessor): @staticmethod def process_group_hash(arg): group_ranks = dist.get_process_group_ranks(arg) - group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() - return group_ranks_hash - - @staticmethod - def is_distributed_op(module): - return getattr(module, "op_is_distributed", False) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + return f"{group_ranks_hash:08x}" @staticmethod def is_hookable_element(element): @@ -214,43 +211,52 @@ class PytorchDataProcessor(BaseDataProcessor): logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.") return {"type": "torch.distributed.ReduceOp", "value": op_type} + @staticmethod + def _cast_to_float_if_fp8(tensor): + dtype = str(tensor.dtype) + if is_float8_tensor(tensor): + dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype + logger.debug( + f"The {dtype} tensor analyzing/saving is unsupported in dump function." + f"Casting to float for processing." + ) + tensor = tensor.float() + return tensor, dtype + @classmethod def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_pt(tensor.contiguous(), file_path) + self._async_dump_cache.clear() + def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.torch_object_key: return self.torch_object_key[suffix_stack[-1]](element) - if isinstance(element, torch.Size): - return self._analyze_torch_size(element) - if isinstance(element, torch.memory_format): - return self._analyze_memory_format(element) - if isinstance(element, dist.ProcessGroup): - return self._analyze_process_group(element) - if isinstance(element, dist.P2POp): - return self._analyze_p2pop(element) - if isinstance(element, dist.ReduceOp): - return self._analyze_reduce_op(element) - converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) - if converted_numpy is not element: - return {"type": numpy_type, "value": converted_numpy} - if isinstance(element, torch.Tensor): - return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, np.ndarray): - return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) - if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): - return self._analyze_builtin(element) - return {} - def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): - if self.is_distributed_op(module): - module_input_output.update_output_with_args_and_kwargs() - return super().analyze_forward_output(name, module, module_input_output) + suffix_str = Const.SEP.join(str(s) for s in suffix_stack) + type_analyzer = [ + (PytorchDataProcessor.builtin_type, self._analyze_builtin), + (torch.Size, self._analyze_torch_size), + (torch.Tensor, lambda e: self._analyze_tensor(e, suffix_str)), + (torch.memory_format, self._analyze_memory_format), + (dist.ProcessGroup, self._analyze_process_group), + (dist.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)), + (dist.ReduceOp, self._analyze_reduce_op), + (PytorchDataProcessor.np_type[:-1], self._analyze_numpy), + (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), + ] + for type_key, analyze_fn in type_analyzer: + if isinstance(element, type_key): + return analyze_fn(element) + return {} - def _analyze_p2pop(self, arg): + def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: - tensor_info = self._analyze_tensor(arg.tensor, []) + tensor_info = self._analyze_tensor(arg.tensor, suffix) p2pop_info.update({"tensor": tensor_info}) p2pop_info.update({"op": arg.op.__name__}) p2pop_info.update({"peer": arg.peer}) @@ -263,63 +269,71 @@ class PytorchDataProcessor(BaseDataProcessor): return p2pop_info def _analyze_tensor(self, tensor, suffix): + tensor, dtype = self._cast_to_float_if_fp8(tensor) tensor_stat = self.get_stat_info(tensor, self.config.async_dump) tensor_json = {} tensor_json.update({'type': 'torch.Tensor'}) - tensor_json.update({'dtype': str(tensor.dtype)}) + tensor_json.update({'dtype': dtype}) tensor_json.update({"shape": tensor.shape}) - if tensor_stat.stack_tensor_stat is None: - tensor_json.update({"Max": tensor_stat.max}) - tensor_json.update({"Min": tensor_stat.min}) - tensor_json.update({"Mean": tensor_stat.mean}) - tensor_json.update({"Norm": tensor_stat.norm}) - tensor_json.update({"requires_grad": tensor.requires_grad}) - if tensor_stat.max is not None: - if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") - if tensor_stat.min is not None: - if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") - else: - tensor_json.update({"requires_grad": tensor.requires_grad}) - tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat}) + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) + tensor_json.update({"requires_grad": tensor.requires_grad}) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: tensor_md5 = self.get_md5_for_tensor(tensor) tensor_json.update({Const.MD5: tensor_md5}) return tensor_json - -class StatisticsDataProcessor(PytorchDataProcessor): - pass - - -class TensorDataProcessor(PytorchDataProcessor): - def dump_async_data(self): - for file_path, tensor in self._async_dump_cache.items(): - save_pt(tensor.contiguous(), file_path) - self._async_dump_cache.clear() - - def _analyze_tensor(self, tensor, suffix): + def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - single_arg = super()._analyze_tensor(tensor, suffix) + single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix) single_arg.update({"data_name": dump_data_name}) + tensor, _ = self._cast_to_float_if_fp8(tensor) if self.config.async_dump: self._async_dump_cache[file_path] = tensor.clone().detach() else: saved_tensor = tensor.clone().contiguous().detach() save_pt(saved_tensor, file_path) return single_arg - - def _analyze_numpy(self, ndarray, suffix): + + def _analyze_and_save_ndarray(self, ndarray, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) save_pt(torch.tensor(ndarray), file_path) - ndarray_json = super()._analyze_numpy(ndarray, suffix) + ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix) ndarray_json.update({"data_name": dump_data_name}) return ndarray_json +class StatisticsDataProcessor(PytorchDataProcessor): + def _analyze_tensor(self, tensor, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_tensor(tensor, suffix) + else: + return super()._analyze_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + if any(item in self.current_api_or_module_name for item in self.config.tensor_list): + return self._analyze_and_save_ndarray(ndarray, suffix) + else: + return super()._analyze_ndarray(ndarray, suffix) + + +class TensorDataProcessor(PytorchDataProcessor): + def _analyze_tensor(self, tensor, suffix): + return self._analyze_and_save_tensor(tensor, suffix) + + def _analyze_ndarray(self, ndarray, suffix): + return self._analyze_and_save_ndarray(ndarray, suffix) + + class OverflowCheckDataProcessor(PytorchDataProcessor): __slots__ = ["cached_tensors_and_file_paths"] @@ -383,7 +397,8 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): self._analyze_maybe_overflow_flag() if self.has_overflow: for file_path, tensor in self.cached_tensors_and_file_paths.items(): - save_pt(tensor, file_path) + tensor, _ = self._cast_to_float_if_fp8(tensor) + save_pt(tensor.clone().contiguous().detach(), file_path) self.real_overflow_nums += 1 if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums: logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, " @@ -409,10 +424,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): raise RuntimeError(f"overflow check failed") from e def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None or tensor_json['Min'] is None: + tensor_stat_index = tensor_json.get(Const.TENSOR_STAT_INDEX) + if tensor_stat_index is None: + logger.warning("tensor_stat_index does not exist in tensor_json.") + return + max_tensor = self.data_writer.get_buffer_values_max(tensor_stat_index) + min_tensor = self.data_writer.get_buffer_values_min(tensor_stat_index) + + if max_tensor is None or min_tensor is None: + return + + if torch.isinf(max_tensor) or torch.isnan(max_tensor): + self.has_overflow = True return - self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \ - np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']) + + if torch.isinf(min_tensor) or torch.isnan(min_tensor): + self.has_overflow = True def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) @@ -508,11 +535,13 @@ class KernelDumpDataProcessor(PytorchDataProcessor): return if self.config.is_backward_kernel_dump: - self.forward_args = self.clone_and_detach_tensor(module_input_output.args) - self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs) try: + self.forward_args = self.clone_and_detach_tensor(module_input_output.args) + self.forward_kwargs = self.clone_and_detach_tensor(module_input_output.kwargs) output = module.forward(*self.forward_args, **self.forward_kwargs) - except Exception: + except Exception as e: + if isinstance(e, MsprobeException): + logger.warning(str(e)) self._print_unsupported_log(name) self.enable_kernel_dump = False return @@ -554,9 +583,17 @@ class KernelDumpDataProcessor(PytorchDataProcessor): self.stop_kernel_dump() logger.info(f"The kernel data of {name} is dumped successfully.") - @recursion_depth_decorator("KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor") + @recursion_depth_decorator( + "KernelDump: KernelDumpDataProcessor.clone_and_detach_tensor", + max_depth=Const.DUMP_MAX_DEPTH + ) def clone_and_detach_tensor(self, input_params): if isinstance(input_params, torch.Tensor): + if is_float8_tensor(input_params): + raise MsprobeException( + MsprobeException.UNSUPPORTED_TYPE_ERROR, + f"L2 backward dump does not support float8 type." + ) if input_params.requires_grad: return input_params.clone().detach().requires_grad_() return input_params.clone() @@ -571,6 +608,8 @@ class KernelDumpDataProcessor(PytorchDataProcessor): def analyze_single_element(self, element, suffix_stack): if isinstance(element, torch.Tensor): + if is_float8_tensor(element): + return {} if not self.is_found_output_tensor: if element.requires_grad: self.forward_output_tensor = element diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index b1e26d16f9741765c1c9600a64efb112aa0f42d7..2f8ef29e40bb84fcdd4dfd3370269b0d54f86694 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,12 +16,14 @@ import csv import os import copy -import numpy as np +import threading from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json from msprobe.core.common.log import logger -from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.decorator import recursion_depth_decorator + +lock = threading.Lock() class DataWriter: @@ -34,10 +36,12 @@ class DataWriter: self.dump_tensor_data_dir = None self.debug_file_path = None self.flush_size = 1000 + self.larger_flush_size = 20000 self.cache_data = {} self.cache_stack = {} self.cache_construct = {} self.cache_debug = {} + self.stat_stack_list = [] @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -54,13 +58,54 @@ class DataWriter: if is_new_file: change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders", max_depth=Const.DUMP_MAX_DEPTH) + def _replace_stat_placeholders(self, data, stat_result): + if isinstance(data, dict): + keys = list(data.keys()) # 获取当前所有键 + for key in keys: # 递归所有变量 + value = data[key] + if key == Const.TENSOR_STAT_INDEX and isinstance(value, int): + if value >= 0: + idx = value + else: + return + stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4 + # 构建新字段并删除旧键 + new_entries = { + "type": data["type"], + "dtype": data["dtype"], + "shape": data["shape"], + "Max": stat_values[0], + "Min": stat_values[1], + "Mean": stat_values[2], + "Norm": stat_values[3] + } + del data[key] + + # 重构字典顺序 + updated_dict = {} + # 通过插入排序后字段保证字段写入json的有序 + updated_dict.update(new_entries) + # 遍历原字典其他字段(排除已删除的tensor_stat_index) + for k in data: + if k not in new_entries: + updated_dict[k] = data[k] + data.clear() + data.update(updated_dict) + else: + self._replace_stat_placeholders(value, stat_result) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_stat_placeholders(item, stat_result) + def reset_cache(self): self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} def initialize_json_file(self, **kwargs): - if self.debug_file_path and not self.cache_debug: + if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug: # debug level case only create debug.json debug_dict = copy.deepcopy(kwargs) debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) @@ -86,39 +131,59 @@ class DataWriter: def flush_data_periodically(self): dump_data = self.cache_data.get(Const.DATA) - if dump_data and isinstance(dump_data, dict) and len(dump_data) % self.flush_size == 0: + if not dump_data or not isinstance(dump_data, dict): + return + + length = len(dump_data) + + threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size + + if length % threshold == 0: self.write_json() - def update_data(self, new_data): - if not isinstance(new_data, dict) or len(new_data.keys()) != 1: - logger.warning(f"The data info({new_data}) should be a dict with only one outer key.") - return - dump_data = self.cache_data.get(Const.DATA) - if not isinstance(dump_data, dict): - logger.warning(f"The dump data({dump_data}) should be a dict.") - return - key = next(iter(new_data.keys())) - if key in dump_data: - dump_data.get(key).update(new_data.get(key)) - else: - dump_data.update(new_data) + def update_data(self, new_data): + with lock: + if not isinstance(new_data, dict) or len(new_data.keys()) != 1: + logger.warning(f"The data info({new_data}) should be a dict with only one outer key.") + return + dump_data = self.cache_data.get(Const.DATA) + if not isinstance(dump_data, dict): + logger.warning(f"The dump data({dump_data}) should be a dict.") + return + + key = next(iter(new_data.keys())) + if key in dump_data: + dump_data.get(key).update(new_data.get(key)) + else: + dump_data.update(new_data) - def update_stack(self, new_data): - self.cache_stack.update(new_data) + def update_stack(self, name, stack_data): + with lock: + api_list = self.cache_stack.get(stack_data) + if api_list is None: + self.cache_stack.update({stack_data: [name]}) + else: + api_list.append(name) def update_construct(self, new_data): - self.cache_construct.update(new_data) + with lock: + self.cache_construct.update(new_data) def update_debug(self, new_data): - self.cache_debug['data'].update(new_data) + with lock: + self.cache_debug['data'].update(new_data) def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") save_json(file_path, self.cache_data, indent=1) def write_stack_info_json(self, file_path): - save_json(file_path, self.cache_stack, indent=1) + num, new_cache_stack = 0, {} + for key, value in self.cache_stack.items(): + new_cache_stack[num] = [value, key] + num += 1 + save_json(file_path, new_cache_stack, indent=1) def write_construct_info_json(self, file_path): save_json(file_path, self.cache_construct, indent=1) @@ -126,38 +191,61 @@ class DataWriter: def write_debug_info_json(self, file_path): save_json(file_path, self.cache_debug, indent=1) + def append_stat_to_buffer(self, stat_vector): + """ + 直接使用 Python list 存储 stat_vector, + 将 stat_vector 存入 self.stat_stack_list 的方式 + """ + self.stat_stack_list.append(stat_vector) + return len(self.stat_stack_list) - 1 + + def get_buffer_values_max(self, index): + if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1: + return self.stat_stack_list[index][0] + else: + logger.warning(f"stat_stack_list[{index}] The internal data is incomplete," + f" and the maximum value cannot be obtained.") + return None + + def get_buffer_values_min(self, index): + if 0 <= index < len(self.stat_stack_list) and len(self.stat_stack_list[index]) >= 1: + return self.stat_stack_list[index][1] + else: + logger.warning(f"stat_stack_list[{index}] Internal data is incomplete" + f" and minimum values cannot be obtained.") + return None + + def flush_stat_stack(self): + """ + 在 flush 阶段,将所有存储的统计值从设备搬到 CPU, + 这里返回一个列表,每个元素是 [Max, Min, Mean, Norm] 的数值列表 + """ + if not self.stat_stack_list: + return [] + result = [ + [ + x.item() if hasattr(x, "item") else x + for x in stat_values + ] + for stat_values in self.stat_stack_list + ] + self.stat_stack_list = [] + return result + def write_json(self): - if self.cache_data: - self.write_data_json(self.dump_file_path) - if self.cache_stack: - self.write_stack_info_json(self.stack_file_path) - if self.cache_construct: - self.write_construct_info_json(self.construct_file_path) - if self.cache_debug: - self.write_debug_info_json(self.debug_file_path) - - def fill_stack_tensor_data(self): - self.process_stat_data_recursive(self.cache_data) - - def process_stat_data_recursive(self, data, depth=0): - if depth > Const.MAX_DEPTH: - logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.") - raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR) - if isinstance(data, dict): - if "tensor_stat" in data.keys(): - tensor_stat = data["tensor_stat"] - if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]): - logger.warning("Some bad data in async dump") - else: - tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1] - if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE: - tensor_stat_data = tensor_stat_data.cpu() - for index, stat in zip(tensor_stat_index, tensor_stat_data): - data.update({index: stat.item()}) - del data["tensor_stat"] - else: - for key in data.keys(): - self.process_stat_data_recursive(data[key], depth + 1) - elif isinstance(data, (list, tuple)): - for i in data: - self.process_stat_data_recursive(i, depth + 1) \ No newline at end of file + with lock: + # 在写 JSON 前,统一获取统计值 + stat_result = self.flush_stat_stack() + # 遍历 cache_data,将占位符替换为最终统计值 + if stat_result: + self._replace_stat_placeholders(self.cache_data, stat_result) + if self.cache_debug: + self._replace_stat_placeholders(self.cache_debug, stat_result) + if self.cache_data: + self.write_data_json(self.dump_file_path) + if self.cache_stack: + self.write_stack_info_json(self.stack_file_path) + if self.cache_construct: + self.write_construct_info_json(self.construct_file_path) + if self.cache_debug: + self.write_debug_info_json(self.debug_file_path) diff --git a/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..03698530b08b976b580bc4e94f9962baf2d2f21f --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/debugger/precision_debugger.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const, FileCheckConst, MsgConst +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.file_utils import FileChecker, load_json +from msprobe.core.common.utils import get_real_step_or_rank, check_init_step +from msprobe.core.common_config import CommonConfig + + +class BasePrecisionDebugger: + _instance = None + tasks_not_need_debugger = [Const.GRAD_PROBE] + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(BasePrecisionDebugger, cls).__new__(cls) + cls._instance.config = None + cls._instance.enable_dataloader = False + cls._instance.initialized = False + cls.service = None + cls.first_start = False + return cls._instance + + def __init__( + self, + config_path=None, + task=None, + dump_path=None, + level=None, + step=None + ): + if self.initialized: + return + self.initialized = True + self._check_input_params(config_path, task, dump_path, level) + self.common_config, self.task_config = self._parse_config_path(config_path, task) + self.task = self.common_config.task + if step is not None: + self.common_config.step = get_real_step_or_rank(step, Const.STEP) + + @staticmethod + def _check_input_params(config_path, task, dump_path, level): + if not config_path: + config_path = os.path.join(os.path.dirname(__file__), + "../../config.json") + if config_path is not None: + if not isinstance(config_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") + file_checker = FileChecker( + file_path=config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) + file_checker.common_check() + + if task is not None and task not in Const.TASK_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") + + if dump_path is not None: + if not isinstance(dump_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") + + if level is not None and level not in Const.LEVEL_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + + @staticmethod + def _get_task_config(task, json_config): + raise NotImplementedError("Subclass must implement _get_task_config") + + @classmethod + def forward_backward_dump_end(cls): + instance = cls._instance + instance.stop() + + @classmethod + def set_init_step(cls, step): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + check_init_step(step) + instance.service.init_step = step + instance.service.loop = 0 + + @classmethod + def register_custom_api(cls, module, api, api_prefix=None): + if not api_prefix: + api_prefix = getattr(module, "__name__", "Custom") + if not isinstance(api_prefix, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, "api_prefix must be string") + if not hasattr(module, api): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}") + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + instance.service.register_custom_api(module, api, api_prefix) + + @classmethod + def restore_custom_api(cls, module, api): + if not hasattr(module, api): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"module {str(module)} does not have {api}") + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + instance.service.restore_custom_api(module, api) + + @classmethod + def _get_instance(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task in BasePrecisionDebugger.tasks_not_need_debugger: + instance = None + return instance + + def _parse_config_path(self, json_file_path, task): + if not json_file_path: + json_file_path = os.path.join(os.path.dirname(__file__), + "../../config.json") + json_config = load_json(json_file_path) + common_config = CommonConfig(json_config) + if task: + task_config = self._get_task_config(task, json_config) + else: + if not common_config.task: + common_config.task = Const.STATISTICS + task_config = self._get_task_config(common_config.task, json_config) + return common_config, task_config diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py index 22a8b6c13411b68a6566d0686062f8c74cb27196..5d9c72a6f2d60203b0d9ba716e867e39ee22d807 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py @@ -31,6 +31,7 @@ class GradConst: STEP = "step" BOUNDS = "bounds" OUTPUT_PATH = "output_path" + TIME_STAMP = "time_stamp" # level const LEVEL = "level" @@ -51,7 +52,7 @@ class GradConst: BOUNDS_MINIMUM = -2**63 BOUNDS_MAXIMUM = 2**63 - 1 - # file safty + # file safety DATA_DIR_AUTHORITY = 0o750 DATA_FILE_AUTHORITY = 0o640 DIRECTORY_LENGTH = 4096 diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py index 4f2b25bd28dfe330a8716695278ab8c64222c4b6..f50fc0f4e381db0e4069ef99b5c70b593f1580d0 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py @@ -112,7 +112,7 @@ class GradComparator: result.append([key] + value) result_csv_path = os.path.join(output_dir, "similarities.csv") if os.path.exists(result_csv_path): - logger.warning(f"{result_csv_path} will be recoverd") + logger.warning(f"{result_csv_path} will be deleted") remove_path(result_csv_path) write_csv(result, result_csv_path) @@ -121,7 +121,7 @@ class GradComparator: similarities = {} logger.info(f"{len(steps)} steps will be compared") grad_weight_order = cls._get_grad_weight_order(path1, path2) - for step in tqdm(steps, desc="culculate similarities (by step)"): + for step in tqdm(steps, desc="calculate similarities (by step)"): grad_files = cls._get_matched_grad_files(path1, path2, step) same_count_summary = 0 total_count_summary = 0 diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py index de3e4156acc74f135120e06116b5894a0e9ed09e..468367a54a8bf4926edd5a8f25cefaa5890ec40c 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py @@ -82,7 +82,7 @@ class ListCache(list): if len(self) == 0: return if not self._output_file: - logger.warning("dumpfile path is not setted") + logger.warning("dumpfile path is not set.") write_csv(self, self._output_file) logger.info(f"write {len(self)} items to {self._output_file}.") self.clear() diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c662309b91ec3d3089d6b45e84194de63eee2429 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/hook_manager.py @@ -0,0 +1,244 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +import os + +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs) + + +class HookSet: + def __init__(self, forward_hook=None, forward_pre_hook=None, backward_hook=None, backward_pre_hook=None): + self.forward_hook = forward_hook + self.forward_pre_hook = forward_pre_hook + self.backward_hook = backward_hook + self.backward_pre_hook = backward_pre_hook + + +class BaseHookManager(ABC): + inner_switch = False + hook_handle_dict = {} + params_grad_info = {} + + def __init__(self, data_collector, config, attl_manager=None): + self.data_collector = data_collector + self.config = config + self.attl_manager = attl_manager + + @property + def _pid(self): + return os.getpid() + + @property + @abstractmethod + def _is_recompute(self): + pass + + @staticmethod + @abstractmethod + def _no_grad_context(): + pass + + @staticmethod + @abstractmethod + def _add_count(name): + pass + + @staticmethod + @abstractmethod + def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): + pass + + @staticmethod + def _clear_input_kwargs(module): + if hasattr(module, 'msprobe_input_kwargs'): + del module.msprobe_input_kwargs + + @abstractmethod + def build_hook(self): + pass + + @abstractmethod + def _get_params_dict(self, module): + pass + + @abstractmethod + def _need_exchange(self, module): + pass + + def _register_param_hook(self, name, module, params_dict): + ori_name = name.rsplit(Const.SEP, 2)[0] + grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD + # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook + setattr(module, 'params_grad_name', grad_name) + # data_mode为forward时,不注册参数hook + if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): + for param_name, param in params_dict.items(): + if param.requires_grad: + name = ori_name + Const.SEP + param_name + old_handle = BaseHookManager.hook_handle_dict.get(name) + if old_handle and hasattr(old_handle, "remove"): + old_handle.remove() + handle = param.register_hook(self._build_grad_hook(module, ori_name, param_name)) + BaseHookManager.hook_handle_dict[name] = handle + + def _init_params_grad_info(self, module, params_dict): + ''' + 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 + ''' + if not params_dict: + return + if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): + grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None + # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 + if not BaseHookManager.params_grad_info.get(grad_name): + data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} + # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 + if data_info.get(grad_name): + # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 + self.data_collector.handle_data(grad_name, data_info, + flush=self.data_collector.data_processor.is_terminated) + # 记录当前模块的参数梯度信息已占位 + BaseHookManager.params_grad_info[grad_name] = True + + def _should_execute_hook(self, hook_type, module, is_forward): + is_module_hook = hook_type == Const.MODULE + if hasattr(module, 'async_op_dump_flag') and getattr(module, 'async_op_dump_flag'): + return False + if is_module_hook and not Runtime.is_running: + return False + elif not is_module_hook and is_forward and not Runtime.is_running: + return False + elif not is_module_hook and not is_forward and not module.forward_data_collected: + return False + if BaseHookManager.inner_switch: + return False + if not self.data_collector or self.data_collector.data_processor.is_terminated: + return False + return True + + def _build_grad_hook(self, module, ori_name, param_name): + def hook_fn(grad): + if not self._should_execute_hook(Const.MODULE, module, False): + return + BaseHookManager.inner_switch = True + self.data_collector.params_data_collect(ori_name, param_name, self._pid, grad) + BaseHookManager.inner_switch = False + return + return hook_fn + + def _build_forward_pre_hook(self, hook_type, full_name, api_name): + def forward_pre_hook(module, args, kwargs=None): + if hook_type == Const.MODULE: + return + if not self._should_execute_hook(hook_type, module, True): + return + if kwargs is None: + kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {} + with self._no_grad_context(): + BaseHookManager.inner_switch = False + module.forward_data_collected = True + self._add_count(api_name) + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) + self.data_collector.update_api_or_module_name(full_name) + if getattr(self.config, "online_run_ut", False): + BaseHookManager.inner_switch = False + return + self.data_collector.forward_input_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + BaseHookManager.inner_switch = False + return forward_pre_hook + + def _build_forward_hook(self, hook_type, full_name): + def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + if not self._should_execute_hook(hook_type, module, True): + self._clear_input_kwargs(module) + return None + kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs) + BaseHookManager.inner_switch = True + self.data_collector.update_api_or_module_name(full_name) + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + with self._no_grad_context(): + if getattr(self.config, "online_run_ut", False): + if self.data_collector.scope and not self.data_collector.scope.check(full_name): + return None + if self.attl_manager: + self.attl_manager.attl_send(full_name, args, kwargs, output) + BaseHookManager.inner_switch = False + return None + if hook_type == Const.MODULE: + params_dict = self._get_params_dict(module) + setattr(module_input_output, Const.PARAMS, params_dict) + if params_dict: + self._register_param_hook(full_name, module, params_dict) + self.data_collector.update_api_or_module_name(full_name) + self.data_collector.forward_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + self._init_params_grad_info(module, params_dict) + else: + self.data_collector.forward_output_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + self._clear_input_kwargs(module) + + if self.data_collector.if_return_forward_new_output(): + forward_new_output = self.data_collector.get_forward_new_output() + BaseHookManager.inner_switch = False + return forward_new_output + + BaseHookManager.inner_switch = False + return output + return forward_hook + + def _build_backward_hook(self, hook_type, full_name): + def backward_hook(module, grad_input, grad_output): + if not self._should_execute_hook(hook_type, module, False): + return + BaseHookManager.inner_switch = True + self.data_collector.update_api_or_module_name(full_name) + if getattr(self.config, "online_run_ut", False): + BaseHookManager.inner_switch = False + return + need_exchange = self._need_exchange(module) if hook_type == Const.MODULE else True + if need_exchange: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) + else: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + self.data_collector.backward_data_collect( + full_name, + module, + self._pid, + module_input_output, + self._is_recompute + ) + BaseHookManager.inner_switch = False + return backward_hook + diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py b/debug/accuracy_tools/msprobe/core/kernel_dump/kernel_config.py similarity index 100% rename from debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py rename to debug/accuracy_tools/msprobe/core/kernel_dump/kernel_config.py diff --git a/debug/accuracy_tools/msprobe/core/monitor/__init__.py b/debug/accuracy_tools/msprobe/core/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py similarity index 49% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py rename to debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py index 9a0b71e8a5791bc216c82737d1d4f4a482abceb9..8c50ad761682c05533d525acaa39e6f830cc4e48 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,18 +12,205 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import sys +import math import argparse import ast import heapq +from abc import ABC +from dataclasses import dataclass, field +from typing import List -from msprobe.pytorch.common.log import logger from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import check_path_before_create, save_json, create_directory, remove_path, \ +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import save_json, create_directory, remove_path, \ check_file_or_directory_path, load_json -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData + + +class ScanRule(ABC): + name = "ScanRule" + + def apply(self, cur, history=None): + raise NotImplementedError("abstract method apply is not implemented") + + +class AnomalyTurbulence(ScanRule): + name = "AnomalyTurbulence" + + def __init__(self, threshold) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + """ + :param cur: float, current metric value + :param history: float, history weighted average + :return: bool, whether the current value deviates from the historical average value of current metric + """ + up_bound = history * (1 + self.threshold) + return abs(cur) > up_bound + + +class AnomalyNan(ScanRule): + name = "AnomalyNan" + + def __init__(self, threshold=None) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + return math.isnan(cur) or (self.threshold is not None and abs(cur) > self.threshold) + + +class AnomalyScanner: + + @staticmethod + def load_rules(specs: List[dict]): + """ + specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] + """ + if specs is None: + return [] + alert_rules = [] + for spec in specs: + # 使用get方法获取键值,如果键不存在则返回None + rule_cls_name = spec.get("rule_name") + rule_args = spec.get("args") + + # 检查必要的键是否存在 + if rule_cls_name is None or (rule_cls_name == "AnomalyTurbulence" and rule_args is None): + logger.warning(f"Spec is missing required keys: {spec}") + continue + + cur_module = sys.modules.get(__name__) + try: + rule_cls = getattr(cur_module, rule_cls_name) + except AttributeError: + logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") + continue + + try: + rule_instance = rule_cls(**rule_args) if rule_args is not None else rule_cls() + alert_rules.append(rule_instance) + except Exception as e: + logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") + continue + + return alert_rules + + @staticmethod + def scan(scan_rules: List[ScanRule], history, cur): + anomaly = False + for rule in scan_rules: + anomaly = rule.apply(cur, history=history) + if anomaly: + return anomaly, rule.name + return anomaly, None + + +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') + message (str): anomaly detect message + step (int): training step + """ + if not isinstance(tag, tuple) or len(tag) != 2: + raise ValueError("tag must be a tuple with length 2") + tag_name = tag[0] + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(tag_name, -1) + if MonitorConst.NAME_SEP in param_name: + vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) + else: + vpp_stage = 0 + + return GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + """ + 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 + 比较规则为: + step 和 micro_step 值越小优先级越高; + vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; + call_id 值越小优先级越高。 + """ + if not isinstance(other, GradAnomalyData): + return NotImplemented + + self_train_stage = self.get_train_stage(self.tag_name) + other_train_stage = self.get_train_stage(other.tag_name) + + def vpp_pp_comparator(anomaly): + """ + Determine the priority rule for vpp and pp based on train stage + Forward stage prefers smaller vpp and pp + Other stages prefer larger vpp and pp + """ + if self_train_stage == MonitorConst.FORWARD_STAGE: + return anomaly.vpp_stage, anomaly.pp_stage + else: + return -anomaly.vpp_stage, -anomaly.pp_stage + + self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] + other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] + return self_cmp < other_cmp + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + @staticmethod + def get_train_stage(tag_name): + """ + :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" + :return: int, if forward return 0; if backward return 1; if optimizer return 2 + """ + key_ = tag_name.split("/")[-1] + return MonitorConst.TRAIN_STAGE.get(key_, MonitorConst.DEFAULT_STAGE) + + def to_dict(self): + return self.__dict__ + + def get_key(self): + # 0:1.self_attention.core_attention_flash_0/rank0/input_grad + return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) class AnomalyDataWriter: @@ -46,12 +233,7 @@ class AnomalyDataWriter: def init_detected_json(self): """初始化落盘文件""" - check_path_before_create(self.dump_path) - if not os.path.exists(self.dump_path): - create_directory(self.dump_path) - - if not os.path.exists(self.dump_rank_dir): - create_directory(self.dump_rank_dir) + create_directory(self.dump_rank_dir) if os.path.exists(self.json_path): check_file_or_directory_path(self.json_path, isdir=False) @@ -66,11 +248,12 @@ class AnomalyDataWriter: anomalies: GradAnomalyData对象列表 """ anomalies_json = self.get_anomaly_dict(anomalies) - logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") + if anomalies_json: + logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") - data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} - data_to_write.update(anomalies_json) - save_json(self.json_path, data_to_write, indent=1) + data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} + data_to_write.update(anomalies_json) + save_json(self.json_path, data_to_write, indent=1) class AnomalyDataLoader: @@ -145,27 +328,6 @@ class AnomalyAnalyse: save_json(json_path, sorted_data, indent=1) -def _get_parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, - help=" The anomaly detect result dictionary: generate from monitor tool.", - required=True, - ) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The analyse task result out path.", - required=False, - ) - parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, - help=" Top K number of earliest anomalies.", - required=False, - ) - parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str, - help=" Analyse which steps.", - required=False, - ) - return parser.parse_args(sys.argv[1:]) - - def _get_step_and_stop(args): try: step_list = ast.literal_eval(args.step_list) @@ -196,6 +358,27 @@ def _anomaly_analyse(): logger.info(f"{index}: {anomaly.message}") +def _get_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, + help=" The anomaly detect result dictionary: generate from monitor tool.", + required=True, + ) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The analyse task result out path.", + required=False, + ) + parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, + help=" Top K number of earliest anomalies.", + required=False, + ) + parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str, + help=" Analyse which steps.", + required=False, + ) + return parser.parse_args(sys.argv[1:]) + + if __name__ == "__main__": _anomaly_analyse() logger.info("Analyse task completed.") diff --git a/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py b/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py index 54dae2576e48b7ad75df97fa046e6e90bbd144c2..0e0c50cc6aa0cf93f963a699ee36c13d888ec320 100644 --- a/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py +++ b/debug/accuracy_tools/msprobe/core/overflow_check/abnormal_scene.py @@ -20,6 +20,7 @@ import numpy as np from msprobe.core.overflow_check.api_info import APIInfo from msprobe.core.overflow_check.level import OverflowLevel from msprobe.core.overflow_check.utils import has_nan_inf +from msprobe.core.common.decorator import recursion_depth_decorator class AnomalyScene: @@ -35,6 +36,7 @@ class AnomalyScene: raise NotImplementedError @staticmethod + @recursion_depth_decorator("AbnormalScene: AnomalyScene._has_anomaly") def _has_anomaly(data: Union[Dict, Any]) -> bool: """检查张量是否包含异常值""" if isinstance(data, dict): diff --git a/debug/accuracy_tools/msprobe/core/overflow_check/level.py b/debug/accuracy_tools/msprobe/core/overflow_check/level.py index 2f40468f6551a3787bdae7f9d94a5f66599151a0..0848110178d0effa9b3bc40ae6d4437a800d3f04 100644 --- a/debug/accuracy_tools/msprobe/core/overflow_check/level.py +++ b/debug/accuracy_tools/msprobe/core/overflow_check/level.py @@ -1,22 +1,22 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - - -class OverflowLevel(Enum): - MEDIUM = "medium" - HIGH = "high" - CRITICAL = "critical" +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class OverflowLevel(Enum): + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" diff --git a/debug/accuracy_tools/msprobe/core/service.py b/debug/accuracy_tools/msprobe/core/service.py new file mode 100644 index 0000000000000000000000000000000000000000..00fe27727b86ac8e2e7afa2199e11168767519e5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/service.py @@ -0,0 +1,356 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from abc import ABC, abstractmethod +import copy +from collections import defaultdict +import functools +import os + +from msprobe.core.common.exceptions import DistributedNotInitializedError +from msprobe.core.common.file_utils import create_directory +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.hook_manager import BaseHookManager +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json + + +class BaseService(ABC): + def __init__(self, config): + self.config = copy.deepcopy(config) + self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置 + self.model = None + self.data_collector = build_data_collector(self.config) + self.attl_manager = None + self.current_iter = 0 + self.loop = 0 + self.init_step = 0 + self.cur_token_id = 0 + self.first_start = True + self.primitive_switch = False + self.current_rank = None + self.dump_iter_dir = None + self.should_stop_service = False + self.ori_customer_func = {} + self.debug_variable_counter = None + self.currrent_step_first_debug_save = True + self.logger = None # 子类中注入 + self.api_register = None # 子类中注入 + self.api_template = None # 子类中注入 + self.hook_manager = None # 子类中注入 + self._init_specific_components() + self._register_api_hook() + + @property + def _is_debug_level(self): + return self.config.level == Const.LEVEL_DEBUG + + @property + def _is_l2_level(self): + return self.config.level == Const.LEVEL_L2 + + @property + def _is_mix_level(self): + return self.config.level == Const.LEVEL_MIX + + @property + def _is_need_module_hook(self): + return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0] + + @property + def _is_need_api_hook(self): + return self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2] + + @property + def _is_no_dump_step(self): + return (self.config.step and self.current_iter not in self.config.step) + + @property + def _is_no_dump_rank(self): + return (self.config.rank and self.current_rank not in self.config.rank) + + @property + def _need_tensor_data(self): + """判断是否需要采集tensor数据""" + return bool( + self.config.task in self.data_collector.tasks_need_tensor_data or + (self.config.task == Const.STATISTICS and self.config.tensor_list) + ) + + @property + def _is_online_run_ut(self): + return getattr(self.config, "online_run_ut", False) + + @property + @abstractmethod + def _get_framework_type(self): + """获取框架类型""" + pass + + @staticmethod + @abstractmethod + def _get_current_rank(): + """获取当前rank_id""" + pass + + @staticmethod + def _change_jit_switch(status): + """修改JitDump开关,mindspore子类重写""" + pass + + def start(self, model=None, token_range=None): + """通用start模板""" + self._process_iteration() + if self._is_debug_level: + return + if self._need_stop_service(): + return + self.model = model + self.cur_token_id = 0 + if self.first_start: + try: + self.current_rank = self._get_current_rank() + except DistributedNotInitializedError: + self.current_rank = None + Runtime.current_rank = self.current_rank + if self._is_no_dump_rank: + return + self._register_hook() + if self._is_need_module_hook: + self._register_module_hook() + self.first_start = False + + if token_range: + self._register_infer_count_hook(self.model, token_range) + self.logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully") + if token_range is None: + Runtime.is_running = True + self.primitive_switch = True + self._change_jit_switch(True) + self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ") + if self._is_online_run_ut: + self._run_ut_dispatch(True) + else: + self.create_dirs() + self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.") + + def stop(self): + """通用stop模板""" + if self._is_debug_level or self.should_stop_service: + return + if self._is_no_dump_step or self._is_no_dump_rank: + return + self.logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " + "Please set debugger.start() to turn on the dump switch again. ") + Runtime.is_running = False + self.primitive_switch = False + self._change_jit_switch(False) + if self._is_l2_level: + return + if self._is_online_run_ut: + self._run_ut_dispatch(False) + self._process_async_dump() + self.data_collector.write_json() + + def step(self): + """通用step处理""" + if self.should_stop_service: + return + self._process_async_dump() + self.data_collector.write_json() + self.currrent_step_first_debug_save = True + self.loop += 1 + self._reset_status() + + def save(self, variable, name, save_backward): + ''' + Args: + variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int] + name: str + save_backward: boolean + Return: + void + ''' + if not self._is_debug_level: + return + self.current_iter = self.loop + self.init_step + if self._is_no_dump_step: + return + + if self.currrent_step_first_debug_save: + try: + self.current_rank = self._get_current_rank() + except DistributedNotInitializedError: + self.current_rank = None + + self.create_dirs() + self.debug_variable_counter = defaultdict(int) + self.currrent_step_first_debug_save = False + + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + def register_custom_api(self, module, api_name, api_prefix): + self.ori_customer_func[str(module) + Const.SEP + api_name] = getattr(module, api_name) + ApiRegistry.register_custom_api(module, api_name, api_prefix, + functools.partial(self.build_hook, Const.API), self.api_template) + + def restore_custom_api(self, module, api): + ori_func = self.ori_customer_func.get(str(module) + Const.SEP + api) + if ori_func: + setattr(module, api, ori_func) + + + def build_hook(self, hook_type, name): + return self.hook_manager.build_hook(hook_type, name) + + def create_dirs(self): + """统一目录创建逻辑""" + create_directory(self.config.dump_path) + if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE: + self.dump_iter_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, f"step{self.current_iter}") + else: + self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") + + cur_rank = self.current_rank if self.current_rank is not None else '' + if self._is_l2_level: + self._create_l2_dirs(cur_rank) + else: + self._create_default_dirs(cur_rank) + + @abstractmethod + def _init_specific_components(self): + """初始化框架特定组件""" + pass + + @abstractmethod + def _register_hook(self): + """注册hook函数""" + pass + + @abstractmethod + def _register_module_hook(self): + """注册模块级别的hook函数""" + + def _need_stop_service(self): + if self.should_stop_service: + return True + end_service = self.config.step and self.current_iter > max(self.config.step) or \ + self.data_collector and self.data_collector.data_processor.is_terminated + if end_service: + if self._is_online_run_ut and self.attl_manager: + self.attl_manager.attl_stop() + self.primitive_switch = False + self._change_jit_switch(False) + Runtime.is_running = False + self.should_stop_service = True + print_tools_ends_info() + return True + if self._is_no_dump_step: + return True + return False + + def _register_api_hook(self): + if self._is_need_api_hook: + self.api_register.initialize_hook(functools.partial(self.build_hook, Const.API)) + self.api_register.register_all_api() + self.logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.") + + def _register_infer_count_hook(self, root_model, token_range): + """ + 通过root_model执行的轮次来判断当前在第几个token + param root_model: 需要采集的推理模型 + param token_range: [start, end], 采集infer的token循环范围,左右皆包含在内 + return: None + """ + def infer_hook(model, args): + if self.cur_token_id == token_range[0]: + Runtime.is_running = True + self.primitive_switch = True + self._change_jit_switch(True) + self.logger.info(f"Current token id: {self.cur_token_id}, start dump infer token.") + elif token_range[0] < self.cur_token_id <= token_range[1]: + self.logger.debug(f"Current token id: {self.cur_token_id}.") + elif self.cur_token_id == token_range[1] + 1: + Runtime.is_running = False + self.primitive_switch = False + self._change_jit_switch(False) + self.logger.info( + f"Current token id: {self.cur_token_id}, exceed token_range, early stop dump infer token.") + self.cur_token_id += 1 + if isinstance(root_model, list): + root_model = root_model[0] + self.logger.warning("Infer model can only input one to support token_range, choose the first one.") + if self._is_online_run_ut: + return + root_model.register_forward_pre_hook(infer_hook) + + def _create_l2_dirs(self, cur_rank): + create_directory(self.dump_iter_dir) + kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) + self.config.kernel_config_path = kernel_config_path + + def _create_default_dirs(self, cur_rank): + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + + dump_data_dir = None + if self._need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + + self._configure_dump_paths(dump_dir, dump_data_dir) + + def _configure_dump_paths(self, dump_dir, dump_data_dir): + dump_path_aggregation = DumpPathAggregation() + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") + dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") + self.data_collector.update_dump_paths(dump_path_aggregation) + self.data_collector.initialize_json_file(self._get_framework_type) + + def _process_iteration(self): + """处理迭代计数""" + self.current_iter = self.loop + self.init_step + self.data_collector.update_iter(self.current_iter) + Runtime.current_iter = self.current_iter + + def _process_async_dump(self): + """处理异步dump逻辑""" + if self.config.async_dump and self.config.task in [Const.STATISTICS, Const.TENSOR]: + self.data_collector.data_processor.dump_async_data() + + def _reset_status(self): + """通用状态重置""" + self.data_collector.reset_status() + BaseHookManager.params_grad_info.clear() + if self._is_l2_level: + self.data_collector.data_processor.reset_status() diff --git a/debug/accuracy_tools/msprobe/docs/01.installation.md b/debug/accuracy_tools/msprobe/docs/01.installation.md index 1ab5f6419ba07ec749bad139f874fbc7301fd8b3..a8c45873129a6a768d0f388fdd60143e7c162fa2 100644 --- a/debug/accuracy_tools/msprobe/docs/01.installation.md +++ b/debug/accuracy_tools/msprobe/docs/01.installation.md @@ -13,6 +13,7 @@ pip install mindstudio-probe ``` ## 2 下载 whl 包安装 +下方whl包链接为master分支最新编译的whl包。如果需要使用pre-research分支的代码,请拉取pre-research分支代码,通过源码编译安装。 |版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码| |:--:|:--:|:--:|:--:|:--:|:--:| @@ -52,7 +53,7 @@ pip install ./mindstudio_probe*.whl |参数|说明|是否必选| |--|--|:--:| -|--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。
• adump模块用于MindSpore静态图场景L2级别的dump。
• 仅MindSpore 2.5.0及以上版本支持adump模块。
• 若使用源码安装,编译环境需支持GCC 7或以上版本,和CMAKE 3.14或以上版本。
• 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否| +|--include-mod|指定可选模块,可取值`adump`,表示在编whl包时加入adump模块。默认未配置该参数,表示编基础包。
• adump模块用于MindSpore静态图场景L2级别的dump。
• 仅MindSpore 2.5.0及以上版本支持adump模块。
• 若使用源码安装,编译环境需支持GCC 7.5或以上版本,和CMAKE 3.14或以上版本。
• 生成的whl包仅限编译时使用的python版本和处理器架构可用。|否| # 特性变更说明 @@ -80,8 +81,6 @@ pip install ./mindstudio_probe*.whl ## 1.1.1 -## 1.1.1 - 【数据采集】 - dump 支持 processgroup、namedtuple、slice 等数据类型 diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md index f134bd4536294d209e7b3e6e73fd80b9be61041d..c170ff0152778e82dce1c1d237523ede9ffc0913 100644 --- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md @@ -10,47 +10,61 @@ ### 1.1 通用配置 -| 参数 | 解释 | 是否必选 | -| ----------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对;
"grad_probe":梯度监控;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe)。
**配置示例**:"task": "tensor"。 | 否 | -| dump_path | 设置 dump 数据目录路径,str 类型。
**配置示例**:"dump_path": "./dump_path"。 | 是 | -| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型,默认未配置(表示采集所有卡的数据),应配置元素为 ≥0 的整数或类似"4-6"的字符串,且须配置实际可用的 Rank ID。
PyTorch 场景: Rank ID 从 0 开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的 Rank ID,则 dump 数据为空,比如当前环境 Rank ID 为 0 到 7,实际训练运行 0 到 3 卡,此时若配置 Rank ID 为 4 或不存在的 10 等其他值,dump 数据为空。
MindSpore 场景:所有节点的 Rank ID 均从 0 开始计数,最大取值为每个节点可用卡总数-1,config.json 配置一次 rank 参数对所有节点同时生效。
注意,单卡训练时,rank必须为[],即空列表,不能指定rank。
**配置示例**:"rank": [1, "4-6"]。 | 否 | -| step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | -| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,仅 PyTorch 与 MindSpore 动态图场景支持,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch 与 MindSpore 动态图场景支持;
"L2":dump kernel 级精度数据,PyTorch场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore场景详细介绍见 [MindSpore 场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch 与 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | -| enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | -| async_dump | 异步 dump 开关,bool 类型。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式暂不支持复数类型 tensor
的统计量计算。 | 否 | +| 参数 | 解释 | 是否必选 | +| ----------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对,不支持 MSAdapter 场景;
"grad_probe":梯度监控, 不支持 MSAdapter 场景;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe)。
[1.8 task 配置为 structure](#18-task-配置为-structure)。
**配置示例**:"task": "tensor"。 | 否 | +| dump_path | 设置 dump 数据目录路径,str 类型。
**配置示例**:"dump_path": "./dump_path"。 | 是 | +| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型,默认未配置(表示采集所有卡的数据),应配置元素为 ≥0 的整数或类似"4-6"的字符串,且须配置实际可用的 Rank ID。
PyTorch 场景: Rank ID 从 0 开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的 Rank ID,则 dump 数据为空,比如当前环境 Rank ID 为 0 到 7,实际训练运行 0 到 3 卡,此时若配置 Rank ID 为 4 或不存在的 10 等其他值,dump 数据为空。
MindSpore 场景:所有节点的 Rank ID 均从 0 开始计数,最大取值为每个节点可用卡总数-1,config.json 配置一次 rank 参数对所有节点同时生效。静态图 L0 级别 dump 暂不支持指定rank。
注意,单卡训练时,rank必须为[],即空列表,不能指定rank。
**配置示例**:"rank": [1, "4-6"]。 | 否 | +| step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | +| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch、MSAdapter 以及 MindSpore 均支持;
"L2":dump kernel 级精度数据,PyTorch 场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore 动态图场景详细介绍见 [MindSpore 动态图场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);MindSpore 静态图场景详细介绍见《MindSpore 场景的数据采集》中的 ["**8.1 静态图场景**"](./06.data_dump_MindSpore.md#81-静态图场景)小节;
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | +| enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | +| async_dump | 异步 dump 开关,bool 类型, 支持 task 为 tensor 或 statistic 模式, level 支持 L0、 L1、 mix、 debug 模式。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式暂不支持复数类型 tensor 的统计量计算。
| 否 | #### 1.1.1 模块级精度数据 dump 说明 -仅 PyTorch 与 MindSpore 动态图场景支持。 +PyTorch 与 MindSpore 均支持。 大模型场景下,通常不是简单的利用自动迁移能力实现从 GPU 到 NPU 的训练脚本迁移,而是会对 NPU 网络进行一系列针对性的适配,因此,常常会造成迁移后的 NPU 模型存在部分子结构不能与 GPU 原始模型完全对应。模型结构不一致导致 API 调用类型及数量不一致,若直接按照 API 粒度进行精度数据 dump 和比对,则无法完全比对所有的 API。 本小节介绍的功能是对模型中的大粒度模块进行数据 dump,使其比对时,对于无法以 API 粒度比对的模块可以直接以模块粒度进行比对。 -模块指的是继承 nn.Module 类(PyTorch场景)或 nn.Cell 类(MindSpore场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 +模块指的是继承 nn.Module 类(PyTorch 与 MSAdapter 场景)或 nn.Cell 类(MindSpore 场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 +特别地,在PyTorch场景中,为了规避BackwardHook函数的输出不能进行原地操作的框架限制,工具使用了`torch._C._autograd._set_creation_meta`接口对BackwardHook函数的输出张量进行属性重置,这可能会造成dump数据中缺少原地操作模块(nn.ReLU(inplace=True)及其上一个模块的反向数据。 ### 1.2 task 配置为 statistics - - - - - - - - + + + + + + + + + + +
参数解释是否必选
scopePyTorch 和 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为2,需要配置按照工具命名格式的完整模块名或API名称,用于锁定区间,dump 该范围内的数据。
配置示例: +
scopePyTorch、MSAdapter 以及 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为2,需要配置按照工具命名格式的完整模块名或API名称,用于锁定区间,dump 该范围内的数据。
配置示例: "scope": ["Module.conv1.Conv2d.forward.0", "Module.fc2.Linear.forward.0"], 或 "scope": ["Cell.conv1.Conv2d.forward.0", "Cell.fc2.Dense.backward.0"], 或"scope": ["Tensor.add.0.forward", "Functional.square.2.forward"]。与 level 参数取值相关,level 为 L0 级别时,可配置模块名;level 为 L1 级别时,可配置 API 名, level为 mix 级别时,可配置为模块名或API名。
list自定义采集的算子列表,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 API 的数据),包含以下配置方法:
PyTorch 和 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。 +
PyTorch、MSAdapter 以及 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。
配置示例:"list": ["Module.module.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"], 或 "list": ["Cell.network_with_loss.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"]
PyTorch 和 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时, 会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型("level": "L2"时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式。
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
data_modedump 数据过滤,str 类型。
PyTorch 与 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
MindSpore 静态图场景:仅支持"all"、"input"和"output"参数,且各参数只能单独配置,不支持自由组合。
配置示例:"data_mode": ["all"]。
summary_mode控制 dump 文件输出的模式,str 类型,仅 PyTorch 与 MindSpore 动态图场景支持,可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore静态图jit_level=O2场景L2级dump,支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
配置示例:"summary_mode": ["max", "min"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景在level为 mix 级别时, 会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型(jit_level=O2 时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式。
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
tensor_list自定义采集真实数据的算子列表,list[str] 类型,默认未配置。包含以下配置方法:
PyTorch、MSAdapter 以及 MindSpore 动态图场景指定某一类 API 或模块,即会 dump 这一类 API 或模块输入输出的统计量信息和完整的 tensor 数据。
配置示例:"tensor_list": ["relu"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景目前只支持level配置为 L0, L1 和 mix 级别。
MindSpore 静态图场景不支持。
device控制统计值计算所用的设备,可选值["device", "host"],默认"host"。使用device计算会比host有性能加速,只支持min/max/avg/l2norm统计量。支持 MindSpore静态图 O0/O1 场景。
precision控制统计值计算所用精度,可选值["high", "low"],默认值为"high"。选择"high"时,avg/l2norm统计量使用float32进行计算,会增加device内存占用,精度更高;为"low"时使用与原始数据相同的类型进行计算,device内存占用较少,但在处理较大数值时可能会导致统计量溢出。支持 MindSpore静态图 O0/O1 场景。
data_modedump 数据过滤,str 类型。
PyTorch、MSAdapter 以及 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
MindSpore 静态图场景:L0 级别 dump 仅支持"all"、"forward"和"backward"参数;L2 级别 dump 仅支持"all"、"input"和"output"参数。且各参数只能单独配置,不支持自由组合。
配置示例:"data_mode": ["all"]。
summary_mode控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图以及 MindSpore 静态图 L2 级别 jit_level=O2 场景和 L0 级别 jit_level=O0/O1 场景。
PyTorch、MSAdapter 以及 MindSpore 动态图场景:可选参数为
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore 静态图 jit_level=O2 场景:支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
配置示例:"summary_mode": ["max", "min"]。
-**说明**:"summary_mode"配置为"md5"时,所使用的校验算法为CRC-32算法。 +**说明**: + + +1. "summary_mode" 配置为 "md5" 时,所使用的校验算法为 CRC-32 算法。 + +**示例**: + - [PyTorch场景](03.config_examples.md#11-task-配置为-statistics) + - [MindSpore静态图场景](03.config_examples.md#21-task-配置为-statistics) + - [MindSpore动态图场景](03.config_examples.md#31-task-配置为-statistics) ### 1.3 task 配置为 tensor @@ -60,12 +74,21 @@ | list | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同。 | 否 | | data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 | | file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:
"bin":dump 的 tensor 文件为二进制格式;
"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 | +| summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 | | online_run_uta | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 | | nfs_patha | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 | | hosta | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | | porta | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 | -**a**:online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。 +**说明**: + +1. online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。 + +**示例**: + - [PyTorch场景](03.config_examples.md#12-task-配置为-tensor) + - [MindSpore静态图场景](03.config_examples.md#22-task-配置为-tensor) + - [MindSpore动态图场景](03.config_examples.md#32-task-配置为-tensor) + ### 1.4 task 配置为 run_ut @@ -80,22 +103,46 @@ | portb | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 | | rank_listb | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 | -**a**:white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。 +**说明**: + +1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。 + +2. is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。 -**b**:is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。 +**示例**: +```json +{ + "task": "run_ut", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "L1", + + "run_ut": { + "white_list": [], + "black_list": [], + "error_data_path": "./" + } +} +``` ### 1.5 task 配置为 overflow_check -PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 +PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 | 参数 | 解释 | 是否必选 | | ------------- | ---------------------- | -------- | -| overflow_nums | 最大溢出次数,int 类型,默认为 1,仅 PyTorch 与 MindSpore 动态图场景支持。表示第 N 次溢出后,不再进行溢出检测。过程中检测到溢出 API 对应的 输入输出 数据均 dump。
**配置示例**:"overflow_nums": 3。配置为 -1 时,表示持续检测溢出直到训练结束。 | 否 | -| check_mode | 溢出类型,str 类型,仅 MindSpore 场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"atomic":开启 Atomic 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | +| overflow_nums | 最大溢出次数,int 类型,默认为 1,仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。表示第 N 次溢出后,不再进行溢出检测。过程中检测到溢出 API 对应的 输入输出 数据均 dump。
**配置示例**:"overflow_nums": 3。配置为 -1 时,表示持续检测溢出直到训练结束。 | 否 | +| check_mode | 溢出类型,str 类型,仅 MindSpore v2.3.0 以下版本的静态图场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测;
"atomic":开启 Atomic 的溢出检测;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | + +**示例**: + - [PyTorch场景](03.config_examples.md#14-task-配置为-overflow_check) + - [MindSpore静态图场景](03.config_examples.md#23-task-配置为-overflow_check) + - [MindSpore动态图场景](03.config_examples.md#33-task-配置为-overflow_check) ### 1.6 task 配置为 free_benchmark -仅 PyTorch 场景与 MindSpore 动态图场景支持,且"level"为"L1"。 +仅 PyTorch 与 MindSpore 动态图场景支持,且"level"为"L1"。 - task 配置为 free_benchmark 时,开启**无标杆比对**,在 NPU 环境下通过对当前模型 API 的输入添加扰动因子,二次执行,将得到的输出与未添加扰动因子前的输出进行比对,从而**得出该模型中可能存在因迁移等变化导致精度降低的 API**。 @@ -119,6 +166,10 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore max_sample每个算子预热的采样次数的最大阈值(仅 PyTorch 场景支持),int 类型,默认值为 20。须配置 "if_preheat": "true"。否 +**示例**: + - [PyTorch场景](03.config_examples.md#15-task-配置为-free_benchmark) + - [MindSpore动态图场景](03.config_examples.md#34-task-配置为-free_benchmark) + #### 1.6.1 无标杆比对数据存盘格式 无标杆比对在 dump_path 目录下输出结果文件 `free_benchmark.csv`,如下示例: @@ -162,5 +213,15 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore | L1 | ("param_name", "max", "min", "norm", "shape") | 是 | | L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 | - intervals就是根据值分布bounds划分出的区间。 - MindSpore静态图模式下,L0级别中暂不支持"MD5" +**说明**: + +1. intervals就是根据值分布bounds划分出的区间。 +2. MindSpore静态图模式下,L0级别中暂不支持"MD5" + +### 1.8 task 配置为 structure +structure 模式仅采集模型结构,无其他特殊配置。 + +**示例**: + - [PyTorch场景](03.config_examples.md#16-task-配置为-structure) + - [MindSpore动态图场景](03.config_examples.md#35-task-配置为-structure) + diff --git a/debug/accuracy_tools/msprobe/docs/03.config_examples.md b/debug/accuracy_tools/msprobe/docs/03.config_examples.md index 542250fac243f3ab2f1d0aff87bc509ac7c1a675..0d29a4eb1a824bba2c1bda1a214c9add2e87bdba 100644 --- a/debug/accuracy_tools/msprobe/docs/03.config_examples.md +++ b/debug/accuracy_tools/msprobe/docs/03.config_examples.md @@ -17,6 +17,7 @@ "statistics": { "scope": [], "list": [], + "tensor_list": [], "data_mode": ["all"], "summary_mode": "statistics" } diff --git a/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md index ce3fd54f5a6741b262f6248f70a9f1166ca0b4a6..346481aad12c42994669b7b3ea794843e49c1618 100644 --- a/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/04.kernel_dump_PyTorch.md @@ -6,7 +6,7 @@ ## 1 kernel dump 配置示例 -使用 kernel dump 时,list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。 +使用 kernel dump 时,task 需要配置为 tensor , list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。 API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。 ```json diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index db9a989c9d1c731fd9099d311f3ab3b95e5c7d5d..b18a248aae8e8bc243d86ff414396a53125ee640 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -2,7 +2,7 @@ msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 -dump的'tensor'模式采集数据量大小,可以参考[数据量基线](./26.data_dump_PyTorch_baseline.md)。 +dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](./26.data_dump_PyTorch_baseline.md)。 本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: @@ -15,6 +15,52 @@ functional: # functional为算子类别,找到对应的类别,在该类别 删除API的场景:部分模型代码逻辑会存在API原生类型校验,工具执行dump操作时,对模型的API封装可能与模型的原生API类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md)》中“异常情况”的第10和11条。 +## 快速上手 + +这个示例定义了一个 nn.Module 类型的简单网络,使用原型函数 PrecisionDebugger 进行数据采集。 + +```python +# 根据需要import包 +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 导入工具的数据采集接口 +from msprobe.pytorch import PrecisionDebugger, seed_all + +# 在模型训练开始前固定随机性 +seed_all() + +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger() + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear_1(x) + x2 = self.linear_2(x1) + r1 = F.relu(x2) + return r1 + +if __name__ == "__main__": + module = ModuleOP() + + # 开启数据 dump + debugger.start(model=module) + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() + + # 关闭数据 dump + debugger.stop() +``` + ## 1 接口介绍 ### 1.1 PrecisionDebugger @@ -30,9 +76,11 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model 1. config_path:指定 dump 配置文件路径; 2. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module 或 list[torch.nn.Module] 类型,默认未配置。 level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置该参数。该参数在将来会从该接口移除,建议在 **start** 接口中配置该参数。 -3. 其他参数均在 [config.json](../config.json) 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 +3. 其他参数均在 config.json 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 -此接口的参数均不是必要,且优先级高于 [config.json](../config.json) 文件中的配置,但可配置的参数相比 config.json 较少。 +此接口的参数均不是必要(均不配置的情况下默认采集所有 rank 和 step 的 L1 级别的统计数据),且优先级高于 config.json 文件中的配置,但可配置的参数相比 config.json 较少。 + +注:此接口的初始化需与采集目标在同一个进程中,否则将无法采集目标数据。 ### 1.2 start @@ -41,12 +89,15 @@ level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置 **原型**: ```Python -debugger.start(model=None) +debugger.start(model=None, token_range=None) ``` 1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。 -level 配置为"L0"或"mix"时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。 +level 配置为"L0"|"mix"或token_range不为None时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。 本接口中的 model 比 PrecisionDebugger 中 model 参数优先级更高,会覆盖 PrecisionDebugger 中的 model 参数。 +
对于复杂模型,如果仅需要监控一部分(如model.A,model.A extends torch.nn.Module),传入需要监控的部分(如model.A)即可。 +注意:传入的当前层不会被dump,工具只会dump传入层的子层级。如传入了model.A,A本身不会被dump,而是会dump A.x, A.x.xx等。 +2. token_range:指定推理模型采集时的token循环始末范围,支持传入[int, int]类型,代表[start, end],范围包含边界,默认未配置。 ### 1.3 stop @@ -183,58 +234,65 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | -## 2 示例代码 +### 1.10 set_init_step -### 2.1 快速上手 +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需要写在训练迭代的循环开始前,不能写在循环内。 -这个示例定义了一个 nn.Module 类型的简单网络,在进行数据采集时使用原型函数 PrecisionDebugger 传入 config_path 参数和 model 参数。 +**原型**: -```python -# 根据需要import包 -import torch -import torch.nn as nn -import torch.nn.functional as F +```Python +debugger.set_init_step(step) +``` -# 导入工具的数据采集接口 -from msprobe.pytorch import PrecisionDebugger, seed_all +**参数说明**: -# 在模型训练开始前固定随机性 -seed_all() -# 在模型训练开始前实例化PrecisionDebugger -debugger = PrecisionDebugger(config_path='./config.json') +1.step: 指定的起始step数。 -# 定义网络 -class ModuleOP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear_1 = nn.Linear(in_features=8, out_features=4) - self.linear_2 = nn.Linear(in_features=4, out_features=2) +### 1.11 register_custom_api - def forward(self, x): - x1 = self.linear_1(x) - x2 = self.linear_2(x1) - r1 = F.relu(x2) - return r1 +**功能说明**:注册用户自定义的api到工具用于 L1 dump 。 -if __name__ == "__main__": - module = ModuleOP() - # 开启数据 dump - debugger.start(model=module) +**原型**: - x = torch.randn(10, 8) - out = module(x) - loss = out.sum() - loss.backward() +```Python +debugger.register_custom_api(module, api_name, api_prefix) +``` +**参数说明**: - # 关闭数据 dump - debugger.stop() +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 + +3.api_prefix: [dump.json](./27.dump_json_instruction.md) 中 api 名的前缀,可选,默认为包名的字符串格式, 即 "torch"。 + +### 1.12 restore_custom_api + +**功能说明**:恢复用户原有的自定义的api,取消 dump 。 + +**原型**: + +```Python +debugger.restore_custom_api(module, api_name) ``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 -### 2.2 采集完整的前反向数据 + +## 2 示例代码 + + +### 2.1 采集完整的前反向数据 ```Python from msprobe.pytorch import PrecisionDebugger, seed_all @@ -255,7 +313,7 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` -### 2.3 采集指定代码块的前反向数据 +### 2.2 采集指定代码块的前反向数据 ```Python from msprobe.pytorch import PrecisionDebugger, seed_all @@ -279,7 +337,7 @@ for data, label in data_loader: debugger.step() # 结束一个step的dump ``` -### 2.4 采集函数模块化数据 +### 2.3 采集函数模块化数据 ```Python # 根据需要import包 @@ -321,6 +379,80 @@ if __name__ == "__main__": debugger.stop() ``` +### 2.4 跨文件采集数据 +为了确保所有API都被工具封装,PrecisionDebugger的实例化通常放在训练工程的入口位置,但有的时候,模型定义会在另一个文件中。 假设有两个文件,train.py(为训练工程入口)module.py(为模型定义文件),为了采集module.py中定义的ModuleOP模块中某些子模块或API的前反向数据,需要在train.py和module.py文件中分别导入PrecisionDebugger并进行如下配置。 + +train.py文件: + +```Python +# 根据需要import包 +import torch +from module import ModuleOP + +# 导入工具的数据采集接口 +from msprobe.pytorch import PrecisionDebugger + +# 将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装 +debugger = PrecisionDebugger(config_path='./config.json') + +if __name__ == "__main__": + module = ModuleOP() + + x = torch.randn(10, 8) + out = module(x) + loss = out.sum() + loss.backward() +``` + +module.py文件: + +```Python +import torch +import torch.nn as nn +import torch.nn.functional as F + +from msprobe.pytorch import PrecisionDebugger + +# 定义网络 +class ModuleOP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features=8, out_features=4) + self.linear_2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + PrecisionDebugger.start() + x1 = self.linear_1(x) + PrecisionDebugger.stop() + x2 = self.linear_2(x1) + r1 = F.relu(x2) + return r1 + +``` + +### 2.5 推理模型采集指定token_range + +```Python +from vllm import LLM, SamplingParams +from msprobe.pytorch import PrecisionDebugger, seed_all +# 在模型训练开始前固定随机性 +seed_all() +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 +debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") +# 模型定义及初始化等操作 +prompts = ["Hello, my name is"] +sampling_params = SamplingParams(temprature=0.8, top_p=0.95) +llm = LLM(model='...') +model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() +# 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次 +debugger.start(model=model, token_range=[1,3]) +# 推理模型生成的逻辑 +output = llm.generate(prompts, sampling_params=sampling_params) +# 关闭数据dump并落盘 +debugger.stop() +debugger.step() +``` + ## 3 dump 结果文件介绍 训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下: @@ -334,8 +466,8 @@ if __name__ == "__main__": | | | | ├── Functional.linear.5.backward.output.pt # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 | | | | ... | | | | ├── Module.conv1.Conv2d.forward.0.input.0.pt # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 -| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | └── Module.conv1.Conv2D.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | ├── Module.conv1.Conv2d.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2d.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 | | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.pt。 │ | | ├── dump.json │ | | ├── stack.json @@ -355,7 +487,7 @@ if __name__ == "__main__": ``` * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-dumpjson文件介绍pytorch)。 +* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-PyTorch场景下的dump.json文件)。 * `stack.json`:API/Module的调用栈信息。 * `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 @@ -366,12 +498,14 @@ dump 过程中,pt 文件在对应算子或者模块被执行后就会落盘, pt 文件保存的前缀和 PyTorch 对应关系如下: -| 前缀 | Torch模块 | -| ----------- | ------------------- | +| 前缀 | Torch模块 | +|-------------|---------------------| | Tensor | torch.Tensor | | Torch | torch | | Functional | torch.nn.functional | -| NPU | NPU 亲和算子 | +| NPU | NPU 亲和算子 | | VF | torch._VF | | Aten | torch.ops.aten | | Distributed | torch.distributed | +| MindSpeed | mindspeed.ops | + diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f7507facd2a92f3acbefdc92fa6cd808a155d6e3..fdf19b2b7905e15cdf52d964a5b235b3eb9cb5f8 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -26,12 +26,14 @@ msprobe 工具通过在训练脚本中添加 `PrecisionDebugger` 接口并启动训练的方式,采集模型在运行过程中的精度数据。该工具支持对MindSpore的静态图和动态图场景进行不同Level等级的精度数据采集。 -dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data_dump_MindSpore/data_dump_MindSpore_baseline.md)。 +dump "statistics"模式的性能膨胀大小"与"tensor"模式采集的数据量大小,可以参考[dump基线](data_dump_MindSpore/data_dump_MindSpore_baseline.md)。 ## 5. 场景介绍 -### 5.1 静态图场景 -在静态图场景下,msprobe 仅支持 **L2 Level** 的数据采集。 +### 5.1 静态图场景 +在静态图场景下,msprobe 支持 **L0 Level** 和 **L2 Level** 的数据采集。且当 MindSpore 版本高于 2.5.0 时,若需采集 **L2 Level** 数据,必须使用编包时添加了`--include-mod=adump`选项的 mindstudio-probe whl 包进行 msprobe 工具安装。 +- **L0 Level(Cell 级)** :采集 `Cell` 对象的数据,适用于需要分析特定网络模块的情况。仅支持 2.7.0 及以上版本的 MindSpore 框架。 + - **L2 Level(Kernel 级)** :采集底层算子的输入输出数据,适用于深入分析算子级别的精度问题。 采集方式请参见[示例代码 > 静态图场景](#71-静态图场景)。详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”和[《config.json 配置示例》](./03.config_examples.md#2-mindspore-静态图场景) 中的“MindSpore 静态图场景”。 @@ -46,7 +48,7 @@ dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data 采集方式请参见[示例代码 > 动态图场景](#72-动态图场景)。 -> **注意** :动态图模式下,使用 `PSJit` 或 `PIJit` 装饰的部分实际以静态图模式执行,此时的 **Kernel 级(L2 Level)** 数据采集方式与静态图场景相同。 +> **注意** :动态图模式下,使用 `mindspore.jit` 装饰的部分实际以静态图模式执行,此时的 **Kernel 级(L2 Level)** 数据采集方式与静态图场景相同。 - **L0 Level(Cell 级)** :采集 `Cell` 对象的数据,适用于需要分析特定网络模块的情况。 - **L1 Level(API 级)** :采集 MindSpore API 的输入输出数据,适用于定位 API 层面的精度问题。 @@ -56,7 +58,7 @@ dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data - **debug level (单点保存)**:单点保存网络中变量的正反向数据,适用于用户熟悉网络结构的场景。 -详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”和[《config.json 配置示例》](./03.config_examples.md#3-mindspore-动态图场景) 中的“MindSpore 动态图场景”。 +详细介绍请参见[《config.json 配置文件介绍》](./02.config_introduction.md#11-通用配置)中的“level 参数”。 ## 6 接口介绍 @@ -85,12 +87,15 @@ PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step= **原型**: ```Python -start(model=None) +start(model=None, token_range=None) ``` **参数说明**: -1. model:指定需要采集数据的实例化模型,支持传入mindspore.nn.Cell、List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell] 类型, 默认未配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有Cell 对象数据。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 +1. model:指定需要采集数据的实例化模型,支持传入mindspore.nn.Cell、List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell] 类型,默认未配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有Cell 对象数据。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。token_range不为None时,必须传入model参数。 +
对于复杂模型,如果仅需要监控一部分(如model.A,model.A extends mindspore.nn.Cell),传入需要监控的部分(如model.A)即可。 +注意:传入的当前层不会被dump,工具只会dump传入层的子层级。如传入了model.A,A本身不会被dump,而是会dump A.x, A.x.xx等。 +2. token_range:指定推理模型采集时的token循环始末范围,支持传入[int, int]类型,代表[start, end],范围包含边界,默认未配置。 #### 6.1.2 stop @@ -110,7 +115,7 @@ stop() **功能说明**:结束一个 step 的数据采集,完成所有数据落盘并更新 dump 参数。在一个 step 结束的位置添加,且必须在 **stop** 函数之后的位置调用。 该函数需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码之后,否则可能会导致反向数据丢失。 -**仅未使用 Model 高阶 API 的动态图场景支持。** +**仅未使用 Model 高阶 API 的动态图和静态图场景支持。** **原型**: @@ -144,15 +149,65 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | +#### 6.1.6 set_init_step + +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需要写在训练迭代的循环开始前,不能写在循环内。 + +**原型**: + +```Python +set_init_step(step) +``` + +**参数说明**: + +1.step: 指定的起始step数。 + + +#### 6.1.7 register_custom_api + +**功能说明**:注册用户自定义的api到工具,用于 L1 dump 。 + +**原型**: + +```Python +debugger.register_custom_api(module, api_name, api_prefix) +``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 + +2.api_name: api 名,string类型,即传入 "matmul"。 + +3.api_prefix: [dump.json](./27.dump_json_instruction.md) 中 api 名的前缀,可选,默认为包名的字符串格式, 即 "torch"。 + +#### 6.1.8 restore_custom_api +**功能说明**:恢复用户原有的自定义的api,取消 dump 。 -### 6.2 msprobe.mindspore.common.utils.MsprobeStep +**原型**: + +```Python +debugger.restore_custom_api(module, api_name) +``` +**参数说明**: + +以 torch.matmul api 为例 + +1.module: api 所属的包,即传入 torch。 -**功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度任意训练代码段。 +2.api_name: api 名,string类型,即传入 "matmul"。 + + +### 6.2 msprobe.mindspore.MsprobeStep + +**功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别,和静态图场景下 L0级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度为任意训练代码段。 **原型**: @@ -164,7 +219,17 @@ MsprobeStep(debugger) 1. debugger:PrecisionDebugger对象。 -### 6.3 msprobe.mindspore.seed_all +### 6.3 msprobe.mindspore.MsprobeInitStep + +**功能说明**:MindSpore Callback 类,自动获取并设置初始 step 值。仅适用于静态图 O0/O1 模式的断点续训场景。 + +**原型**: + +```Python +MsprobeInitStep() +``` + +### 6.4 msprobe.mindspore.seed_all **功能说明**:用于固定网络中的随机性和开启确定性计算。 @@ -181,12 +246,59 @@ seed_all(seed=1234, mode=False, rm_dropout=True) 3. rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout失效,以避免因随机dropout造成的网络随机性。建议在采集mindspore数据前开启。注意:通过rm_dropout控制dropout失效或生效需要在初始化Dropout实例前调用才能生效。 +## 7. 示例代码 + +### 7.1 静态图场景 +#### 7.1.1 L0 级别 +**说明**: 静态图 L0 级别的Dump功能是基于mindspore.ops.TensorDump算子实现。在Ascend平台上的Graph模式下,可以通过设置环境变量 [MS_DUMP_SLICE_SIZE 和 MS_DUMP_WAIT_TIME](https://www.mindspore.cn/docs/zh-CN/r2.5.0/api_python/env_var_list.html) 解决在输出大Tesnor或输出Tensor比较密集场景下算子执行失败的问题。 -## 7. 示例代码 +##### 7.1.1.1 未使用 Model 高阶 API -### 7.1 静态图场景 + +```python +import mindspore as ms +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +from msprobe.mindspore import PrecisionDebugger +debugger = PrecisionDebugger(config_path="./config.json") + +# 模型、损失函数的定义以及初始化等操作 +# ... +model = Network() +# 数据集迭代的地方往往是模型开始训练的地方 +for data, label in data_loader: + debugger.start(model) # 进行 L0 级别下Cell 对象的数据采集时调用 + # 如下是模型每个 step 执行的逻辑 + grad_net = ms.grad(model)(data) + # ... + debugger.step() # 更新迭代数 +``` + +##### 7.1.1.2 使用 Model 高阶 API + + +```python +import mindspore as ms +from mindspore.train import Model +ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") + +from msprobe.mindspore import PrecisionDebugger +from msprobe.mindspore.common.utils import MsprobeStep +debugger = PrecisionDebugger(config_path="./config.json") + +# 模型、损失函数的定义以及初始化等操作 +# ... + +model = Network() +# 进行 L0 级别下 Cell 对象的数据采集时调用 +debugger.start(model) +trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'}) +trainer.train(1, train_dataset, callbacks=[MsprobeStep(debugger)]) +``` + +#### 7.1.2 L2 级别 ```python import mindspore as ms @@ -198,7 +310,8 @@ debugger.start() # 请勿将以上初始化流程置于模型实例化或 mindspore.communication.init 调用后 # 模型定义和训练代码 # ... - +debugger.stop() +debugger.step() ``` ### 7.2 动态图场景 @@ -297,15 +410,43 @@ trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy' trainer.train(1, train_dataset) ``` + +#### 7.2.3 推理模型采集指定token_range +需要配合mindtorch套件改造原推理代码,套件包装后使用方式与torch一致,唯一区别为import的是msprobe.mindspore下的PrecisionDebugger。 + +```Python +from vllm import LLM, SamplingParams +from msprobe.mindspore import PrecisionDebugger, seed_all +# 在模型训练开始前固定随机性 +seed_all() +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 +debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") +# 模型定义及初始化等操作 +prompts = ["Hello, my name is"] +sampling_params = SamplingParams(temprature=0.8, top_p=0.95) +llm = LLM(model='...') +model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() +# 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次 +debugger.start(model=model, token_range=[1,3]) +# 推理模型生成的逻辑 +output = llm.generate(prompts, sampling_params=sampling_params) +# 关闭数据dump并落盘 +debugger.stop() +debugger.step() +``` + ## 8. dump 结果文件介绍 ### 8.1 静态图场景 -训练结束后,数据将保存在 `dump_path` 指定的目录下。 +训练结束后,数据将保存在 `dump_path` 指定的目录下。
+L0 级别 dump 的目录结构与动态图场景下目录结构一致。
+L2 级别 dump 的目录结构如下所示: -若jit_level=O2,且使用mindstudio-probe发布包或源码编包时添加了`--include-mod=adump`选项,目录结构示例如下: +若jit_level=O2,MindSpore 版本不低于 2.5.0,且使用mindstudio-probe发布包或源码编包时添加了`--include-mod=adump`选项,目录结构示例如下: ``` ├── dump_path +│ ├── acl_dump_{device_id}.json │ ├── rank_0 │ | ├── {timestamp} │ | │ ├── step_0 @@ -329,9 +470,9 @@ trainer.train(1, train_dataset) **说明** 1. 若配置文件中指定落盘npy格式,但是实际数据格式不在npy支持范围内(如bf16、int4等),则该tensor会以原始码流落盘,并不会转换为npy格式。 2. 若原始文件全名长度超过255个字符,则文件基础名会被转换为长度为32位的随机数字字符串,原始文件名与转换后文件名的对应关系会保存在同目录下的`mapping.csv`文件中。 +3. acl_dump_{device_id}.json 为在 Dump 接口调用过程中生成的中间文件,一般情况下无需关注。 - -其他场景请参见 MindSpore 官方文档中的[数据对象目录](https://www.mindspore.cn/docs/zh-CN/r2.4.0/model_train/debug/dump.html)。 +其他场景下,除 kernel_kbyk_dump.json(jit_level=O0/O1)、kernel_graph_dump.json(jit_level=O2)等无需关注的中间文件外的其他 dump 结果文件请参见 MindSpore 官方文档中的[ Ascend 下 O0/O1 模式 Dump 数据对象目录和数据文件介绍](https://www.mindspore.cn/docs/zh-CN/r2.5.0/model_train/debug/dump.html#%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95%E5%92%8C%E6%95%B0%E6%8D%AE%E6%96%87%E4%BB%B6%E4%BB%8B%E7%BB%8D)与[ Ascend 下 O2 模式 Dump 数据对象目录和数据文件介绍](https://www.mindspore.cn/docs/zh-CN/r2.5.0/model_train/debug/dump.html#%E6%95%B0%E6%8D%AE%E5%AF%B9%E8%B1%A1%E7%9B%AE%E5%BD%95%E5%92%8C%E6%95%B0%E6%8D%AE%E6%96%87%E4%BB%B6%E4%BB%8B%E7%BB%8D-1)。 ### 8.2 动态图场景 @@ -348,9 +489,9 @@ dump 结果目录结构示例如下: | | | | ├── Tensor.__add__.0.forward.output.0.npy | | | | ... | | | | ├── Jit.AlexNet.0.forward.input.0.npy -| | | | ├── Primitive.conv2d.Conv2D.0.forward.input.0.npy -| | | | ├── Cell.conv1.Conv2D.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | ├── Cell.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | ├── Primitive.conv2d.Conv2d.0.forward.input.0.npy +| | | | ├── Cell.conv1.Conv2d.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | ├── Cell.conv1.Conv2d.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 | | | | └── Cell.relu.ReLU.forward.0.input.0.npy # 命名格式为{Cell}.{cell_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Cell的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Cell的第1个参数的第1个元素。 | | | | # 当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Cell}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Cell.0.relu.ReLU.forward.0.input.0.npy。 │ | | ├── dump.json @@ -372,17 +513,41 @@ dump 结果目录结构示例如下: * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-dumpjson文件示例mindspore)。 +* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-mindspore-场景下的-dumpjson-文件)。 * `stack.json`:API/Cell的调用栈信息。 * `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 dump 过程中,npy 文件在对应API或者模块被执行后就会落盘,而 json 文件则需要在正常执行 PrecisionDebugger.stop() 后才会写入完整数据,因此,程序异常终止时,被执行API对应的 npy 文件已被保存,但 json 文件中的数据可能丢失。 -动态图场景下使能 PSJit 或 PIJit,装饰特定 Cell 或 function,被装饰的部分会全部/部分使能**静态图**流程。 +动态图场景下使用 `mindspore.jit` 装饰特定 Cell 或 function 时,被装饰的部分会被编译成**静态图**执行。 + +- config.json 文件配置 level 为 L0 或 mix,且 MindSpore 版本不低于 2.7.0 时, 若存在 construct 方法被 `mindspore.jit` 装饰的 Cell 对象,则 dump_path 下将生成 `graph` 与 `pynative` 目录,分别存放 construct 方法被 `mindspore.jit` 装饰的 Cell 对象的精度数据、其它Cell 或 API 对象的精度数据。示例如下: + +```lua +├── dump_path +│ ├── graph +│ | ├── step0 +│ | | ├── rank0 +│ | │ | ├── dump_tensor_data +| | | | | ├── ... +│ | | | ├── dump.json +│ | | | ├── stack.json +│ | | | └── construct.json +│ | | ├── ... +│ ├── pynative +│ | ├── step0 +│ | | ├── rank0 +│ | │ | ├── dump_tensor_data +| | | | | ├── ... +│ | | | ├── dump.json +│ | | | ├── stack.json +│ | | | └── construct.json +│ | | ├── ... +``` -- PSJit 场景下 config.json 文件配置 level 为 L1 时,被 PSJit 装饰的部分也作为 API 被 dump 到对应目录;配置 level 为 L2 时,则只会 dump 用户网络中静态图流程下的相关 kernel,其结果目录同jit_level 为 O0/O1 时的静态图 dump 相同。 -- PIJit 场景下 config.json 文件配置 level 为 L1 时,会被还原为动态图,按 API 粒度进行 dump;配置 level 为 L2 时,则只会 dump 用户网络中静态图流程下的相关 kernel。 +- config.json 文件配置 level 为 L1 时, 若 `mindspore.jit` 的 `capture_mode` 参数设置为 ast(原 PSJit 场景), 则被装饰的部分也作为 API 被 dump 到对应目录;若 `mindspore.jit` 的 `capture_mode` 参数设置为 bytecode(原 PIJit 场景), 则被装饰的部分会被还原为动态图,按 API 粒度进行 dump。 +- config.json 文件配置 level 为 L2 时, 仅会 dump 被 `mindspore.jit` 装饰部分的 kernel 精度数据,其结果目录同 jit_level 为 O0/O1 时的静态图 dump 结果相同。 npy文件名的前缀含义如下: @@ -393,12 +558,11 @@ npy文件名的前缀含义如下: | Primitive | mindspore.ops.Primitive API数据 | | Mint | mindspore.mint API数据 | | MintFunctional | mindspore.mint.nn.functional API数据 | +| MintDistributed | mindspore.mint.distributed API数据 | | Distributed | mindspore.communication.comm_func API数据 | | Jit | 被"jit"装饰的模块或函数数据 | | Cell | mindspore.nn.Cell 类(模块)数据 | - - ## 9.补充说明 ### 9.1 修改 API 支持列表 @@ -411,3 +575,6 @@ ops: - adaptive_avg_pool2d - adaptive_avg_pool3d ``` +### 9.2 不支持模型 + +静态图场景L0级暂不支持Yi模型。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md index b07568e25a2915a4e8e5c2157e7de4252410f38d..cb4a5e1c81c0b174b0dd00ba3cc6f57b5ee22648 100644 --- a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md @@ -34,16 +34,17 @@ run_ut 预检操作包括以下两种方式: msprobe -f pytorch run_ut -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json ``` - | 参数名称 | 解释 | 是否必选 | - | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | - | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | - | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | - | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | - | -j 或 --jit_compile | 开启 jit 编译。 | 否 | - | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 | + | 参数名称 | 解释 | 是否必选 | + |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- | + | -f 或 --framework | 指定训练框架。pytorch。 | 是 | + | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | + | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | + | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | + | -j 或 --jit_compile | 开启 jit 编译。 | 否 | + | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 | | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 | - | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | - | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 | + | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | + | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 | run_ut 执行结果包括 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 两个文件。`accuracy_checking_result_{timestamp}.csv` 属于 API 级,标明每个 API 是否通过测试。建议用户先查看 `accuracy_checking_result_{timestamp}.csv` 文件,对于其中没有通过测试的或者特定感兴趣的 API,根据其 API name 字段在 `accuracy_checking_details_{timestamp}.csv` 中查询其各个输出的达标情况以及比较指标。详细介绍请参见[ 4 预检结果](#4-预检结果)。 @@ -103,11 +104,12 @@ msprobe -f pytorch multi_run_ut -api_info ./dump_path/step{step_number}/rank{ran | 参数名称 | 解释 | 是否必选 | | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 | | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 | | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 | | -j 或 --jit_compile | 开启 jit 编译。 | 否 | -| -n | 同时执行 run_ut 线程的数量,默认为 8,最大支持 64,但每个 Device 最大支持 8 个线程。当指定多个线程和多个 Device 时,线程数在每张卡上均分。 | 否 | +| -n 或 --num_splits | 同时执行 run_ut 线程的数量,默认为 8,最大支持 64,但每个 Device 最大支持 8 个线程。当指定多个线程和多个 Device 时,线程数在每张卡上均分。 | 否 | | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0~7,共 8 个 Device。 | 否 | | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 | | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 | @@ -212,8 +214,9 @@ Forward Test Success 和 Backward Test Success 是否通过测试是由 `accurac msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/ ``` -| 参数名称 | 说明 | 是否必选 | -| -------------------- | ------------- | -------- | +| 参数名称 | 说明 | 是否必选 | +|-----------------------| ------------- | -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -npu 或 --npu_csv_path | NPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 | | -gpu 或 --gpu_csv_path | GPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 | | -o 或 --out_path | 指定 api_precision_compare.py 执行结果存盘路径,默认为当前目录。 | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md b/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md index a93ad3b62405d549a16e7196e2f2145de68e8674..06b0eaef8c45199ddd7d4466450ad43f721be6dd 100644 --- a/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/08.accuracy_checker_online_PyTorch.md @@ -37,7 +37,7 @@ Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据 | host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | | port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | | rank_list | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。 | 是 | -| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥文件 server.key 和公钥文件 server.crt,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。 | 否 | +| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 server.key、证书 server.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 | #### 3.1.2 NPU 侧在线预检配置说明 @@ -55,21 +55,73 @@ Host 与 GPU Host 设备间建立连接,将 NPU 上对应 API 的输入数据 | nfs_path | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效。 | 否 | | host | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | | port | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 | -| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥文件 client.key 和公钥文件 client.crt,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。 | 否 | +| tls_path | 在线预检模式局域网场景 SSL 证书路径,该路径下包含私钥 client.key、证书 client.crt、自建CA证书 ca.crt、CRL吊销证书 crl.pem,str 类型,未配置该参数时默认取值当前路径。tls_path配置为空字符串时,采用TCP协议明文传输api数据;当配置为路径时,采用TLS1.2协议加密传输数据,加密传输时安全性较高,传输速率较低。其中 crl.pem 为非必需文件,仅当用户存在吊销记录时使用。 | 否 | | online_run_ut_recompute | 模型训练是否使用重计算机制,bool类型,默认为False,表示模型没有使用重计算。在线预检暂不支持重计算机制下反向算子的预检,当模型训练使用重计算时,跳过反向算子预检,默认模型关闭重计算。 | 否 | #### 3.1.3 局域网场景配置示例 -若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法,仅供调试使用,生产环境请申请正式证书。 +若采用 TLS1.2 协议加密传输 api 数据,需配置 SSL 证书,可参考如下生成自签名证书方法。 + +以下秘钥生成方法仅为简单示例,客户应使用与自己需求相符的秘钥生成和存储机制并保证秘钥安全性与机密性,必要时可采用分层秘钥机制。 +以下示例中加密口令仅供参考,使用时请更换为复杂口令,并保护口令安全。 ```shell -# 创建私钥文件server.key -openssl genrsa -out server.key 2048 +# 生成CA证书的根私钥和证书签名请求,其中ca_password为CA私钥加密口令,仅作演示,请更换使用 +openssl req -new -newkey rsa:3072 -passout pass:ca_password -subj "/CN=*ca.com/O=ca.Inc./C=CN/ST=Zhejiang/L=Hangzhou" -keyout ca.key -out ca.csr +# 自签发根证书 +openssl x509 -req -days 365 -in ca.csr -signkey ca.key -passin pass:ca_password -out ca.crt -extensions v3_ca -extfile <(cat <<-EOF +[v3_ca] +basicConstraints = critical,CA:true +keyUsage = critical, keyCertSign, cRLSign +EOF +) + +# 生成client公私钥,其中client_password为私钥加密口令,仅作演示,请更换使用 +openssl genrsa -aes256 -passout pass:client_password -out client.key 3072 +# 基于client公私钥生成签名请求 +openssl req -new -key client.key -passin pass:client_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out client.csr +# 利用自签发的根证书,签发client证书 +openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in client.csr -out client.crt -CAcreateserial -extfile <(cat <<-EOF +[v3_server] +basicConstraints = CA:FALSE +keyUsage = critical, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +EOF +) + +# 生成server公私钥,其中server_password为私钥加密口令,仅作演示,请更换使用 +openssl genrsa -aes256 -passout pass:server_password -out server.key 3072 +# 基于server公私钥生成签名请求 +openssl req -new -key server.key -passin pass:server_password -subj "/CN=*example.com/O=Test, Inc./C=CN/ST=Zhejiang/L=Hangzhou" -out server.csr +# 利用自签发的根证书,签发server证书 +openssl x509 -req -days 180 -CA ca.crt -CAkey ca.key -passin pass:ca_password -in server.csr -out server.crt -CAcreateserial -extfile <(cat <<-EOF +[v3_server] +basicConstraints = CA:FALSE +keyUsage = critical, digitalSignature, keyEncipherment +extendedKeyUsage = serverAuth +EOF +) + +``` -# 创建签名请求文件server.csr -openssl req -new -key server.key -out server.csr +当需要吊销已创建的SSL证书时,通过openssl命令生成CRL证书 crl.pem,示例如下: +```shell +# 创建证书信息的文本数据库,空文件即可 +touch index.txt + +# 创建ca配置文件ca.cnf,内容如下,用于吊销证书使用 +[ca] +default_ca = CA_default +[CA_default] +database = ./index.txt +default_md = sha256 + +# 吊销证书 client.crt,其中ca_password为CA私钥加密口令,与CA创建时保持一致 +openssl ca -revoke client.crt -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password +# 生成CRL文件 +openssl ca -gencrl -config ca.cnf -cert ca.crt -keyfile ca.key -passin pass:ca_password -out crl.pem -crldays 30 +# 查看生成的CRL文件内容: +openssl工具的命令: openssl crl -inform PEM -in crl.pem -text -# 自签名, 生成1年期公钥文件server.crt -openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt ``` 注意:配置TLS协议时,传输性能受机器环境和网络质量的影响,可能触发NPU超时中断模型训练,为避免训练和预检中断,丢弃长时间未传输的api数据,同时NPU侧配置HCCL环境变量,配置方式如下: diff --git a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md index 8e5ab781ce0652ea572e0a0e5fb053655c5f48ec..dd27fcb0b0ea2c16aac5396aadc040b241ac8853 100644 --- a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md @@ -2,7 +2,7 @@ ## 1 简介 -**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API,输出精度情况的诊断和分析。工具以模型中所有 Mint API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 Mint API。本工具支持**随机生成模式和真实数据模式**b。 +**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API 以及 Msadapter场景下迁移的 Mindspore API,输出精度情况的诊断和分析。工具以模型中所有 API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 API。本工具支持**随机生成模式和真实数据模式**b。 a. 支持 Mindspore 版本:2.4/2.5; @@ -31,6 +31,7 @@ msprobe -f mindspore run_ut -api_info ./dump.json -o ./checker_result | 参数名称 | 说明 |参数类型 | 是否必选 | | ---------------------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | | -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | @@ -45,12 +46,13 @@ multi_run_ut 脚本,可以并行在多个Device执行 run_ut 操作,从而 msprobe -f mindspore multi_run_ut -api_info ./dump.json -d 0 1 2 3 ``` -| 参数名称 | 说明 |参数类型 | 是否必选 | -| ---------------------------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | -| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | -| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | -| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | -| -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0 ~ Device数量 - 1 ,例如 0 1 2 3 4。 | List[int] | 否 | +| 参数名称 | 说明 | 参数类型 | 是否必选 | +| ---------------------------- |----------------------------------------------------------------------------------------------------------------------------------------------------------|-----------| ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | +| -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。对其中的mint api以及部分Tensor api进行预检,预检支持的Tensor api列表详见 [ 预检支持列表](../mindspore/api_accuracy_checker/checker_support_api.yaml)。 | str | 是 | +| -o 或 --out_path | 指定预检结果存盘路径,默认“./”。 | str | 否 | +| -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | str | 否 | +| -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0,支持同时指定 0 ~ Device数量 - 1 ,例如 0 1 2 3 4。 | List[int] | 否 | 在不同卡数下,使用38B语言大模型的预检耗时基线参考 [multi_run_ut耗时基线](accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md) diff --git a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md index b4525d738d849a17ca5049bd2214784c6f788d21..a7bf2b7b9c211c639d5ac21b4421634c7ee7ca44 100644 --- a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md @@ -51,14 +51,15 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s 完整参数说明: -| 参数名 | 说明 | 是否必选 | -|-------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| -i 或 --input_path | 指定[比对文件](#214-比对文件),str 类型。 | 是 | -| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。 | 否 | -| -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#214-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| 参数名 | 说明 | 是否必选 | +|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | +| -i 或 --input_path | 指定[比对文件](#214-比对文件),str 类型。 | 是 | +| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 否 | +| -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#214-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | | -c 或 --compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | -| -f 或 --fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#215-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#214-比对文件)的单卡场景示例。 | 否 | +| -f 或 --fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#215-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#214-比对文件)的单卡场景示例。 | 否 | #### 2.1.2 整网比对场景 @@ -66,19 +67,17 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s 支持单卡和多卡,可同时比对多卡的 dump 数据。多机场景需要每个设备单独执行比对操作。 -1. 配置[config.json](../config.json)文件。 - -2. 参见 [PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 +1. 参见 [PyTorch 场景下的数据采集](./05.data_dump_PyTorch.md)章节完成 CPU 或 GPU 与 NPU 的精度数据 dump。 -3. 创建[比对文件](#214-比对文件)。 +2. 创建[比对文件](#214-比对文件)。 -4. 运行命令: +3. 运行命令: ```shell msprobe -f pytorch compare -i ./compare.json -o ./output -s ``` -5. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 +4. 查看比对结果,请参见 [3 精度比对结果分析](#3-精度比对结果分析)。 #### 2.1.3 API和模块无法自动匹配场景 @@ -121,8 +120,8 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s ```json { - "npu_path": "./npu_dump/step0", - "bench_path": "./bench_dump/step0", + "npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) + "bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) "is_print_compare_log": true } ``` @@ -131,8 +130,8 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s | 参数名 | 说明 | 是否必选 | | -------------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| -| npu_path | 配置 NPU 环境下的 dump.json 文件(单卡场景)或真实数据目录(多卡场景),str 类型。 | 是 | -| bench_path | 配置 CPU、GPU 或 NPU 环境下的 dump.json 文件(单卡场景)或真实数据目录(多卡场景),str 类型。 | 是 | +| npu_path | 配置 NPU 环境下的 dump.json 文件(单卡场景)或 dump 目录(多卡场景),str 类型。 | 是 | +| bench_path | 配置 CPU、GPU 或 NPU 环境下的 dump.json 文件(单卡场景)或 dump 目录(多卡场景),str 类型。 | 是 | | stack_path | 配置 NPU dump 目录下的 stack.json 文件,str 类型。如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | | is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | @@ -162,10 +161,109 @@ NPU.npu_fusion_attention.4.forward.input.0: NPU.npu_fusion_attention.4.forward.i Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0: Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0 ``` -API和模块名称在dump.json文件中的“data_name”字段展示,如下图红框处所示: +当dump.json文件中存在“data_name”字段时,API和模块名称为data_name字段去掉文件后缀,如下图红框处所示: ![pt_dump](./img/pt_dump.png) +当dump.json文件中不存在“data_name”字段时,名称的拼写规则如下: + +input_args、input_kwargs和output使用统一的命名规则,当值是list类型时,名称后面添加'.{index}',当值类型是dict类型时,名称后面加'.{key}',当值类型是具体Tensor或null或int或float或bool或空list/dict等时,命名结束。 + +以下面api的dump文件为例: +```yaml + "Functional.max_pool2d.0.forward": { + "input_args": [ + { + "type": "torch.Tensor", + "dytpe": "torch_float32", + "shape": [ + 1, + 64, + 14, + 14 + ], + "Max": xxx, + "Min": xxx, + "Mean": xxx, + "Norm": xxx, + "requires_grad": true + }, + { + "type": "int", + "value": 3 + }, + { + "type": "int", + "value": 2 + }, + { + "type": "int", + "value": 1 + }, + { + "type": "int", + "value": 1 + } + ], + "input_kwargs": { + "ceil_mode": { + "type": "bool", + "value": false + }, + "return_indices": { + "type": "bool", + "value": false + }, + }, + "output": [ + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 1, + 64, + 7, + 7 + ], + "Max": xxx, + "Min": xxx, + "Mean": xxx, + "Norm": xxx, + "requires_grad": true + } + ] + } +``` + +初始名称为Functional.max_pool2d.0.forward,input_args是list,长度为5,第0项后面是Tensor,命名结束;第1-4项后面均是int,命名结束;按照顺序命名为 +``` +Functional.max_pool2d.0.forward.input.0 +Functional.max_pool2d.0.forward.input.1 +Functional.max_pool2d.0.forward.input.2 +Functional.max_pool2d.0.forward.input.3 +Functional.max_pool2d.0.forward.input.4 +``` +input_kwargs是dict,key是ceil_mode、return_indices,值均是bool,命名结束;命名为 +``` +Functional.max_pool2d.0.forward.input.ceil_mode +Functional.max_pool2d.0.forward.input.return_indices +``` +output是list,长度为1,第0项后面是Tensor,命名结束;按照顺序命名为 +``` +Functional.max_pool2d.0.forward.output.0 +``` +综上,生成的的op_name为 +``` +Functional.max_pool2d.0.forward.input.0 +Functional.max_pool2d.0.forward.input.1 +Functional.max_pool2d.0.forward.input.2 +Functional.max_pool2d.0.forward.input.3 +Functional.max_pool2d.0.forward.input.4 +Functional.max_pool2d.0.forward.input.ceil_mode +Functional.max_pool2d.0.forward.input.return_indices +Functional.max_pool2d.0.forward.output.0 +``` + ### 2.2 比对函数方式 #### 2.2.1 compare 函数 @@ -180,13 +278,13 @@ compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_mat **参数说明**: -| 参数名 | 说明 | 是否必选 | -| ------------ |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| 参数名 | 说明 | 是否必选 | +| ------------ |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | | input_param | 配置 dump 数据文件及目录,dict 类型。配置参数包括:
"npu_json_path":指定 NPU dump 目录下的 dump.json 文件。
**配置示例**:"npu_json_path": "./npu_dump/dump.json"。
"bench_json_path":指定 CPU、GPU 或 NPU dump 目录下的 dump.json 文件。
**配置示例**:"bench_json_path": "./bench_dump/dump.json"。
"stack_json_path":指定 NPU dump 目录下的 stack.json 文件。
**配置示例**:"stack_json_path": "./npu_dump/stack.json"。
"is_print_compare_log":配置是否开启单个算子的日志打屏。
**配置示例**:True 或 False。 | 是 | -| output_path | 配置比对结果文件存盘目录,str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。 | 是 | -| stack_mode | 配置 stack_mode 的开关,bool 类型。仅当配置 stack_json_path 时需要,开启时比对结果呈现NPU_Stack_Info,关闭时不呈现。当不配置stack_json_path 时,自动识别是否存在stack.json,存在时呈现NPU_Stack_Info,否则不呈现。
**配置示例**:stack_mode=True,默认为 False。 | 否 | -| auto_analyze | 自动精度分析,bool 类型。开启后工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 advisor_{timestamp}.txt 文件)。
**配置示例**:auto_analyze=False,默认为 True。 | 否 | -| fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。
**配置示例**:fuzzy_match=True,默认为 False。 | 否 | +| output_path | 配置比对结果文件存盘目录,str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| stack_mode | 配置 stack_mode 的开关,bool 类型。仅当配置 stack_json_path 时需要,开启时比对结果呈现NPU_Stack_Info,关闭时不呈现。当不配置stack_json_path 时,自动识别是否存在stack.json,存在时呈现NPU_Stack_Info,否则不呈现。
**配置示例**:stack_mode=True,默认为 False。 | 否 | +| auto_analyze | 自动精度分析,bool 类型。开启后工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 advisor_{timestamp}.txt 文件)。
**配置示例**:auto_analyze=False,默认为 True。 | 否 | +| fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。
**配置示例**:fuzzy_match=True,默认为 False。 | 否 | **函数示例**: @@ -215,12 +313,12 @@ compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs) **参数说明**: -| 参数名 | 说明 | 是否必选 | -| -------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| npu_dump_dir | 配置 NPU 环境下的 dump 目录。str 类型。dump 数据目录须指定到 step 级。
**配置示例**:'./npu_dump/step0'。 | 是 | -| bench_dump_dir | 配置 CPU、GPU 或 NPU 环境下的 dump 目录。str 类型。
**配置示例**:'./gpu_dump/step0'。 | 是 | -| output_path | 配置比对结果文件存盘目录。需要预先创建 output_path 目录。str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_rank{npu_ID}-rank{cpu/gpu/npu_ID}_{timestamp}.xlsx`。 | 是 | -| **kwargs | 支持 compare 的所有可选参数。 其中,stack_mode不生效,自动识别是否存在stack.json,如存在,呈现NPU_Stack_Info,否则不呈现。 | 否 | +| 参数名 | 说明 | 是否必选 | +| -------------- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| npu_dump_dir | 配置 NPU 环境下的 dump 目录。str 类型。dump 数据目录须指定到 step 级。
**配置示例**:'./npu_dump/step0'。 | 是 | +| bench_dump_dir | 配置 CPU、GPU 或 NPU 环境下的 dump 目录。str 类型。
**配置示例**:'./gpu_dump/step0'。 | 是 | +| output_path | 配置比对结果文件存盘目录。需要预先创建 output_path 目录。str 类型。
**配置示例**:'./output'。文件名称基于时间戳自动生成,格式为:`compare_result_rank{npu_ID}_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| **kwargs | 支持 compare 的所有可选参数。 其中,stack_mode不生效,自动识别是否存在stack.json,如存在,呈现NPU_Stack_Info,否则不呈现。 | 否 | **函数示例**: @@ -257,11 +355,11 @@ PyTorch 精度比对是以 CPU 或 GPU 的计算结果为标杆,通过计算 统计量有 4 种:最大值(max)、最小值(min)、平均值(mean)和 L2-范数(L2 norm)。 -|dump 数据模式|Cosine (tensor 余弦相似度)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)|Result (比对结果)|Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)|Data_Name (NPU 真实数据名)| -|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| -|真实数据模式|√|√|√|√|√|||√||||√|√|√|√| -|统计数据模式||||||√|√|√|||√||√|√|| -|MD5 模式|||||||||√|√|√|||√|| +|dump 数据模式|Cosine (tensor 余弦相似度)|EucDist (tensor 欧式距离)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)|Result (比对结果)|Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)| Data_Name ([NPU真实数据名,Bench真实数据名]) | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---------------------------------:| +|真实数据模式|√|√|√|√|√|√|||√||||√|√|√| √ | +|统计数据模式|||||||√|√|√|||√||√|√| | +|MD5 模式||||||||||√|√|√|||√| | 上表中NPU_Stack_Info字段需要配置-s参数生成。 @@ -315,12 +413,12 @@ MD5 模式: 1. "Need double check api accuracy.":四个统计值中至少 1 个相对误差 > 0.5(统计数据模式); 2. "Fuzzy matching data, the comparison arruracy may be affected.":NPU 或 Bench 的真实数据名没有匹配上(真实数据模式); -3. "Dump file: {} not found.":NPU 真实数据不存在或者读取出错(真实数据模式); +3. "Dump file: {} not found or read failed.":NPU 或 Bench 的真实数据不存在或者读取出错(真实数据模式); 4. "No bench data matched.":Bench 的 API 没有匹配上、Bench 真实数据不存在或读取出错(真实数据模式); 5. "This is empty data, can not compare.":读取到的数据为空(真实数据模式); 6. "Shape of NPU and bench Tensor do not match. Skipped.":NPU 和 Bench 的数据结构不一致(真实数据模式); 7. "The Position of inf or nan in NPU and bench Tensor do not match.":NPU 和 Bench 的数据有 nan/inf(真实数据模式); -8. "This is type of 0-d tensor, can not calculate 'Cosine', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); +8. "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); 9. "Dtype of NPU and bench Tensor do not match.":NPU 和 Bench 数据的数据类型不同(真实数据模式); 10. "":除以上情况的其余情况(真实数据模式、统计数据模式)。 @@ -330,13 +428,15 @@ MD5 模式: 1. Cosine:通过计算两个向量的余弦值来判断其相似度,数值越接近于 1 说明计算出的两个张量越相似,实际可接受阈值为大于 0.99。在计算中可能会存在 nan,主要由于可能会出现其中一个向量为 0。 -2. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 +2. EucDist:通过计算两个向量的欧式距离来判断其相似度,定义为多维空间中两个点之间的绝对距离。数值越接近0,张量越相似,数值越大,差异越大。 + +3. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 -3. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 +4. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 当 dump 数据中存在 0 或 Nan 时,比对结果中最大相对误差则出现 inf 或 Nan 的情况,属于正常现象。 -4. One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 +5. One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 ## 4 多卡比对结果提取汇总通信算子数据 @@ -358,11 +458,12 @@ msprobe -f pytorch merge_result -i ./input_dir -o ./output_dir -config ./config. **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| ---------------------- |------------------------------------------------------------------------------------| -------- | -| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | -| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。 | 是 | -| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | +| 参数名 | 说明 | 是否必选 | +| --------------------- |-------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | +| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | +| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | **yaml文件说明** @@ -378,10 +479,10 @@ compare_index: - MeanRelativeErr ``` -| 参数名 | 说明 | -| ------------- | ------------------------------------------------------------ | -| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | -| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、Norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | +| 参数名 | 说明 | +| ------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | +| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、L2norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、EucDist、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | **汇总结果件说明** diff --git a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md index 1b1824a774f15a86106585669d5f3412b3faca2e..55a148058a761f03b50b20ba635789e37241629f 100644 --- a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md @@ -19,7 +19,7 @@ msprobe精度比对工具主要用于如下场景: - 通过对同一个网络模型,在整网环境下分别在MindSpore动态图和PyTorch环境下获得API或模块dump数据,由用户指定可以比对的API或模块,以PyTorch数据作为标杆,进行自动比对,从而实现跨框架的精度对比。 - 通过对同一个网络模型,在整网环境下分别在MindSpore动态图和PyTorch环境下获得API或模块dump数据,由用户指定可以比对的模型代码中的Layer层,以PyTorch数据作为标杆,进行自动比对,从而实现跨框架的精度对比。 -执行精度比对操作需要安装msprobe工具。详见《[MindStudio精度调试工具](../README.md)》的“工具安装”章节。 +执行精度比对操作需要安装msprobe工具。详见[《msprobe 工具安装指南》](./01.installation.md)。 ## 2 命令行比对 @@ -35,17 +35,18 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| -------------------- | ------------------------------------------------------------ | -------- | -| -i或--input_path | 指定比对文件。比对文件内容及示例请参见[比对文件](#31-比对文件)或[比对文件(kernel)](#32-比对文件kernel)(比对文件(kernel)仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 是 | -| -o或--output_path | 配置比对结果文件存盘目录,默认会在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:
`compare_result_{timestamp}.xlsx`
`compare_result_{rank_id}_{step_id}_{timestamp}.xlsx`(仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 否 | -| -s或--stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,需要使用[比对文件](#31-比对文件)的单卡场景配置stack_path指定stack.json文件,才能生成详细调用栈信息,否则在比对时会报错;暂不支持多卡场景。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -c或--compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | -| -f或--fuzzy_match | 模糊匹配。开启后,对于网络中同一层级且命名仅调用次数不同的API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | -| -am或--api_mapping | 跨框架比对。配置该参数时表示开启跨框架API比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(api_mapping)](#33-自定义映射文件api_mapping)。仅[跨框架的API比对](#25-跨框架的api比对)场景需要配置。 | 否 | -| -cm或--cell_mapping | 跨框架比对。配置该参数时表示开启跨框架cell模块比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(cell_mapping)](#34-自定义映射文件cell_mapping)。仅[跨框架的cell模块比对](#26-跨框架的cell模块比对)场景需要配置。 | 否 | -| -dm或--data_mapping | 同框架或跨框架比对。通过映射文件指定两个具体参数的对应关系,可以在L0、L1或mix采集场景下使用。配置该参数的同时需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 | 否 | -| -lm或--layer_mapping | 跨框架比对。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer_mapping)](#36-自定义映射文件layer_mapping)。仅[跨框架的Layer层比对](#27-跨框架的layer层比对)场景需要配置。 | 否 | +| 参数名 | 说明 | 是否必选 | +| -------------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | +| -i或--input_path | 指定比对文件。比对文件内容及示例请参见[比对文件](#41-比对文件)或[比对文件(kernel)](#42-比对文件kernel)(比对文件(kernel)仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。 | 是 | +| -o或--output_path | 配置比对结果文件存盘目录,默认会在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:
`compare_result_{timestamp}.xlsx`
`compare_result_{rank_id}_{step_id}_{timestamp}.xlsx`(仅[不同版本下的全量kernel比对](#23-不同版本下的全量kernel比对)场景支持)。
提示:output目录下与结果件同名文件将被删除覆盖。 | 否 | +| -s或--stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,需要使用[比对文件](#41-比对文件)的单卡场景配置stack_path指定stack.json文件,才能生成详细调用栈信息,否则在比对时会报错;暂不支持多卡场景。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| -c或--compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 | +| -f或--fuzzy_match | 模糊匹配。开启后,对于网络中同一层级且命名仅调用次数不同的API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 | +| -am或--api_mapping | 跨框架比对。配置该参数时表示开启跨框架API比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(api_mapping)](#43-自定义映射文件api_mapping)。仅[跨框架的API比对](#25-跨框架的api比对)场景需要配置。 | 否 | +| -cm或--cell_mapping | 跨框架比对。配置该参数时表示开启跨框架cell模块比对功能,可以指定自定义映射文件*.yaml,不指定映射文件时按照msprobe定义的默认映射关系进行比对。自定义映射文件的格式请参见[自定义映射文件(cell_mapping)](#44-自定义映射文件cell_mapping)。仅[跨框架的cell模块比对](#26-跨框架的cell模块比对)场景需要配置。 | 否 | +| -dm或--data_mapping | 同框架或跨框架比对。通过映射文件指定两个具体参数的对应关系,可以在L0、L1或mix采集场景下使用。配置该参数的同时需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 | 否 | +| -lm或--layer_mapping | 跨框架比对。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer_mapping)](#46-自定义映射文件layer_mapping)。仅[跨框架的Layer层比对](#27-跨框架的layer层比对)场景需要配置。 | 否 | 动态图模式没有填写任何mapping时,按照同框架比对的方式进行比对,比对数据和标杆数据的Cell或Api名称需要完全相同才能匹配得上。 @@ -53,7 +54,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 1. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore静态图精度数据的采集,得到不同框架版本的API dump数据。 -2. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +2. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 3. 执行如下示例命令进行比对: @@ -67,7 +68,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 1. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore静态图精度数据的采集,得到不同框架版本的kernel dump数据。 -2. 创建比对文件,文件内容及示例请参见[比对文件(kernel)](#32-比对文件kernel)。 +2. 创建比对文件,文件内容及示例请参见[比对文件(kernel)](#42-比对文件kernel)。 3. 执行如下示例命令进行比对: @@ -85,7 +86,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》完成不同环境下MindSpore动态图精度数据的采集,得到不同框架版本的cell模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -101,7 +102,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下API精度数据的采集,得到两个框架的API dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -115,14 +116,14 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s msprobe -f mindspore compare -i ./compare.json -o ./output -s -am api_mapping.yaml ``` - api_mapping.yaml文件配置请参见[自定义映射文件(api_mapping)](#33-自定义映射文件api_mapping)。 + api_mapping.yaml文件配置请参见[自定义映射文件(api_mapping)](#43-自定义映射文件api_mapping)。 不传入api_mapping.yaml的情况下将按照内置的api映射进行匹配;传入api_mapping.yaml的情况下优先按照api_mapping.yaml的内容进行匹配,api_mapping.yaml中没有涉及的按照内置的api映射进行匹配。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 @@ -132,7 +133,7 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下cell模块精度数据的采集,得到两个框架的cell模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -146,14 +147,14 @@ msprobe -f mindspore compare -i ./compare.json -o ./output -s msprobe -f mindspore compare -i ./compare.json -o ./output -s -cm cell_mapping.yaml ``` - cell_mapping.yaml文件配置请参见[自定义映射文件(cell_mapping)](#34-自定义映射文件cell_mapping)。 + cell_mapping.yaml文件配置请参见[自定义映射文件(cell_mapping)](#44-自定义映射文件cell_mapping)。 不传入cell_mapping.yaml的情况下仅将Cell改成Module后进行匹配;传入cell_mapping.yaml的情况下将按照cell_mapping.yaml的内容进行匹配。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 @@ -165,7 +166,7 @@ layer_mapping可以从Layer层识别整网的API和Cell,简化配置。 2. 参见《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》和《[PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)》完成不同环境下API或模块精度数据的采集,得到两个框架的API或模块dump数据。 -3. 创建比对文件,文件内容及示例请参见[比对文件](#31-比对文件)。 +3. 创建比对文件,文件内容及示例请参见[比对文件](#41-比对文件)。 4. 执行如下示例命令进行比对: @@ -173,16 +174,34 @@ layer_mapping可以从Layer层识别整网的API和Cell,简化配置。 msprobe -f mindspore compare -i ./compare.json -o ./output -s -lm layer_mapping.yaml ``` - layer_mapping.yaml文件配置请参见[自定义映射文件(layer_mapping)](#36-自定义映射文件layer_mapping)。 + layer_mapping.yaml文件配置请参见[自定义映射文件(layer_mapping)](#46-自定义映射文件layer_mapping)。 此外,也可以通过data_mapping.yaml文件实现具体参数的匹配,例: ```shell msprobe -f mindspore compare -i ./compare.json -o ./output -s -dm data_mapping.yaml ``` - data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#35-自定义映射文件data_mapping)。 + data_mapping.yaml的写法请参见[自定义映射文件(data_mapping)](#45-自定义映射文件data_mapping)。 5. 查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 +### 2.8 单点数据比对 +1. 参见 [单点保存工具](./28.debugger_save_instruction.md)章节完成 CPU 或 GPU 与 NPU 的单点数据采集。 + +2. 创建比对文件,文件内容及示例请参见[比对文件(单点数据)](#47-比对文件单点数据)。 + +3. 执行如下示例命令进行比对: + + ```shell + msprobe -f mindspore compare -i ./compare.json -o ./output + ``` + +4. Pytorch & MindSpore 动态图场景查看比对结果,请详见PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 +MindSpore静态图场景比对结果: +- `result.csv` 文件列出了所有执行精度比对的 单点保存数据 详细信息和比对结果,示例如下: + + ![compare_result](./img/save_compare_result_sample.png) +具体字段含义同PyTorch目录下的《[PyTorch 场景的精度比对-精度比对结果分析](./10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 + ## 3 多卡比对结果提取汇总通信算子数据 本功能是将多卡比对场景的比对结果,进行通信算子数据提取和汇总,输出整理好的通信算子多卡比对精度表。 @@ -204,11 +223,12 @@ msprobe -f mindspore merge_result -i ./input_dir -o ./output_dir -config ./confi **完整参数说明** -| 参数名 | 说明 | 是否必选 | -| ---------------------- | ------------------------------------------------------------ | -------- | -| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | -| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。 | 是 | -| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | +| 参数名 | 说明 | 是否必选 | +|-----------------------|-------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | +| -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 | +| -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。
提示:output目录下与结果件同名文件将被删除覆盖。 | 是 | +| -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。
yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 | **yaml文件说明** @@ -224,10 +244,10 @@ compare_index: - MeanRelativeErr ``` -| 参数名 | 说明 | -| ------------- | ------------------------------------------------------------ | -| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | -| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、Norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | +| 参数名 | 说明 | +| ------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| api | 表示需要汇总的API或module名称。如果没有配置,工具会提示报错。
api名称配置格式为:`{api_type}.{api_name}.{API调用次数}.{前向反向}`
须按顺序配置以上四个字段,可按如下组合配置:
{api_type}
{api_type}.{api_name}
{api_type}.{api_name}.{API调用次数}
{api_type}.{api_name}.{API调用次数}.{前向反向}
这里的api指代API或module。 | +| compare_index | 表示需要汇总的比对指标。compare_index需为dump_mode对应比对指标的子集。如果没有配置,工具将根据比对结果自动提取dump_mode对应的全部比对指标进行汇总。
统计数据模式比对指标:Max diff、Min diff、Mean diff、L2norm diff、MaxRelativeErr、MinRelativeErr、MeanRelativeErr、NormRelativeErr
真实数据模式比对指标:Cosine、EucDist、MaxAbsErr、MaxRelativeErr、One Thousandth Err Ratio、Five Thousandths Err Ratio | **汇总结果件说明** @@ -279,20 +299,20 @@ compare_index: 多卡场景示例如下: ```json { -"npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) -"bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) +"npu_path": "./npu_dump/step0", # 需填写到step层级(rank的上一层级) +"bench_path": "./bench_dump/step0", # 需填写到step层级(rank的上一层级) "is_print_compare_log": true } ``` **参数说明** -| 参数名 | 说明 | 是否必选 | -| -------------------- | ------------------------------------------------------------ |------| -| npu_path | 配置NPU环境下的dump.json文件(单卡场景)。跨框架场景指定为MindSpore的json文件。数据类型:str。 | 是 | -| bench_path | 配置CPU、GPU或NPU环境下的dump.json文件(单卡场景)。 跨框架场景指定为PyTorch的json文件。数据类型:str。 | 是 | -| stack_path | 配置NPU dump目录下的stack.json文件。数据类型:str。 如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | -| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值true或false,默认为true。关闭后则只输出常规日志。数据类型:bool | 否 | +| 参数名 | 说明 | 是否必选 | +| -------------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| +| npu_path | 配置NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。跨框架场景指定为MindSpore的dump.json文件或dump目录。数据类型:str。 | 是 | +| bench_path | 配置CPU、GPU或NPU环境下的dump.json文件(单卡场景)或dump目录(多卡场景)。跨框架场景指定为PyTorch的dump.json文件或dump目录。数据类型:str。 | 是 | +| stack_path | 配置NPU dump目录下的stack.json文件。数据类型:str。如果没有配置stack_path,命令行-s参数不生效,程序自动识别是否存在stack.json文件,如存在,则比对结果中呈现NPU_Stack_Info,如不存在,则不呈现。如果配置了stack_path,比对结果中是否呈现NPU_Stack_Info则通过命令行参数-s来控制。 | 否 | +| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值true或false,默认为true。关闭后则只输出常规日志。数据类型:bool。 | 否 | ### 4.2 比对文件(kernel) @@ -573,7 +593,7 @@ input_args、input_kwargs和output使用统一的命名规则,当值是list类 "md5": "28f8f74f" } ] -} +} ``` , 初始名称为`Cell.network.module.NetworkWithLoss.forward.0`,`input_args`是`list`,长度为2,按照顺序命名为 @@ -646,4 +666,36 @@ yaml文件中只需配置MindSpore与PyTorch模型代码中功能一致但名称 模型代码示例: -![ms_dump](./img/ms_layer.png) \ No newline at end of file +![ms_dump](./img/ms_layer.png) + +### 4.7 比对文件(单点数据) + +MindSpore动态图单卡场景示例如下: + ```json +{ +"npu_path": "./npu_dump/debug.json", +"bench_path": "./bench_dump/debug.json" +} + ``` + +MindSpore动态图多卡场景(step0目录下包含debug.json文件)示例如下: +```json +{ +"npu_path": "./npu_dump/step0", +"bench_path": "./bench_dump/step0" +} +``` + +MindSpore静态图场景(不区分单/多卡)示例如下: +```json +{ +"npu_path": "./npu_dump/", +"bench_path": "./bench_dump/", +"map_dict": {"input": "x"}, +"common": true +} +``` +- `npu_path`表示NPU dump文件目录,可指定到./npu_dump/ 或者./npu_dump/step0 或者./npu_dump/step0/rank0 保证对应即可,比对结果保持相同目录结构。 +- `bench_path`表示bench dump文件目录,指定同上。 +- `common`表示开启MindSpore静态图单点保存比对,默认关闭。 +- `map_dict`可用于当单点保存比对的`npy`文件名称不完全对应时,通过手动指定保证比对正确执行,比对指定名称对应,如{"input": "x"},则`input_float32_1.npy`会对应`x_float32_1.npy`。 diff --git a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md index 97b049000c6aca9a69aeca66e1a27a4260b3d142..ea41bcdc10a07a1ddcb033c82765b09fd90ebbf3 100644 --- a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md @@ -12,13 +12,13 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 ### 1.2 接口介绍 -溢出检测功能提供的接口与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**1 接口介绍**"章节。 +溢出检测功能提供的接口与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**接口介绍**"章节。 其中 PrecisionDebugger 中的 task 或是 config.json 中的 task 需要指定为 **overflow_check**,详见[配置文件介绍](./02.config_introduction.md)中的 "**1.1 通用配置介绍**"和"**1.5 task 配置为 overflow_check**"章节。 ### 1.3 示例代码 -溢出检测功能使用方式与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**2 示例代码**"章节。 +溢出检测功能使用方式与数据采集任务一致,详见[ PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)中的"**示例代码**"章节。 ### 1.4 结果文件介绍 @@ -28,7 +28,7 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 溢出数据采集功能在昇腾 NPU 上支持饱和模式(仅支持 Atlas 训练系列产品)和 INF/NAN 模式。 -INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 +INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 INF/NAN 模式的使能方式如下: @@ -58,8 +58,9 @@ export INF_NAN_MODE_ENABLE=1 msprobe -f pytorch run_overflow_check -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json ``` -| 参数名称 | 说明 | 是否必选 | -| -------------------------- |------------------------------------| -------- | +| 参数名称 | 说明 | 是否必选 | +|---------------------------|------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | | -api_info或--api_info_file | 指定采集下来的 API 信息文件 dump.json。 | 是 | | -j或--jit_compile | 开启 jit 编译。 | 否 | | -d或--device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为0。 | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md index 33ff4a0259aef02d122022402966c65358e8efff..ab280f1119cd17634a9a45aa48ad7e4ec78facb6 100644 --- a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md @@ -11,21 +11,23 @@ export INF_NAN_MODE_ENABLE=1 export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" ``` -**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 溢出检测任务的配置示例见[MindSpore 静态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#23-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)、[MindSpore 动态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#33-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)。 ## 1 接口介绍 -溢出检测功能提供的接口与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**1 接口介绍**"](./06.data_dump_MindSpore.md#1-接口介绍)章节。 +溢出检测功能提供的接口与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**接口介绍**"](./06.data_dump_MindSpore.md#6-接口介绍)章节。 需要注意,目前暂不支持动态图 "L1" level 下 primitive op 的溢出检测。 ## 2 示例代码 -溢出检测功能使用方式与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**2 示例代码**"](./06.data_dump_MindSpore.md#2-示例代码)节。 +溢出检测功能使用方式与数据采集任务一致,详见MindSpore 场景的精度数据采集中的["**示例代码**"](./06.data_dump_MindSpore.md#7-示例代码)节。 ## 3 溢出检测结果文件介绍 -溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 kernel 的真实数据或统计信息。详见MindSpore 场景的精度数据采集中的["**3 dump 结果文件介绍**"](./06.data_dump_MindSpore.md#3-dump-结果文件介绍)章节。 +溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 kernel 的真实数据或统计信息。详见MindSpore 场景的精度数据采集中的["**8. dump 结果文件介绍**"](./06.data_dump_MindSpore.md#8-dump-结果文件介绍)章节。 + +**说明**:在静态图 O2 编译等级下,若 MindSpore 版本为 2.4,或者 MindSpore 版本为 2.5,且未使用编包时添加了`--include-mod=adump`选项的 mindstudio-probe whl 包,则会产生 kernel_graph_overflow_check.json 中间文件,一般情况下无需关注。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md b/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md index 68a3d1a57dc1b649ffdb6d02d7be378900458e65..931e74a6dc2b266af888504e7d4e10e819ffa5d3 100644 --- a/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/14.data_parse_PyTorch.md @@ -18,6 +18,9 @@ msprobe -f pytorch parse Parse >>> ``` +| 参数名称 | 说明 | 是否必选 | +|---------------------------|------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。pytorch。 | 是 | 可在 parse 的界面中执行 Shell 命令,以及如下场景的相关解析命令(详细介绍请参见以下章节。): @@ -26,13 +29,7 @@ Parse >>> - 支持交互式指定 pkl 文件中 API 对应 dump 数据查看。 - 支持 API 进行可选层级比对和打印(统计级和像素级)。 -Ctrl+C 可以退出 parse 交互式界面。不退出 parse 交互式界面若需要执行非该界面下的内置 Shell 命令,且命令与 parse 交互式界面命令冲突时,非该界面命令需要使用 run 命令,在相关命令前加上 run 前缀,如下示例: - -```bash -msprobe -f pytorch parse -Parse >>> run vim cli.py -Parse >>> vim cli.py -``` +Ctrl+C 可以退出 parse 交互式界面。 ### 2.2 kernel 层级算子数据批量转换 @@ -44,11 +41,11 @@ Parse >>> vim cli.py cad -m my_dump_path [-out output_path] [-asc msaccucmp_path] ``` -| 参数名称 | 说明 | 是否必选 | -| -------- | ------------------------------------------------------------ | -------- | -| -m | 待转换 kernel dump 数据目录。需要指定到 kernel dump 数据的 deviceid 级目录。 | 是 | -| -out | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_convert。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | -| -asc | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py。 | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ------------------------------------------------------------ | -------- | +| -m 或 --my_dump_path | 待转换 kernel dump 数据目录。需要指定到 kernel dump 数据的 deviceid 级目录。 | 是 | +| -out 或 --output_path | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_convert。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | +| -asc 或 --msaccucmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py。 | 否 | **示例代码**: @@ -105,12 +102,12 @@ Parse >>> cad -m /home/xxx/my_dump_path/20000124003856/0 vc -m my_dump_path -g golden_dump_path [-out output_path] [-cmp_path msaccucmp_path] ``` -| 参数名称 | 说明 | 是否必选 | -| --------- | ------------------------------------------------------------ | -------- | -| -m | 待比对 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | -| -g | 标杆 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | -| -out | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_comapre。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | -| -cmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | +| 参数名称 | 说明 | 是否必选 | +|------------------------------| ------------------------------------------------------------ | -------- | +| -m 或 --my_dump_path | 待比对 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | +| -g 或 --golden_dump_path | 标杆 kernel dump 数据目录。如果比对单个算子,需要指定到 kernel dump 数据的 model_id 级目录;如果批量比对,则指定到 cad 转换后的 timestamp 级目录。 | 是 | +| -out 或 --output_path | 结果输出目录,须指定已存在的目录,默认为 ./parse_data/acl_batch_compare。未指定时保存在默认路径下,比对结束后会打印 log 提示输出结果存放路径。 | 否 | +| -cmp_path 或 --msaccucmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | 输出结果:`batch_compare_{timestamp}.csv` 文件。 @@ -119,7 +116,7 @@ vc -m my_dump_path -g golden_dump_path [-out output_path] [-cmp_path msaccucmp_p ```bash # 传入待比对数据目录以及标杆数据目录 Parse >>> vc -m ./my_dump_path -g ./golden_data_path -[INFO]Compare result is saved in : parse_data/acl_batch_comapre/batch_compare_1707271118.csv +[INFO]Compare result is saved in : parse_data/acl_batch_compare/batch_compare_1707271118.csv ``` ### 2.3 kernel 算子数据的 npy 转换 @@ -130,12 +127,12 @@ Parse >>> vc -m ./my_dump_path -g ./golden_data_path dc -n file_name/file_path [-f format] [-out output_path] ``` -| 参数名称 | 说明 | 是否必选 | -| --------- | ------------------------------------------------------------ | -------- | -| -n | 需转换的 dump 数据文件或 dump 数据文件目录。 | 是 | -| -f | 开启 format 转换,指定该参数时需要配置 format 格式。当前内置的 Format 转换支持如下类型:
FRACTAL_NZ 转换 NCHW;
FRACTAL_NZ 转换成 NHWC;
FRACTAL_NZ 转换 ND;
HWCN 转换 FRACTAL_Z;
HWCN 转换成 NCHW;
HWCN 转换成 NHWC;
NC1HWC0 转换成 HWCN;
NC1HWC0 转换成 NCHW;
NC1HWC0 转换成 NHWC;
NCHW 转换成 FRACTAL_Z;
NCHW转换成NHWC;
NHWC转换成FRACTAL_Z;
NHWC转换成HWCN;
NHWC转换成NCHW;
NDC1HWC0转换成NCDHW。 | 否 | -| -out | 结果输出目录。 | 否 | -| -cmp_path | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ------------------------------------------------------------ | -------- | +| -n 或 --name | 需转换的 dump 数据文件或 dump 数据文件目录。 | 是 | +| -f 或 --format | 开启 format 转换,指定该参数时需要配置 format 格式。当前内置的 Format 转换支持如下类型:
FRACTAL_NZ 转换 NCHW;
FRACTAL_NZ 转换成 NHWC;
FRACTAL_NZ 转换 ND;
HWCN 转换 FRACTAL_Z;
HWCN 转换成 NCHW;
HWCN 转换成 NHWC;
NC1HWC0 转换成 HWCN;
NC1HWC0 转换成 NCHW;
NC1HWC0 转换成 NHWC;
NCHW 转换成 FRACTAL_Z;
NCHW转换成NHWC;
NHWC转换成FRACTAL_Z;
NHWC转换成HWCN;
NHWC转换成NCHW;
NDC1HWC0转换成NCDHW。 | 否 | +| -out 或 --output_path | 结果输出目录。 | 否 | +| -cmp_path 或 --msaccucmp | 指定 msaccucmp 路径,默认路径为:/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py | 否 | - 输出结果:npy 文件。 @@ -149,9 +146,9 @@ dc -n file_name/file_path [-f format] [-out output_path] pt -n file_path ``` - | 参数名称 | 说明 | 是否必选 | - | -------- | ------------- | -------- | - | -n | npy 文件路径。 | 是 | + | 参数名称 | 说明 | 是否必选 | + |-------------| ------------- | -------- | + | -n 或 --name | npy 文件路径。 | 是 | 打印统计信息:shape, dtype, max, min 和 mean。默认在 npy 文件路径下将该数据保存为 txt 文件。 @@ -197,10 +194,10 @@ TextFile:./parse_data/dump_convert/Add.fp32_vars_add_1fp32_vars_Relu_6.24.5.1636 pk -f pkl_path -n api_name ``` -| 参数名称 | 说明 | 是否必选 | -| -------- | ----------------------- | -------- | -| -f | 指定 dump.json 文件路径。 | 是 | -| -n | 指定 API 名称。 | 是 | +| 参数名称 | 说明 | 是否必选 | +|-------------| ----------------------- | -------- | +| -f 或 --file | 指定 dump.json 文件路径。 | 是 | +| -n 或 --name | 指定 API 名称。 | 是 | - 输出结果:打印统计信息(shape, dtype, max和min mean)。 - 若 pkl 文件中存在相应的堆栈信息,则会打印堆栈信息。 @@ -225,20 +222,20 @@ Statistic Info: 输入以下命令, 进行统计级和像素级比对。 ```bash -cn -m my_data*.npy -g gloden*.npy [-p num] [-al atol] [-rl rtol] +cn -m my_data*.npy -g golden*.npy [-p num] [-al atol] [-rl rtol] ``` - 统计级比对:对 tensor 整体进行余弦值及相对误差的计算。 - 像素级比对:对输入的两个 npy 文件进行逐元素比对。若两个 tensor 对应元素的相对误差或绝对误差大于**误差阈值**(-al 和 -rl 配置)则被标记为错误数据。 -| 参数名称 | 说明 | 是否必选 | -| -------- | ----------------------------------------------- | -------- | -| -m | 待比对数据。 | 是 | -| -g | 标杆数据。 | 是 | -| -p | 设置比对结束后打印错误元素的个数,默认值 20。 | 否 | -| -al | 判定数据存在精度问题的绝对误差阈值,默认 0.001。 | 否 | -| -rl | 判定数据存在精度问题的相对误差阈值,默认 0.001。 | 否 | -| -s | 将 npy 文件保存成 txt 文件,用于查看,默认开启。 | 否 | +| 参数名称 | 说明 | 是否必选 | +|-------------------------| ----------------------------------------------- | -------- | +| -m 或 --my_dump_path | 待比对数据。 | 是 | +| -g 或 --golden_dump_path | 标杆数据。 | 是 | +| -p 或 --print | 设置比对结束后打印错误元素的个数,默认值 20。 | 否 | +| -al 或 --atol | 判定数据存在精度问题的绝对误差阈值,默认 0.001。 | 否 | +| -rl 或 --rtol | 判定数据存在精度问题的相对误差阈值,默认 0.001。 | 否 | +| -s 或 --save | 将 npy 文件保存成 txt 文件,用于查看,默认开启。 | 否 | 输出结果: diff --git a/debug/accuracy_tools/msprobe/docs/17.grad_probe.md b/debug/accuracy_tools/msprobe/docs/17.grad_probe.md index f210088013415e40167f3eea3aab6163b0c947dc..da1183617610c61a41d6e0b27cf070fb9644112a 100644 --- a/debug/accuracy_tools/msprobe/docs/17.grad_probe.md +++ b/debug/accuracy_tools/msprobe/docs/17.grad_probe.md @@ -65,6 +65,7 @@ + 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。 + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示: + ![Alt text](./img/grad_probe_image-1.png) 2. 插入代码。示例代码如下: diff --git a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md index e686c61b68add9c9a1ade9ae3e89b897c9b8d6bf..b8de3c6d68dadb206a27a1a78f80fe9bc321d2ea 100644 --- a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md +++ b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md @@ -70,7 +70,7 @@ PyTorch NPU在线精度比对是msprobe工具实现在PyTorch训练过程中直 | api_list | dump范围,dump_mode="list"时设置,需要Dump Aten Ir API名称,默认为None,Aten Ir API名称可以通过dir(torch.ops.aten)查看。 | 否 | | dump_path| dump文件生成的路径。 | 是 | | tag | 传入tag字符串,成为dump文件夹名一部分,默认为None。 | 否 | -| process_num | 多进程并发数,默认为0。 | 否 | +| process_num | 多进程并发数,默认为0,最大不超过CPU核数的四分之一。 | 否 | | debug | debug信息打印,默认为False。 | 否 | ### dump数据存盘说明 dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。 diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index fa1b7d06d6c52b55c49f26352f823de41b28cb2d..f57fadf9afda3b4718da21cbce8714e0632e006e 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -10,7 +10,7 @@ 要求: - PyTorch场景:torch不低于**2.0** -- MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,暂不支持**msadapter**套件 +- MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,已支持**msadapter**套件 ## 功能介绍 下表中字段为训练状态轻量化监控工具的完整功能点: @@ -21,12 +21,13 @@ | [权重梯度监控](#权重梯度监控) | 开启权重梯度监控 | PyTorch、MindSpore | | [激活值监控](#激活值监控) | 开启激活值监控 | PyTorch、MindSpore | | [优化器状态监控](#优化器状态监控) | 开启优化器状态监控 | PyTorch、MindSpore | +| [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore | | [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore | | [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch | -| [Module全量监控](#Module全量监控) | 对全量module的输入输出做监控 | PyTorch、MindSpore | -| [Parameter全量监控](#Parameter全量监控) | 对全量Parameter的输入输出做监控 | PyTorch、MindSpore | -| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`均支持,`ndigits`仅PyTorch支持 | PyTorch、MindSpore | -| [梯度异常时序判断](#梯度异常时序判断) | 梯度异常时自动梯度落盘 | PyTorch | +| [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch | +| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore | +| [mbs粒度梯度监控](#mbs粒度梯度监控) | 开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore | +| [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore | | [csv格式数据转tensorboard可视化显示](#csv格式数据转tensorboard可视化显示) | 将csv转为tensorboard文件显示 | PyTorch | | [动态启停](#动态启停) | 训练过程中动态修改配置开启监控 | PyTorch、MindSpore | | [功能重载](#功能重载) | 训练中开启激活值监控。待废弃,请使用动态启停功能代替。 | PyTorch | @@ -205,12 +206,26 @@ monitor.monitor_gnorm_with_ad( 本工具针对分布式计算框架megatron和deepspeed框架做了适配,暂不支持其他框架。 +### 采集module堆栈信息 +- 工具配置示例: +```json +{ + "targets": { + }, + "format": "csv", + "stack_info": true +} +``` +开启 `stack_info` 后会采集监控的第一个 step 的所有 module 的堆栈信息,输出格式仅支持 csv 。 ## 高阶功能 + ### 指定监控对象 -工具支持对nn.Module(**激活值监控**)和nn.Parameter(**权重监控**、**权重梯度监控、优化器监控**)对象实现相应的监控行为,在配置文件的"targets"(dict)字段指定,targets格式为{module_name/param_name: {filed: format}}。 +工具支持对指定nn.Module进行状态监控,在配置文件的`targets`字段中指定,`targets`格式为{module_name: {}}。 + +module_name可以通过nn.Module的接口named_modules()获取。 #### 打印模型结构 工具提供可选项`print_struct`打印模型结构,帮助配置targets。工具会在在第一个step后打印结构并停止训练进程,模型结构默认打印在`$MONITOR_OUTPUT_DIR/module_struct.json`。 @@ -221,7 +236,6 @@ monitor.monitor_gnorm_with_ad( ``` 输出样例: -字段`config`用于配置文件中指定module target。其余为各个元素的shape和dtype。 ```json "0:63.mlp.linear_fc2": { @@ -245,40 +259,30 @@ monitor.monitor_gnorm_with_ad( } }, ``` +对于module对象,通常关心前向/反向传播的输入和输出: -- Module - 对于module对象,通常关心其前向的输入(input)输出(output)和反向的输入--前向输出的梯度(output_grad)和输出--前向输入的梯度(input_grad)。同时需要声明这些对象的类型,通常为"tensor"或"tuple\[length]"。 +- 前向的输入(input) +- 前向的输出(output) +- 反向的输入,表示前向输出的梯度(output_grad) +- 反向的输出,表示前向输入的梯度(input_grad) - "tensor"可以直接用来计算统计量,"tuple"需要进一步指定监控的索引。如"tuple[2]:0",表示该对象为长度2的tuple,对第0元素进行监控;不指定索引时,默认对第0元素进行监控。 - module_name可以通过nn.Module的接口`named_modules()`获取。 -```json -// 示例:对一个名为"module.encoder.layers.0.mlp"的module,监控其前向输入第0元素和输出。 -{ - "targets": { - "module.encoder.layers.0.mlp": { - "input": "tuple[2]:0", - "output": "tensor" - } - } -} -``` -#### Module全量监控 -工具提供简便的全量module监控方式。或不配置targets、all_xy字段,同样表示全量监控。 +#### 指定监控对象 + +targets字段指定监控对象示例如下: ```json -{ - "targets": {}, - "all_xy": true +// 示例:对一个名为"module.encoder.layers.0.mlp"的module。 +"targets": { + "module.encoder.layers.0.mlp": {} } ``` +对于parameter对象,通常会关注其在一个训练迭代中的梯度(weight grad)、adam类优化器中的动量(1st moment, 2nd moment)。 +parameter归属于某一module,可以通过指定module_name来监控包含在这一module中的**所有**parameter。 -- Parameter - 对于parameter对象,通常会关注其在一个训练迭代中的梯度(weight grad)、adam类优化器中的动量(1st moment, 2nd moment)。 - parameter归属于某一module,也可以通过指定module_name来监控包含在这一module中的**所有**parameter。 +param_name可以通过nn.Module的接口`named_parameters()`获取。 - param_name可以通过nn.Module的接口`named_parameters()`获取。 ```json // 示例:监控"module.encoder.layers.0.mlp"的所有参数和"module.embedding.word_embedding.weight"这一参数 { @@ -289,8 +293,10 @@ monitor.monitor_gnorm_with_ad( } ``` -#### Parameter全量监控 -工具提供简便的全量parameter监控方式。或不配置targets,同样表示全量监控。 + +#### 全量监控 + +工具提供简便的全量module对象监控方式。 ```json { @@ -298,7 +304,48 @@ monitor.monitor_gnorm_with_ad( } ``` +### l2可解释特征监控 +- 工具配置 +```json +{ + "l2_targets": { + "attention_hook": ["0:0.self_attention.core_attention.flash_attention"], + "linear_hook": ["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"], + "token_hook": ["0:0"], + "norm_hook": ["0:0.input_layernorm"] + }, + + "proxy_model": false, + "recording_l2_features": true, + "module_ranks": [0] +} +``` +| 配置项 | 类型 | 说明 | 示例值 | +|--------|------|------|--------| +| **l2_targets.attention_hook** | List[str] | 指定需要监控的注意力层, 采集"entropy"和"sorftmax_max"指标,可以通过[打印模型结构功能](#打印模型结构)获取 | `["0:0.self_attention.core_attention.flash_attention"]` | +| **l2_targets.linear_hook** | List[str] | 指定需要监控的线性层, 采集"sr"和 "kernel_norm"指标,可以通过[打印模型结构功能](#打印模型结构)获取 | `["0:0.self_attention.linear_qkv", "0:1.self_attention.linear_qkv"]` | +| **l2_targets.token_hook** | List[str] | 指定需要监控输出token间相似度的模型层, 采集"token_similarity指标,可以通过[打印模型结构功能](#打印模型结构)获取 | `["0:0"]` | +| **l2_targets.norm_hook** | List[str] | 指定需要监控的归一化层, 采集"std_x"和"jacobian"指标,可以通过[打印模型结构功能](#打印模型结构)获取 | `["0:0.input_layernorm"]` | +| **proxy_model** | bool | 是否监控模型dp间reduce浮点误差加法影响, 默认为false, 选择true会在开启[权重梯度监控](#权重梯度监控)的target中注册的模块的梯度聚合影响指标(proxy)监控,**需要在[Train_mon](#公开接口)定义时指定dp_group,并设置module_ranks指定监控module的dp组** | `false` | +| **module_ranks** | List[int] | 指定需要监控的module所在的dp组, | `[0]` | +| **recording_l2_features** | bool | 是否开启L2层特征数据采集 | `true` | + + +#### L2可解释特征监控指标说明 + +| **指标名称** | **适用Hook类型** | **数学定义/计算方式** | **监控意义** | +|--------------------|-------------------|-------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| +| **entropy** | attention_hook | $H(p)=-\sum p_i \log p_i$,其中$p_i$为注意力权重 | 衡量注意力分布的不确定性,低熵值表示注意力集中,高熵值表示注意力分散 | +| **softmax_max** | attention_hook | $\max(\text{softmax}(QK^T/\sqrt{d}))$ | 反映注意力机制的聚焦程度,高值表示存在显著主导的注意力头 | +| **sr(stable_rank)** | linear_hook | $\frac{\|W\|_F}{\|W\|_2}$(稳定秩,Frobenius范数除以谱范数) | 评估权重矩阵的有效秩,低值表示矩阵接近低秩状态 | +| **kernel_norm** | linear_hook | $\|W\|_F$(Frobenius范数) | 评估权重矩阵的缩放程度,与模型泛化能力相关 | +| **token_similarity** | token_hook | $\text{cosine\_sim}(t_i, t_j)=\frac{t_i \cdot t_j}{\|t_i\|\|t_j\|}$ | 跟踪token嵌入的相似性变化,检测表征坍缩现象 | +| **std_x** | norm_hook | $\sqrt{\frac{1}{N}\sum(x_i-\mu)^2}$ | 特征值的标准差,反映层输入的分布稳定性 | +| **jacobian** | norm_hook | $\|\frac{\partial f(x)}{\partial x}\|_2$ | 评估模型局部线性度,高值可能预示训练不稳定 | +|**proxy**|proxy_hook| $\frac{1}{m} \sum_{i=1}^m\left\|z_i\right \| / \left\|\frac{1}{m} \sum_{i=1}^m z_i\right\|$ | 聚合前梯度离散度与中心趋势的相对比值,评估模型dp间梯度分散程度,以估计reduce浮点误差加法可能带来的影响| + ### 输出格式和统计量 + 工具配置示例: ```json { @@ -333,7 +380,7 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir 监控结果写入csv文件中,可以通过`ndigits`字段设置小数位数。 表头为 vpp_stage | name | step | micro_step(optional) | *ops |。 仅在激活值监控的输出文件中包含micor_step。 - 激活值监控的name为.\, 其他任务的name为> + 激活值监控的name为.\, 其他任务的name为 - **api** 监控结果不落盘,在训练过程中可以通过`generate_wgrad_metrics`、`generate_xy_metrics`等接口获取,使用方式参考[公开接口](#公开接口) 。 @@ -349,16 +396,54 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir ![step_count_per_record](img/monitor/step_count_per_record.png) -### 梯度异常时序判断 +### mbs粒度梯度监控 + +当配置梯度监控任务时,工具默认`global_batch_size`粒度进行梯度监控。当需要监控`micro_batch_size`粒度梯度信息时,在配置文件中配置`monitor_mbs_grad`为`true`,配置示例如下: + +```json +{ + "wg_distribution": true, + "monitor_mbs_grad": true +} +``` + +应用范围 + +- **仅支持采集聚合前梯度**,在梯度累积场景下,聚合后梯度已无法区分`micro_batch`数据。 +- PyTorch场景下,Megatron和DeepSpeed训练框架下均支持,FSDP训练框架下暂不支持。 +- MindSpore场景下均支持。 + +### 异常告警 + +工具的异常告警功能旨在自动判断训练过程中的异常现象,用户可通过在配置文件中配置alert字段来指定告警规则,并在训练过程中根据该规则及时打屏对用户发出告警。 + + 1. 训练前配置相关参数 -工具支持自动判断训练过程中的梯度异常,需要在配置文件中设置alert相关字段。"AnomalyTurbulence"会将当前数值与历史均值比较,如果相对偏差超过阈值,会在打屏信息中提示用户。如果打开"`dump`"选项,则会将异常梯度相关信息落盘到目录`monitor_output/anomaly_detected`,用于后续时序判断。 +当前支持的异常告警规则如下: + +| 异常告警 |解释| rule_name | args是否可选 | +|--------------|----|-----------|---------------------------------------------------------------------| +| 历史均值偏离告警 |将当前数值与历史均值比较。如果相对偏差超过阈值,会在打屏信息中提示用户指标偏离。当前仅对`norm`和`mean`指标生效。| AnomalyTurbulence | 否,必须传入threshold。当指标超过`(1+threshold)*avg`时,识别为偏离历史均值。 | +| nan值/极大值告警 |根据是否提供threshold来判断nan值或极大值| AnomalyNan | 是, 若未配置args或未配置threshold,则默认检测nan,若提供threshold,则检测nan值以及绝对值超过阈值的极大值 | + +除此之外,我们在alert中支持dump配置项,如果打开"`dump`"选项,则会将异常信息落盘到目录`monitor_output/anomaly_detected`。 + +- 历史均值偏离告警案例如下: ```json "alert": { - "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], + "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], // 0.5表示偏离50%则提示偏离 + "dump": true + }, +``` +- nan值/极大值告警案例如下: +```json + "alert": { + "rules": [{"rule_name": "AnomalyNan", "args": {"threshold": 1e10}}], "dump": true }, ``` + 2. 实例化工具时传入流水线并行group ```python monitor = TrainerMon( @@ -395,9 +480,9 @@ python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomal ``` 异常事件分析结束,将topk事件写入文件`anomaly_detected/anomaly_analyse.json`。异常分析支持以下参数配置: -| 字段名 | 解释 | 是否必选 | -| ----------------- | ------------------------------------------------------------ | -------- | -| -d 或 --data_path | 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。 | 是 | +| 字段名 | 解释 | 是否必选 | +| ----------------- | --------------------------------------------------------- | -------- | +| -d 或 --data_path | 指定异常落盘文件夹,监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。 | 是 | | -o 或 --out_path | 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件。 | 否 | | -k 或 --topk | 指定保留前topk个异常,默认为8。 | 否 | | -s 或 --step_list | 指定分析的step范围,默认为[]。 | 否 | @@ -411,39 +496,65 @@ python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomal from msprobe.pytorch.monitor.csv2tb import csv2tensorboard_by_step # 前三个参数用来指定需要转换的一批文件,指定monitor输出目录及一个时间范围,会对这个范围内的文件进行转换 # process_num指定拉起的进程个数,默认为1,更多的进程个数可以加速转换 -# data_type_list是一个列表,指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据: -# ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"] -# 不指定就转换全部数据 -# output_dirpath可指定输出目录, 不传值时保存到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳 +# data_type_list是一个列表,指定需要转换的数据类型,默认转换全部数据,数据类型应来自输出件文件前缀,所有类型数据: +# ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"] +# output_dirpath可指定输出目录,默认保存到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳 csv2tensorboard_by_step( - monitor_path="~/monitor_output", - time_start="Dec03_21-34-40", - time_end="Dec03_21-34-42", + monitor_path="~/monitor_output", # 必填 + time_start="Dec03_21-34-40", # 必填 + time_end="Dec03_21-34-42", # 必填 process_num=8, - data_type_list=["param"] + data_type_list=["grad_unreduced"] ) ``` +将csv数据转换为sqlite数据。 + +```python +from msprobe.pytorch.monitor.csv2db import CSV2DBConfig, csv2db +# output_dirpath可指定输出目录,默认保存到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳 +# step_partition可以控制数据库中按step分区的间隔,默认每500步一个表 +config = CSV2DBConfig( + monitor_path="~/monitor_output",# 与转换为tensorboard用法一致 + time_start="Dec03_21-34-40",# 与转换为tensorboard用法一致 + time_end="Dec03_21-34-42",# 与转换为tensorboard用法一致 + process_num=8,# 与转换为tensorboard用法一致 + data_type_list=["grad_unreduced"],# 与转换为tensorboard用法一致 + step_partition=500, + output_dirpath="~/monitor_output" +) +csv2db(config) +``` + ### 动态启停 动态启停模式:支持用户在训练过程中随时启动/更新监控。 -用户可在训练开始前通过配置环境变量DYNAMIC_MONITOR=True来确认开启动态启停模式,该模式下需要配合config.json文件中的dynamic_on字段来使用。 +用户可在训练开始前通过配置环境变量`DYNAMIC_MONITOR=True`来确认进入动态启停模式,该模式下需要配合config.json文件中的`dynamic_on`字段来使用。 在动态启停模式下,启动和停止分别由如下控制: -- 启动: - 首次监控:config.json文件中dynamic_on字段为true,代表是否需要开启监控。 - 非首次监控:config文件时间戳更新且config.json文件中dynamic_on字段为true。 -- 停止: - 到达collect_times之后自动停止并改config.json文件中dynamic_on字段为false,可再通过上述操作重启。 +- **启动**: + - 首次监控:查看config.json文件中`dynamic_on`字段,若为`true`则在下一步开启监控。 + - 非首次监控:查看config.json文件时间戳,若时间戳更新且config.json文件中`dynamic_on`字段为`true`则在下一步开启监控。 +- **停止**: + 到达`collect_times`之后自动停止并改config.json文件中`dynamic_on`字段为`false`,可再通过上述操作重启。 -大部分情况下,用户可在看到异常趋势后再手动更新config.json文件并打开dynamic_on开关;此外,使用时若想要在一开始就启动监控,可直接打开dynamic_on开关做基础配置的监测(首次不要求时间戳更新) +**注意事项:**: -注意事项: +- 默认监控启动皆统一在配置初始化或查询到更新后的下一步,即第n步挂上hook将在第n+1步启动采集,如需采集第0步数据请使用静态模式。 +- config.json中途修改出错时,若此时不在监控则不生效,若在监控则用原配置继续。 +- 达到`collect_times`之后程序会自动将该值置为`false`待下次改`true`重启。 -- 默认监控启动皆统一在配置初始化或查询到更新后的下一步,也就是若第n步挂上hook则第n+1步才启动采集,如需采集第0步数据请用静态模式。 -- config中途修改出错时,若此时不在监控就不生效,若在监控则用原配置继续。 -- 达到collect_times之后会自动将该值置为false待下次改true重启。 +**支持的使用场景说明如下:** + +| 场景 | 监控模式 | 操作步骤 | 结果描述 | +|-----------------------------------------------|----|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------| +| 场景1: 使用默认静态模式 | 静态 | 1. 配置环境变量:`export DYNAMIC_MONITOR=False `
或不设置该环境变量 | 走默认分支进行数据采集和保存,不受config.json中`dynamic_on`影响 | +| 场景2: 进入动态启停模式,初始不启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.配置config.json中`dynamic_on: false`或不设置该字段 | 初始状态下无监控,不进行数据采集和保存 | +| 场景3: 进入动态启停模式,初始即启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.配置config.json中`dynamic_on: true` | 根据初始配置在第1步(初始计数为0)开启监控并保存,采集`collect_times`次数后结束监控 | +| 场景4: 进入动态启停模式,初始暂不启动监控,训练中途启动 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.开始时配置config.json中`dynamic_on: false`或不设置该字段
3.训练中途修改config.json中`dynamic_on: true` | 训练中途根据最新配置在下一步开启监控并保存,采集`collect_times`次数后结束监控 | +| 场景5: 进入动态启停模式,监控还未结束时中途修改config.json采集配置 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.期间配置`dynamic_on: true`启动采集
3.在采集还未达到`collect_times`次数前,中途修改config.json配置 | 更新前按旧配置采集并保存,更新后下一步以最新config.json采集且`collect_times`重新从0开始计数。此功能可配合中途`collect_times`改0来实现提前停止监控。 +| 场景6: 进入动态启停模式,在根据`collect_times`结束监控后,需重新启动监控 | 动态 | 1.配置环境变量:`export DYNAMIC_MONITOR=True`
2.期间`dynamic_on: true`启动采集
3.采集达到`collect_times`次数后结束监控,程序自动改`dynamic_on:false`
4.配置config.json中`dynamic_on:true`重启监控 | 更新前按旧配置采集并保存,中途停止监控后无采集,重启后下一步以最新config.json重启采集且`collect_times`重新从0开始计数。 ### 功能重载 此功能将在2026年废弃。请使用[动态启停](#动态启停)功能代替。 @@ -499,8 +610,8 @@ csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_ | time_start | 起始时间戳。搭配time_end一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | | time_end | 结束时间戳。搭配time_start一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | | process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | -| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]。
不指定就转换全部数据。 | 否 | - +| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 | +| output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | - 在模型任意位置获取当前参数**梯度**统计量 ```python TrainerMon.generate_wgrad_metrics() -> tuple[dict, dict] @@ -561,6 +672,7 @@ TrainerMon.monitor_gnorm_with_ad(model, grad_acc_steps, optimizer, dp_group, tp_ "mv_distribution": true, "param_distribution": true, "wg_distribution": true, + "monitor_mbs_grad": true, "cc_distribution": {"enable":true, "cc_codeline":[]}, "alert": { "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], @@ -578,33 +690,36 @@ TrainerMon.monitor_gnorm_with_ad(model, grad_acc_steps, optimizer, dp_group, tp_ 下面详细解释各个字段: -| 字段名字 | 是否必选 | 解释 | -| ----------------------- | -------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| "targets" | 可选 | 指定需要监控的模型层和监控对象, 例如transformer的第0层language_model.encoder.layers.0,可选择监控input、output、input_grad、output_grad。如果不清楚模型结构, 可以将 "print_struct" 字段设置为 true, 监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。未配置时默认为全量监控。 | -| "input" | 可选 | "tuple[2]:0"的意思是目标module的前向input参数为长度为2的tuple, 我们关心的是tuple第0个元素。 | -| "output" | 必选 | "tensor"的意思是目标module的前向output参数类型为tensor | -| "input_grad" | 可选 | "tuple[2]:0"的意思是目标module的后向input_grad参数是长度为2的tuple, 我们关心的是tuple的第0个元素。 | -| "output_grad" | 必选 | "tuple[1]:0"的意思是目标module的后向input_grad参数是长度为1的tuple, 我们关心的是tuple的第0个元素。 | -| "dynamic_on" | 可选 | 在动态启停时使用,true代表打开监控,false代表关闭监控,默认值为false,且达到collect_times之后会自动将该值置为false待下次改true重启。**仅PyTorch场景支持此参数**。 | -| "collect_times" | 可选 | 设置采集次数,达到该次数后停止监控,默认值为100000000,目的是一直采集。 | -| "start_step" | 可选 | 设置开始采集step,模型训练达到start_step后开始监控采集,默认值为0,表示从step0开始监控采集。 | -| "step_interval" | 可选 | 设置采集step间隔,默认值为1,表示每个step均采集监控数据。 | -| "print_struct" | 可选 | 设置为true后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。不填默认为false。**仅PyTorch场景支持此参数**。 | -| "module_ranks" | 可选 | 用于在分布式训练场景中希望控制在哪些rank开启module监控。如果不填,则默认在所有rank开启。 列表内rank要求为int类型。 | -| "ur_distribution" | 可选 | 若为true则会统计adam优化器指定模块(targets中指定)参数的update和ratio向量的数值分布,并展示在heatmap里,默认为false,同时format字段必须设置为tensorboard。
依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。**仅PyTorch场景支持此参数**。 | -| "xy_distribution" | 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 | -| "all_xy" | 可选 | 开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。 | -| "forward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。 | -| "backward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。 | -| "mv_distribution" | 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。版本依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。**仅PyTorch场景支持此参数**。 | +| "xy_distribution" | 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 | +| "all_xy" | 可选 | 开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。 | +| "forward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。 | +| "backward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。 | +| "mv_distribution" | 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。版本=2.4.0 -## 展示示例 +## 更新通知 -支持重建模型的层级结构; +请注意,tb_graph_ascend插件已于2025/3/12更新到1.0.0版本,如果当前环境已安装旧版本插件,推荐升级。 -支持两个模型的结构差异比对; +更新内容如下: -支持两个模型的精度数据比对,支持疑似有精度问题节点的快速搜索,自动跳转展开节点所在的层级。 +- 优化了信息栏,使用了更人性化、更美观的展示界面; +- 提升了节点渲染和搜索性能; +- 双图比对场景画布分离,操作左图时不会影响到右图; +- 新增浏览器匹配节点功能,双图比对场景有未匹配节点时,可通过在浏览器页面手动选中调试侧和标杆侧的未匹配节点进行精度比对; +- 新增颜色图例可配置功能。 + +## 工具特性 + +- 支持重建模型的层级结构; +- 支持两个模型的结构差异比对; +- 支持两个模型的精度数据比对; +- 支持模型数据的溢出检测; +- 支持多卡场景的批量构图,能够关联各卡的通信节点,分析各卡之间的数据传递; +- 支持节点名称搜索,按精度比对结果筛选节点,按溢出检测结果筛选节点,支持自动跳转展开节点所在的层级; +- 支持跨套件、跨框架的模型比对。 ![vis_show](./img/visualization/vis_showcase.png) ## 1.依赖安装 -分级可视化工具依赖**msprobe工具**和**tensorboard。** - ### 1.1 安装msprobe工具 [msprobe工具安装](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/01.installation.md) @@ -28,6 +40,8 @@ ``pip3 install tb-graph-ascend``即可。 +如需升级工具,请先``pip3 uninstall tb-graph-ascend``再``pip3 install tb-graph-ascend``即可。 + ## 2.模型结构数据采集 [MindSpore场景的精度数据采集](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md) @@ -45,12 +59,14 @@ msprobe -f mindspore graph -i ./compare.json -o ./output | 参数名 | 说明 | 是否必选 | |-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -f 或 --framework | 指定训练框架。mindspore。 | 是 | | -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | | -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | -| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer), 如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 | 否 | +| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer), 如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。配置该参数后,将仅按节点名称进行比对,忽略节点的 type 和 shape。如果调试侧和标杆侧有名称不同的节点,则需要配置自定义映射文件,-lm参数传入自定义映射文件路径;如果调试侧和标杆侧节点名称相同,则仅指定-lm即可。 | 否 | | -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | | -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | | -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | +| -mm 或 --multi_mapping | 一对一、一对多、多对一、多对多节点映射,例如待调试侧若干小算子与标杆侧融合算子比对等场景,需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(multi)](#73-自定义映射文件multi) | 否 | #### 3.1.1 匹配说明 @@ -62,7 +78,7 @@ msprobe -f mindspore graph -i ./compare.json -o ./output - 节点的层级一致(父节点们一致) 2.模糊匹配 -- Cell节点dump名称一致,两个匹配上的Cell节点, 忽略各自节点下所有api的dump调用次数,按照名称一致+Cell节点内的调用顺序进行匹配 +- Cell节点dump名称一致,两个匹配上的Cell节点,忽略各自节点下所有api的dump调用次数,按照名称一致+Cell节点内的调用顺序进行匹配 - ![fuzzy_match_ms.png](./img/visualization/fuzzy_match_ms.png) - 参数shape一致 @@ -83,11 +99,12 @@ msprobe -f mindspore graph -i ./compare.json -o ./output ``` **比对文件参数说明**: -| 参数名 | 说明 | 是否必选 | -|-------------------|-------------------------------------------------------------------------------------------------------|------| -| npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | -| bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置 | 否 | -| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | +| 参数名 | 说明 | 是否必选 | +|-------------------|--------------------------------------------------------------------------------------------------------------------------|------| +| npu_path | 指定待调试侧比对路径,str类型。工具根据路径格式自动进行单rank比对、多rank批量比对或多step批量比对,具体格式参考3.2 图构建和比对。 | 是 | +| bench_path | 指定标杆侧比对路径,str类型。单图构建场景可以不配置。 | 否 | +| is_print_compare_log | 配置是否开启单个算子的日志打屏。可取值 true 或 false,默认为 true。关闭后则只输出常规日志,bool 类型。 | 否 | +| parallel_merge | 配置是否开启不同切分策略下的图合并,dict类型。rank_size、tp、pp参数按实际情况进行配置。比对时配置npu、bench,只构图配置npu。 配置示例见[3.2.5 不同切分策略下的图合并](#325-不同切分策略下的图合并)。 | 否 | ### 3.2 图构建和比对 @@ -313,6 +330,33 @@ dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为- 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 +#### 3.2.5 不同切分策略下的图合并 + +适用场景:不同Tensor Parallelism(TP)、Pipeline Parallelism(PP)切分策略下,两个模型产生了精度差异,需要进行整网数据比对,但被切分的数据分布于多rank中,需要将分布在各个rank的数据合并后再进行比对。 + +使用限制: + +- 当前支持基于Megatron、MindSpeed-LLM套件的模型进行图合并,其他套件的模型图合并效果有待验证; +- 当前仅支持msprobe工具dump的statistics数据; +- 图合并比对时要确保DP切分一致,例如rank=8 tp=1 pp=8的配置,dp=1,图合并将得到一张图,rank=8 tp=1 pp=4的配置,dp=2,图合并将得到两张图,暂不支持数量不一致的图进行比对。 + +使能方式: + +在compare.json里增加parallel_merge配置项, rank_size、tp、pp参数按实际情况进行配置; + +npu_path、bench_path的配置以及执行命令请参考[3.2.3 批量构建或比对](#323-批量构建或比对) + +``` +{ + "npu_path": "./npu_dump", + "bench_path": "./bench_dump", # 只进行图构建可不配置 + "is_print_compare_log": true, + "parallel_merge": { + "npu": {"rank_size": 8, "tp": 8, "pp": 1}, + "bench": {"rank_size": 8, "tp": 1, "pp": 8} # 只进行图构建可不配置 + } +} +``` ## 4.启动tensorboard @@ -329,11 +373,25 @@ tensorboard --logdir out_path --bind_all --port [可选,端口号] ubuntu是机器地址,6008是端口号。 -**注意,ubuntu需要替换为真实的服务器地址,例如真实的服务器地址为10.123.456.78,则需要在浏览器窗口输入http://10.123.456.78:6008** +**注意,ubuntu需要替换为真实的服务器地址,例如真实的服务器地址为10.123.456.78,则需要在浏览器窗口输入 http://10.123.456.78:6008** ### 4.2 不可直连的服务器 -**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试使用vscode连接服务器,在vscode终端输入:** +**如果链接打不开(服务器无法直连需要挂vpn才能连接等场景),可以尝试以下方法,选择其一即可:** + +1.本地电脑网络手动设置代理,例如Windows10系统,在【手动设置代理】中添加服务器地址(例如10.123.456.78) + +![proxy](./img/visualization/proxy.png) + +然后,在服务器中输入: +``` +tensorboard --logdir out_path --bind_all --port 6008[可选,端口号] +``` + +最后,在浏览器窗口输入 http://10.123.456.78:6008 + +**注意,如果当前服务器开启了防火墙,则此方法无效,需要关闭防火墙,或者尝试后续方法** +2.或者使用vscode连接服务器,在vscode终端输入: ``` tensorboard --logdir out_path ``` @@ -341,6 +399,14 @@ tensorboard --logdir out_path 按住CTRL点击链接即可 +3.或者将构图结果件vis文件从服务器传输至本地电脑,在本地电脑中安装tb_graph_ascend插件查看构图结果 + +电脑终端输入: +``` +tensorboard --logdir out_path +``` +按住CTRL点击链接即可 + ## 5.浏览器查看 ### 5.1 浏览器打开图 @@ -361,35 +427,68 @@ tensorboard --logdir out_path ### 5.5 未匹配节点筛选 节点匹配规则: -1.名称一致 +参考[匹配说明](#311-匹配说明) ,不符合匹配规则的节点为无匹配节点,颜色标灰。适用于排查两个模型结构差异的场景。 -2.节点输入输出参数数量一致,参数type、shape一致 +![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) -3.节点的层级一致(父节点们一致) +### 5.6 手动选择节点匹配 -![vis_unmatch_info.png](./img/visualization/vis_unmatch_info.png) +可通过浏览器界面,通过鼠标选择两个待匹配的灰色节点进行匹配。当前暂不支持真实数据模式。 + +![vis_match_info.png](./img/visualization/vis_match_info.png) ## 6.图比对说明 -### 颜色 +### 6.1 颜色 颜色越深,精度比对差异越大,越可疑,具体信息可见浏览器页面左下角颜色图例。 -### 疑似有精度问题判定 - -#### 真实数据模式 +#### 6.1.1 真实数据模式 节点中所有输入的最小双千指标和所有输出的最小双千分之一指标的差值,反映了双千指标的下降情况,**值越大精度差距越大,颜色标记越深**。 ``One Thousandth Err Ratio(双千分之一)精度指标:Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好`` -#### 统计信息模式 +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 + +#### 6.1.2 统计信息模式 节点中输出的统计量相对误差,**值越大精度差距越大,颜色标记越深**。 -``相对误差:abs((npu统计值 - bench统计值) / bench统计值)`` +``相对误差:abs((npu统计值 - bench统计值) / bench统计值)`` + +如果调试侧(NPU)节点的output指标中的最大值(MAX)或最小值(MIN)中存在 nan/inf/-inf,直接标记为最深颜色。 -#### md5模式 +#### 6.1.3 md5模式 节点中任意输入输出的md5值不同。 +### 6.2 指标说明 + +精度比对从三个层面评估 API 的精度,依次是:真实数据模式、统计数据模式和 MD5 模式。比对结果分别有不同的指标。 + +**公共指标**: +- name: 参数名称,例如input.0 +- type: 类型,例如mindspore.Tensor +- dtype: 数据类型,例如BFloat32 +- shape: 张量形状,例如[32, 1, 32] +- Max: 最大值 +- Min: 最小值 +- Mean: 平均值 +- Norm: L2-范数 + +**真实数据模式指标**: +- Cosine: tensor 余弦相似度 +- EucDist: tensor 欧式距离 +- MaxAbsErr: tensor 最大绝对误差 +- MaxRelativeErr: tensor 最大相对误差 +- One Thousandth Err Ratio: tensor 相对误差小于千分之一的比例(双千分之一) +- Five Thousandth Err Ratio: tensor 相对误差小于千分之五的比例(双千分之五) + +**统计数据模式指标** +- (Max, Min, Mean, Norm) diff: 统计量绝对误差 +- (Max, Min, Mean, Norm) RelativeErr: 统计量相对误差 + +**MD5模式指标** +- md5: CRC-32 值 + ## 7.附录 ### 7.1 自定义映射文件(Layer) @@ -482,11 +581,36 @@ yaml文件中只需配置MindSpore与PyTorch模型代码中功能一致但名称 ] } ``` +### 7.3 自定义映射文件(multi) +支持一对一、一对多、多对一、多对多节点映射配置,**多个节点使用英文逗号,分隔开**。 + +配置多个节点时,如果待配置节点为Cell.layer3.Linear.forward.0、Cell.layer4.Linear.forward.0和Cell.layer5.Linear.forward.0,则Cell.layer4.Linear.forward.0无需配置,仅取首尾节点配置即可(Cell.layer3.Linear.forward.0,Cell.layer5.Linear.forward.0)。注意,**配置节点的先后顺序不能乱(construct.json中的节点名称顺序代表先后顺序,请参考[dump结果文件介绍](./06.data_dump_MindSpore.md#82-动态图场景))**,Cell.layer3.Linear.forward.0在前,就不能配置成Cell.layer5.Linear.forward.0,Cell.layer3.Linear.forward.0,会导致配置无效。 + +```yaml +# 一对一 +Cell.layer.Linear.forward.0: Cell.layer1.Linear.forward.0 +``` +```yaml +# 一对多 +Cell.layer.Linear.forward.0: Cell.layer1.Linear.forward.0,Cell.layer2.Linear.forward.0 +``` +```yaml +# 多对一 +Cell.layer1.Linear.forward.0,Cell.layer2.Linear.forward.0: Cell.layer.Linear.forward.0 +``` +```yaml +# 多对多 +Cell.layer3.Linear.forward.0,Cell.layer5.Linear.forward.0: Cell.layer1.Linear.forward.0,Cell.layer2.Linear.forward.0 +``` # FAQ 1. 图比对场景,节点呈现灰色,且没有精度比对数据,怎么处理? 节点呈现灰色,代表左边待调试侧节点与右边标杆侧节点没有匹配上,可能有以下几点原因: - **标杆侧确实没有能与待调试侧匹配上的节点**,属于代码实现上的差异,请确认此差异是否正常,是否会影响到整网精度。 -- **节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致**,导致节点无法匹配,具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明)。如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 -- **节点名称不一致**,导致节点无法匹配,可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 +- **节点名称一致,但节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致,导致节点无法匹配** + - 具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明); + - 如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 +- **节点名称不一致**,导致节点无法匹配,目前提供两种方法,选其一即可 + - 可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md); + - 可通过浏览器页面手动选择未匹配节点进行匹配,请参考[手动选择节点匹配](#56-手动选择节点匹配)。 diff --git a/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md b/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md index 59e2755ec3e5a3939af3a20d19fda12031a9bf51..e7c8dc7de74930d6ef9c5ef2c172a9dda4d4a040 100644 --- a/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/23.generate_operator_PyTorch.md @@ -33,15 +33,15 @@ b. 在生成单API脚本时可以选择由工具构造随机数获得 dump 数 ``` **配置文件参数说明** - | 参数名称 | 解释 | 是否必选 | - | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- | - | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存可以不指定。 | 否 | - | api_name | 算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。如果已经提取了可疑算子并保存可以不指定 | 否 | - | extract_api_path | 提取可疑算子的json文件路径 | 是 | - | propagation | 选择复现算子的forward还是backward,默认为forward | 否 | - | data_mode | 选择复现算子的随机数据(random_data)还是真实数据(real_data)模式,默认为random_data | 否 | - | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为1234 | 否 | - | iter_times | 仅random_data模式有效,表示单API运行的次数 | 否 | + | 参数名称 | 解释 | 是否必选 | + | ---------------------------- |----------------------------------------------------------------------------| ---------------------------------- | + | dump_json_path | dump.json的文件路径,包含所有dump算子的信息;如果已经提取了可疑算子并保存可以不指定。 | 否 | + | api_name | 算子名,如Functional.softmax.3、Tensor.add.0、Torch.matmul.5等。如果已经提取了可疑算子并保存可以不指定 | 否 | + | extract_api_path | 提取可疑算子的json文件路径 | 是 | + | propagation | 选择复现算子的forward还是backward,默认为forward | 否 | + | data_mode | 选择复现算子的随机数据(random_data)还是真实数据(real_data)模式,默认为random_data | 否 | + | random_seed | 仅random_data模式有效,表示手动设定的随机种子,默认为1234 | 否 | + | iter_times | 仅random_data模式有效,表示单API运行的次数,由于安全相关原因,最大支持设置为1000 | 否 | ### 2.2 运行命令生成单API脚本 config_op.json配置好后,运行如下命令: diff --git a/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md b/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md index 05e3900d2647b07ed5334082e3ac519cfc7fb2b2..9e741305363e7ee8419a2df8661c6afad463201c 100644 --- a/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md +++ b/debug/accuracy_tools/msprobe/docs/24.code_mapping_Mindspore.md @@ -20,9 +20,10 @@ msprobe -f mindspore code_mapping --ir --dump_data [--outp ``` -| 参数名称 | 说明 |参数类型 | 是否必选 | -| ---------------------------- |-------------------------------------------------------------------------------------------------------------------------------------------|---------------------- | ---------------------------------- | -| --ir | 指定 MindSpore 静态图运行时生成的IR图文件。 | str | 是 | -| --dump_data | 指定dump数据文件(支持tensor或statistic模式的dump数据)。可指定单个dump数据 文件或dump数据文件的父目录,指定父目录表示关联目录下的所有dump数据文件。 | str | 是 | -| --output | 关联结果输出目录,默认为"./",只在tensor模式时生效,会把数据文件路径和代码调用栈的关联关系存到output路径下的code_mapping_{时间戳}.csv中。如果关联的是statistic模式,则会把statistic.csv中每个条目加上该条目对应的代码栈。 | str | 否 | +| 参数名称 | 说明 | 参数类型 | 是否必选 | +| ---------------------------- |-------------------------------------------------------------------------------------------------------------------------------------------|------| ---------------------------------- | +| -f 或 --framework | 指定训练框架。mindspore。 | str | 是 | +| --ir | 指定 MindSpore 静态图运行时生成的IR图文件。 | str | 是 | +| --dump_data | 指定dump数据文件(支持tensor或statistic模式的dump数据)。可指定单个dump数据 文件或dump数据文件的父目录,指定父目录表示关联目录下的所有dump数据文件。 | str | 是 | +| --output | 关联结果输出目录,默认为"./",只在tensor模式时生效,会把数据文件路径和代码调用栈的关联关系存到output路径下的code_mapping_{时间戳}.csv中。如果关联的是statistic模式,则会把statistic.csv中每个条目加上该条目对应的代码栈。 | str | 否 | diff --git a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md index f6f5db9781223fc299df978dfd55a9d2af2e07e6..61e3c3c591bd6c778c1a4f6f47a393f3a5fd08c5 100644 --- a/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/25.tool_function_introduction.md @@ -2,28 +2,29 @@ ## 1 PyTorch框架 -| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | -|------------------------------------------------------------------------------------|---------------------------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------| -| [数据采集
(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失
3、当前对inplace操作API或Module的支持度有限
4、暂不支持参数及参数梯度的采集 | -| [离线预检
(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、依赖GPU环境
2、不支持通信算子
3、仅支持部分融合算子 | -| [整网比对
(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | -| [在线预检
(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信
2、重计算模式下,不支持反向aten算子预检 | -| [溢出检查
(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module
2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 | -| [数据解析
(parse_tool)](./14.data_parse_PyTorch.md) | 互交式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU | -| [无标杆比对
(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 | -| [梯度状态监测
(grad_probe)](./17.grad_probe.md) | 可导出模型权重梯度数据并对比相似度,助力确认训练过程精度问题step和反向中的异常。 | 1、需要分析梯度数据时
2、需要定位发生问题的step时 | 暂无 | -| [在线精度比对
(online_dispatch)](./18.online_dispatch.md) | 训练过程中直接完成NPU和CPU的精度比对并输出比对结果。 | 1、执行一次就可获取NPU和CPU分别执行后的精度比对结果 | 暂无 | -| [训练状态监控
(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析
2、仅支持megatron、deepspeed框架
3、少量增加时间和显存膨胀 | -| [可视化比对
(visualization) ](./21.visualization_PyTorch.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | -| [单API自动生成脚本
(generate_operator) ](./23.generate_operator_PyTorch.md) | 解析dump的精度数据,提取可疑的API算子,自动生成单API复现脚本,并根据不同的API采用不同的比对算法,给定最终比对结果数据;帮助开发者分析算子精度问题。 | 1、该工具支持从整网dump下来的数据中提取可疑算子,并自动生成单API脚本
2、除了支持复现单API的前反向过程,同时会根据不同的API选择不同的比对方法,并给出比对结果 |1、不支持通信算子
2、融合算子需手动修改脚本进行适配
3、目前比对的标杆均为和CPU进行比对,暂不支持直接NPU和GPU比对 +| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | +| --------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [数据采集`
`(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析`
` 2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API`
`2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失`
`3、当前对inplace操作API或Module的支持度有限`
`4、暂不支持参数及参数梯度的采集 | +| [离线预检`
`(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查`
`2、精度排查不受模型累计误差影响 | 1、依赖GPU环境`
`2、不支持通信算子`
`3、仅支持部分融合算子 | +| [整网比对`
`(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响`
`2、当模型规模较大时,比对所需时间较长 | +| [在线预检`
`(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信`
`2、重计算模式下,不支持反向aten算子预检 | +| [溢出检查`
`(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module`
`2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 | +| [数据解析`
`(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU | +| [无标杆比对`
`(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查`
`2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用`
`2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 | +| [梯度状态监测`
`(grad_probe)](./17.grad_probe.md) | 可导出模型权重梯度数据并对比相似度,助力确认训练过程精度问题step和反向中的异常。 | 1、需要分析梯度数据时`
`2、需要定位发生问题的step时 | 暂无 | +| [在线精度比对`
`(online_dispatch)](./18.online_dispatch.md) | 训练过程中直接完成NPU和CPU的精度比对并输出比对结果。 | 1、执行一次就可获取NPU和CPU分别执行后的精度比对结果 | 暂无 | +| [训练状态监控`
`(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析`
`2、仅支持megatron、deepspeed框架`
`3、少量增加时间和显存膨胀 | +| [可视化比对`
`(visualization) ](./21.visualization_PyTorch.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子`
`2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响`
`2、当模型规模较大时,比对所需时间较长 | +| [单API自动生成脚本`
`(generate_operator) ](./23.generate_operator_PyTorch.md) | 解析dump的精度数据,提取可疑的API算子,自动生成单API复现脚本,并根据不同的API采用不同的比对算法,给定最终比对结果数据;帮助开发者分析算子精度问题。 | 1、该工具支持从整网dump下来的数据中提取可疑算子,并自动生成单API脚本`
`2、除了支持复现单API的前反向过程,同时会根据不同的API选择不同的比对方法,并给出比对结果 | 1、不支持通信算子`
`2、融合算子需手动修改脚本进行适配`
`3、目前比对的标杆均为和CPU进行比对,暂不支持直接NPU和GPU比对 | ## 2 MindSpore框架 -| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | -|----------------------------------------------------------------------|-------------------------------------------------------------------|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------| -| [数据采集
(dump)](./06.data_dump_MindSpore.md) | 采集模型训练过程中的API或Cell层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Cell的前反向输入输出数据保存下来分析
2、模型出现溢出时,可用于查看哪些API或Cell出现了溢出 | 1、API级数据采集仅支持白名单列表上的API
2、当前对inplace操作API或Cell的支持度有限
3、暂不支持参数及参数梯度的采集 | -| [离线预检
(api_accuracy_checker)](./09.accuracy_checker_MindSpore.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查
2、精度排查不受模型累计误差影响 | 1、仅针对MindSpore.mint API | -| [整网比对
(compare)](./11.accuracy_compare_MindSpore.md) | NPU精度数据与标杆数据的比对,支持MindSpore框架内和与PyTorch跨框架的比对,助力快速定位精度异常API或Cell。 | 1、MindSpore同框架静态图比对
2、MindSpore同框架动态图比对
3、MindSpore vs PyTorch跨框架动态图比对 | 1、部分PyTorch的API关联不到MindSpore,需要手动配置映射关系 | -| [溢出检查
(overflow_checker)](./13.overflow_check_MindSpore.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,可用于定位最先溢出的API或Cell或kernel
2、相比数据采集,性能更优,磁盘压力更小 | 1、除具有与数据采集功能相同的局限性外,动态图场景下,不支持 Primitive 和 Jit 类 API 的检测
2、动态图场景下,仅支持检测API或Cell级别溢出
3、静态图场景下,仅支持检测kernel级别溢出 | -| [无标杆比对
(free_benchmark)](./16.free_benchmarking_MindSpore.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查
2、对个别算子进行升精度修复,验证其对模型loss的影响 | 1、仅支持动态图场景
2、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用
3、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用
4、不支持“to cpu”操作,不支持预热功能 | -| [可视化比对
(visualization) ](./22.visualization_MindSpore.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子
2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响
2、当模型规模较大时,比对所需时间较长 | +| 功能名(英文) | 简介 | 适用场景/优势 | 当前版本局限性 | +| ---------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [数据采集 `
`(dump)](./06.data_dump_MindSpore.md) | 采集模型训练过程中的API或Cell层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Cell的前反向输入输出数据保存下来分析 `
` 2、模型出现溢出时,可用于查看哪些API或Cell出现了溢出 | 1、API级数据采集仅支持白名单列表上的API `
`2、当前对inplace操作API或Cell的支持度有限 `
`3、暂不支持参数及参数梯度的采集 | +| [离线预检 `
`(api_accuracy_checker)](./09.accuracy_checker_MindSpore.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查 `
`2、精度排查不受模型累计误差影响 | 1、仅针对MindSpore.mint API | +| [整网比对 `
`(compare)](./11.accuracy_compare_MindSpore.md) | NPU精度数据与标杆数据的比对,支持MindSpore框架内和与PyTorch跨框架的比对,助力快速定位精度异常API或Cell。 | 1、MindSpore同框架静态图比对 `
`2、MindSpore同框架动态图比对 `
`3、MindSpore vs PyTorch跨框架动态图比对 | 1、部分PyTorch的API关联不到MindSpore,需要手动配置映射关系 | +| [溢出检查 `
`(overflow_checker)](./13.overflow_check_MindSpore.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,可用于定位最先溢出的API或Cell或kernel `
`2、相比数据采集,性能更优,磁盘压力更小 | 1、除具有与数据采集功能相同的局限性外,动态图场景下,不支持 Primitive 和 Jit 类 API 的检测 `
`2、动态图场景下,仅支持检测API或Cell级别溢出 `
`3、静态图场景下,仅支持检测kernel级别溢出 | +| [无标杆比对 `
`(free_benchmark)](./16.free_benchmarking_MindSpore.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查 `
`2、对个别算子进行升精度修复,验证其对模型loss的影响 | 1、仅支持动态图场景 `
`2、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用 `
`3、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 `
`4、不支持“to cpu”操作,不支持预热功能 | +| [可视化比对 `
`(visualization) ](./22.visualization_MindSpore.md) | 解析dump的精度数据,还原模型图结构,比对各层级精度数据,助力理解模型结构、分析精度问题。 | 1、整网精度比对定位可疑算子,通过浏览器展示比对结果,支持快速搜索到可疑算子 `
`2、支持查看模型层级结果,比对模型层级结构差异 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响 `
`2、当模型规模较大时,比对所需时间较长 | +| [训练状态监控 `
`(monitor)](./19.monitor.md) | 收集模型训练过程中的激活值、梯度和优化器状态,助力分析计算、通信、优化器各部分异常情况。 | 1、通过监控模块级统计量指标,快速定位异常模块位置,如loss出现nan | 1、仅支持模块级别统计量指标分析 `
`2、仅支持megatron、deepspeed框架 `
`3、少量增加时间和显存膨胀 | diff --git a/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md b/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md index 5ca199ab6171a3634af0b26844d6ba8e7d04933f..8127d7d7410eb46ebdaf510c6002de9724e8b108 100644 --- a/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md +++ b/debug/accuracy_tools/msprobe/docs/26.data_dump_PyTorch_baseline.md @@ -1,8 +1,19 @@ # PyTorch 场景的精度数据采集基线 +## "statistics"模式(未开启md5)采集时间膨胀参考基线 + +该基线为PyTorch框架下,使用"statistics"模式采集数据性能膨胀的参考基线。本基线测试了LLAMA2-7B语言大模型在不同采集模式8卡下的时间膨胀。 + +| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | +|:--------:|:--------:|:--------------------:|:------------------:| +| L0 | ≈17.4 s | ≈17.4 s (无膨胀) | ≈78.4 s (膨胀4.5倍) | +| L1 | ≈17.4 s | ≈20.7 s (膨胀1.2倍) | ≈353 s (膨胀20倍) | +| mix | ≈17.4 s | ≈20.7 s (膨胀1.2倍) | ≈430 s (膨胀24.7 倍) | + + ## "tensor"模式采集数据量参考基线 -该基线为pytorch框架下,使用"tensor"模式采集数据量参考基线。本基线测试了两个模型,分别为LLAMA2-7B和LLAMA2-13B,测试了不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 +该基线为PyTorch框架下,使用"tensor"模式采集数据量参考基线。本基线测试了两个模型,分别为LLAMA2-7B和LLAMA2-13B,测试了不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 ### LLAMA2-7B diff --git a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md index f994dc2301bcae6b23dc7a7503297aa4fe5b3724..bf992a02aba6c9b4c6c1d18077775c0a8f4325ea 100644 --- a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md @@ -1,8 +1,8 @@ # dump.json文件说明及示例 -## 1. dump.json文件示例(PyTorch) +## 1. PyTorch 场景下的 dump.json 文件 -### 1.1 L0级别 +### 1.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为: `output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` @@ -168,7 +168,7 @@ dump.json文件中包含以下数据名称: } ``` -### 1.2 L1级别 +### 1.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为: `output = torch.nn.functional.relu(input)` @@ -264,13 +264,13 @@ dump.json文件中包含以下数据名称: } ``` -### 1.3 mix级别 +### 1.3 mix 级别 mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 -## 2. dump.json文件示例(MindSpore) +## 2. MindSpore 场景下的 dump.json 文件 -### 2.1 L0级别 +### 2.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。 以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为: @@ -429,7 +429,7 @@ dump.json文件中包含以下数据名称: } ``` -### 2.2 L1级别 +### 2.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为: `output = mindspore.ops.relu(input)` @@ -521,5 +521,275 @@ L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的 } ``` -### 2.3 mix级别 +### 2.3 mix 级别 + mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 + +## 3. MSAdapter 场景下的 dump.json 文件 + +### 3.1 L0 级别 + +L0 级别的 dump.json 文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以 Conv2d 模块为例,网络中模块调用代码为: +`output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` + +dump.json文件中包含以下数据名称: + +- `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 +- `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 +- `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 + +**说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv1.Conv2d.forward.0`。 + +```json +{ + "task": "tensor", + "level": "L0", + "framework": "mindtorch", + "dump_data_dir": "/dump/path", + "data": { + "Module.conv2.Conv2d.forward.0": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 1.638758659362793, + "Min": 0.0, + "Mean": 0.2544615864753723, + "Norm": 70.50277709960938, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 1.6815717220306396, + "Min": -1.5120246410369873, + "Mean": -0.025344856083393097, + "Norm": 149.65576171875, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.output.0.npy" + } + ], + "parameters": { + "weight": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.05992485210299492, + "Min": -0.05999220535159111, + "Mean": -0.0006165213999338448, + "Norm": 3.421217441558838, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.weight.npy" + }, + "bias": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.05744686722755432, + "Min": -0.04894155263900757, + "Mean": 0.006410328671336174, + "Norm": 0.17263513803482056, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.bias.npy" + } + } + }, + "Module.conv2.Conv2d.parameters_grad": { + "weight": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.018550323322415352, + "Min": -0.008627401664853096, + "Mean": 0.0006675920449197292, + "Norm": 0.26084786653518677, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.weight.npy" + } + ], + "bias": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.014914230443537235, + "Min": -0.006656786892563105, + "Mean": 0.002657240955159068, + "Norm": 0.029451673850417137, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.bias.npy" + } + ] + }, + "Module.conv2.Conv2d.backward.0": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 0.0015069986693561077, + "Min": -0.001139344065450132, + "Mean": 3.3215508210560074e-06, + "Norm": 0.020567523315548897, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 0.0007466732058674097, + "Min": -0.00044813455315306783, + "Mean": 6.814070275140693e-06, + "Norm": 0.01474067009985447, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.output.0.npy" + } + ] + } + } +} +``` + +### 3.2 L1 级别 +L1级别的dump.json文件包括API的前反向的输入输出。以 relu API 为例,网络中 API 调用代码为: +`output = torch.nn.functional.relu(input)` + +dump.json文件中包含以下数据名称: +- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 +- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。 + +```json +{ + "task": "tensor", + "level": "L1", + "framework": "mindtorch", + "dump_data_dir":"/dump/path", + "data": { + "Functional.relu.0.forward": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": -1.3364859819412231, + "Mean": 0.03711778670549393, + "Norm": 236.20692443847656, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": 0.0, + "Mean": 0.16849493980407715, + "Norm": 175.23345947265625, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.output.0.npy" + } + ] + }, + "Functional.relu.0.backward": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00013352684618439525, + "Mean": 0.00011915402356442064, + "Norm": 0.007598237134516239, + "requires_grad": false, + "data_name": "Functional.relu.0.backward.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00012117840378778055, + "Mean": 2.0098118724831693e-08, + "Norm": 0.006532244384288788, + "requires_grad": false, + "data_name": "Functional.relu.0.backward.output.0.npy" + } + ] + } + } +} +``` + +### 3.3 mix 级别 + +mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md index 6b8cc558aa22526158033cfb35f31203d8b04278..4988586c0568b391739f7c14f1a9452461f1a6f1 100644 --- a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md @@ -1,4 +1,4 @@ -# MindSpore 场景的 kernel dump 说明 +# MindSpore 动态图场景的 kernel dump 说明 当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。 diff --git a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..6549b15e7adbb8b8bc66102820672c67b0b30437 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md @@ -0,0 +1,229 @@ +# MSAdapter 场景的精度数据采集 + +MSAdapter 是一款 MindSpore 生态适配工具,可以将 PyTorch 训练脚本高效迁移至 MindSpore 框架执行,以实现在不改变原有 PyTorch 用户开发习惯的情况下,使得 PyTorch 代码能在昇腾上获得高效性能。 + +msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 + +本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: + +```yaml +functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API + - conv1d + - conv2d + - conv3d +``` + +删除 API 的场景:部分模型代码逻辑会存在 API 原生类型校验,工具执行dump操作时,对封装后的模型 API 可能与模型的原生 API 类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md#33-异常情况)》中“异常情况”的第10和11条。 + +## 1. 工具安装 + +请参见[《msprobe 工具安装指南》](./01.installation.md)。 + +## 2 接口介绍 + +### 2.1 msprobe.mindspore.PrecisionDebugger + +**功能说明**:通过加载 dump 配置文件的方式来确定 dump 操作的详细配置。 + +**原型**: + +```Python +PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step=None) +``` + +**参数说明**: + +1. config_path:指定 dump 配置文件路径,string 类型。参数示例:"./config.json"。未配置该路径时,默认使用 [config.json](../config.json) 文件的默认配置,配置选项含义可见 [config.json 介绍](./02.config_introduction.md)。 + +2. 其他参数与 [config.json](../config.json) 文件中的同名配置字段含义相同,具体可见 [config.json 介绍](./02.config_introduction.md)。当参数值非None时,优先级高于 [config.json](../config.json) 文件中的同名配置。 + +#### 2.1.1 start + +**功能说明**:启动精度数据采集。需要与 [**stop**](#212-stop) 接口一起添加在训练迭代的 for 循环内。 + +**原型**: + +```Python +start(model=None) +``` + +**参数说明**: + +1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。level 配置为 "L0" 或 "mix" 时,必须在该接口中配置该参数。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 + +#### 2.1.2 stop + +**功能说明**:停止精度数据采集。在 **start** 接口调用之后的任意位置添加。若 **stop** 接口添加在反向计算代码之后,则会采集 **start** 和该接口之间的前反向数据。 +若 **stop** 接口添加在反向计算代码之前,则需要将 [**step**](#213-step) 接口添加到反向计算代码之后,才能采集 **start** 和该接口之间的前反向数据。 + +**注意**:**stop** 接口必须调用,否则可能导致精度数据落盘不全。 + +**原型**: + +```Python +stop() +``` + +#### 2.1.3 step + +**功能说明**:进行训练 step 数的自增,完成当前 step 所有数据的落盘并更新 dump 参数。在一个 step 训练结束的位置添加,且必须在 **stop** 接口之后的位置调用。该接口需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码之后,否则可能会导致反向数据丢失。 + +**原型**: + +```Python +step() +``` + +#### 2.1.4 forward_backward_dump_end + +**功能说明**:停止精度数据采集。与 **stop** 接口功能相同,该函数在将来会被移除,建议使用 **stop** 接口。 + +**原型**: + +```Python +forward_backward_dump_end() +``` + +#### 2.1.5 save + +**功能说明**:单点保存网络执行过程中正反向数值,并以统计值/张量文件落盘。 + +**原型**: +```python +save(variable, name, save_backward=True) +``` + +**参数说明**: +| 参数名称 | 参数含义 | 支持数据类型 | 是否必选| +| ---------- | ------------------| ------------------- | ------------------- | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | +| name | 指定的名称 | str | 是 | +| save_backward | 是否保存反向数据 | boolean | 否 | + +### 2.2 msprobe.mindspore.seed_all + +**功能说明**:用于固定网络中的随机性和开启确定性计算。 + +**原型**: +```python +seed_all(seed=1234, mode=False, rm_dropout=True) +``` + +**参数说明**: + +1. seed: 随机性种子,默认值:1234,非必选。参数示例: seed=1000。该参数用于 random、numpy.random, mindspore.common.Initializer、mindspore.nn.probability.distribution的随机数生成以及 Python 中 str、bytes、datetime 对象的 hash 算法。 + +2. mode:确定性计算使能,可配置 True 或 False,默认值:False,非必选。参数示例:mode=True。该参数设置为 True 后,将会开启算子确定性运行模式与归约类通信算子(AllReduce、ReduceScatter、Reduce)的确定性计算。注意:确定性计算会导致 API 执行性能降低,建议在发现模型多次执行结果不同的情况下开启。 + +3. rm_dropout:控制 dropout 失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使 mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout 失效,以避免因随机 dropout 造成的网络随机性。建议在采集数据前调用。 + +**注意**:通过 rm_dropout 控制 dropout 失效或生效需要在初始化 Dropout 实例前调用才能生效。 + +## 3 示例代码 + +以下为添加了 msprobe 工具 dump 接口的示例训练脚本。 + +```python +import mindspore as ms +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 导入工具的数据采集接口 +from msprobe.mindspore import PrecisionDebugger + +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger(config_path='./config.json') + + +# 定义网络 +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = nn.Linear(in_features=8, out_features=4) + self.linear2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x1) + logits = F.relu(x2) + return logits + + +net = Net() + + +def train_step(inputs): + return net(inputs) + + +if __name__ == "__main__": + data = (torch.randn(10, 8), torch.randn(10, 8), torch.randn(10, 8)) + grad_fn = ms.value_and_grad(train_step, grad_position=0) + + for inputs in data: + # 开启数据 dump + debugger.start(model=net) + + out, grad = grad_fn(inputs) + + # 停止数据 dump + debugger.stop() + # 更新 step 信息 + debugger.step() +``` + +## 4 dump 结果文件介绍 + +训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下: + +```lua +├── dump_path +│ ├── step0 +│ | ├── rank0 +│ | │ ├── dump_tensor_data +| | | | ├── Tensor.permute.1.forward.npy +| | | | ├── Functional.linear.5.backward.output.npy # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 +| | | | ... +| | | | ├── Module.conv1.Conv2d.forward.0.input.0.npy # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 +| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.npy # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.npy。 +│ | | ├── dump.json +│ | | ├── stack.json +│ | | └── construct.json +│ | ├── rank1 +| | | ├── dump_tensor_data +| | | | └── ... +│ | | ├── dump.json +│ | | ├── stack.json +| | | └── construct.json +│ | ├── ... +│ | | +| | └── rank7 +│ ├── step1 +│ | ├── ... +│ ├── step2 +``` +* `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 +* `dump_tensor_data`:保存采集到的张量数据。 +* `dump.json`: 保存 API 或 Module 前反向数据的统计量信息。包含 dump 数据的 API 名称或 Module 名称,各数据的 dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置 summary_mode="md5" 时的 CRC-32 数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#3-msadapter-场景下的-dumpjson-文件)。 +* `stack.json`:API/Module 的调用栈信息。 +* `construct.json`:分层分级结构,level 为 L1 时,construct.json 内容为空。 + + +当 task 为 tensor 时,dump 过程中,npy 文件在对应算子或者模块被执行后就会落盘,而 json 文件则需要在正常执行 PrecisionDebugger.stop() 后才会写入完整数据。因此如果程序异常终止,终止前被执行算子的相关 npy 文件得以保存,但 json 文件中的数据可能丢失。 + +其中 rank 为设备上各卡的 ID,每张卡上 dump 的数据会生成对应 dump 目录。非分布式场景下没有 rank ID,目录名称为 rank。 + +npy 文件名的前缀含义如下: + +| 前缀 | 含义 | +| ----------- | ---------------------------- | +| Tensor | torch.Tensor API数据 | +| Torch | torch API数据 | +| Functional | torch.nn.functional API数据 | +| NPU | NPU 亲和API数据 | +| Distributed | torch.distributed API数据 | +| Jit | 被 "jit" 装饰的模块或函数数据 | +| Module | torch.nn.Module 类(模块)数据 | \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..e963a60e8361be2569e7f85ee0d97df9194d6d91 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md @@ -0,0 +1,31 @@ +# MSAdapter 场景的溢出检测 + +msprobe 工具提供 MSAdapter 场景下的溢出检测功能。其检测对象为 **API** 级别(除 Primitive 和 Jit 类 API)或**模块**级别,分别对应 config.json 配置中的 **"L1"** 、**"L0"** level。 + +需要注意,本工具仅支持在 INF/NAN 模式a下进行溢出检测。INF/NAN 模式的使能方式如下: + +```Shell +# 使能 CANN 侧 INF/NAN 模式 +export INF_NAN_MODE_ENABLE=1 +# 使能 MindSpore 框架侧 INF/NAN 模式 +export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" +``` + +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 + +溢出检测任务的配置示例见["**MindSpore 动态图场景 task 配置为 overflow_check**"](./03.config_examples.md#33-task-配置为-overflow_check)小节。 + + +## 1 接口介绍 + +溢出检测功能提供的接口与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**2 接口介绍**"](./29.data_dump_MSAdapter.md#2-接口介绍)小节。 + +需要注意,目前暂不支持 "L1" level 下 primitive op 的溢出检测。 + +## 2 示例代码 + +溢出检测功能使用方式与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**3 示例代码**"](./29.data_dump_MSAdapter.md#3-示例代码)小节。 + +## 3 溢出检测结果文件介绍 + +溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 模块 的真实数据或统计信息。详见 MSAdapter 场景的精度数据采集中的["**4 dump 结果文件介绍**"](./29.data_dump_MSAdapter.md#4-dump-结果文件介绍)小节。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/31.config_checking.md b/debug/accuracy_tools/msprobe/docs/31.config_checking.md new file mode 100644 index 0000000000000000000000000000000000000000..6696f9bce205e82ded6cf886fa2dcf69252728e6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/31.config_checking.md @@ -0,0 +1,95 @@ +# config check + +## 介绍 + +该工具主要适用于对比两个环境下可能影响训练精度的配置差异,支持mindspore和pytorch两个框架,包括: + +- 环境变量 +- 三方库版本 +- 训练超参 +- 权重 +- 数据集 +- 随机操作 + + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md) + +## 使用说明 + +用户需要在两个待比对的训练的环境上分别进行数据采集, 工具会采集两个环境下影响精度的配置,采集结果上传到同一机器进行比对。 + +### 数据采集 + +#### 静态数据采集 + +静态数据采集仅支持环境变量,三方库版本及训练超参采集,其中环境变量,三方库版本默认采集,训练超参采集需要用户传入启动训练的 shell 脚本路径或 yaml 配置文件, +支持多个输入,不传入表示不采集。 + +启动命令如下 +```shell +msprobe -f pytorch/mindspore config_check -d **.sh **.yaml -o output_path +``` +-f或--framework 代表训练框架,传入pytorch或mindspore,必选。 + +-d或--dump 代表数据采集模式,可传入启动训练的 shell 脚本路径或 yaml 配置文件路径,可选,不传入代表不采集。 + +-o或--output 代表输出路径,可选,默认为 config_check_pack.zip。 + +#### 动态数据采集 + + +在训练流程执行到的第一个python脚本开始处插入如下代码: +``` +from msprobe.core.config_check import ConfigChecker +ConfigChecker.apply_patches(fmk) +``` + +说明: + +- fmk:训练框架。可选 pytorch 和 mindspore ,不传默认为 pytorch。 + +在模型初始化好之后插入如下代码: +``` +from msprobe.core.config_check import ConfigChecker +ConfigChecker(model, shell_path, output_zip_path, fmk) +``` + +说明: + +- model:初始化好的模型。不传或缺省就不会采集权重和数据集。 +- shell_path:动态采集模式下支持 **megatron** 训练超参自动捕获,使用 **megatron** 时推荐不传入,其他情况下可传入训练脚本路径,类型为列表,传入一个或多个训练配置/启动脚本。不传或缺省就不会采集超参。 +- output_zip_path:输出zip包的路径,不传默认为"./config_check_pack.zip"。 +- fmk:当前是什么框架。可选 pytorch 和 mindspore ,不传默认为 pytorch。 + +采集完成后会得到一个zip包,里面包括各项[影响精度的配置](#介绍)。 + +在另一个环境上执行上述操作,得到另一个zip包 + +### 数据比对 + +将两个zip包传到同一个环境下,使用如下命令进行比对: + +```shell +msprobe -f pytorch config_check -c bench_zip_path cmp_zip_path -o output_path +``` + +-c或--compare 表示compare,数据对比,有两个参数。其中**bench_zip_path** 为标杆侧采集到的数据, **cmp_zip_path** 为待对比侧采集到的数据。 + +**output_path 会被删掉再新建**,不传默认为"./config_check_result", 在 **output_path** 里会生成2个目录和1个文件: +- bench:bench_zip_path里打包的数据。 +- cmp:cmp_zip_path里打包的数据。 +- result.xlsx:比对结果。里面会有多个sheet页,其中**summary**总览通过情况,其余页是具体检查项的详情。其中step为micro_step。 + +## 通过标准 + +以下五项检查通过: + +- 环境变量 +- 三方库版本 +- 训练超参 +- 权重 +- 数据集 + +这五项检查在**精度比对**前必须保证达成。 diff --git a/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md b/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md new file mode 100644 index 0000000000000000000000000000000000000000..c49b4bfc8ee079cfdf2583c0c84372fe74aec6a7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/32.checkpoint_compare.md @@ -0,0 +1,60 @@ +# 权重比对 + +msprobe 工具提供大模型权重比对功能。当前支持pytorch下megatron/mindspeed不同模型并行策略下的权重互相比对。 + +> **Attention:** Ensure megatron in the PYTHONPATH to load a megatron checkpoint. + +## 1. 工具安装 + +[msprobe工具安装](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/01.installation.md) + +## 2. 工具使用 +```shell +msprobe -f pytorch config_checking -c PATH/TO/A/CHECKPOINT PATH/TO/THE/OTHER/CHECKPOINT -s -o PATH/FOR/OUTPUT +``` + +**命令行参数说明**: + +| 参数名 | 说明 | 是否必选 | +|------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| +| -c --compare | 需要比较的两个checkpoint路径 | 是 | +| -s --ckpt-sim | store_true。使能权重比对功能,否则为配置比对 | 是 | +| -o 或 --out | 权重比对结果文件存盘目录,默认为'ckpt_compare_out.json' | 否 | + + + + +Sample stdout: +```txt +Loaded checkpoint from iteration x +Found xxx total parameters across all ranks +Loaded checkpoint from iteration x +Found xxx total parameters across all ranks +2025-03-25 08:24:48 (552546) [WARNING] Parameters not in ckpt2: set() +2025-03-25 08:24:48 (552546) [WARNING] Parameters not in ckpt1: set() +... +[INFO] Comparison results written to ckpt_compare_out.json +``` + +Sample result: +```json +{ + "embedding.word_embeddings.weight": { + "l2": 0.0, + "cos": 1.0, + "numel": 25755648, + "shape": [ + 50304, + 512 + ] + }, + "decoder.layers.0.input_layernorm.bias": { + "l2": 0.0, + "cos": 0.9999999403953552, + "numel": 512, + "shape": [ + 512 + ] + } +} +``` \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/33.nan_analyze.md b/debug/accuracy_tools/msprobe/docs/33.nan_analyze.md new file mode 100644 index 0000000000000000000000000000000000000000..37b0be8bb4fba80e30d93eee4de99fd19e9bb47b --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/33.nan_analyze.md @@ -0,0 +1,73 @@ +# 整网首个溢出节点分析 + +## 介绍 +在分析inf、nan的场景下,会采集多个rank下的多个step的dump数据,前面出现的异常会传播到同rank后续的节点,并通过通信算子传播到其他rank的后续节点中,因此如何分析首个nan出现的节点位置尤为重要。 + +通过nan_analyze工具可以对pytorch的dump数据进行分析。在多卡场景下,检测到每张卡中产生inf/nan的节点。若是经过通信导致的inf/nan,可以分析并找出首个产生inf/nan的rank和节点。 + +## 安装教程 + +参见 msprobe [安装教程](./01.installation.md)。 + +## 使用说明 + +当前仅支持分析pytorch的dump数据。 + +### 采集数据 + +参见 [PyTorch 场景的精度数据采集](./05.data_dump_PyTorch.md)。 + +### 执行命令 + +```commandline +msprobe -f pytorch nan_analyze -i dump_step_path -o output_dir_path +``` + +| 参数 | 说明 | +|--------------------|---------------------------------------------| +| -f 或 --framework | 指定训练框架。pytorch。必选。 | +| -i 或 --input_path | dump数据的目录。需指定到step层级,如`-i /xxx/dump/step0/` | +| -o 或 --output_path | 输出文件的目录,可选,不填时默认在当前目录下创建 \"./output/" 目录。 | + +### 输出文件介绍 + +当日志打印 +``` +Cannot find any anomaly node, no need to generate analyze file. +``` +时,分析认为不存在异常节点,不生成分析文件。 + +存在异常节点时,生成`anomaly_analyze_{timestamp}.json`文件,结构为: +```json +{ + "rank_0": [ // 卡号 + { + "op_name": "Tensor.op_name.0.forward", // 节点名 + "data_info": { + "input_args": [], // input_args数据 + "input_kwargs": {}, // input_kwargs数据 + "output": [] // output数据 + }, + "construct_info": [], // 节点层级数据 + "stack_info": {} // 堆栈数据 + } + ] +} +``` + +## 异常判定 + +### 异常计算节点判定 +当某个计算节点的输入值正常,即Max或Min中不存在inf或nan,而输出值存在异常时认为从此节点开始产生了溢出,并有可能向后传递。 + +### 异常通信节点判定 +通信节点按照功能分为有向节点,如`send`, `recv`, `scatter`, `gather`, `broadcast`, `reduce`等,以及无向节点,如`all_gather`, `all_reduce`, `reduce_scatter`, `all_to_all`等。 + +对于有向节点,当src节点的input存在异常时,通常认为传入的数据中本身就存在异常,因此考虑异常节点发生在src节点所在rank的上一个或多个计算节点中;当src节点的input正常而output存在异常值,或dst节点的output存在异常值时,考虑是通信节点本身的操作产生了异常数据。 + +对于无向节点,当节点input存在异常时,认为传入的数据中本身就存在异常,因此考虑异常节点发生在src节点所在rank的上一个或多个计算节点中;当input正常而output异常时,考虑是通信节点本身的操作产生了异常数据。 + +### 顺序判定 +对于相连接的有向通信算子,认为src节点的异常发生早于dst节点;对于无向通信算子,认为异常是同时发生的。 + +对于计算节点按照dump的顺序排序。 diff --git a/debug/accuracy_tools/msprobe/docs/FAQ.md b/debug/accuracy_tools/msprobe/docs/FAQ.md index 833ca07a236f33e69b102d4acb45d35cd6fe7e3a..bc34e9b36b798c6f9cf01c486565360764df4f66 100644 --- a/debug/accuracy_tools/msprobe/docs/FAQ.md +++ b/debug/accuracy_tools/msprobe/docs/FAQ.md @@ -36,6 +36,9 @@ 该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。 - 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。 + +5. 在使用 msprobe 进行 Pytorch 框架的数据采集功能时,请注意确认环境变量 NPU_ASD_ENABLE=0 ,即关闭特征值检测功能。 由于工具冲突, 在该功能开启的情况下可能导致某些 api 数据采集的缺失。 + # 2 精度预检(PyTorch) 1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)? @@ -58,11 +61,7 @@ 答:对于 fp16 的数据,CPU 会上升一个精度 fp32 去计算,这是和算子那边对齐的精度结论,CPU 用更高精度去计算会更接近真实值。 -6. 添加预检工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 - - 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 Tensor: 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要 dump 关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。 - -7. Tensor 魔法函数具体对应什么操作? +6. Tensor 魔法函数具体对应什么操作? 答: @@ -202,15 +201,11 @@ def npu_forward_fused_softmax(self, input_, mask): 答:正常现象,dataloader 通过 raise 结束程序,堆栈信息可忽略。 -10. 添加 msprobe 工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 - - 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: ` 下的 `- __getitem__`,工具会跳过采集该 API。如果是需要采集关键位置 API 也可以考虑根据报错堆栈信息注释引发报错的类型检查。 - -11. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。 +10. 使用 msprobe 工具数据采集功能后,模型出现报错,报错信息为:`activation_func must be F.gelu` 或 `ValueError(Only support fusion of gelu and swiglu)`。 答:这一类报错常见于 Megatron/MindSpeed/ModelLink 等加速库或模型仓中,原因是工具本身会封装 torch 的 API(API类型和地址会发生改变),而有些 API 在工具使能前类型和地址就已经确定,此时工具无法对这类 API 再进行封装,而加速库中会对某些 API 进行类型检查,即会把工具无法封装的原始的 API和工具封装之后的 API 进行判断,所以会报错。 规避方式有3种:①将PrecisionDebugger的实例化放在文件的开始位置,即导包后的位置,确保所有API都被封装;②注释 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中的 `-gelu` 或者 `-silu`,工具会跳过采集该 API。③ 可以考虑根据报错堆栈信息注释引发报错的类型检查。 -12. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 +11. 添加 msprobe 工具后触发与 AsStrided 算子相关、或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 答:注释工具目录 `mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml` 文件中 `Tensor: `下的 `-t` 和 `- transpose`。 diff --git a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md index 0a76c51d71d77c9cbc86d98600203e6faa71a0f6..275aa66e53f25587facb2034dba5706b71bab0bb 100644 --- a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +++ b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md @@ -1,6 +1,17 @@ # MindSpore 场景的精度数据采集基线 -## "tensor"模式采集数据量参考基线 +## "statistics"模式(未开启md5)采集**时间**膨胀参考基线 + +该基线为MindSpore框架下,使用"statistics"模式采集数据性能膨胀参考基线。测试了38B语言大模型在不同采集模式8卡下的性能膨胀。 + +| 采集模式 | 无工具 (耗时) | 加工具但未使能 Dump (耗时) | 加工具并使能 Dump (耗时) | +|:--------:|:-------------:|:--------------------:|:----------------:| +| L0 | ≈340 ms | ≈340 ms (无膨胀) | ≈1.2 s (膨胀3.5倍) | +| L1 | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈3.8 s (膨胀11倍) | +| mix | ≈340 ms | ≈0.7–1.2 s (膨胀2~4倍) | ≈5.5 s (膨胀16倍) | + + +## "tensor"模式采集**数据量**参考基线 该基线为MindSpore框架下,使用"tensor"模式采集数据量参考基线。本基线测试了38B语言大模型在不同采集模式下,不同global_batch_size下,单卡和8卡下,数据量的变化。 diff --git a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md index 543d260650361431ffb8b5142ae3df6b09d0db1d..14bb2cd2c54793b5a61af5e106bcfcd484e8ecef 100644 --- a/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +++ b/debug/accuracy_tools/msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md @@ -51,6 +51,7 @@ debugger = PrecisionDebugger(config_path=config_path) # 设置 MindSpore 设备上下文 context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0) +print("Context set successfully. Please wait for the training task.") # 定义卷积层 def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True): @@ -199,7 +200,7 @@ python alexnet_model.py ## 5. 数据分析 -在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。: +在 `dump_path` 参数指定的路径下(本例中为 `./output`),会出现如下目录结构,后续精度数据分析操作可使用 msprobe 工具的精度预检和精度比对等功能,详细流程请参见[《msprobe使用手册》](../../README.md#2-精度预检)。 ```bash output/ @@ -208,4 +209,5 @@ output/ ├── construct.json # level为L0时,保存Cell的层级关系信息。当前场景为空 ├── dump.json # 保存API前反向输入输出数据的统计量信息 └── stack.json # 保存API的调用栈 + ...... ``` \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/img/compare_result.png b/debug/accuracy_tools/msprobe/docs/img/compare_result.png index 07cdb51707fe43d07723ed976275d99f55b50571..b321ebed8c7ea04357b57da81cc31ee038d4b94f 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/compare_result.png and b/debug/accuracy_tools/msprobe/docs/img/compare_result.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/ms_layer.png b/debug/accuracy_tools/msprobe/docs/img/ms_layer.png index d64fc0bbc0c7fe6c7d99151ec9d1ab589436eb09..ddacdc97b2934aab3d8d68cec5445f3d09136019 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/ms_layer.png and b/debug/accuracy_tools/msprobe/docs/img/ms_layer.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png b/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png new file mode 100644 index 0000000000000000000000000000000000000000..51f902e1b9acdc17255ae7745a77a2b9bc5117b6 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/save_compare_result_sample.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png b/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png new file mode 100644 index 0000000000000000000000000000000000000000..3033214904ca3a8a1f50f187a382c47c23f05786 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/visualization/proxy.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png index 96e8521fde4b776ba915a00b5d77851b8406c153..93ee108b0cbaa145d61b75beac024dc377ecba4a 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_browser_1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png new file mode 100644 index 0000000000000000000000000000000000000000..2d0c68cd12ab31c891be6f22de04f230472d4e2d Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_match_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png index ddd59b37f044fe64c02148b698b95296592e0399..5b625089d5c85b970089293ae754c3fb6488fd6d 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_precision_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png index 7c55b33840163c388f8fde69f0bbc531b23f81f6..0db7f67f356700f55a7995b9e3c19df4de318939 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_search_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png index 9a6217e04848e671d784ed0b484d2fe10151bde7..75fb14cbdaca50d764b77696edef56d31c8cb0f9 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_show_info.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png index e95b5eeee663d91a67b1ace422c8681797ca96c1..f4f07dc1e7b429c862af074bf6d07ec560e788d6 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_showcase.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png index e4c9ed4306f9a7b20d031d32f18c815628030da6..4b123a4e7d06016cd76effd2cebcc30d6f4c2226 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png and b/debug/accuracy_tools/msprobe/docs/img/visualization/vis_unmatch_info.png differ diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py index 089c29eb098ad4305edcca1306462f8924dd9291..e9f2e09ac8dba07e57b37bf598d82c8330434d19 100644 --- a/debug/accuracy_tools/msprobe/mindspore/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -17,12 +17,12 @@ import os try: from msprobe.lib import _msprobe_c - os.environ["MS_HOOK_ENABLE"] = "on" os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__ except ImportError: from .common.log import logger logger.info("Module _msprobe_c has not been installed. L2-Dump may not work normally.") from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger -from msprobe.mindspore.common.utils import seed_all -from msprobe.mindspore.monitor.module_hook import TrainerMon \ No newline at end of file +from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep +from msprobe.mindspore.monitor.module_hook import TrainerMon +from msprobe.mindspore.dump.graph_tensor_dump import save, save_grad diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py index 98c6b4b98530ec447c2e239c11b5d4d7b927d874..557d731e042913da3a622035219ec8dea0409ab4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py @@ -16,7 +16,7 @@ import os from tqdm import tqdm -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml from msprobe.core.common.utils import add_time_as_suffix from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo @@ -25,6 +25,7 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context, trim_output_compute_element_list) +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer @@ -156,6 +157,7 @@ class ApiAccuracyChecker: real_api_str = Const.SEP.join(api_name_str_list[1:-2]) api_list = load_yaml(yaml_path) supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) + supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True @@ -165,6 +167,9 @@ class ApiAccuracyChecker: if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True + if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True return False def parse(self, api_info_path): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py index f42702be0b114e40e5e31dc4326bd9ca21f82202..82c2790452776733f924eccb82da3dc6a339594f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py @@ -15,11 +15,13 @@ import mindspore from mindspore import ops -from msprobe.core.common.const import Const, MsCompareConst +from msprobe.core.common.const import Const from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple +from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger @@ -64,7 +66,9 @@ api_parent_module_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional, (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops, + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion } @@ -83,7 +87,9 @@ api_parent_module_str_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed" + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops", + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion" } @@ -125,7 +131,8 @@ class ApiRunner: err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) api_type_str, api_sub_name = api_name_list[0], api_name_list[1] - if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \ + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ and api_platform == Const.MS_FRAMEWORK: err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) @@ -139,21 +146,24 @@ class ApiRunner: def get_api_instance(api_type_str, api_sub_name, api_platform): """ Args: - api_type_str: str, Union["MintFunctional", "Mint", "Tensor"] + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] api_sub_name: str, e.g. "relu" - api_platform: str: Union["mindpore", "torch"] + api_platform: str: Union["mindspore", "pytorch"] Return: api_instance: function object Description: - get mindspore.mint/torch api fucntion + get mindspore.mint/torch api function mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} """ - - api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) - api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) + if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch": + api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform)) + else: + api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) full_api_name = api_parent_module_str + Const.SEP + api_sub_name if not hasattr(api_parent_module, api_sub_name): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py index ead03d25ea5c2e6bb0422486f1939c5b31ee589b..da2f8ad612fcf3a42083894ff1b8e56db757f919 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py @@ -18,9 +18,10 @@ from abc import ABC, abstractmethod import mindspore import numpy as np import torch -from msprobe.core.common.const import CompareConst, MsCompareConst +from msprobe.core.common.const import CompareConst from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class CompareResult: diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py new file mode 100644 index 0000000000000000000000000000000000000000..c5adc9842414fd616816f30f3c0b66c22b0e86b7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py @@ -0,0 +1,541 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +import torch +import torch.nn as nn +import numpy as np + +from einops import rearrange + + +from msprobe.pytorch.common.utils import logger + +GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86 +SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM" + +FaForwardParams = namedtuple("FaForwardParams", + ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"]) +FaBackwardParams = namedtuple("FaBackwardParams", + ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"]) +RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams", + ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"]) + + +def softmax_forward(x): + x_max = torch.max(x, dim=-1, keepdims=True)[0] + x_sub = x.sub(x_max) + y = torch.exp(x_sub) + x_sum = y.sum(dim=-1, keepdims=True) + res = y.div(x_sum) + return res, x_max, x_sum + + +def softmax_grad(dp, softmax_res): + muls = dp * softmax_res + muls_r = muls.sum(dim=-1, keepdims=True) + sub_r = dp - muls_r + res = sub_r * softmax_res + return res + + +def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype): + if num_kv_heads == 0 or num_kv_heads > num_heads: + raise ValueError(f"num_kv_heads must be non-zero and bigger than num_heads.") + + factor = num_heads // num_kv_heads + kv_shape = kv_tensor.shape + b = kv_shape[0] + s = kv_shape[2] + d = kv_shape[3] + kv_res = torch.zeros([b, num_heads, s, d]).to(dtype) + for i in range(num_heads): + j = i // factor + kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :] + return kv_res + + +def calculate_qk(q, k, attn_mask, pse, scalar_value): + if k.dim() != 4: + raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})") + + if k.dim() == 3: + k = k.unsqueeze(1) # 在head维度扩展 + + if pse is None or len(pse.shape) == 0: + qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value) + else: + qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value) + if attn_mask is None or len(attn_mask.shape) == 0: + return qk + else: + qk = qk + attn_mask.bool() * (-40000.0) # -10000 + return qk + + +def fusion_attention_forward(forward_params): + q = forward_params.q + k = forward_params.k + v = forward_params.v + drop_mask = forward_params.drop_mask + attn_mask = forward_params.attn_mask + pse = forward_params.pse + scalar_value = forward_params.scalar_value + keep_prob = forward_params.keep_prob + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, softmax_max, softmax_sum = softmax_forward(qk) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res + else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(backward_params): + dx = backward_params.dx + q = backward_params.q + k = backward_params.k + v = backward_params.v + softmax_res = backward_params.softmax_res + drop_mask = backward_params.drop_mask + pse = backward_params.pse + scalar_value = backward_params.scalar_value + keep_prob = backward_params.keep_prob + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value) + dq = torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + try: + if input_layout == "BSH": + b, s1, h1 = query.shape + _, s2, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "SBH": + s1, b, h1 = query.shape + s2, _, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "BSND": + b, s1, n1, d = query.shape + _, s2, n2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + elif input_layout == "BNSD": + b, n1, s1, d = query.shape + _, n2, s2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + if d == 0: + raise ValueError(f"Value d must be non-zero.") + _dtype = query.dtype + ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype) + return ret + + +def convert_from_bnsd(_input, input_layout): + """ + transform qkv from bnsd to input_layout. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) + input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + """ + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + """ + transform qkv from input_layout to bnsd. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + n (int): num_heads + input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) + """ + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(GTYPE) + + +def generate_attn_mask(*args): + """ + # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现 + ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype) + """ + + sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args + shape = [s1, s2] + + if attn_mask is not None: + # 当FA的输入已经包含attn_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原 + if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4: + logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}") + + if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048: + if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)): + if sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + logger.debug(f"反向转换attn_mask {attn_mask.shape}") + return attn_mask.to(dtype) + + return attn_mask.to(dtype) + + if attn_mask is not None: + if attn_mask.dim() == 2: + if attn_mask.shape[0] != s1 or attn_mask.shape[1] != s2: + raise ValueError(f"Invalid attn_mask shape `SS` {attn_mask.shape}") + shape = [s1, s2] + elif attn_mask.dim() == 4: + if attn_mask.shape[1] == 1: + shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2] + else: + shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2] + + if sparse_mode == 0: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + elif sparse_mode == 1: # no sparse + attn_mask = torch.from_numpy(np.zeros(shape)) + elif sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + # 注:不会出现sparse_mode=5的情况,该情况要求必须要传入attn_mask,且attn_mask矩阵数据格式须为BNSS或B1SS, + # 因此可以认为FA的输入已经是正确的attn_mask了 + return attn_mask.to(dtype) + + +def generate_kv(key, value, n1, n2): + # N不等长适配by cdy + if not (n1 == n2): + k_new = broadcast_kv(n1, n2, key, key.dtype) + v_new = broadcast_kv(n1, n2, value, value.dtype) + else: + k_new = key + v_new = value + return k_new, v_new + + +def rebuid_softmax_by_qkv(q, k, attn_mask, pse, scalar_value): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max)) + """ + logger.info("Using QKV to rebuild original softmax") + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, _, _ = softmax_forward(qk) + return softmax_res + + +def rebuild_softmax_by_max_sum(softmax_params): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) + """ + q = softmax_params.q + k = softmax_params.k + attn_mask = softmax_params.attn_mask + pse = softmax_params.pse + scalar_value = softmax_params.scalar_value + softmax_max = softmax_params.softmax_max + softmax_sum = softmax_params.softmax_sum + logger.info("Using softmax_max and softmax_sum to rebuild original softmax") + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + if softmax_max.shape[-1] == 0: + raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}") + repeat_dim = qk.shape[-1] // softmax_max.shape[-1] + softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div( + softmax_sum.repeat(1, 1, 1, repeat_dim)) + return softmax_res + + +def get_head_num(*args, **kwargs): + if kwargs.get("head_num", None): + head_num = kwargs.get("head_num") + elif len(args) >= 4: + head_num = args[3] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return head_num + + +def get_input_layout(*args, **kwargs): + if kwargs.get("input_layout", None): + input_layout = kwargs.get("input_layout") + elif len(args) >= 5: + input_layout = args[4] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return input_layout + + +def npu_fusion_attention_forward_patch(*args, **kwargs): + if len(args) < 2: + raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.") + + # query, key, value, head_num, input_layout + head_num = get_head_num(*args, **kwargs) + input_layout = get_input_layout(*args, **kwargs) + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout) + if n1 == n2 and s1 == s2: + logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (n1 % n2 == 0 and n1 >= n2): + raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.") + + dims_kwargs = { + "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2, + "d": d, "h1": h1, "h2": h2, "dtype": dtype + } + new_kwargs = { + "keep_prob": 1, + "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention_backward_patch(*args, **kwargs): + if len(args) != 6: + raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5]) + if n1 == n2 and s1 == s2: + logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (n1 % n2 == 0 and n1 >= n2): + raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.") + + dims_kwargs = { + "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2, + "d": d, "h1": h1, "h2": h2, "dtype": dtype + } + + new_kwargs = { + "keep_prob": 1, + "scalar_value_value": kwargs.get("scalar_value_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "softmax_max": kwargs.get("softmax_max"), + "softmax_sum": kwargs.get("softmax_sum"), + "softmax_in": kwargs.get("softmax_in"), + "attention_in": kwargs.get("attention_in"), + "seed": kwargs.get("seed", 0), + "offset": kwargs.get("offset", 0), + "numels": kwargs.get("numels", 0), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +class FlashAttentionScore(nn.Module): + def __init__(self): + super(FlashAttentionScore, self).__init__() + # You can initialize any parameters here if necessary + + def forward(self, *inputs, **kwargs): + # Extract the inputs for the attention calculation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs) + query, key, value = new_args[0], new_args[1], new_args[2] + + input_layout = get_input_layout(*inputs, **kwargs) + + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tokens") + pse = new_kwargs.get("real_shift") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + + attn_mask = generate_attn_mask(*args_temp) + query = convert_to_bnsd(query, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + forward_params = FaForwardParams( + q=query, + k=key, + v=value, + drop_mask=None, + attn_mask=attn_mask, + pse=pse, + scalar_value=scalar_value, + keep_prob=keep_prob + ) + + out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params) + + # If output dimension is 5, reshape accordingly + if out_golden.dim() == 5: + out_golden = out_golden.reshape(out_golden.size(0), + out_golden.size(1) * out_golden.size(2), + out_golden.size(3), out_golden.size(4)) + + out_golden = convert_from_bnsd(out_golden, input_layout) + + # Ensure the output matches the desired layout + out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu() + + return out_golden + + def backward(self, *inputs, **kwargs): + # The backward pass will be similar to what was described for the gradient computation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs) + query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5] + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + softmax_max = new_kwargs.get("softmax_max") + softmax_sum = new_kwargs.get("softmax_sum") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + attn_mask = generate_attn_mask(*args_temp) + + query = convert_to_bnsd(query, n1, input_layout) + dx = convert_to_bnsd(dx, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + k_new, v_new = generate_kv(key, value, n1, n2) + + if SOFTMAX_BUILD_MODE == "QKV": + softmax_res = rebuid_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value) + else: + softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum) + softmax_res = rebuild_softmax_by_max_sum(softmax_params) + + backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob) + dq, dk, dv = fusion_attention_backward(backward_params) + + # Reshape as needed + if dq.dim() == 5: + dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4)) + if dk.dim() == 5: + dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4)) + if dv.dim() == 5: + dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4)) + + dq = convert_from_bnsd(dq, input_layout) + dk = convert_from_bnsd(dk, input_layout) + dv = convert_from_bnsd(dv, input_layout) + + return dq.cpu(), dk.cpu(), dv.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py similarity index 35% rename from debug/accuracy_tools/msprobe/pytorch/compare/match.py rename to debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py index d676b85f20bbb1083f7f8f10bc3d9237a89f7b55..e1344541e89c4dafd9d49d63e3fdea117366bdd9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/match.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py @@ -13,37 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from msprobe.core.common.utils import CompareException -from msprobe.core.common.file_utils import load_yaml +from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore -class AtenIrMapping(): +class FusionOperator: + """ + 所有融合算子的父类,定义了通用的接口和属性。 + """ + + # 初始化操作符字典 def __init__(self): - cur_path = os.path.dirname(os.path.realpath(__file__)) - yaml_path = os.path.join(cur_path, "mapping.yaml") - self.aten_mapping = load_yaml(yaml_path) - - def match(self, op1, op2): - if "Aten" in op1 and "Aten" not in op2: - return self.match_op(op1, op2) + self.flash_attention_score = None # 用于存放 FlashAttentionScore 操作符 + self._register_operators() + + def __getattr__(self, name): + """ 动态获取算子类 """ + if hasattr(self, name): + return getattr(self, name) else: - return self.match_op(op2, op1) - - def match_op(self, aten_op, torch_op): - try: - aten_op_raw_name_overload = '_'.join(aten_op.split("_")[1:-3]) - aten_op_raw_name = aten_op_raw_name_overload.split('.')[0] - torch_op_raw_name = '_'.join(torch_op.split("_")[1:-3]).lower() - except IndexError as e: - err_msg = f"Dump op name format error: {aten_op}, {torch_op}. Your dump data may be corrupted." - raise CompareException.INVALID_DATA_ERROR(err_msg) from e - matching_op = self.aten_mapping.get(aten_op_raw_name) - if matching_op is None: - return False - if matching_op.lower() == torch_op_raw_name: - return True - return False - - -graph_mapping = AtenIrMapping() + raise AttributeError(f"'FusionOperator' object has no attribute '{name}'") + + def _register_operators(self): + """ 注册操作符到父类,以便通过 fusion.xxx 调用 """ + self.flash_attention_score = FlashAttentionScore() + + +fusion = FusionOperator() diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py index 748adf7d02cafe3983fe1990b40b1e77e993698b..24f6eb717e7ebf8fabb59d397d493831011e1161 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py @@ -16,12 +16,13 @@ import os import csv -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms from msprobe.core.common.file_utils import check_file_or_directory_path from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class ResultCsvEntry: @@ -187,7 +188,7 @@ class DataManager: def record_exception_skip(self, api_name, forward_or_backward, err_msg): ''' - record exception_skip infomation into self.record_exception_skip. + record exception_skip information into self.record_exception_skip. self.record_exception_skip: dict{str: dict{"forward": str/None, "backward": str/None}} string in key is api_name, string in value is err_msg ''' @@ -269,7 +270,7 @@ class DataManager: entry.backward_pass_status, overall_err_msg ] - # change row if this api has excption_skip infomation + # change row if this api has exception_skip information if api_name in self.results_exception_skip: if self.results_exception_skip[api_name][Const.FORWARD] is not None: row[1] = CompareConst.SKIP diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py index e764140badf4c107ea83044353aba19a1c412fe0..1913675ad162bf690fc0aed5fc84c245ae4f73ca 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py @@ -27,10 +27,11 @@ import numpy as np from tqdm import tqdm # 本地应用/库特定导入 -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class MultiApiAccuracyChecker(ApiAccuracyChecker): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py index 84f2706cc55fa3d0a1fba13d54ba8310371f1a43..13e2645ea14932afa3ac3e9ea131e443b2ee931e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py @@ -19,7 +19,8 @@ import sys from pathlib import Path import mindspore from msprobe.mindspore.common.log import logger -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst +from msprobe.mindspore.common.const import MsCompareConst import torch as mindtorch from torch import Tensor as mindtorch_tensor import torch.nn.functional as mindtorch_func @@ -107,7 +108,8 @@ def delete_torch_paths(): if count_delete_env_path >= MsCompareConst.MAX_RECURSION_DEPTH - 1: raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure " - f"the PYTHONPATH environment variable depth does not exceed {Const.MAX_RECURSION_DEPTH}.") + f"the PYTHONPATH environment variable depth does not " + f"exceed {MsCompareConst.MAX_RECURSION_DEPTH}.") if not is_mindtorch(): diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 6dc5d510ef51ab2a135a8bdf9f15ac670fba9e56..cc90cb03e0e0377c6ea58e9ba9be60439d004777 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,21 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope +from collections import OrderedDict + +from mindspore import Tensor +from mindspore.common.hook_handle import HookHandle +from mindspore.ops.operations import _inner_ops as inner + from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.runtime import Runtime +from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.utils import ( + is_mindtorch, + get_cells_and_names_with_index, + has_kwargs_in_forward_hook, + is_graph_mode_cell_dump_allowed +) +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump + + +def get_cell_construct(construct): + def _construct(self, *args, **kwargs): + if hasattr(self, 'msprobe_hook'): + setattr(self, 'msprobe_input_kwargs', kwargs) + return construct(self, *args, **kwargs) + return _construct class CellProcessor: cell_count = {} cell_stack = [] - api_parent_node = "" + api_parent_node = None module_node = {} + cell_bw_hook_kernels = {} + cell_backward_pre_hook = [] + cell_backward_hook = [] def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None @staticmethod - def set_cell_count(cell_name): + def set_and_get_calls_number(cell_name): if cell_name not in CellProcessor.cell_count: CellProcessor.cell_count[cell_name] = 0 else: @@ -38,42 +67,184 @@ class CellProcessor: def reset_cell_stats(cls): cls.cell_count = {} cls.cell_stack = [] - cls.api_parent_node = "" + cls.api_parent_node = None cls.module_node = {} + cls.cell_bw_hook_kernels = {} + cls.cell_backward_pre_hook = [] + cls.cell_backward_hook = [] - def node_hook(self, name_prefix, start_or_stop, **kwargs): - def begin_hook(cell, input_data): - full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True) - if CellProcessor.cell_stack: - CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1] - else: - CellProcessor.module_node[full_name] = None + def register_cell_hook(self, models, build_hook, config: DebuggerConfig): + if not models: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + 'The model cannot be None, when level is "L0" or "mix"') + + is_registered = False + model_type = Const.MODULE if is_mindtorch() else Const.CELL + cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models) + construct_name = '_call_impl' if is_mindtorch() else '_run_construct' + + for index, cells_and_names in cells_with_index_in_pynative_mode.items(): + model = models if index == "-1" else models[int(index)] + for name, cell in cells_and_names: + if cell == model: + continue + + if not has_kwargs_in_forward_hook(): + if not hasattr(cell.__class__, 'msprobe_construct'): + setattr(cell.__class__, 'msprobe_construct', True) + if hasattr(cell.__class__, construct_name): + setattr(cell.__class__, construct_name, + get_cell_construct(getattr(cell.__class__, construct_name))) + setattr(cell, 'msprobe_hook', True) + + cell_index = (index + Const.SEP) if index != "-1" else "" + prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}' + + forward_pre_hook = self.build_cell_hook(prefix, build_hook) + cell.register_forward_pre_hook(forward_pre_hook) + + if not is_registered: + logger.info("The cell hook function is successfully mounted to the model.") + is_registered = True + + if is_graph_mode_cell_dump_allowed(config): + cells_and_names_in_graph_mode = [] + for index, cells_and_names in cells_with_index_in_graph_mode.items(): + model = models if index == "-1" else models[int(index)] + for name, cell in cells_and_names: + if cell == model: + continue + cell_index = (index + Const.SEP) if index != "-1" else "" + cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell)) + + if cells_and_names_in_graph_mode: + Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE + GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle() - CellProcessor.cell_stack.append(full_name) - CellProcessor.api_parent_node = full_name + def build_cell_hook(self, cell_name, build_data_hook): + def forward_pre_hook(cell, args): + index = CellProcessor.set_and_get_calls_number(cell_name) + full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}' - if self.scope: - self.scope.begin_module(full_name) + self.set_construct_info_in_pre_hook(full_forward_name) - def end_hook(cell, input_data, output_data): - if CellProcessor.cell_stack: - CellProcessor.cell_stack.pop() - if CellProcessor.cell_stack: - CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] + if not hasattr(cell, 'msprobe_forward_hook'): + if is_mindtorch(): + cell.register_forward_hook(forward_hook, prepend=True, with_kwargs=True) + else: + forward_hook_dict = getattr(cell, '_forward_hook', OrderedDict()) + if has_kwargs_in_forward_hook(): + forward_hook_with_kwargs_dict = getattr(cell, '_forward_hook_with_kwargs', OrderedDict()) + handle = HookHandle(forward_hook_dict, extra_dict=forward_hook_with_kwargs_dict) + forward_hook_with_kwargs_dict[handle.handle_id] = True + else: + handle = HookHandle(forward_hook_dict) + forward_hook_dict[handle.handle_id] = forward_hook + forward_hook_dict.move_to_end(handle.handle_id, last=False) + + setattr(cell, 'msprobe_forward_hook', True) + + def get_backward_hook(backward_data_hook, full_backward_name): + def backward_hook_fn(cell, grad_input, grad_output): + new_output = backward_data_hook(cell, grad_input, grad_output) + self.set_construct_info_in_hook(full_backward_name) + cell.has_pre_hook_called = False + return new_output + return backward_hook_fn + + enable_hooked = sum( + [isinstance(ele, Tensor) and ele.dtype not in MsConst.NonDifferentiableType for ele in args] + ) + if enable_hooked: + backward_hook = OrderedDict() + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + backward_hook[full_backward_name] = get_backward_hook(hook_set.backward_hook, full_backward_name) + CellProcessor.cell_backward_hook.append(backward_hook) + bw_hook = inner.CellBackwardHook(full_backward_name, cell, + self.cell_backward_hook[-1]) + bw_hook.register_backward_hook() + CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook + + args = bw_hook(*args) + + return args + + def forward_hook(cell, args, kwargs_or_output, output_or_kwargs=None): + index = CellProcessor.cell_count.get(cell_name, 0) + full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}' + + self.set_construct_info_in_hook(full_forward_name) + + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + hook_result = hook_set.forward_hook(cell, args, kwargs_or_output, output_or_kwargs) + if hook_result is not None: + outputs = hook_result else: - CellProcessor.api_parent_node = None + outputs = output_or_kwargs if has_kwargs_in_forward_hook() else kwargs_or_output + + bw_hook = CellProcessor.cell_bw_hook_kernels.get(full_forward_name) + if bw_hook: + if not isinstance(outputs, (Tensor, tuple)): + logger.warning("For backward hooks to be called," + " cell output should be a Tensor or a tuple of Tensors" + f" but received {type(outputs)}") + if isinstance(outputs, tuple): + new_outputs = bw_hook(*outputs) + else: + new_outputs = bw_hook(outputs) + if isinstance(outputs, tuple) and len(outputs) == 1: + new_outputs = (new_outputs,) + outputs = new_outputs + + def get_backward_pre_hook(full_backward_name, backward_data_hook): + def backward_pre_hook_fn(cell, grad_output): + cell.has_pre_hook_called = True + self.set_construct_info_in_pre_hook(full_backward_name) + if backward_data_hook: + backward_data_hook(cell, (), grad_output) + self.set_construct_info_in_hook(full_backward_name) + cell.has_pre_hook_called = False + return backward_pre_hook_fn - if self.scope: - self.scope.end_module(cell.mindstudio_reserved_name) + backward_pre_hook = OrderedDict() + backward_data_hook = None if bw_hook else hook_set.backward_hook + backward_pre_hook[full_backward_name] = get_backward_pre_hook(full_backward_name, backward_data_hook) + CellProcessor.cell_backward_pre_hook.append(backward_pre_hook) + bw_pre_hook = inner.CellBackwardHook(full_backward_name, cell, + self.cell_backward_pre_hook[-1]) + bw_pre_hook.register_backward_pre_hook() - return begin_hook if Const.START == start_or_stop else end_hook + if isinstance(outputs, tuple): + result = bw_pre_hook(*outputs) + else: + result = bw_pre_hook(outputs) + if isinstance(outputs, tuple): + if len(outputs) == 1: + result = (result,) + if len(result) != len(outputs): + raise TypeError( + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(outputs)}" + ) + return result + + return forward_pre_hook - def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False): - if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called: - cell.has_pre_hook_called = False + def set_construct_info_in_pre_hook(self, full_name): + if self.cell_stack: + CellProcessor.module_node[full_name] = self.cell_stack[-1] else: - if is_called_by_pre_hook: - cell.has_pre_hook_called = True - index = self.set_cell_count(cell_name) - cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index) - return cell.mindstudio_reserved_name + CellProcessor.module_node[full_name] = None + CellProcessor.cell_stack.append(full_name) + CellProcessor.api_parent_node = full_name + if self.scope: + self.scope.begin_module(full_name) + + def set_construct_info_in_hook(self, full_name): + if self.cell_stack: + CellProcessor.cell_stack.pop() + CellProcessor.api_parent_node = CellProcessor.cell_stack[-1] if self.cell_stack else None + if self.scope: + self.scope.end_module(full_name) diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py index ee35750fb35c100e2025b0dcbdd9e20ef998b2ee..e09178d6dce5da7adc382f7ee62e8e32fca4aac4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py @@ -34,19 +34,6 @@ class Parser: if isinstance(subgraph_node.attrs, list): subgraph_node.attrs.extend(attrs) - @staticmethod - def parse_graph_attributes(text: str, graph_node: GraphNode) -> None: - attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL) - match = attr_pattern.search(text, graph_node.pos) - if match: - attrs = match.group(1).strip().split('\n') - for attr in attrs: - if not attr: - break - key, value = attr.split(':') - if isinstance(graph_node.attrs, dict): - graph_node.attrs[key.strip()] = value.strip() - @staticmethod def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]: code_info = [] @@ -124,8 +111,9 @@ class Parser: scope_match = scope_pattern.search(text, end_pos) scope = scope_match.group(1) if scope_match else "" - id_pattern = re.compile(r'.*cnode_primal_attrs:' - r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE) + id_pattern = re.compile( + r'cnode_primal_attrs:'r'\s*\{[\w+]{1, 10000}\b(?:forward_unique_id|unique_id):\s*\"(\d+)\"', + re.IGNORECASE) unique_id_match = id_pattern.search(text, end_pos, scope_match.start()) unique_id = unique_id_match.group(1) if unique_id_match else None @@ -186,7 +174,7 @@ class Parser: node_info.var_inputs.append(callee_name) def parse_subgraphs(self, text: str) -> None: - subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{') + subgraph_pattern = re.compile(r'/subgraph\s+@([\w+]{1,1000)(\([^\)]{1,100}\))?\s+\S[^\{]\{/+') matches = list(subgraph_pattern.finditer(text)) end_pos = 0 for match in matches: @@ -203,11 +191,6 @@ class Parser: subgraph_info.end = end_pos logging.info('Parsed subgraph: %s', subgraph_name) - def count_nodes(self) -> Tuple[int, int]: - total_nodes = len(self.nodes) - total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode')) - return total_nodes, total_cnodes - def create_backward_map(self): for node in self.nodes.values(): if node.scope and node.scope.startswith("Gradients"): diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index 9e8c79e51284b8e9696dde150481609f7da8b488..700c669e20dc18b3824126a09e5ceb20f67693a3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -15,6 +15,7 @@ import numpy as np import mindspore as ms +from mindspore import dtype as mstype from msprobe.core.common.const import Const as CoreConst @@ -23,14 +24,20 @@ class Const: CELL = "cell" API = "api" KERNEL = "kernel" + CELL_AND_API = 'cell_and_api' TOOL_LEVEL_DICT = { CoreConst.LEVEL_L0: CELL, CoreConst.LEVEL_L1: API, - CoreConst.LEVEL_L2: KERNEL + CoreConst.LEVEL_L2: KERNEL, + CoreConst.LEVEL_MIX: CELL_AND_API } - PYNATIVE_MODE = "pynative" + + PYNATIVE_MODE = CoreConst.PYNATIVE_MODE + GRAPH_MODE = "graph" GRAPH_GE_MODE = "graph_ge" GRAPH_KBYK_MODE = "graph_kbyk" + PYNATIVE_GRAPH_MODE = CoreConst.PYNATIVE_GRAPH_MODE + JIT_LEVEL = "jit_level" JIT_LEVEL_O0 = "O0" JIT_LEVEL_O1 = "O1" @@ -61,6 +68,7 @@ class Const: DROPOUT_API_NAME_PREFIX = "dropout" GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT] + GRAPH_CELL_DUMP_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.FORWARD, CoreConst.BACKWARD] HOOK_MS_PREFIX_DICT = { OPS_DATA_PREFIX: OPS_PREFIX, @@ -69,6 +77,69 @@ class Const: MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX } + NonDifferentiableType = ( + mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, + mstype.int16, mstype.short, mstype.uint16, mstype.ushort, + mstype.int32, mstype.intc, mstype.uint32, mstype.uintc, + mstype.int64, mstype.intp, mstype.uint64, mstype.uintp + ) + + +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + + class ProcessStatus: + SUCCESS = "success" + API_NOT_FOUND = "api_not_found" + EXCEPTION_SKIP = "exception_skip" + class FreeBenchmarkConst: ADD_NOISE = "add_noise" diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index ded3faaa22b565ef35c17a7596782976ddf9125d..ce087aca726a7fd766e3e134c0ca5dda513a7df6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -13,19 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import random +import types import mindspore as ms - from mindspore import ops +from mindspore.common.jit_config import JitConfig from mindspore.mint import nn +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy from msprobe.core.common.log import logger -from msprobe.core.common.const import Const -from msprobe.core.common.utils import CompareException, check_seed_all +from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid +from msprobe.mindspore.common.const import Const as MsConst + +try: + from mindspore._c_expression import _set_init_iter +except ImportError: + enable_dynamic_kbyk_dump = False +else: + enable_dynamic_kbyk_dump = True + +mindtorch_check_result = None +register_backward_hook_functions = {} +kwargs_exist_in_forward_hook = None + + +class MsprobeStep(ms.train.Callback): + def __init__(self, debugger): + super(MsprobeStep, self).__init__() + self.debugger = debugger + + def on_train_begin(self, run_context): + self.debugger.start() + if enable_dynamic_kbyk_dump: + _set_init_iter(0) + + def on_train_step_begin(self, run_context): + self.debugger.start() + + def on_train_step_end(self, run_context): + self.debugger.stop() + self.debugger.step() + + +class MsprobeInitStep(ms.train.Callback): + def on_train_begin(self, run_context): + try: + from ms._c_expression import _set_init_iter + except ImportError: + logger.warning('MsprobeInitStep does not work on this version of MindSpore.') + return + cb_params = run_context.original_args() + _set_init_iter(cb_params.cur_step_num) def get_rank_if_initialized(): @@ -58,8 +102,8 @@ def convert_to_int(value): def clean_input_kwargs(cell): - if hasattr(cell, 'input_kwargs'): - del cell.input_kwargs + if hasattr(cell, 'msprobe_input_kwargs'): + del cell.msprobe_input_kwargs def list_lowest_level_directories(root_dir): @@ -93,20 +137,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True): remove_dropout() -class MsprobeStep(ms.train.Callback): - - def __init__(self, debugger): - super(MsprobeStep, self).__init__() - self.debugger = debugger - - def on_train_step_begin(self, run_context): - self.debugger.start() - - def on_train_step_end(self, run_context): - self.debugger.stop() - self.debugger.step() - - class Dropout(ops.Dropout): def __init__(self, keep_prob=0.5, seed0=0, seed1=1): super().__init__(1., seed0, seed1) @@ -142,9 +172,6 @@ def remove_dropout(): nn.functional.dropout = dropout_ext -mindtorch_check_result = None - - def is_mindtorch(): global mindtorch_check_result if mindtorch_check_result is None: @@ -159,17 +186,17 @@ def is_mindtorch(): return mindtorch_check_result -register_backward_hook_functions = {} - - def set_register_backward_hook_functions(): global register_backward_hook_functions + if register_backward_hook_functions: + return + if is_mindtorch(): import torch from msprobe.mindspore.mindtorch import (_call_impl, register_full_backward_pre_hook, register_full_backward_hook) - if not hasattr(torch, "register_full_backward_hook"): + if not hasattr(torch.nn.Module, "register_full_backward_hook"): setattr(torch.nn.Module, "_call_impl", _call_impl) setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook) setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook) @@ -182,9 +209,11 @@ def set_register_backward_hook_functions(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)): + valid_data_types = (ms.Tensor, int, float, str) + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, ms.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): @@ -196,4 +225,104 @@ def check_save_param(variable, name, save_backward): logger.warning("PrecisionDebugger.save_backward name not valid, " "should be bool. " "Skip current save process.") - raise ValueError \ No newline at end of file + raise ValueError + + +def is_graph_mode_cell_dump_allowed(config): + if config.task not in [Const.TENSOR] or is_mindtorch() or not hasattr(ops, 'TensorDump'): + return False + valid_mix_level = [MsConst.CELL_AND_API, Const.LEVEL_MIX] + if config.level in valid_mix_level and config.execution_mode == MsConst.PYNATIVE_MODE: + return True + return config.level == MsConst.CELL or config.level == Const.LEVEL_L0 + + +@recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit') +def is_decorated_by_jit(func): + closure = getattr(func, '__closure__', []) + if closure: + for obj in closure: + if isinstance(obj.cell_contents, JitConfig): + return True + elif isinstance(obj.cell_contents, types.FunctionType) and hasattr(obj.cell_contents, '__closure__'): + if is_decorated_by_jit(obj.cell_contents): + return True + return False + + +@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names') +def get_cells_and_names(model, cells_set=None, name_prefix=''): + cells_set = cells_set if cells_set else set() + if model in cells_set: + return + + cells_set.add(model) + jit_decorated = is_decorated_by_jit(model.construct) + yield name_prefix, model, jit_decorated + if jit_decorated: + return + + children_cells = getattr(model, '_cells') + for name, cell in children_cells.items(): + if cell: + cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name + jit_decorated = is_decorated_by_jit(model.construct) + if jit_decorated: + yield cells_name_prefix, cell, jit_decorated + else: + for ele in get_cells_and_names(cell, cells_set, cells_name_prefix): + yield ele + + +def get_cells_and_names_with_index(models): + cells_with_index_in_pynative_mode = {} + cells_with_index_in_graph_mode = {} + + def distinguish_cells(cells): + cells_in_pynative_mode = [] + cells_in_graph_mode = [] + for name, cell, jit_decorated in cells: + if jit_decorated: + cells_in_graph_mode.append((name, cell)) + else: + cells_in_pynative_mode.append((name, cell)) + return cells_in_pynative_mode, cells_in_graph_mode + + if is_mindtorch(): + if isinstance(models, (list, tuple)): + for index, model in enumerate(models): + cells_with_index_in_pynative_mode[str(index)] = model.named_modules() + else: + cells_with_index_in_pynative_mode["-1"] = models.named_modules() + else: + if isinstance(models, (list, tuple)): + for index, model in enumerate(models): + cells = get_cells_and_names(model) + cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells) + cells_with_index_in_pynative_mode[str(index)] = cells_in_pynative_mode + cells_with_index_in_graph_mode[str(index)] = cells_in_graph_mode + else: + cells = get_cells_and_names(models) + cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells) + cells_with_index_in_pynative_mode["-1"] = cells_in_pynative_mode + cells_with_index_in_graph_mode["-1"] = cells_in_graph_mode + + return cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode + + +def has_kwargs_in_forward_hook(): + global kwargs_exist_in_forward_hook + + if kwargs_exist_in_forward_hook is None: + if is_mindtorch(): + kwargs_exist_in_forward_hook = True + return kwargs_exist_in_forward_hook + + try: + func_params = inspect.signature(nn.Cell.register_forward_hook).parameters + kwargs_exist_in_forward_hook = 'with_kwargs' in func_params + except Exception: + kwargs_exist_in_forward_hook = False + return kwargs_exist_in_forward_hook + + return kwargs_exist_in_forward_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py index 46f825330dbb8b7ff5ce9d42cef5c6b74e3846f2..fa8b68070945f08c0a18d2fc2c142b05de8707fe 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py @@ -47,7 +47,13 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): 'bench_json_path': bench_path, 'is_print_compare_log': is_print_compare_log } - ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs) + try: + ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}', **kwargs) + except CompareException as e: + if e.code == CompareException.INVALID_DATA_ERROR: + logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.") + if e.code == CompareException.INVALID_TASK_ERROR: + logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.") def ms_graph_compare(inputs, outputs): diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 8509a7f38add0c2e8d3f3638f4c247895e07bd6d..dd2c6f8c103337498e037db00f65911329b2621d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -13,410 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -from collections import defaultdict - -import numpy as np -import pandas as pd - -from msprobe.core.common.const import CompareConst, Const -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml -from msprobe.core.common.log import logger -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \ - check_op_str_pattern_valid, get_dump_mode, set_dump_path -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.check import dtype_mapping +from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping -from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list - - -class MappingConfig: - def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None): - self.cell_mapping = cell_mapping - self.api_mapping = api_mapping - self.data_mapping = data_mapping - - -class MSComparator(Comparator): - """ - 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。 - cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系; - api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系; - data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系; - is_cross_framework: 是否跨框架。 - """ - def __init__(self, mode_config, mapping_config=None, is_cross_framework=False): - super().__init__(mode_config) - self.frame_name = MSComparator.__name__ - - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode - - if mapping_config: - self.cell_mapping = mapping_config.cell_mapping - self.api_mapping = mapping_config.api_mapping - self.data_mapping = mapping_config.data_mapping - - if self.data_mapping: - self.cross_frame = is_cross_framework - else: - self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None - self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping) - self.api_mapping_dict = self.load_mapping_file(self.api_mapping) - if self.api_mapping is not None: - self.ms_to_pt_mapping = self.load_internal_api() - - if isinstance(self.data_mapping, str) or self.data_mapping is None: - self.data_mapping_dict = self.load_mapping_file(self.data_mapping) - elif isinstance(self.data_mapping, dict): - self.data_mapping_dict = self.data_mapping - else: - raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " - f"{type(self.data_mapping)}") - - def calc_accuracy(self, result_df, header): - condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A - result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A) - result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH - - def calc_summary_diff(data_type: str): - def type_check(val): - check_series = pd.Series(False, index=val.index) - val_str = val.astype(str) - check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True - return check_series - - def get_number(val): - return pd.to_numeric(val.astype(str), errors='coerce') - - ms_val = result_df['NPU ' + data_type] - pt_val = result_df['Bench ' + data_type] - diff_name = data_type.capitalize() + ' diff' - rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr' - condition_na = ~type_check(ms_val) | ~type_check(pt_val) - result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A - result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val) - condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna() - condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna() - result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN - condition_pt_zero = pt_val == 0 - result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN - condition_ref_err = condition_not_nan_diff & ~condition_pt_zero - result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] / - pt_val[condition_ref_err] * 100) - result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name] - .abs().astype(str) + '%') - magnitude = get_number(result_df[diff_name]).abs() / ( - pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON) - return magnitude > CompareConst.MAGNITUDE - - if self.dump_mode == Const.MD5: - condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5] - result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS - result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF - elif self.dump_mode == Const.SUMMARY: - warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']] - warning_flag = pd.DataFrame(warning_list).all() - result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = '' - result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING - result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.' - else: - fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, - CompareConst.ERROR_MESSAGE] - result_df.loc[~condition_no_bench, fill_cols] = '' - result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES - return result_df[header] - - def make_result_df(self, result): - header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:] - - if self.stack_mode: - header.append(CompareConst.STACK) - if self.dump_mode == Const.ALL: - header.append(CompareConst.DATA_NAME) - result.rename(columns={'op_name_x': CompareConst.NPU_NAME, - 'op_name_y': CompareConst.BENCH_NAME, - 'dtype_x': CompareConst.NPU_DTYPE, - 'dtype_y': CompareConst.BENCH_DTYPE, - 'shape_x': CompareConst.NPU_SHAPE, - 'shape_y': CompareConst.BENCH_SHAPE, - 'md5_x': CompareConst.NPU_MD5, - 'md5_y': CompareConst.BENCH_MD5, - 'data_name_x': CompareConst.DATA_NAME, - 'stack_info_x': CompareConst.STACK}, inplace=True) - - npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM] - bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, - CompareConst.BENCH_NORM] - - def set_summary(summary): - if summary == CompareConst.N_A: - return [CompareConst.N_A] * 4 - summary_list = [] - for i in summary: - if i is None: - summary_list.append(CompareConst.N_A) - elif str(i).lower() == 'nan': - summary_list.append(CompareConst.NAN) - else: - summary_list.append(i) - return summary_list - - result[npu_summary] = result['summary_x'].apply(set_summary).tolist() - result[bench_summary] = result['summary_y'].apply(set_summary).tolist() - result_df = pd.DataFrame(columns=header) - for h in header: - if h in result.columns: - result_df[h] = result[h] - return self.calc_accuracy(result_df, header) - - def load_internal_api(self): - cur_path = os.path.dirname(os.path.realpath(__file__)) - yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE)) - return load_yaml(yaml_path) - - def load_mapping_file(self, mapping_file): - if isinstance(mapping_file, str): - mapping_dict = load_yaml(mapping_file) - else: - mapping_dict = {} - return mapping_dict - - def process_cell_mapping(self, npu_op_name): - if not npu_op_name: - return CompareConst.N_A - param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP) - if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name): - return CompareConst.N_A - npu_op_name = npu_op_name.replace("Cell", "Module", 1) - if self.cell_mapping_dict: - # get cell name & class name from op_name - # Cell.fc1.Dense.forward.0.input.0 - cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0] - if cell_name in self.cell_mapping_dict: - npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1) - return npu_op_name - - def read_npy_data(self, dir_path, file_name, load_pt_file=False): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - if load_pt_file: - import torch - from msprobe.pytorch.common.utils import load_pt - data_value = load_pt(data_path, True).detach() - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - else: - data_value = load_npy(data_path) - return data_value - - def process_internal_api_mapping(self, npu_op_name): - # get api name & class name from op_name - # Functional.addcmul.0.forward.input.0 - ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP)) - class_name = ms_api_name.split(Const.SEP)[0] - if class_name == "Mint": - return npu_op_name.replace("Mint", "Torch") - elif class_name == "MintFunctional": - return npu_op_name.replace("MintFunctional", "Functional") - elif self.ms_to_pt_mapping.get(ms_api_name): - return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name)) - else: - return npu_op_name - - def get_api_name(self, api_list): - try: - api_name = api_list[0] + Const.SEP + api_list[1] - except IndexError as error: - logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error - return api_name - - def compare_process(self, file_lists): - npu_json_path, bench_json_path, stack_json_path = file_lists - npu_json_data = load_json(npu_json_path) - bench_json_data = load_json(bench_json_path) - stack_json_data = load_json(stack_json_path) if self.stack_mode else None - - npu_df = self.gen_data_df(npu_json_data, stack_json_data) - bench_df = self.gen_data_df(bench_json_data, stack_json_data) - if self.cell_mapping: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping) - elif self.api_mapping: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping) - if isinstance(self.api_mapping, str): - self.modify_compare_data_with_user_mapping(npu_df, bench_df) - else: - npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME] - npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str) - bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str) - npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE] - bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME] - bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE] - match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE], - how='outer') - match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A) - - def gen_dtype_condition(): - npu_dtype = match_result['dtype_x'] - bench_dtype = match_result['dtype_y'] - if self.cross_frame: - npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype) - return ((npu_dtype == bench_dtype) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) | - ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) | - ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) | - ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) | - ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16))) - - match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A - return self.make_result_df(match_result) - - def modify_compare_data_with_user_mapping(self, npu_df, bench_df): - def get_api_indices_dict(op_name_df): - api_indices_dict = defaultdict(list) - for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]): - api = self.get_api_name(name.split(Const.SEP)) - api_indices_dict[api].append(op_index) - return api_indices_dict - - ms_api_indices_dict = get_api_indices_dict(npu_df) - pt_api_indices_dict = get_api_indices_dict(bench_df) - - def gen_input_compare_key(pattern, term): - flag = True - for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')): - if op_name.split(pattern)[1].startswith(str(prefix)): - npu_df.loc[index, CompareConst.COMPARE_KEY] = ( - op_name.replace(pattern + str(prefix), - pattern + str(mapping_dict.get(f'pt_{term}')[i]))) - flag = False - return flag - - for mapping_dict in self.api_mapping_dict: - keys_to_compare = [ - ('ms_args', 'pt_args'), - ('ms_output', 'pt_output'), - ('ms_parameters', 'pt_parameters'), - ('ms_parameters_grad', 'pt_parameters_grad'), - ] - if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare): - logger.warning('The user-defined mapping table is incorrect,\ - make sure that the number of parameters is equal') - continue - - ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api') - if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict: - continue - for index in ms_api_indices_dict.get(ms_api): - op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1) - if CompareConst.INPUT_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args') - elif CompareConst.KWARGS_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args') - elif CompareConst.OUTPUT_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output') - elif CompareConst.PARAMS_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters') - elif CompareConst.PARAMS_GRAD_PATTERN in op_name: - is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad') - else: - logger.error(f'Excepted op_name: {op_name}') - raise CompareException(CompareException.INVALID_DATA_ERROR) - if is_abandoned: - npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned' - - def gen_data_df(self, data_json, stack_json_data): - result = { - CompareConst.OP_NAME: [], - Const.DTYPE: [], - Const.SHAPE: [], - Const.SUMMARY: [], - 'stack_info': [] - } - if self.dump_mode == Const.ALL: - result['data_name'] = [] - elif self.dump_mode == Const.MD5: - result[Const.MD5] = [] - for data_name in data_json['data']: - check_op_str_pattern_valid(data_name) - merge_list = self.gen_merge_list(data_json, data_name, stack_json_data) - if not merge_list: - continue - - op_name_list = merge_list.get(CompareConst.OP_NAME) - summary_list = merge_list.get(Const.SUMMARY) - data_name_list = merge_list.get('data_name') - op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, - summary_list, - data_name_list) - for op_name in op_name_reorder: - result[CompareConst.OP_NAME].append(op_name) - if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): - struct = merge_list[CompareConst.INPUT_STRUCT].pop(0) - elif CompareConst.OUTPUT_PATTERN in op_name: - struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0) - elif CompareConst.PARAMS_PATTERN in op_name: - struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0) - else: - struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0) - result[Const.DTYPE].append(struct[0]) - result[Const.SHAPE].append(struct[1]) - if self.dump_mode == Const.MD5: - result[Const.MD5].append(struct[2]) - result[Const.SUMMARY].append(summary_reorder.pop(0)) - result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None) - if self.dump_mode == Const.ALL: - result['data_name'].append(data_name_reorder.pop(0)) - return pd.DataFrame(result) +from msprobe.mindspore.compare.utils import read_npy_data, check_cross_framework -def check_cross_framework(bench_json_path): - pattern = r'"data_name":\s*"[^"]+\.pt"' - with FileOpen(bench_json_path, 'r') as file: - for line in file: - if re.search(pattern, line): - return True - return False +def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, cross_frame) -> tuple: + n_value = read_npy_data(npu_dir, npu_data_name) + if cross_frame: + from msprobe.pytorch.compare.utils import read_pt_data + b_value = read_pt_data(bench_dir, bench_data_name) + else: + b_value = read_npy_data(bench_dir, bench_data_name) + return n_value, b_value def ms_compare(input_param, output_path, **kwargs): - try: - auto_analyze = kwargs.get('auto_analyze', True) - fuzzy_match = kwargs.get('fuzzy_match', False) - cell_mapping = kwargs.get('cell_mapping', None) - api_mapping = kwargs.get('api_mapping', None) - data_mapping = kwargs.get('data_mapping', None) - layer_mapping = kwargs.get('layer_mapping', None) - suffix = kwargs.get('suffix', '') + config = setup_comparison(input_param, output_path, **kwargs) - set_dump_path(input_param) - dump_mode = get_dump_mode(input_param) - if 'stack_json_path' in input_param: - stack_mode = kwargs.get('stack_mode', False) - else: - stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param - check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) - create_directory(output_path) - check_compare_param(input_param, output_path, dump_mode, stack_mode) - except (CompareException, FileCheckException) as error: - logger.error('Compare failed. Please check the arguments and do it again!') - raise CompareException(error.code) from error - if layer_mapping: - data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path) + if config.layer_mapping: + config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path) - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping) is_cross_framework = check_cross_framework(input_param.get('bench_json_path')) - ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework) - ms_comparator.compare_core(input_param, output_path, suffix=suffix) + mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, config.dump_mode) + mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping) + ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework) + ms_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py index 701988ba483de4e13d85892dbb42d62c7cc805b8..ecf8e84d136fdfcbcab6372e45099b1c931900ea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py @@ -85,11 +85,13 @@ def statistic_data_read(statistic_file_list, statistic_file_path): } for statistic_file in statistic_file_list: content = read_csv(statistic_file, as_pd=False) + if not content: + logger.error(f'Empty dump file: {statistic_file}') + raise CompareException(f'Empty dump file: {statistic_file}') header = content[0] - for key in header_index.keys(): - for index, value in enumerate(header): - if key == value: - header_index[key] = index + for index, value in enumerate(header): + if value in header_index: + header_index[value] = index statistic_data_list.extend(content[1:]) for key in header_index.keys(): @@ -97,7 +99,14 @@ def statistic_data_read(statistic_file_list, statistic_file_path): logger.warning(f"Data_path {statistic_file_path} has no key {key}.") for data in statistic_data_list: - compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}" + ''' + 13列分别是OpType, OpName, TaskId, StreamId, TimeStamp, IO, Slot, DataSize, + DataType, Shape, MaxValue, MinValue, L2NormValue + ''' + if len(data) < 13: + logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!') + raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!') + compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot op_name = f"{compare_key} {statistic_file_path}" timestamp = int(data[4]) result_data = [op_name, compare_key, timestamp] @@ -195,11 +204,12 @@ class GraphMSComparator: if not error_flag: result_list, err_msg = compare_ops_apply(n_value, b_value, False, "") result_dict[CompareConst.COSINE] = result_list[0] - result_dict[CompareConst.MAX_ABS_ERR] = result_list[1] - result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[2] - result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[3] - result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[4] - result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[1]) + result_dict[CompareConst.EUC_DIST] = result_list[1] + result_dict[CompareConst.MAX_ABS_ERR] = result_list[2] + result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3] + result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4] + result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5] + result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2]) result_dict[CompareConst.ERROR_MESSAGE] = err_msg return pd.Series(result_dict) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/utils.py b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9c78e8f74426c23982723fcf90f729fc9e694c --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst +from msprobe.core.common.utils import detect_framework_by_dump_json + + +def read_npy_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.NUMPY_SUFFIX, False) + data_path = path_checker.common_check() + data_value = load_npy(data_path) + return data_value + + +def check_cross_framework(bench_json_path): + framework = detect_framework_by_dump_json(bench_json_path) + return framework == Const.PT_FRAMEWORK diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 92155b4ec4ebd636477ef67f1c75b43e7a82b802..02e8c629708df3a679dc9ffe58a499f7cd02c01f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -15,12 +15,18 @@ import os +from mindspore import nn + from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import create_directory +from msprobe.core.common.log import logger from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.const import FreeBenchmarkConst -from msprobe.core.common.log import logger +from msprobe.mindspore.common.utils import is_mindtorch + +if is_mindtorch(): + import torch class DebuggerConfig: @@ -41,8 +47,16 @@ class DebuggerConfig: self.check_mode = task_config.check_mode self.framework = Const.MS_FRAMEWORK self.summary_mode = task_config.summary_mode + self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None + self.device_stat_precision_mode = task_config.device_stat_precision_mode \ + if hasattr(task_config, 'device_stat_precision_mode') else None self.async_dump = common_config.async_dump if common_config.async_dump else False + if hasattr(task_config, 'td_config_path'): + self.td_config_path = "" if not task_config.td_config_path else task_config.td_config_path + else: + self.td_config_path = "" self.check() + self._check_statistics_config(task_config) create_directory(self.dump_path) if self.task == Const.FREE_BENCHMARK: @@ -53,13 +67,40 @@ class DebuggerConfig: self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage if self.handler_type == FreeBenchmarkConst.FIX and \ self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE: - raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, " - f"but got {self.pert_type}.") + logger.error("pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {self.pert_type}.") + raise ValueError if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX: - raise ValueError("handler_type must be check or empty when fuzz_stage is backward, " - f"but got {self.handler_type}.") + logger.error("handler_type must be check or empty when fuzz_stage is backward, " + f"but got {self.handler_type}.") + raise ValueError self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL + @staticmethod + def check_model(models, token_range=None): + if token_range and not models: + error_info = "The 'model' parameter must be provided when token_range is not None" + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info) + + target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") + if models is None or isinstance(models, target_module_type[0]): + return models + error_model = None + if isinstance(models, (list, tuple)): + for model in models: + if not isinstance(model, target_module_type[0]): + error_model = model + break + else: + error_model = models + + if error_model is not None: + error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " + f"type, currently there is a {type(error_model)} type.") + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, error_info) + return models + def check(self): if not self.dump_path: raise Exception("Dump path is empty.") @@ -74,8 +115,12 @@ class DebuggerConfig: self.check_mode = "all" if not isinstance(self.async_dump, bool): raise Exception("The parameters async_dump should be bool.") - if self.async_dump and self.task == Const.TENSOR and not self.list: - raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.") + if self.async_dump and self.task == Const.TENSOR: + if self.level_ori == Const.LEVEL_DEBUG: + self.list = [] # async_dump + debug level case ignore list + if not self.list and self.level_ori != Const.LEVEL_DEBUG: + raise Exception("The parameters async_dump is true in tensor task," + " the parameters list cannot be empty.") if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]: logger.warning_on_rank_0( f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. " @@ -84,15 +129,24 @@ class DebuggerConfig: self.level_ori = Const.LEVEL_MIX return True - def check_config_with_l2(self): - if self.level_ori != Const.LEVEL_L2: - return - if self.task != Const.TENSOR: + def check_config_with_l2(self, is_graph_config): + if not is_graph_config and self.task != Const.TENSOR: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the task must be set to tensor.") - if self.scope: + if not is_graph_config and self.scope: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the scope cannot be configured.") - if not self.list or len(self.list) != 1: + if not is_graph_config and (not self.list or len(self.list) != 1): raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the list must be configured as a list with one api name.") + + def _check_statistics_config(self, task_config): + if self.task != Const.STATISTICS: + return + self.tensor_list = [] + if not hasattr(task_config, "tensor_list"): + return + if self.level_ori == Const.LEVEL_DEBUG and task_config.tensor_list: + logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.") + return + self.tensor_list = task_config.tensor_list diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 7694d71dd98ae1c7c4611f9435a274ac018e5df6..182eaf43339cdde072ebc7163a24c00474c9454a 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -19,22 +19,33 @@ from collections import defaultdict, namedtuple import mindspore as ms from mindspore._c_expression import MSContext -from msprobe.core.common.const import Const, FileCheckConst, MsgConst -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.const import Const, MsgConst +from msprobe.core.common.utils import check_token_range +from msprobe.core.common.runtime import Runtime +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst -from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param +from msprobe.mindspore.common.utils import ( + set_register_backward_hook_functions, + check_save_param, + is_graph_mode_cell_dump_allowed +) from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor -from msprobe.mindspore.ms_config import parse_json_config -from msprobe.mindspore.runtime import Runtime -from msprobe.mindspore.service import Service +from msprobe.mindspore.ms_config import parse_task_config +from msprobe.mindspore.mindspore_service import MindsporeService from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +try: + from mindspore._c_expression import _dump_start, _dump_stop, _dump_step, _set_init_iter, _dump_set_dynamic +except ImportError: + enable_dynamic_kbyk_dump = False +else: + enable_dynamic_kbyk_dump = True + try: from msprobe.lib import _msprobe_c except ImportError: @@ -44,9 +55,7 @@ except ImportError: ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"]) -class PrecisionDebugger: - _instance = None - task_not_need_service = [Const.GRAD_PROBE] +class PrecisionDebugger(BasePrecisionDebugger): def __new__(cls, config_path=None, task=None, dump_path=None, level=None, step=None, opt=None): @@ -62,61 +71,32 @@ class PrecisionDebugger: level=None, step=None): if self.initialized: return - self.initialized = True - set_register_backward_hook_functions() + super().__init__(config_path, task, dump_path, level, step) - if not config_path: - config_path = os.path.join(os.path.dirname(__file__), "../../config.json") - - config_params = ConfigParameters(config_path, task, dump_path, level) - self.check_input_params(config_params) - - common_config, task_config = parse_json_config(config_path) - common_config.task = task if task else common_config.task - self.task = common_config.task if self.task == Const.GRAD_PROBE: - self.gm = GradientMonitor(common_config, task_config) + self.gm = GradientMonitor(self.common_config, self.task_config) return - common_config.step = get_real_step_or_rank( - step, Const.STEP) if step is not None else common_config.step - common_config.level = level if level else common_config.level - common_config.dump_path = dump_path if dump_path else common_config.dump_path - self.config = DebuggerConfig(common_config, task_config) + self.common_config.level = level if level else self.common_config.level + self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path + self.config = DebuggerConfig(self.common_config, self.task_config) - if _msprobe_c: + if self._is_kernel_dump() and _msprobe_c: + os.environ["MS_HOOK_ENABLE"] = "on" _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path) self.config.execution_mode = self._get_execution_mode() if self._need_service(): - self.config.check_config_with_l2() - self.service = Service(self.config) + self.service = MindsporeService(self.config) Runtime.step_count = 0 Runtime.is_running = False + if enable_dynamic_kbyk_dump: + _dump_set_dynamic() @staticmethod - def check_input_params(args): - if args.config_path is not None: - if not isinstance(args.config_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") - file_checker = FileChecker( - file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) - file_checker.common_check() - - if args.task is not None and args.task not in Const.TASK_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") - - if args.dump_path is not None: - if not isinstance(args.dump_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") - - if args.level is not None and args.level not in Const.LEVEL_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + def _get_task_config(task, json_config): + return parse_task_config(task, json_config) @staticmethod def _get_execution_mode(): @@ -137,9 +117,7 @@ class PrecisionDebugger: return MsConst.PYNATIVE_MODE @staticmethod - def _is_graph_dump(config): - if config.level != MsConst.KERNEL: - return False + def _is_graph_dump(config: DebuggerConfig): if not config.list: return True is_graph = any(item.startswith("name-regex") for item in config.list) @@ -147,63 +125,62 @@ class PrecisionDebugger: return is_graph @classmethod - def start(cls, model=None): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().start() - if instance.task in PrecisionDebugger.task_not_need_service: + def start(cls, model=None, token_range=None): + instance = cls._get_instance() + if instance is None: return - - instance.config.execution_mode = cls._get_execution_mode() - if cls._need_service(): - if not instance.service: - instance.service = Service(instance.config) - instance.service.start(model) + if cls._is_kernel_dump(): + cls._start_kernel_dump() else: - if not instance.first_start: - api_register.api_set_ori_func() - handler = TaskHandlerFactory.create(instance.config) - handler.handle() - + check_token_range(token_range) + instance.config.execution_mode = cls._get_execution_mode() + if cls._need_service(): + if not instance.service: + instance.service = MindsporeService(instance.config) + instance.config.check_model(model, token_range) + instance.service.start(model, token_range) + else: + if not instance.first_start: + get_api_register().restore_all_api() + handler = TaskHandlerFactory.create(instance.config, model) + handler.handle() + Runtime.is_running = True instance.first_start = True - Runtime.is_running = True - - @classmethod - def forward_backward_dump_end(cls): - instance = cls._instance - instance.stop() @classmethod def stop(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().stop() + instance = cls._get_instance() + if instance is None: + return + if instance.task == Const.GRAD_PROBE: instance.gm.stop() - if instance.task in PrecisionDebugger.task_not_need_service: - return if instance.service: instance.service.stop() - Runtime.is_running = False - + else: + Runtime.is_running = False + if enable_dynamic_kbyk_dump: + _dump_stop() + if cls._is_kernel_dump() and _msprobe_c: + _msprobe_c._PrecisionDebugger().stop() + @classmethod def step(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: - _msprobe_c._PrecisionDebugger().step() - if instance.task in PrecisionDebugger.task_not_need_service: + instance = cls._get_instance() + if instance is None: return + if instance.service: instance.service.step() + if is_graph_mode_cell_dump_allowed(instance.config): + GraphModeCellDump.step() + if enable_dynamic_kbyk_dump: + _dump_step(1) + if cls._is_kernel_dump() and _msprobe_c: + _msprobe_c._PrecisionDebugger().step() + HOOKCell.cell_count = defaultdict(int) CellProcessor.reset_cell_stats() - Runtime.step_count += 1 @classmethod @@ -230,7 +207,7 @@ class PrecisionDebugger: instance.config.execution_mode = cls._get_execution_mode() if cls._need_service(): if not instance.service: - instance.service = Service(instance.config) + instance.service = MindsporeService(instance.config) instance.service.save(variable, name, save_backward) @classmethod @@ -238,7 +215,41 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.config.level_ori == Const.LEVEL_L2: + return False if instance.config.execution_mode != MsConst.PYNATIVE_MODE: return False else: - return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config) \ No newline at end of file + return instance.config.task != Const.FREE_BENCHMARK + + @classmethod + def _is_kernel_dump(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + return instance.config.level_ori == Const.LEVEL_L2 + + @classmethod + def _start_kernel_dump(cls): + instance = cls._get_instance() + is_graph_config = cls._is_graph_dump(instance.config) + instance.config.check_config_with_l2(is_graph_config) + if not is_graph_config: + if not instance.service: + instance.service = MindsporeService(instance.config) + instance.service.start() + else: + if _msprobe_c: + _msprobe_c._PrecisionDebugger().start() + if not instance.first_start: + get_api_register().restore_all_api() + handlers = TaskHandlerFactory.create(instance.config) + for handler in handlers: + handler.handle() + if enable_dynamic_kbyk_dump: + _set_init_iter(0) + if enable_dynamic_kbyk_dump: + is_valid_rank = (not instance.config.rank or Runtime.rank_id in instance.config.rank) + is_valid_step = (not instance.config.step or Runtime.step_count in instance.config.step) + if is_valid_rank and is_valid_step: + _dump_start() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py new file mode 100644 index 0000000000000000000000000000000000000000..04d3d2c1fd728c444a167c547e32a1f7382ad13a --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/cell_dump_process.py @@ -0,0 +1,592 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +from multiprocessing import Pool +import os +import re +import time + +import numpy as np +import mindspore as ms +from mindspore import nn, ops + +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_utils import load_npy, save_json, remove_path, load_yaml, move_file +from msprobe.mindspore.common.log import logger + +CONSTRUCT_FILE_NAME = "construct.json" +DEFAULT_RANK_DIR = "rank0" +KEY_LAYERS = "layers" +construct = {} +cell_list = [] +KEY_SIDE_EFFECT = "side_effect_io" +KEY_TOPLAYER = "TopLayer" +KEY_FORWARD = CoreConst.FORWARD +KEY_BACKWARD = CoreConst.BACKWARD +KEY_INPUT = CoreConst.INPUT +KEY_OUTPUT = CoreConst.OUTPUT +td = ops.TensorDump() +if (ms.__version__ >= "2.5.0"): + td_in = ops.TensorDump("in") +else: + td_in = ops.TensorDump() +td.add_prim_attr(KEY_SIDE_EFFECT, False) +td_in.add_prim_attr(KEY_SIDE_EFFECT, False) +np_ms_dtype_dict = { + "bool": ms.bool_, + "int8": ms.int8, + "byte": ms.byte, + "int16": ms.int16, + "short": ms.short, + "int32": ms.int32, + "intc": ms.intc, + "int64": ms.int64, + "intp": ms.intp, + "uint8": ms.uint8, + "ubyte": ms.ubyte, + "uint16": ms.uint16, + "ushort": ms.ushort, + "uint32": ms.uint32, + "uintc": ms.uintc, + "uint64": ms.uint64, + "uintp": ms.uintp, + "float16": ms.float16, + "half": ms.half, + "float32": ms.float32, + "single": ms.single, + "float64": ms.float64, + "double": ms.double, + "bfloat16": ms.bfloat16, + "complex64": ms.complex64, + "complex128": ms.complex128 +} + + +def gen_file_path(dump_path, cell_prefix, suffix, io_type, index): + step_path = os.path.join(dump_path, "{step}") + rank_path = os.path.join(step_path, "{rank}") + data_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + file_name = CoreConst.SEP.join([cell_prefix, suffix, io_type, str(index)]) + return os.path.join(data_path, file_name) + + +def partial_func(func, dump_path, cell_prefix, index, io_type): + def newfunc(*args, **kwargs): + return func(dump_path, cell_prefix, index, io_type, *args, **kwargs) + return newfunc + + +def clip_gradient(dump_path, cell_prefix, index, io_type, dx): + if io_type == KEY_OUTPUT: + temp = td(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx) + dx = ops.depend(dx, temp) + if io_type == KEY_INPUT: + temp = td_in(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx) + dx = ops.depend(dx, temp) + return dx + + +def need_tensordump_in(cell_obj, attr): + return hasattr(cell_obj, attr) and getattr(cell_obj, attr) == "in" + + +def cell_construct_wrapper(func, self): + def new_construct(self, *args, **kwargs): + new_args = [] + out_list = [] + + index = 0 + item = None + # The inputs of the cell. + for index, item in enumerate(args): + if self.data_mode == "backward" or self.data_mode == "all": + if ops.is_tensor(item): + item = self.output_clips[index](item) + if self.data_mode == "forward" or self.data_mode == "all": + if ops.is_tensor(item): + if need_tensordump_in(self, 'input_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index), + item + ) + item = ops.depend(item, temp) + new_args.append(item) + + out = func(*new_args, **kwargs) + + # The outputs of the cell. + if isinstance(out, tuple): + for index, item in enumerate(out): + if self.data_mode == "backward" or self.data_mode == "all": + if ops.is_tensor(item): + item = self.input_clips[index](item) + if self.data_mode == "forward" or self.data_mode == "all": + if ops.is_tensor(item): + if need_tensordump_in(self, 'output_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index), + item + ) + item = ops.depend(item, temp) + out_list.append(item) + else: + out_list.append(item) + out_list = tuple(out_list) + return out_list + else: + if self.data_mode == "backward" or self.data_mode == "all": + out = self.input_clips[0](out) + if self.data_mode == "forward" or self.data_mode == "all": + if ops.is_tensor(out): + if need_tensordump_in(self, 'output_dump_mode'): + temp = td_in( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + else: + temp = td( + gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0), + out + ) + out = ops.depend(out, temp) + return out + + return new_construct.__get__(self, type(self)) + + +# 获取目录下所有文件名并根据TensorDump落盘自增id从小到大排序 +def sort_filenames(path): + filenames = os.listdir(path) + id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$') + filenames.sort(key=lambda x: int(id_pattern.findall(x)[0])) + return filenames + + +# 删除重复dump的文件:自定义文件名相同,并且数据相同 +def del_same_file(path, filenames): + result_list = [] + seen_prefixes = {} + for current_filename in filenames: + parts = current_filename.rsplit(CoreConst.REPLACEMENT_CHARACTER, 1) + prefix = parts[0] + if prefix not in seen_prefixes: + result_list.append(current_filename) + seen_prefixes[prefix] = current_filename + else: + current_file_path = os.path.join(path, current_filename) + current_file = load_npy(current_file_path) + prev_filename = seen_prefixes[prefix] + prev_file_path = os.path.join(path, prev_filename) + prev_file = load_npy(prev_file_path) + if np.array_equal(current_file, prev_file): + remove_path(current_file_path) + logger.warning(f"{current_file_path} is deleted!") + else: + result_list.append(current_filename) + return result_list + + +def rename_filename(path): + filenames = sort_filenames(path) + filenames = del_same_file(path, filenames) + + filename_dict = {} + for filename in filenames: + name_field = filename.rsplit(CoreConst.REPLACEMENT_CHARACTER, 1)[0] + + if name_field in filename_dict: + filename_dict[name_field] += 1 + else: + filename_dict[name_field] = 0 + + cell_index = filename_dict[name_field] + + # 修改文件名,增加重复调用Cell的序号 + if CoreConst.FORWARD_PATTERN in filename: + # Format: Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}_{dtype}_{id}.npy + new_file_name = filename.replace(CoreConst.FORWARD_PATTERN, + CoreConst.FORWARD_PATTERN + str(cell_index) + CoreConst.SEP) + if CoreConst.BACKWARD_PATTERN in filename: + new_file_name = filename.replace(CoreConst.BACKWARD_PATTERN, + CoreConst.BACKWARD_PATTERN + str(cell_index) + CoreConst.SEP) + move_file(os.path.join(path, filename), os.path.join(path, new_file_name)) + logger.info("==========The rename_filename phase is Finished!==========") + + +# Extract the field between the first "." and the third to last ".", i.e. {cell_name} +def get_cell_name(cell_str): + parts = cell_str.split(CoreConst.SEP) + if len(parts) < 4: + return None + start_index = 1 + end_index = len(parts) - 3 + return CoreConst.SEP.join(parts[start_index:end_index]) + + +# Extract the field between the last "." and the second to last ".", i.e. {data_made} +def get_data_mode(cell_str): + last_dot_index = cell_str.rfind(CoreConst.SEP) + second_last_dot_index = cell_str.rfind(CoreConst.SEP, 0, last_dot_index) + data_mode = cell_str[second_last_dot_index + 1:last_dot_index] + return data_mode + + +# 判断二者之间是否存在父子关系 +def check_relation(cell_name, parent_cell_name): + layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$" + last_dot_index = cell_name.rfind(CoreConst.SEP) + if last_dot_index != -1: + # 如果cell_name最后一个'.'之前的字段等于parent_cell_name,则判定存在父子关系 + sub_cell_name = cell_name[:last_dot_index] + if sub_cell_name == parent_cell_name: + return True + elif re.search(layers_pattern, cell_name): + # 如果cell_name以".layer.{layer_id}"结尾,且去掉该字段后等于parent_cell_name,则判定存在父子关系 + sub_cell_name = re.sub(layers_pattern, '', cell_name) + if sub_cell_name == parent_cell_name: + return True + return False + + +def get_construct(cell_list_input): + for cell in cell_list_input: + cell_name = get_cell_name(cell) + cell_data_mode = get_data_mode(cell) + found_flag = False + for parent_cell in cell_list_input: + parent_cell_name = get_cell_name(parent_cell) + parent_data_mode = get_data_mode(parent_cell) + has_relation = check_relation(cell_name, parent_cell_name) + if has_relation and parent_data_mode == cell_data_mode: + construct.update({cell: parent_cell}) + found_flag = True + break + if not found_flag: + construct.update({cell: None}) + + +def generate_construct(path): + global construct + filenames = sort_filenames(path) + + # 提取文件名中Cell.{cell_name}.{class_name}.{data_mode}.{重复调用此cell的序号}字段,并存入cell_list + for filename in filenames: + point_position = 3 + mid_field = filename.rsplit(CoreConst.SEP, point_position)[0] + if KEY_INPUT in filename: + if mid_field in cell_list: + cell_list.remove(mid_field) + cell_list.append(mid_field) + else: + if mid_field not in cell_list: + index = filenames.index(filename) + output_field = mid_field + KEY_OUTPUT + find_flag = False + for filename_other in cell_list[index + 1:]: + if output_field in filename_other: + find_flag = True + if find_flag is False: + cell_list.append(mid_field) + + get_construct(cell_list) + + # 生成JSON文件 + rank_dir = os.path.dirname(path) + json_path = os.path.join(rank_dir, CONSTRUCT_FILE_NAME) + save_json(json_path, construct, indent=1) + + # 清空'construct'继续处理下一个路径下的数据 + construct = {} + logger.info(f"Construct data saved to {json_path}") + + +def process_file(file_path): + try: + # 读取.npy文件内容 + npy_content = load_npy(file_path) + logger.info(f"Loaded {file_path}: shape is {npy_content.shape}, dtype is {npy_content.dtype}") + + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split(CoreConst.SEP) + data_dtype = "" + # 获取0_float32_165或者0_in_float32_165中的float32 + data_dtype_list = parts[-2].split('_') + if len(data_dtype_list) > 1: + data_dtype = data_dtype_list[-2] + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = CoreConst.SEP.join(parts[:-3]) + ms_dtype = np_ms_dtype_dict.get(data_dtype) + if ms_dtype is None: + logger.warning(f"Get dtype None from file {file_path}") + + # 修改落盘文件名字,去掉TensorDump自带的数据类型和自增id字段 + data_file_name = os.path.basename(file_path) + data_file_dir = os.path.dirname(file_path) + parts = data_file_name.split(CoreConst.SEP) + if len(parts) >= 2: + param_index = parts[-2].split(CoreConst.REPLACEMENT_CHARACTER)[0] + pre_parts = CoreConst.SEP.join(parts[:-2]) + new_file_name = pre_parts + CoreConst.SEP + param_index + CoreConst.NUMPY_SUFFIX + move_file(os.path.join(data_file_dir, data_file_name), os.path.join(data_file_dir, new_file_name)) + logger.info(f"{data_file_name} is renamed to {new_file_name}") + else: + logger.warning(f"Failed to rename {data_file_name}.") + new_file_name = data_file_name + + tensor_json = { + CoreConst.TYPE: 'mindspore.Tensor', + CoreConst.DTYPE: str(ms_dtype), + CoreConst.SHAPE: list(npy_content.shape), + CoreConst.MAX: npy_content.max().item(), + CoreConst.MIN: npy_content.min().item(), + CoreConst.MEAN: npy_content.mean().item(), + CoreConst.NORM: np.linalg.norm(npy_content).item(), + CoreConst.DATA_NAME: new_file_name + } + + # 根据文件名的最后一个部分(输入或输出)确定是添加到input_args还是output + if parts[-3] == KEY_INPUT: + return op_name, CoreConst.INPUT_ARGS, tensor_json + elif parts[-3] == KEY_OUTPUT: + return op_name, KEY_OUTPUT, tensor_json + else: + return None, None, None + + except Exception as e: + logger.error(f"Error reading {file_path}: {e}") + return None, None, None + + +def custom_sort(item, key_to_index): + key = item[0] + return key_to_index.get(key, float('inf')) + + +def generate_dump_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + dump_data = {"task": "tensor", "level": "L0", "dump_data_dir": path, "data": {}} + + with Pool(processes=10) as pool: + file_paths = [] + for root, _, files in os.walk(path): + for file in files: + if file.endswith(FileCheckConst.NUMPY_SUFFIX): + file_paths.append((os.path.join(root, file),)) + file_paths.sort() + results = pool.starmap(process_file, file_paths) + + # 收集结果 + for op_name, key, tensor_json in results: + if op_name: + if op_name not in dump_data.get(CoreConst.DATA, {}): + dump_data.get(CoreConst.DATA, {})[op_name] = {CoreConst.INPUT_ARGS: [], + CoreConst.INPUT_KWARGS: {}, + KEY_OUTPUT: []} + if key not in dump_data.get(CoreConst.DATA, {}).get(op_name, {}): + dump_data.get(CoreConst.DATA, {}).get(op_name, {})[key] = [] + dump_data.get(CoreConst.DATA, {}).get(op_name, {}).get(key, []).append(tensor_json) + + # 根据cell_list排序 + data_dict = dump_data.get(CoreConst.DATA, {}) + key_to_index = {key: index for index, key in enumerate(cell_list)} + sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: custom_sort(item, key_to_index))) + dump_data[CoreConst.DATA] = sorted_data_dict + + # 将数据写入dump.json + json_path = os.path.join(os.path.dirname(path), 'dump.json') + save_json(json_path, dump_data, indent=1) + + logger.info(f"Dump data saved to {json_path}") + + +def generate_stack_info(path): + if not os.path.exists(path): + logger.error("The provided path does not exist.") + return + + stack_data = {} + file_paths = [] + # 传入的path为工具生成的./dump_tensor_data,内容为npy文件 + for root, _, files in os.walk(path): + for file in files: + if file.endswith(FileCheckConst.NUMPY_SUFFIX): + file_paths.append(os.path.join(root, file)) + file_paths.sort() + for file_path in file_paths: + # 文件名举例:Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy + parts = os.path.basename(file_path).split(CoreConst.SEP) + # op_name是Cell.network._backbone.loss.CrossEntropyLoss.forward.0 + op_name = CoreConst.SEP.join(parts[:-3]) + stack_data.update({op_name: []}) + + # 将数据写入stack.json + json_path = os.path.join(os.path.dirname(path), 'stack.json') + save_json(json_path, stack_data, indent=1) + + logger.info(f"Stack data saved to {json_path}") + + +def is_download_finished(directory, interval=3): + """ + 判断指定目录在一段时间后是否有数据被下载完成 + :param directory: 指定目录的路径 + :param interval: 检查的时间间隔(秒),默认为 3 秒 + :return: 如有数据被下载完成返回 True,否则返回 False + """ + # 检查目录是否存在 + if not os.path.exists(directory): + logger.warning(f"The specified directory {directory} does not exist.") + return False + initial_modification_time = os.path.getmtime(directory) + time.sleep(interval) + current_modification_time = os.path.getmtime(directory) + # 比较初始和当前修改时间 + if current_modification_time > initial_modification_time: + return False + else: + return True + + +def process(dump_path): + rank_id = os.environ.get('RANK_ID') + rank_dir = DEFAULT_RANK_DIR + if rank_id is not None: + rank_dir = CoreConst.RANK + str(rank_id) + + step_dir_list = os.listdir(dump_path) + for step_dir in step_dir_list: + step_path = os.path.join(dump_path, step_dir) + rank_path = os.path.join(step_path, rank_dir) + npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA) + while True: + is_finished = is_download_finished(npy_path) + if not is_finished: + logger.info("There is data being downloaded in the specified directory, continue checking...") + else: + logger.info("There is no data being downloaded in the specified directory, Stop checking.") + break + logger.info("==========Start processing data that has already been stored on the disk!==========") + rename_filename(npy_path) + generate_construct(npy_path) + generate_dump_info(npy_path) + generate_stack_info(npy_path) + logger.info("==========JSON file generation completed!==========") + + +def get_yaml_keys(yaml_data): + keys = [] + for key, _ in yaml_data.items(): + keys.append(key) + return keys + + +def get_tensordump_mode(input_str): + left_index = input_str.find('(') + right_index = input_str.find(')') + + # 提取括号内的字符串 + if left_index != -1 and right_index != -1: + inner_str = input_str[left_index + 1:right_index] + # 分割字符串得到元素列表 + elements = inner_str.split(',') + if len(elements) >= 2: + # 去除元素前后的空格 + first_element = elements[0].strip() + second_element = elements[1].strip() + return first_element, second_element + return None, None + + +def set_tensordump_mode(cell, input_str): + first_str, second_str = get_tensordump_mode(input_str) + if first_str and second_str: + cell.input_dump_mode = first_str + cell.output_dump_mode = second_str + + +def start(net=None, dump_path="./", data_mode=CoreConst.ALL, td_config_path=''): + if net is None: + return + + if isinstance(net, nn.Cell): + net = (('', net),) + + if td_config_path == "": + yaml_data = {} + else: + yaml_data = load_yaml(td_config_path) + first_layer_key = get_yaml_keys(yaml_data) + + black_list = ["grad_reducer", ""] + + for name_and_model in net: + for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]): + class_name = cell.__class__.__name__ + # 跳过黑名单cell + if name in black_list: + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + # 跳过框架内部的cell + if class_name.startswith(CoreConst.REPLACEMENT_CHARACTER): + logger.info(f"Cell {name}.{class_name} is skipped!") + continue + else: + # Format: Cell.{cell_name}.{class_name} + cell.cell_prefix = CoreConst.SEP.join([CoreConst.CELL, name, cell.__class__.__name__]) + + # 根据yaml配置文件设置cell的TensorDump模式 + if class_name in first_layer_key: + layer_data = yaml_data.get(class_name) + if layer_data: + for child_name, child_cell in cell.cells_and_names(): + if child_name in layer_data: + set_tensordump_mode(child_cell, layer_data[child_name]) + top_layer_data = yaml_data.get(KEY_TOPLAYER) + if top_layer_data and name in top_layer_data: + set_tensordump_mode(cell, top_layer_data[name]) + + # 替换construct函数 + cell.construct = cell_construct_wrapper(cell.construct, cell) + logger.info(f"Cell {name}: construct function is wrapped!") + cell.dump_path = dump_path + cell.data_mode = data_mode + cell.input_clips = [] + cell.output_clips = [] + # It is assumed that each cell has a maximum of 50 outputs and 50 inputs. + for i in range(50): + cell.input_clips.append( + ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_INPUT)) + ) + cell.output_clips.append( + ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_OUTPUT)) + ) + + logger.info("==========The cell_dump_process_start phase is Finished!==========") + atexit.register(process, dump_path=dump_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index 0ca63b4a84aee00127bca37b7da36888e905a5aa..1f7fe0379db2cbaa59fa856a38ce431fb3655a51 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,16 +14,18 @@ # limitations under the License. from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump +from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump class DumpToolFactory: tools = { Const.CELL: { - Const.GRAPH_KBYK_MODE: None, - Const.GRAPH_GE_MODE: None, + Const.GRAPH_KBYK_MODE: GraphModeCellDump, + Const.GRAPH_GE_MODE: GraphModeCellDump, Const.PYNATIVE_MODE: None }, Const.API: { @@ -39,14 +41,21 @@ class DumpToolFactory: } @staticmethod - def create(config: DebuggerConfig): - if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: - raise Exception("data_mode must be one of all, input, output.") + def create(config: DebuggerConfig, model=None): + if config.level == Const.CELL: + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, forward, backward.") + else: + if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST: + raise Exception("data_mode must be one of all, input, output.") + if config.level == Const.KERNEL: + return (KernelGraphDump(config), KernelKbykDump(config)) tool = DumpToolFactory.tools.get(config.level) if not tool: raise Exception("Valid level is needed.") tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Data dump is not supported in {config.execution_mode} mode " - f"when dump level is {config.level}.") - return tool(config) + logger.error(f"Data dump is not supported in {config.execution_mode} mode " + f"when dump level is {config.level}.") + raise ValueError + return tool(config, model) if tool == GraphModeCellDump else tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..7ab63f20fb7b50a2c5cbc9f4d49aecdb33452e03 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mindspore as ms +from mindspore import hal, ops +from mindspore.ops.primitive import _run_op + +from msprobe.core.common.const import Const as CoreConst +from msprobe.core.common.runtime import Runtime +from msprobe.mindspore.common.const import Const +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +import msprobe.mindspore.dump.cell_dump_process as cellDumper + +tensordump_flag = True +try: + from mindspore._c_expression import _tensordump_set_step +except ImportError: + tensordump_flag = False + + +class GraphModeCellDump: + def __init__(self, config: DebuggerConfig, model, strict=True): + self.net = model + self.white_list = [] + self.black_list = [] + self.dump_path = config.dump_path if config.dump_path else "./" + self.rank = config.rank + self.step = config.step + self.scope = config.scope + self.list = config.list + self.data_mode = config.data_mode + self.file_format = config.file_format + self.td_config_path = config.td_config_path + self.check_config(strict) + self.set_step() + + @staticmethod + def step(): + hal.synchronize() + temp_tensor = ms.Tensor([1], dtype=ms.float32) + step_flag = "" + _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) + ops.tensordump(step_flag, temp_tensor) + + def check_config(self, strict): + if not self.net: + raise Exception("The model is empty and cell dump is not enabled.") + + if strict: + if self.rank: + raise Exception("In graph mode, cell dump does not currently support specifying rank.") + if self.scope: + raise Exception("In graph mode, cell dump does not currently support specifying scope.") + if self.list: + raise Exception("In graph mode, cell dump does not currently support specifying list.") + if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + raise Exception("In graph mode and cell dump, data_mode must be one of all, forword, backword.") + if self.file_format != []: + logger.warning("In graph mode, cell dump does not currently support specifying file_format." + " The file will be stored in npy format.") + else: + self.rank = [] + self.scope = [] + self.list = [] + self.file_format = [] + if len(self.data_mode) != 1 or self.data_mode[0] not in Const.GRAPH_CELL_DUMP_DATA_MODE_LIST: + self.data_mode = [CoreConst.ALL] + + return True + + def set_step(self): + if tensordump_flag: + _tensordump_set_step(self.step) + else: + raise Exception( + "Importing _tensordump_set_step failed, " + "please use the latest version package of MindSpore." + ) + + def handle(self): + os.environ['MS_JIT_MODULES'] = 'msprobe' + + if Runtime.run_mode == Const.PYNATIVE_GRAPH_MODE: + dump_path = os.path.join(self.dump_path, Const.GRAPH_MODE) + else: + dump_path = self.dump_path + + cellDumper.start( + net=self.net, + dump_path=dump_path, + data_mode=self.data_mode[0], + td_config_path=self.td_config_path + ) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3f249e7e7065d52046aa6991a9d8553bb230d6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_tensor_dump.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import OrderedDict +import mindspore as ms + + +def _iterate_items(data): + if isinstance(data, (dict, OrderedDict)): + return data.items() + elif isinstance(data, (list, tuple)): + return enumerate(data) + else: + raise TypeError("Unsupported data type") + + +class _SaveBase: + def __init__(self, save_dir): + super(_SaveBase, self).__init__() + self.path = save_dir + self.save_func = _npy_save + + def get_save_func(self): + return self.save_func + + +@ms.jit_class +class _SaveCell(_SaveBase): + def __call__(self, name, data): + return self.get_save_func()(self.path, name, data) + + +class _SaveGradBase: + def __init__(self, save_dir, name): + super(_SaveGradBase, self).__init__() + self.file = save_dir + name + + +@ms.jit_class +class _SaveGradCell(_SaveGradBase): + def __init__(self, save_dir, name): + super(_SaveGradCell, self).__init__(save_dir, name) + self.ms_save_grad = ms.ops.InsertGradientOf( + _wrapper_save_grad_func(self.file)) + + def __call__(self, x): + if isinstance(x, ms.Tensor): + return self.ms_save_grad(x) + else: + raise TypeError(f"For 'save_grad', the type of argument 'data' must be mindspore.Tensor or torch.tensor, " + f"but got {type(x)}") + + +def _npy_save_ops(file, data): + if isinstance(data, ms.Tensor): + if data.dtype == ms.bfloat16: + data = data.float() + ms.ops.TensorDump()(file, data) + else: + raise TypeError(f"For 'save', the type of argument 'data' must be mindspore.Tensor or torch.tensor, " + f"but got {type(data)}") + + +def _wrapper_save_grad_func(file): + def _save_grad_func(grad): + data = grad + if data.dtype == ms.bfloat16: + data = data.float() + ms.ops.TensorDump()(file, data) + return grad + return _save_grad_func + + +def _npy_save(save_dir, item_name, data): + if isinstance(data, (list, tuple, dict, OrderedDict)): + for key, val in _iterate_items(data): + _npy_save(save_dir, f"{item_name}.{key}", val) + else: + if data is None: + return + _npy_save_ops(f"{save_dir}{item_name}", data) + + +def generate_dump_dir(save_dir, sep=os.sep): + """ + usage: generate dump directory path str in mindspore graph mode + """ + full_suffix = '{step}' + sep + '{rank}' + sep + if save_dir and save_dir[-1] != sep: + result_dir = save_dir + sep + full_suffix + else: + result_dir = save_dir + full_suffix + return result_dir + + +def save(save_dir, name, data): + """ + save tensor. + """ + dump_dir = generate_dump_dir(save_dir) + _SaveCell(dump_dir)(name, data) + + +def save_grad(save_dir, name, data): + """ + save grad. + """ + dump_dir = generate_dump_dir(save_dir) + suffix_name = name + '_grad' + return _SaveGradCell(dump_dir, suffix_name)(data) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..59ae1214e7b13ae62f28af2caf218a3cd9e613fc --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import inspect + +from mindspore import Tensor, ops, mint +from mindspore.mint import distributed +from mindspore.mint.nn import functional +from mindspore.communication import comm_func + +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + + +stub_tensor_existed = True +try: + from mindspore.common._stub_tensor import StubTensor +except ImportError: + stub_tensor_existed = False + +cur_path = os.path.dirname(os.path.realpath(__file__)) +if not is_mindtorch(): + _api_types = { + Const.MS_FRAMEWORK: { + Const.MS_API_TYPE_OPS: (ops, (ops,)), + Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)), + Const.MS_API_TYPE_MINT: (mint, (mint,)), + Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)), + Const.MS_API_TYPE_COM: (comm_func, (comm_func,)), + Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,)) + } + } + if stub_tensor_existed: + _api_types.get(Const.MS_FRAMEWORK).update( + {Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,))} + ) + + _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),) + _backlist = [] +else: + import torch + import torch_npu + _api_types = { + Const.MT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: (torch, (torch,)), + Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)), + Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d)) + } + } + _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module', + MsConst.SUPPORTED_API_LIST_FILE),) + _backlist = [f'{Const.PT_API_TYPE_TENSOR}.__setitem__'] + +_inner_used_api = { + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: ( + ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point" + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: ( + Tensor, "to", "numel", 'sum' + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: ( + mint, "max", "min", "mean", "norm" + ) +} + + +class ApiTemplate(HOOKCell): + def __init__(self, api_name, api_func, prefix, hook_build_func): + self.api_name = api_name + self.api_func = api_func + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + super().__init__(hook_build_func) + distributed_prefix = Const.DIST_API_TYPE_PREFIX if is_mindtorch() else Const.MINT_DIST_API_TYPE_PREFIX + if prefix == distributed_prefix: + self.op_is_distributed = True + + @staticmethod + def async_to_sync(output): + # Fake handle, used to return after the CommHandle executes the wait method + fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() + if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): + output[1].wait() + output = (output[0], fake_handle) + elif hasattr(output, "wait"): + output.wait() + output = fake_handle + return output + + def construct(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + + output = self.api_func(*args, **kwargs) + + if self.prefix_api_name.startswith( + (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX) + ): + try: + bound = inspect.signature(self.api_func).bind(*args, **kwargs) + bound.apply_defaults() + use_async_op_flag = bound.arguments.get("async_op", False) + except Exception as e: + use_async_op_flag = False + logger.warning(f"fail to get dist api's func signature because {e}, no wait") + + if use_async_op_flag or self.api_name in ["isend", "irecv"]: + output = self.async_to_sync(output) + if self.api_name == "batch_isend_irecv" and isinstance(output, list): + output = [self.async_to_sync(handle) for handle in output] + + return output + + def forward(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + return self.api_func(*args, **kwargs) + + +api_register = None +stub_tensor_set = False + + +def get_api_register(return_new=False): + global stub_tensor_set + + def stub_method(method): + def wrapped_method(*args, **kwargs): + return method(*args, **kwargs) + return wrapped_method + if not is_mindtorch() and stub_tensor_existed and not stub_tensor_set: + api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, []) + for attr_name in dir(StubTensor): + attr = getattr(StubTensor, attr_name) + if attr_name in api_names and callable(attr): + setattr(StubTensor, attr_name, stub_method(attr)) + stub_tensor_set = True + + if return_new: + return ApiRegistry( + _api_types, + _inner_used_api, + _supported_api_list_path, + ApiTemplate, + _backlist + ) + + global api_register + if api_register is None: + api_register = ApiRegistry( + _api_types, + _inner_used_api, + _supported_api_list_path, + ApiTemplate, + _backlist + ) + return api_register diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py deleted file mode 100644 index 7aee1deccd9689985c7a2e270648bd0877cd7cf3..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mindspore import Tensor, ops, mint -from mindspore.mint.nn import functional -from mindspore.common._stub_tensor import StubTensor -from mindspore.communication import comm_func - -from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP, - HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP, - HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP, - HOOKTorchDistributedOP, HOOKTorchNpuOP, - get_wrap_api_list, get_wrap_torch_api_list, setup_hooks) -from msprobe.core.common.utils import Const -from msprobe.mindspore.common.utils import is_mindtorch - -if is_mindtorch(): - import torch - import torch_npu - - -def stub_method(method): - def wrapped_method(*args, **kwargs): - return method(*args, **kwargs) - return wrapped_method - - -class ApiRegistry: - def __init__(self): - self.tensor_ori_attr = {} - self.stub_tensor_ori_attr = {} - self.functional_ori_attr = {} - self.mint_ops_ori_attr = {} - self.mint_func_ops_ori_attr = {} - self.distributed_ori_attr = {} - self.norm_inner_ops_ori_attr = {} - - self.torch_ori_attr = {} - self.torch_tensor_ori_attr = {} - self.torch_functional_ori_attr = {} - self.torch_distributed_ori_attr = {} - self.torch_npu_ori_attr = {} - - self.tensor_hook_attr = {} - self.stub_tensor_hook_attr = {} - self.functional_hook_attr = {} - self.mint_ops_hook_attr = {} - self.mint_func_ops_hook_attr = {} - self.distibuted_hook_attr = {} - self.norm_inner_ops_hook_attr = {} - - self.torch_hook_attr = {} - self.torch_tensor_hook_attr = {} - self.torch_functional_hook_attr = {} - self.torch_distributed_hook_attr = {} - self.torch_npu_hook_attr = {} - - self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"] - - @staticmethod - def store_ori_attr(ori_api_group, api_list, api_ori_attr): - for api in api_list: - if Const.SEP in api: - sub_module_name, sub_op = api.rsplit(Const.SEP, 1) - sub_module = getattr(ori_api_group, sub_module_name) - ori_api_func = getattr(sub_module, sub_op) - else: - ori_api_func = getattr(ori_api_group, api) - if ori_api_group == StubTensor: - api_ori_attr[api] = stub_method(ori_api_func) - continue - api_ori_attr[api] = ori_api_func - - @staticmethod - def set_api_attr(api_group, attr_dict): - for api, api_attr in attr_dict.items(): - if Const.SEP in api: - sub_module_name, sub_op = api.rsplit(Const.SEP, 1) - sub_module = getattr(api_group, sub_module_name, None) - if sub_module is not None: - setattr(sub_module, sub_op, api_attr) - else: - setattr(api_group, api, api_attr) - - def norm_inner_op_set_hook_func(self): - self.set_api_attr(ops, self.norm_inner_ops_hook_attr) - - def norm_inner_op_set_ori_func(self): - self.set_api_attr(ops, self.norm_inner_ops_ori_attr) - - def api_set_hook_func(self): - if is_mindtorch(): - self.set_api_attr(torch, self.torch_hook_attr) - self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr) - self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr) - self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr) - self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr) - self.set_api_attr(torch_npu, self.torch_npu_hook_attr) - else: - self.set_api_attr(Tensor, self.tensor_hook_attr) - self.set_api_attr(StubTensor, self.stub_tensor_hook_attr) - self.set_api_attr(ops, self.functional_hook_attr) - self.set_api_attr(mint, self.mint_ops_hook_attr) - self.set_api_attr(functional, self.mint_func_ops_hook_attr) - self.set_api_attr(comm_func, self.distibuted_hook_attr) - - def api_set_ori_func(self): - if is_mindtorch(): - self.set_api_attr(torch, self.torch_ori_attr) - self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr) - self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr) - self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr) - self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr) - self.set_api_attr(torch_npu, self.torch_npu_ori_attr) - else: - self.set_api_attr(Tensor, self.tensor_ori_attr) - self.set_api_attr(StubTensor, self.stub_tensor_ori_attr) - self.set_api_attr(ops, self.functional_ori_attr) - self.set_api_attr(mint, self.mint_ops_ori_attr) - self.set_api_attr(functional, self.mint_func_ops_ori_attr) - self.set_api_attr(comm_func, self.distributed_ori_attr) - - def initialize_hook(self, hook): - setup_hooks(hook) - if is_mindtorch(): - wrap_torch_api_name = get_wrap_torch_api_list() - self.store_ori_attr(torch, - wrap_torch_api_name.torch_api_names, self.torch_ori_attr) - self.store_ori_attr(torch.Tensor, - wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr) - self.store_ori_attr(torch.nn.functional, - wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr) - self.store_ori_attr(torch.distributed, - wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr) - self.store_ori_attr(torch_npu, - wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr) - for attr_name in dir(HOOKTorchOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name) - for attr_name in dir(HOOKTorchTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name) - for attr_name in dir(HOOKTorchFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name) - for attr_name in dir(HOOKTorchDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name) - for attr_name in dir(HOOKTorchNpuOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name) - return - - wrap_api_name = get_wrap_api_list() - self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr) - self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr) - self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr) - self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr) - self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr) - self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr) - self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr) - for attr_name in dir(HOOKTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name) - for attr_name in dir(HOOKStubTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name) - for attr_name in dir(HOOKFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name) - if api_name in self.norm_inner_ops: - self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name) - for attr_name in dir(HOOKMintOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name) - for attr_name in dir(HOOKMintNNFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name) - for attr_name in dir(HOOKDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:] - self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name) - - -api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py index b68a7d995a56497a219281c5a43d692c46cfac4d..62e14e9f287f5c0301f24126c19209e00f62da40 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py @@ -15,11 +15,16 @@ from collections import defaultdict +import mindspore as ms from mindspore import nn +from msprobe.core.common.runtime import Runtime from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions +ms_version = ms.__version__ + + def add_cell_count(name): HOOKCell.cell_count[name] += 1 @@ -28,29 +33,34 @@ def get_cell_count(name): return HOOKCell.cell_count[name] -def __init__(self, build_hook) -> None: +def __init__(self, hook_build_func) -> None: super(HOOKCell, self).__init__() self.changed_status = False - self.input_kwargs = {} - self.prefix = "" + self.msprobe_input_kwargs = {} if not HOOKCell.g_stop_hook: HOOKCell.g_stop_hook = True self.changed_status = True - if hasattr(self, "prefix_api_name"): - self.prefix = self.prefix_api_name - self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix) - self.register_forward_pre_hook(forward_pre_hook) - self.register_forward_hook(forward_hook) - register_backward_hook_functions["full"](self, backward_hook) - register_backward_hook_functions["pre"](self, backward_pre_hook) + + if not Runtime.is_running: + return + prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else "" + if callable(hook_build_func): + hook_set = hook_build_func(prefix) + if ms_version < "2.6.0" and not is_mindtorch(): + getattr(self, "_forward_pre_hook", {})[id(self)] = hook_set.forward_pre_hook + getattr(self, "_forward_hook", {})[id(self)] = hook_set.forward_hook + else: + self.register_forward_pre_hook(hook_set.forward_pre_hook) + self.register_forward_hook(hook_set.forward_hook) + register_backward_hook_functions["full"](self, hook_set.backward_hook) + register_backward_hook_functions["pre"](self, hook_set.backward_pre_hook) # 重载call,加全局标志。 def __call__(self, *args, **kwargs): try: - self.input_kwargs = kwargs + self.msprobe_input_kwargs = kwargs out = super(HOOKCell, self).__call__(*args, **kwargs) except Exception as e: raise e diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5581a44ca5eb1f1f0ecab4b30255b5c4e09f8b5a --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore.common.api import _no_grad +from msprobe.core.common.const import Const +from msprobe.core.common.utils import replace_last_occurrence +from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs +from msprobe.core.hook_manager import BaseHookManager, HookSet +from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + + +class MindsproeHookManager(BaseHookManager): + @property + def _is_recompute(self): + return None + + @staticmethod + def _no_grad_context(): + return _no_grad() + + @staticmethod + def _add_count(name): + HOOKCell.add_cell_count(name) + + @staticmethod + def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): + if not has_kwargs_in_forward_hook() or hook_type == Const.API: + kwargs = module.msprobe_input_kwargs if hasattr(module, 'msprobe_input_kwargs') else {} + output = kwargs_or_output + else: + kwargs = kwargs_or_output + output = output_or_kwargs + return kwargs, output + + def build_hook(self, hook_type, name): + if hook_type == Const.API: + full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD + else: + full_forward_name = name + full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD) + hookset = HookSet( + forward_hook=self._build_forward_hook(hook_type, full_forward_name), + forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name), + backward_hook=self._build_backward_hook(hook_type, full_backward_name), + backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name) + ) + return hookset + + def _need_exchange(self, module): + if not hasattr(module, 'has_pre_hook_called') or not module.has_pre_hook_called: + return False + else: + return True + + def _get_params_dict(self, module): + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in module.parameters_dict(recurse=False).items() + } + return params_dict + + def _build_backward_pre_hook(self, hook_type, name): + def backward_pre_hook(module, grad_input): + if self.config.level != Const.LEVEL_L2: + return + if not self._should_execute_hook(hook_type, module, False): + return + BaseHookManager.inner_switch = True + module_input = ModuleBackwardInputs(grad_input=grad_input) + self.data_collector.update_api_or_module_name(name) + self.data_collector.backward_input_data_collect(name, module, self._pid, module_input) + BaseHookManager.inner_switch = False + return backward_pre_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py index 656e48c678956563a6f2d1d5f5ab8a4d03f074e7..b8cf0078e46efd90fac863fd6f1b48604ca63ce1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -21,6 +21,7 @@ from mindspore.common.tensor import Tensor from msprobe.core.common.utils import Const, DumpException from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs, ModuleForwardInputsOutputs) +from msprobe.core.hook_manager import BaseHookManager from msprobe.mindspore.common.log import logger @@ -179,7 +180,7 @@ class PrimitiveHookService: current_count = self.primitive_counters.get(primitive_name, 0) updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}" - if not self.service_instance.primitive_switch: + if not self.service_instance.primitive_switch or BaseHookManager.inner_switch: return origin_func(*args, **kwargs) captured_grads_input, captured_grads_output = [], [] diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 723b0cbc93f78d50f703838eb488de6733008906..eae8f85a87fb2b0986cefb2e6faae7399a86f367 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -564,15 +564,15 @@ tensor: - all - amax - amin + - angle - any - arccos - arccosh - - argmax - - angle - arcsin - arcsinh - arctan - arctanh + - argmax - argmin - argsort - asin @@ -582,19 +582,23 @@ tensor: - atanh - baddbmm - bernoulli + - bfloat16 - bincount - bitwise_and - bitwise_or - bitwise_xor - bmm - bool + - bool astype - broadcast_to + - byte - ceil - - cholesky_solve - cholesky + - cholesky_solve - clamp - clip - conj + - copy - copysign - cos - cosh @@ -606,11 +610,13 @@ tensor: - deg2rad - diag - diagflat + - diagonal - diff - digamma - div - div_ - divide + - double - equal - erf - erfc @@ -618,13 +624,16 @@ tensor: - exp - expand_as - expm1 + - flatten - flip - fliplr - flipud + - float - float_power - floor - fmod - frac + - from_numpy - gather_elements - ge - geqrf @@ -648,12 +657,12 @@ tensor: - inner - int - inverse + - is_complex + - is_signed - isclose - isfinite - isinf - isnan - - is_complex - - is_signed - isneginf - isposinf - isreal @@ -704,28 +713,27 @@ tensor: - new_ones - new_zeros - nextafter - - norm - nonzero + - norm - not_equal - ormqr - permute - pow - prod - qr + - rad2deg - ravel - real - reciprocal - remainder - renorm - - rad2deg - - tile - repeat_interleave - reshape - reshape - - round + - resize - rot90 + - round - rsqrt - - sum_to_size - scatter - sgn - short @@ -745,7 +753,8 @@ tensor: - sub - sub_ - subtract - - subtract + - sum + - sum_to_size - svd - swapaxes - swapdims @@ -753,13 +762,13 @@ tensor: - take - tan - tanh - - trace - - swapaxes + - tensor_split - tile + - to - topk - - tril - - tensor_split + - trace - transpose + - tril - true_divide - trunc - unbind @@ -769,17 +778,6 @@ tensor: - view - where - xlogy - - from_numpy - - std - - take - - var - - all - - any - - copy - - diagonal - - flatten - - resize - - sum mint.ops: - abs @@ -1027,3 +1025,21 @@ communication.comm_func: - recv - isend - irecv + +mint.distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - all_to_all_single + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor + - batch_isend_irecv diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py deleted file mode 100644 index 0e97929ecd7f8444b19fd531efc49883d0df58de..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from mindspore import Tensor, mint, ops -from mindspore.common._stub_tensor import StubTensor -from mindspore.communication import comm_func -from mindspore.mint.nn import functional - -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml -from msprobe.mindspore.common.const import Const as MsConst -from msprobe.mindspore.common.utils import is_mindtorch -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell - -if is_mindtorch(): - import torch - import torch_npu - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE) -torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE) - - -class HOOKTensor(object): - pass - - -class HOOKStubTensor(object): - pass - - -class HOOKFunctionalOP(object): - pass - - -class HOOKMintOP(object): - pass - - -class HOOKMintNNFunctionalOP(object): - pass - - -class HOOKDistributedOP(object): - pass - - -class HOOKTorchOP(object): - pass - - -class HOOKTorchTensor(object): - pass - - -class HOOKTorchFunctionalOP(object): - pass - - -class HOOKTorchDistributedOP(object): - pass - - -class HOOKTorchNpuOP(object): - pass - - -class ApiTemplate(HOOKCell): - def __init__(self, api_name, api_dict, prefix, hook): - self.api_name = api_name - self.api_func = api_dict[api_name] - self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP - super().__init__(hook) - - @staticmethod - def async_to_sync(output): - # Fake handle, used to return after the CommHandle executes the wait method - fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() - if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): - output[1].wait() - output = (output[0], fake_handle) - elif hasattr(output, "wait"): - output.wait() - output = fake_handle - return output - - def construct(self, *args, **kwargs): - if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): - return args[0] if args else kwargs.get(Const.INPUT) - - output = self.api_func(*args, **kwargs) - - if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): - if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: - output = self.async_to_sync(output) - return output - - def forward(self, *args, **kwargs): - if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): - return args[0] if args else kwargs.get(Const.INPUT) - return self.api_func(*args, **kwargs) - - -class WrapApiName: - def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names, - distributed_api_names): - self.tensor_api_names = tensor_api_names - self.stub_tensor_api_names = stub_tensor_api_names - self.ops_api_names = ops_api_names - self.mint_api_names = mint_api_names - self.mint_nn_func_api_names = mint_nn_func_api_names - self.distributed_api_names = distributed_api_names - - -class WrapTorchApiName: - def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names): - self.torch_api_names = torch_api_names - self.tensor_api_names = tensor_api_names - self.functional_api_names = functional_api_names - self.distributed_api_names = distributed_api_names - self.npu_api_names = npu_api_names - - -def get_wrap_api_list(): - api_list = load_yaml(yaml_path) - tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY) - ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY) - mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY) - mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY) - distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY) - wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)), - set(tensor_api) & set(dir(StubTensor)), - set(ops_api) & set(dir(ops)), - set(mint_api) & set(dir(mint)), - set(mint_nn_func_api) & set(dir(functional)), - set(distributed_api) & set(dir(comm_func))) - return wrap_api_name - - -def get_wrap_torch_api_list(): - api_list = load_yaml(torch_yaml_path) - torch_api = api_list.get("torch") - tensor_api = api_list.get("tensor") - functional_api = api_list.get("functional") - distributed_api = api_list.get("distributed") - npu_api = api_list.get("torch_npu") - wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)), - set(tensor_api) & set(dir(torch.Tensor)), - set(functional_api) & set(dir(torch.nn.functional)), - set(distributed_api) & set(dir(torch.distributed)), - set(npu_api) & set(dir(torch_npu))) - return wrap_api_name - - -def wrap_api_func(api_name, api_dict, prefix, hook): - def api_function(*args, **kwargs): - return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs) - return api_function - - -def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class): - for api_name in api_list: - if callable(api_dict[api_name]): - setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook)) - - -def setup_hooks(hook): - if is_mindtorch(): - torch_wrap_api_name = get_wrap_torch_api_list() - wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names, - {f: getattr(torch, f) for f in dir(torch)}, - MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP) - wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names, - {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)}, - MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor) - wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names, - {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)}, - MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP) - wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names, - {f: getattr(torch.distributed, f) for f in dir(torch.distributed)}, - MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP) - wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)}, - MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP) - return - - wrap_api_name = get_wrap_api_list() - wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)}, - MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor) - wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)}, - MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor) - wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)}, - MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP) - wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)}, - MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP) - wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)}, - MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP) - wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)}, - MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index 0a32200639a1f3805f815c37caaef5d3bb64c82f..90565ac6d404a1e35e20adc0e55ee81f7d2b4eb2 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os from collections import defaultdict +import os +import types +import mindspore +from mindspore import nn from mindspore._c_expression import PyNativeExecutor_ try: from mindspore.common.api import _MindsporeFunctionExecutor @@ -24,30 +27,31 @@ except ImportError: from msprobe.core.common.log import logger from msprobe.core.common.const import Const +from msprobe.core.common.runtime import Runtime from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register + + +_api_register = get_api_register() def dump_jit(name, in_feat, out_feat, is_forward): pid = os.getpid() - ori_args = str(name) - index = ori_args.find("<") - if index != 0 and index != -1: - result = ori_args[0:index] - elif name is not None and "<" not in str(name): - result = str(name) - else: - result = "JitFunction" + name = name if name else "JitFunction" if JitDump.need_dump(): if is_forward: - JitDump.jit_count[result] += 1 - name_template = (Const.JIT + Const.SEP + result + Const.SEP + - str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD) + if name in JitDump.jit_count: + JitDump.jit_count[name] += 1 + else: + JitDump.jit_count[name] = 0 + name_template = (Const.JIT + Const.SEP + name + Const.SEP + + str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD) JitDump.data_collector.update_api_or_module_name(name_template) module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat) JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output) else: - name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \ + name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \ Const.BACKWARD JitDump.data_collector.update_api_or_module_name(name_template) module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat) @@ -57,7 +61,7 @@ def dump_jit(name, in_feat, out_feat, is_forward): class JitDump(_MindsporeFunctionExecutor): dump_config = None jit_enable = False - jit_dump_switch = True + jit_dump_switch = False jit_count = defaultdict(int) def __init__(self, *args, **kwargs): @@ -68,19 +72,17 @@ class JitDump(_MindsporeFunctionExecutor): self._executor = PyNativeExecutor_.get_instance() def __call__(self, *args, **kwargs): - if JitDump.jit_dump_switch: - api_register.api_set_ori_func() + _api_register.restore_all_api() out = super().__call__(*args, **kwargs) - if JitDump.jit_dump_switch and len(args) > 0: - if self.name and self.name != "construct": + if JitDump.jit_dump_switch and len(args) > 0 and self.name: + if self.name != "construct": dump_jit(self.name, args, out, True) - else: - dump_jit(args[0], args, out, True) + elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell): + dump_jit(args[0].__class__.__name__, args, out, True) JitDump.jit_enable = True elif len(args) == 0: logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.") - if JitDump.jit_dump_switch: - api_register.api_set_hook_func() + _api_register.register_all_api() return out @classmethod @@ -101,9 +103,15 @@ class JitDump(_MindsporeFunctionExecutor): def grad(self, obj, grad, weights, grad_position, *args, **kwargs): if JitDump.jit_dump_switch and JitDump.jit_enable: - api_register.api_set_ori_func() - output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) + _api_register.restore_all_api() + if mindspore.__version__ >= "2.5": + output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values())) + else: + output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) if JitDump.jit_dump_switch and JitDump.jit_enable: - dump_jit(obj, args, None, False) - api_register.api_set_hook_func() + if isinstance(obj, types.FunctionType): + dump_jit(obj.__name__, args, None, False) + elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell): + dump_jit(obj.__class__.__name__, args, None, False) + _api_register.register_all_api() return output diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py index 2c46b0c73e7789ea41afb991bb985e089b2349cd..91a6ab93abad35c39810d2b3e9ec731605694aa5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py @@ -39,9 +39,19 @@ class KernelKbykDump: common_set["input_output"] = 0 common_set["kernels"] = [] common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7] - e2e_set = dict() - e2e_set["enable"] = True - e2e_set["trans_flag"] = True + + if config.stat_cal_mode and config.device_stat_precision_mode: + e2e_set = { + "enable": not config.async_dump, + "trans_flag": True, + "stat_calc_mode": config.stat_cal_mode, + "device_stat_precision_mode": config.device_stat_precision_mode + } + else: + e2e_set = { + "enable": not config.async_dump, + "trans_flag": True + } if config.list: common_set["dump_mode"] = 1 diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py index 57b7de4fa567d73a19178256d79f5e4cbeb38864..6a7c85b3c5874e310d71fefac33a93e5ef5e57ce 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py @@ -19,22 +19,27 @@ import os import traceback import mindspore as ms + from msprobe.core.common.const import Const from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import check_path_length, load_yaml +from msprobe.core.common.runtime import Runtime +from msprobe.core.hook_manager import HookSet from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.common.utils import get_rank_if_initialized from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory -from msprobe.mindspore.runtime import Runtime + + +_api_register = get_api_register() class ApiPyNativeSelfCheck: @@ -60,8 +65,8 @@ class ApiPyNativeSelfCheck: self.store_original_func() def handle(self): - api_register.initialize_hook(self.build_hook) - api_register.api_set_hook_func() + _api_register.initialize_hook(self.build_hook) + _api_register.register_all_api() def build_hook(self, api_name): def pre_hook(cell, input_data): @@ -71,7 +76,7 @@ class ApiPyNativeSelfCheck: ret = None if not need_wrapper_func(): - del cell.input_kwargs + del cell.msprobe_input_kwargs return ret api_name_with_id = api_name_with_id[:-1] @@ -80,9 +85,9 @@ class ApiPyNativeSelfCheck: api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)]) if api_name in self.api_list: ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name), - *input_data, **cell.input_kwargs) + *input_data, **cell.msprobe_input_kwargs) - del cell.input_kwargs + del cell.msprobe_input_kwargs return ret def backward_hook(cell, grad_input, grad_output): @@ -101,8 +106,13 @@ class ApiPyNativeSelfCheck: def pre_backward_hook(cell, grad_input): return None - - return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook + + return HookSet( + forward_hook=wrap_forward_hook, + forward_pre_hook=pre_hook, + backward_hook=wrap_backward_hook, + backward_pre_hook=pre_backward_hook + ) def store_original_func(self): for api_name in self.api_list: @@ -166,13 +176,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs): return ret logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.") - api_register.api_set_ori_func() + _api_register.restore_all_api() try: perturbation = PerturbationFactory.create(api_name_with_id) params.fuzzed_result = perturbation.handle(params) if params.fuzzed_result is False: - api_register.api_set_hook_func() + _api_register.register_all_api() return ret if Config.stage == Const.BACKWARD: params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs) @@ -183,7 +193,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs): logger.error(f"[{api_name_with_id}] Error: {str(e)}") logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}") - api_register.api_set_hook_func() + _api_register.register_all_api() return ret diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py index 14a72a5e6b6a6289595897a15c46a0e6397bcd1a..c3f9b27fe2b5792119f7105955cdecfa6bdc51d4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py @@ -19,10 +19,10 @@ from typing import Any, Optional import mindspore as ms from mindspore import Tensor, ops +from msprobe.core.common.runtime import Runtime from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams -from msprobe.mindspore.runtime import Runtime class Tools: diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py index 3fd1430bff792d5043429caac8fe477e457b8bee..39ca164f2043c5d8f6d2e05987edfffe5bca2bee 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation @@ -41,4 +42,5 @@ class PerturbationFactory: if perturbation: return perturbation(api_name_with_id) else: - raise Exception(f'{Config.pert_type} is a invalid perturbation type') + logger.error(f'{Config.pert_type} is a invalid perturbation type') + raise ValueError diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py index 35b5eb2ab65511fa4320dc97702a60a9c8d07f62..b21b15d1758a90e62861c7edf2976d38ab43c5f0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck @@ -41,8 +42,10 @@ class SelfCheckToolFactory: def create(config: DebuggerConfig): tool = SelfCheckToolFactory.tools.get(config.level) if not tool: - raise Exception(f"{config.level} is not supported.") + logger.error(f"{config.level} is not supported.") + raise ValueError tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.") + logger.error(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.") + raise ValueError return tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py index 01e46e019a4d1634a4592970386d855637c34e8f..f55c254de3531e56525b49a12ba697bd4e9dfdf8 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py @@ -16,6 +16,7 @@ import os import threading from typing import Dict, Union, Tuple +import time from msprobe.core.common.utils import is_int from msprobe.core.common.file_utils import create_directory, check_path_before_create @@ -40,8 +41,12 @@ class GlobalContext: def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance_lock.acquire() - cls._instance = object.__new__(cls) - cls._instance_lock.release() + try: + cls._instance = object.__new__(cls) + except Exception as e: + raise RuntimeError("grad_probe global context init failed") from e + finally: + cls._instance_lock.release() return cls._instance def init_context(self, config_dict: Dict): @@ -68,6 +73,7 @@ class GlobalContext: create_directory(self._setting.get(GradConst.OUTPUT_PATH)) else: logger.warning("The output_path exists, the data will be covered.") + self._setting[GradConst.TIME_STAMP] = str(int(time.time())) def get_context(self, key: str): if key not in self._setting: diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py index 8a154f4d65f63e55f6b0cf3165d3c905bcb68546..c46d55b7b481bc89a56f2eac997c1618fb2cdda2 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py @@ -111,7 +111,8 @@ class CSVGenerator(Process): output_path = context.get_context(GradConst.OUTPUT_PATH) self.level = context.get_context(GradConst.LEVEL) self.bounds = context.get_context(GradConst.BOUNDS) - self.dump_dir = f"{output_path}/rank{rank_id}/Dump/" + time_stamp = context.get_context(GradConst.TIME_STAMP) + self.dump_dir = f"{output_path}/rank{rank_id}/Dump{time_stamp}/" self.save_dir = f"{output_path}/rank{rank_id}/" self.current_step = None self.stop_event = multiprocessing.Event() diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py index 9cc30ea1b9d6575bcd5af94c27f19cb93ed7246d..820f7f21d0cd5e6b1fd98f93f1515515407358c0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py @@ -15,6 +15,7 @@ import hashlib from abc import ABC, abstractmethod +import zlib import mindspore from mindspore import ops @@ -76,8 +77,8 @@ class CsvMd5(CsvItem): def generate_csv_content(csv_input): grad = csv_input.grad tensor_bytes = grad.float().numpy().tobytes() - md5_hash = hashlib.md5(tensor_bytes) - return [md5_hash.hexdigest()] + md5_hash = f"{zlib.crc32(tensor_bytes):08x}" + return [md5_hash] @register_csv_item(GradConst.DISTRIBUTION) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py index 1aa9fcfad10815d5845de66ab0ea6d4d7211741f..36857636fa301db37ae4267f8e18d41d9f0328a5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py @@ -49,12 +49,10 @@ class HookInput: self.param_list = grad_context.get_context(GradConst.PARAM_LIST) self.rank_id = get_rank_id() output_path = grad_context.get_context(GradConst.OUTPUT_PATH) - self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump") + time_stamp = grad_context.get_context(GradConst.TIME_STAMP) + self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", f"Dump{time_stamp}") self.save_dir = os.path.join(output_path, f"rank{self.rank_id}") self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH) - if os.path.exists(self.save_dir): - logger.warning(f"Delete existing path {self.save_dir}.") - remove_path(self.save_dir) self.level = grad_context.get_context(GradConst.LEVEL) self.bounds = grad_context.get_context(GradConst.BOUNDS) self.mode = mindspore.get_context("mode") diff --git a/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py b/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py new file mode 100644 index 0000000000000000000000000000000000000000..ad78be966f07ba5e59dfb055c1b4e27dfbacde86 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/mindspore_service.py @@ -0,0 +1,111 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import mindspore as ms +from mindspore.ops.primitive import Primitive + +from msprobe.core.common.utils import Const +from msprobe.core.service import BaseService +from msprobe.mindspore.cell_processor import CellProcessor +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.utils import ( + get_rank_if_initialized, + is_mindtorch, + get_cells_and_names_with_index +) +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate +from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from msprobe.mindspore.dump.jit_dump import JitDump + +try: + from mindspore.common._pijit_context import PIJitCaptureContext +except ImportError: + pijit_label = False +else: + pijit_label = True + + +class MindsporeService(BaseService): + @property + def _get_framework_type(self): + return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK + + @staticmethod + def _get_current_rank(): + return get_rank_if_initialized() + + def empty(self, *args, **kwargs): + pass + + def _init_specific_components(self): + self.logger = logger + self.api_register = get_api_register() + self.primitive_hook_service = PrimitiveHookService(self) + self.cell_processor = CellProcessor(self.data_collector.scope) + self.hook_manager = MindsproeHookManager(self.data_collector, self.config) + self._setup_jit_context() + self.api_template = ApiTemplate + + def _setup_jit_context(self): + if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: + JitDump.set_config(self.config) + JitDump.set_data_collector(self.data_collector) + if hasattr(ms.common.api, "_MindsporeFunctionExecutor"): + ms.common.api._MindsporeFunctionExecutor = JitDump + else: + ms.common.api._JitExecutor = JitDump + ms.common.api._PyNativeExecutor.grad = JitDump.grad + if pijit_label: + PIJitCaptureContext.__enter__ = self.empty + PIJitCaptureContext.__exit__ = self.empty + + def _register_module_hook(self): + self.cell_processor.register_cell_hook(self.model, self.build_hook, self.config) + self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") + + def _register_hook(self): + self._register_primitive_hook() + + def _register_primitive_hook(self): + if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]: + return + if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST: + return + + primitive_set = set() + cells_and_names_with_index, _ = get_cells_and_names_with_index(self.model) + for cells_and_names in cells_and_names_with_index.values(): + for _, cell in cells_and_names: + for attribute, value in vars(cell).items(): + if isinstance(value, Primitive): + primitive_set.add((attribute, value)) + + for pname, primitive in primitive_set: + primitive_class_name = primitive.__class__.__name__ + primitive_combined_name = pname + Const.SEP + primitive_class_name + new_primitive = type('NewPrimitive', (primitive.__class__,), + {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, + primitive_combined_name)}) + primitive.__class__ = new_primitive + + def _reset_status(self): + super()._reset_status() + self.primitive_hook_service.primitive_counters.clear() + JitDump.jit_count = defaultdict(int) + + def _change_jit_switch(self, status): + JitDump.jit_dump_switch = status diff --git a/debug/accuracy_tools/msprobe/mindspore/mindtorch/__init__.py b/debug/accuracy_tools/msprobe/mindspore/mindtorch/__init__.py index fc695d05ccc010f824b61db39a8ea77714d2d73b..13427188c913bff230fd798c1374c9839d7dd092 100644 --- a/debug/accuracy_tools/msprobe/mindspore/mindtorch/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/mindtorch/__init__.py @@ -1,18 +1,18 @@ -# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .mindtorch_adaptor import (_call_impl, - register_full_backward_pre_hook, - register_full_backward_hook) +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mindtorch_adaptor import (_call_impl, + register_full_backward_pre_hook, + register_full_backward_hook) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py new file mode 100644 index 0000000000000000000000000000000000000000..ef72a75ca246a8943bf580ba490465d2cca2c09b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from mindspore import nn +from mindspore import communication +from msprobe.mindspore.monitor.utils import logger +from msprobe.mindspore.common.utils import is_mindtorch +if is_mindtorch(): + import torch + + +def is_valid_instance(model): + return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell) + + +def get_submodules(model): + if not is_valid_instance(model): + logger.info("Counter invalid model, nothing to hook") + return {} + return model.named_modules() if is_mindtorch() else model.cells_and_names() + + +def get_parameters(model): + if not is_valid_instance(model): + return {} + if is_mindtorch(): + return model.named_parameters() + else: + return model.parameters_and_names() + + +def get_rank(): + if comm_is_initialized(): + return communication.get_rank() + return 0 + + +def comm_is_initialized(): + return communication.GlobalComm.INITED + + +def optimizer_pre_hook(optimizer, fn): + """ + fn should be fn(optimizer, args, **kwargs) + """ + if is_mindtorch(): + origin_api = optimizer.__class__.step + + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + fn(optimizer, args, kwargs) + out = func(*args, **kwargs) + return out + return wrapper + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + return (optimizer.__class__.step, origin_api) + + else: + handle = optimizer.register_forward_pre_hook(fn) + return handle + + +def optimizer_post_hook(optimizer, fn): + if is_mindtorch(): + origin_api = optimizer.__class__.step + + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + out = func(*args, **kwargs) + fn(optimizer, args, kwargs) + return out + return wrapper + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + return (optimizer.__class__.step, origin_api) + + else: + handle = optimizer.register_forward_hook(fn) + return handle diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py similarity index 42% rename from debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py index 3544ebbd025614349585bc799b15e00a5c2c7956..85c1096123c337a123b16f18236655bfe6e49c5e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py @@ -15,91 +15,20 @@ import itertools import os -import sys -import statistics as st -from abc import ABC -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass from collections import defaultdict import pandas as pd - from mindspore import ops +from mindspore import Tensor from mindspore import _no_grad + from msprobe.core.common.log import logger from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.core.common.const import FileCheckConst, MonitorConst -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, history, cur): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, history, cur): - baseline = st.mean(history) if isinstance(history, list) else history - - up_bound = baseline + baseline * self.threshold - if baseline > 0: - return cur > up_bound - else: - return cur < up_bound - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or rule_args is None: - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(history, cur) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -112,130 +41,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] -BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT, - MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str @@ -254,6 +59,41 @@ class BaseWriterWithAD: self.anomaly_factory = writer_input.anomaly_factory self.anomalies = [] self.ndigits = writer_input.ndigits + self.beta = 0.99 + + @staticmethod + def stack_tensors(tensor_list): + """ + Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group, + stack them separately, migrate xpu_group to cpu, and then restore in the order of input. + + :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')] + :return: result: list of float + """ + cpu_tensors = [] + xpu_tensors = [] + + for tensor in tensor_list: + if isinstance(tensor, Tensor): + # 将device上的tensor先stack后to cpu + xpu_tensors.append(tensor) + else: + cpu_tensors.append(tensor) + + xpu_stack = ops.stack(xpu_tensors).tolist() if xpu_tensors else ops.tensor([]) + + # 按照输入的顺序恢复 + result = [] + cpu_tensors_idx, xpu_tensors_idx = 0, 0 + for tensor in tensor_list: + if isinstance(tensor, Tensor): + result.append(xpu_stack[xpu_tensors_idx]) + xpu_tensors_idx += 1 + else: + result.append(cpu_tensors[cpu_tensors_idx]) + cpu_tensors_idx += 1 + + return result def get_anomalies(self): """返回已检测到的异常列表 @@ -272,12 +112,17 @@ class BaseWriterWithAD: Returns: None """ - detected = False - if self.ad_rules: - avg = self._update_tag2scalars(tag, scalar_value) - detected, rule_name = self._ad(scalar_value, history=avg) + if not self.ad_rules or tag[-1] in ["shape", "dtype"]: + return + if isinstance(scalar_value, Tensor): + scalar_value = scalar_value.item() + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) if detected: - exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]: + return + exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, " + f"current value {scalar_value}, history mean {avg}.") logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}") # append to self.anomalies for dump if self.anomaly_factory: @@ -290,8 +135,12 @@ class BaseWriterWithAD: tags = list(itertools.product(metric_value.keys(), op_list)) for op2tensor in metric_value.values(): tensors.extend(op2tensor.values()) + + if not tensors: + return + with _no_grad(): - metric_list = ops.stack(tensors).tolist() if tensors else [] + metric_list = self.stack_tensors(tensors) for tag, metric in zip(tags, metric_list): self.add_scalar(tag, metric, step, need_explain) @@ -311,11 +160,11 @@ class BaseWriterWithAD: Returns: float: The average value before update. """ + abs_scalar_value = abs(scalar_value) if tag not in self.tag2scalars: - self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0} avg = self.tag2scalars[tag]['avg'] - new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) - self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value self.tag2scalars[tag]['count'] += 1 return avg @@ -353,11 +202,10 @@ class CSVWriterWithAD(BaseWriterWithAD): new_data = [] for name, metric_value in self.context_dict.items(): - if MonitorConst.NAME_SEP not in name: - new_data.append([name] + [step] + metric_value) - else: - new_data.append(name.split(MonitorConst.NAME_SEP) + [step] + metric_value) - new_data = pd.DataFrame(new_data).round(self.ndigits) + new_line = name.split(MonitorConst.NAME_SEP) + metric_value + new_line.insert(2, step) + new_data.append(new_line) + new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) @@ -375,30 +223,15 @@ class CSVWriterWithAD(BaseWriterWithAD): name += '.output' self.context_dict[name].append(scalar_value) - def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False): + def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False, **kwargs): need_explain = prefix == 'other' super().write_metrics(op_list, metric_value, step, prefix='', need_explain=need_explain) - # generate csv headers - # set hashmap to reduce the number of headers generated. - # 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_ - if prefix in {"actv", "actv_grad"}: - if prefix == "actv": - input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] - else: - input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT] - ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, op_list)] - csv_header = ["module_name", "step", *ops_] + if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step"): + self.header = MonitorConst.CSV_HEADER_MICRO_STEP + op_list else: - csv_header = ["param_name", "step", *op_list] - - keys = list(metric_value.keys()) - if keys and MonitorConst.NAME_SEP in keys[0]: - csv_header.insert(0, "vpp_stage") - - self.header = csv_header + self.header = MonitorConst.CSV_HEADER + op_list self.write_csv(prefix, step) - self.header = [] def close(self): pass diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/features.py b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py index be958dadfe8fcc50f26f16c93b3a090269235d1e..4f8b78cf418e22375562a81d5a6a7aa48b2c9435 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/features.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py @@ -46,6 +46,8 @@ def get_max(x: Tensor): @_no_grad() def get_zeros(x: Tensor, eps: float): + if x.numel() == 0: + return Tensor(float('nan')) return mint.sum(mint.abs(x) < eps) / x.numel() @@ -54,10 +56,21 @@ def get_nans(t): return ops.isnan(t.astype(mstype.float32)).sum() -FUNC_MAP = {"min" : get_min, - "max" : get_max, - "mean" : get_mean, - "norm" : get_norm, - "nans" : get_nans, - "zeros": get_zeros - } \ No newline at end of file +def get_shape(t): + return t.shape + + +def get_dtype(t): + return t.dtype + + +FUNC_MAP = { + "min": get_min, + "max": get_max, + "mean": get_mean, + "norm": get_norm, + "nans": get_nans, + "zeros": get_zeros, + "shape": get_shape, + "dtype": get_dtype +} diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 068be9ff6c782bec2bf637999ef5f0eabe0c2675..5e5ecbe18f578cb55cc79bb00818585d3c6912bb 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -20,21 +20,23 @@ from collections import defaultdict from datetime import datetime import pytz -import mindspore as ms +import pandas as pd from mindspore import Tensor, mint from mindspore import nn, _no_grad -from mindspore.communication import get_rank from msprobe.core.common.log import logger -from msprobe.core.common.const import MonitorConst +from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter +from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ - is_skip_step, get_metrics, get_single_metrics, get_target_output_dir -from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec -from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput -from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ - get_process_group + is_skip_step, get_metrics, get_target_output_dir +from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory +from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput +from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate +from msprobe.core.common.file_utils import write_df_to_csv +from msprobe.core.common.utils import analyze_api_call_stack FORMAT_MAPPING = { MonitorConst.CSV: CSVWriterWithAD, @@ -88,24 +90,7 @@ class ModuleHookContext: self.actvgrad = [] self.module_name = module_name self.struct = {} - self.format_by_arg = {} - self.verified = False - self.focused_in_col = 0 - self.focused_out_col = 0 - self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found - - def set_format_by_arg(self, key_name: str, target_config: dict): - cared = target_config.get(self.module_name, self.struct) - if key_name in cared: - if isinstance(cared[key_name], dict): - # current cared is self.struct - config = cared[key_name].get('config') - self.format_by_arg[key_name] = config - else: - # current cared is target_config[self.module_name] - self.format_by_arg[key_name] = cared[key_name] - elif key_name in ['input', 'input_grad']: - self.ignore_in = True + self.stack = "" def reset(self): self.actv.clear() @@ -186,6 +171,7 @@ class TrainerMon: self.config_file_path = config_file_path self.process_group = process_group self.params_have_main_grad = params_have_main_grad + self.is_mindtorch = is_mindtorch() self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) @@ -218,6 +204,7 @@ class TrainerMon: self.dp_group = None self.tp_group = None self.micro_batch_number = 1 + self.optimizer_mon = None # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) @@ -240,6 +227,8 @@ class TrainerMon: self.optimizer_hooked = False self.param_registered = False self.struct_printed = False + self.pre_step_hooks = [] + self.post_step_hooks = [] # 动静态区分 self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true' @@ -276,6 +265,9 @@ class TrainerMon: self.param_distribution = self.config.get("param_distribution", False) self.mg_direction = self.config.get('mg_direction', False) # main grad direction self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops + self.stack_info = self.config.get('stack_info', False) + self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False) + if not self.cc_distribution.get('enable', False): self.cc_log_only = False else: @@ -283,8 +275,6 @@ class TrainerMon: self.cc_log_only = self.cc_distribution.get('cc_log_only', False) self.cc_logged_stack = defaultdict(set) self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) - self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) - api_register.redirect_api() self.common_info() # 初始化AnomalyData工厂 @@ -298,18 +288,25 @@ class TrainerMon: if self.format not in FORMAT_MAPPING: logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") self.format = MonitorConst.CSV - writer = FORMAT_MAPPING[self.format] self.step_count_per_record = self.config.get('step_count_per_record', 1) - self.summary_writer = writer( - WriterInput( - self.tensorboard_dir, - self.alert_rules, - self.unique_id, - self.anomaly_data_factory, - self.ndigits, - self.step_count_per_record + if not self.module_rank_list or (self.rank in self.module_rank_list): + writer = FORMAT_MAPPING[self.format] + self.summary_writer = writer( + WriterInput( + self.tensorboard_dir, + self.alert_rules, + self.unique_id, + self.anomaly_data_factory, + self.ndigits, + self.step_count_per_record + ) ) - ) + + # 初始化anomaly detected文件目录 + if self.anomaly_data_factory: + self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"), + self.rank) + self.anomaly_data_writer.init_detected_json() def common_info(self): if not self.xy_distribution: @@ -341,6 +338,7 @@ class TrainerMon: self.micro_batch_number = grad_acc_steps self.dp_group = dp_group self.tp_group = tp_group + self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer) self.hook_step_final(optimizer) if not isinstance(model, list): model = [model] @@ -358,19 +356,31 @@ class TrainerMon: if self.monitoring: module_rank_valid = self.is_target_rank() step_condition = (context.step >= self.start_step and ( - context.step - self.start_step) % self.step_interval == 0) + context.step - self.start_step) % self.step_interval == 0) if module_rank_valid and step_condition: self.has_collect_times += 1 + + if self.anomaly_data_factory: + self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) self.write_grad_tb(context.step) self.write_mv_tb(context) self.write_param_tb(context) + if self.stack_info: + self.write_stack_info() + self.stack_info = False + for handle in self.handles["stack"]: + handle.remove() + self.handles["stack"].clear() if context.metric_dict: self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other') context.metric_dict.clear() + if self.anomaly_data_factory: + self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) self.summary_writer.clear_anomalies() + self.call_id = 0 self.param_name_call_id.clear() @@ -380,7 +390,23 @@ class TrainerMon: context.step += 1 self.dynamic_monitor(optimizer) - optimizer.register_forward_hook(step_final_hook) + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + for hook in self.pre_step_hooks: + hook(optimizer, args, kwargs) + out = func(*args, **kwargs) + for hook in self.post_step_hooks: + hook(optimizer, args, kwargs) + step_final_hook(optimizer, args, kwargs) + return out + + return wrapper + + if self.is_mindtorch: + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + else: + optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer) + return def dynamic_monitor(self, optimizer): @@ -408,13 +434,14 @@ class TrainerMon: validate_config(config) self.config = config self.set_config() + self.start_step = context.step # 动态启停时不受原start_step影响,永远从下一步开始 logger.warning(f"config is updated at step{context.step - 1}, " f"will start new hook at step{context.step}.") except Exception as e: logger.error(f"set config wrong because {e}, not updated, please check!!!") return - self._remove_all_hooks() + self._remove_all_hooks(optimizer) self.register_hooks(optimizer) def register_hooks(self, optimizer): @@ -422,6 +449,9 @@ class TrainerMon: self.hook_modules() self.hook_optimizer(optimizer) self._patch_grad_sync() + if self.cc_distribution.get('enable', False): + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() self.monitoring = True def hook_modules(self): @@ -436,45 +466,36 @@ class TrainerMon: hooked_count = 0 for vpp_stage, model_chunk in enumerate(self.model): - if not isinstance(model_chunk, nn.Cell): + if not is_valid_instance(model_chunk): logger.info("Target Model is not Cell") continue vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' - targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys() + targets = [x for x, _ in get_submodules(model_chunk)] if self.print_struct else self.targets.keys() hooked_count += self._hook_module(targets, model_chunk, vpp_stage) logger.info(f"> {hooked_count} modules are monitored.") def hook_optimizer(self, optimizer): - def optimizer_pre_hook_function(opt, grad_names, gradients): + def optimizer_pre_step_hook(opt, *args, **kwargs): context = self.optimizer_context[opt] if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, self.collect_times): return - gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients - is_select = self.is_select - for idx, grad in enumerate(gradient_list): - grad_name = grad_names[idx] - if is_select and grad_name not in self.targets: - continue - get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad) - - if self.mv_distribution: - # fetch mean - for param in m_list: - name = param.name - if is_select and name not in self.targets: - continue - get_single_metrics(self.ops, name, param, context.exp_avg_metric) - # fetch variance - for param in v_list: - name = param.name - if is_select and name not in self.targets: - continue - get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric) - if self.param_distribution: - for param in param_list: - get_single_metrics(self.ops, param.name, param, context.param_metric) - self.generate_wgrad_metrics() + + grad_dict = {} + if self.wg_distribution: + grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name) + + if self.mv_distribution or self.ur_distribution or self.mg_direction: + if self.is_mindtorch: + context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \ + context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name) + else: + context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer) + + self.generate_wgrad_metrics(grad_dict) + self.generate_mv_metrics(context) + self.generate_param_metrics(context, MonitorConst.PRE_PARAM) + metric_dict = {} for cc in self.cc_context.values(): cc.aggregate() @@ -486,63 +507,86 @@ class TrainerMon: context.metric_dict = metric_dict return - def optimizer_pre_hook_wrapper(func, grad_names): - def wrapper(opt, gradients): - return func(opt, grad_names, gradients) - return wrapper + def optimizer_post_step_hook(optimizer, args, kwargs): + context = self.optimizer_context[optimizer] + self.generate_param_metrics(context, MonitorConst.POST_PARAM) + if self.optimizer_hooked or not self.is_target_rank(): return - m_list = [] - v_list = [] - param_list = [] - grad_names = [] - for param in optimizer.get_parameters(): - if MonitorConst.EXP_AVG_SQ in param.name: - v_list.append(param) - elif MonitorConst.EXP_AVG in param.name: - m_list.append(param) - elif param.name in ['global_step', 'learning_rate']: - pass - else: - param_list.append(param) - grad_names.append(param.name) - - handle = optimizer.register_forward_pre_hook( - optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names)) - self.handles['optimizer'].append(handle) + self.pre_step_hooks.append(optimizer_pre_step_hook) + self.post_step_hooks.append(optimizer_post_step_hook) self.optimizer_hooked = True return - def generate_wgrad_metrics(self): + def generate_wgrad_metrics(self, grad_dict): if not self.wg_distribution: - return {}, {} + return - if self.weight_hooked: - try: - get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) - except Exception as e: - logger.warning(f"An error occurred while generating wgrad pre metrics") - return {}, {} + get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - grad_dict = {} - for param, name in self.param2name.items(): - if self.duplicate_param.get(name, False): - continue - grad = param.main_grad if self.params_have_main_grad else param.grad - if grad is None: - logger.warning(f"grad is None: {name}, maybe something wrong happened.") + def generate_param_map(self, tag, param_tensor): + metrics = {} + if not self.is_mindtorch: + return param_tensor + for name in self.param2name.values(): + key = get_summary_writer_tag_name(name, tag, self.rank) + self.register_param_call_id("optimizer_pre_step_hook", key) + if name not in param_tensor or param_tensor[name] is None: continue - tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) - self._register_param_call_id("hook_optimizer", tag) - grad_dict[tag] = grad - try: - get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - except Exception as e: - logger.warning(f"An error occurred while generating wgrad post metrics") + metrics[key] = param_tensor[name] + return metrics + + def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM): + if not self.param_distribution: + return + tag2param = { + self.name2tag.get(name, {}).get(stage): param + for name, param in self.name2param.items() + if param.numel() != 0 + } + get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric) + + def get_mv_for_ms(self, opt): + if not self.mv_distribution: return {}, {} - return self.grad_context.post, self.grad_context.pre + common_opt = opt + if not is_valid_instance(opt): + common_opt = getattr(opt, 'optimizer') + if not is_valid_instance(common_opt): + logger.warning("Optimizer is not valid, please check usage") + return {}, {} + m_dict = {} + v_dict = {} + for name, param in get_parameters(common_opt): + if MonitorConst.EXP_AVG_SQ in name: + m_dict[name] = param + elif MonitorConst.EXP_AVG in name: + v_dict[name] = param + return m_dict, v_dict + + def generate_mv_metrics(self, opt_context): + if not self.mv_distribution: + return + opt_context.exp_avg_metric = {} + opt_context.exp_avg_sq_metric = {} + m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg) + v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq) + get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric) + get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric) + + def write_stack_info(self): + stack_data = [] + header = ["module_name", "stack_info"] + stack_data.append(header) + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + stack_data.append([fwd_context.module_name, fwd_context.stack]) + filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv') + if not os.path.exists(filepath): + data_frame = pd.DataFrame(columns=stack_data) + write_df_to_csv(data_frame, filepath) def write_xy_tb(self, step): if not self.xy_distribution: @@ -550,27 +594,32 @@ class TrainerMon: for _, fwd_context in self.module_fwd_hook_context_by_module.items(): if len(fwd_context.actv) == 0: continue - self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv') + self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV) fwd_context.actv.clear() if self.grad_context.actv: - self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad') + self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) def write_param_tb(self, opt_context): if not self.param_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param') + param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k} + updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k} + self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM) + self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM) def write_mv_tb(self, opt_context): if not self.mv_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg') - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq') + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, MonitorConst.EXP_AVG) + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, + MonitorConst.EXP_AVG_SQ) def write_grad_tb(self, step): if not self.wg_distribution: return - self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') + self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced', + use_micro_step=self.monitor_mbs_grad) self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') def is_target_rank(self): @@ -578,13 +627,38 @@ class TrainerMon: return False return True - def build_tbtag_tensor_map(self, module_name, tag, tensor): - metrics = {} - key = get_summary_writer_tag_name(module_name, tag, str(self.rank)) + def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor): + """ + :param module_name: str of module name + :param suffix: + :param tag: + :param tensor: torch.tensor or tuple/list of torch.tensor + :return: tensor_map + """ + tensor_map = {} if isinstance(tensor, Tensor): - self._register_param_call_id("_hook_module", key) - metrics[key] = tensor - return metrics + tensor = [tensor] + if isinstance(tensor, tuple) or isinstance(tensor, list): + if len(tensor) == 1: + key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor[0] + else: + for i, tensor_i in enumerate(tensor): + key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor_i + return tensor_map + + def register_param_call_id(self, hook_name: str, key: str): + """ + :param hook_name: + :param key: str, '0:relu_0/output_grad' + :return: + """ + logger.debug(f"{hook_name} {key}: {self.call_id}") + self.param_name_call_id[key] = self.call_id + self.call_id += 1 def _register_param_name(self): for vpp_stage, model_chunk in enumerate(self.model): @@ -593,8 +667,7 @@ class TrainerMon: def _register_chunk(self, model_chunk, prefix): index = 0 - for param in model_chunk.get_parameters(): - param_name = param.name + for param_name, param in get_parameters(model_chunk): if not param.requires_grad: continue if self._is_target_param(param_name, param, prefix): @@ -609,25 +682,37 @@ class TrainerMon: self.duplicate_param[name] = True if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): self.duplicate_param[name] = True + keywords = [ + MonitorConst.PRE_GRAD, + MonitorConst.POST_GRAD, + MonitorConst.PRE_PARAM, + MonitorConst.POST_PARAM + ] self.name2tag[name] = { - MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank), - MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank) + k: get_summary_writer_tag_name(name, k, self.rank) + for k in keywords } index += 1 def _hook_module(self, target_names, module, vpp_stage=''): - if not isinstance(module, nn.Cell): + if not is_valid_instance(module): # nothing to hook return 0 - def fwd_hook_fun(module, module_input, module_output, name): + def fwd_hook_fun(module, args, kwargs, module_output, name): + + module_input = [tensor for tensor in args if isinstance(tensor, Tensor)] + if kwargs: + kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)] + module_input.extend(kwargs_tensors) + if module not in self.module_fwd_hook_context_by_module: self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] if not context.struct: context.struct = { - MonitorConst.ACTV_IN: get_param_struct(module_input), - MonitorConst.ACTV_OUT: get_param_struct(module_output) + Const.INPUT: get_param_struct(module_input), + Const.OUTPUT: get_param_struct(module_output) } if self.print_struct: self.module_struct[context.module_name].update(context.struct) @@ -638,31 +723,18 @@ class TrainerMon: self.collect_times): step_accumulates_one(context, self.micro_batch_number) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets) - context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets) - if not context.format_by_arg: - return - if not context.verified: - if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN], - module_input, context.module_name, - MonitorConst.ACTV_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT], - module_output, context.module_name, - MonitorConst.ACTV_OUT) - context.verified = True tbtag_tensor_map = {} - if not context.ignore_in: - cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] - tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN, - cared_input)) - cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT, - cared_output)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_input)) + module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \ + if isinstance(module_output, tuple) else module_output + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_output)) try: get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) except Exception as e: @@ -687,31 +759,17 @@ class TrainerMon: step_accumulates_one(context, self.micro_batch_number) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets) - context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets) - if not context.format_by_arg: - return - if not context.verified: - if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN], - input_grad, context.module_name, - MonitorConst.ACTVGRAD_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT], - output_grad, context.module_name, - MonitorConst.ACTVGRAD_OUT) - context.verified = True - + valid_input_grad = [tensor for tensor in input_grad if isinstance(tensor, Tensor)] tbtag_tensor_map = {} - if not context.ignore_in: - cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] - tbtag_tensor_map.update( - self.build_tbtag_tensor_map( - f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad)) - cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT, - cared_output_grad)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, valid_input_grad)) + + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, output_grad)) if context.micro_step == 0 and context.actvgrad: logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " @@ -726,20 +784,34 @@ class TrainerMon: return def fwd_hook_fun_wrapper(fwd_hook_fun, name): - def wrapper(module, module_input, module_output): - return fwd_hook_fun(module, module_input, module_output, name) + def wrapper(module, args, kwargs, module_output): + return fwd_hook_fun(module, args, kwargs, module_output, name) + return wrapper + def stack_hook(module, args, kwargs, module_output, name): + if module not in self.module_fwd_hook_context_by_module: + self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + context.stack = analyze_api_call_stack(name) + return + if self.backward_only and self.forward_only: logger.warning('not enable backward_only and forward_only simultaneously') hooked_count = 0 - if self.xy_distribution or self.print_struct: - for module_name, submodule in module.cells_and_names(): - name = self._is_target_module(module_name, target_names, vpp_stage) - if not name: - continue + + for module_name, submodule in get_submodules(module): + if self.stack_info: + name = vpp_stage + squash_param_name(module_name) + handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(stack_hook, name=name), with_kwargs=True) + self.handles["stack"].append(handle) + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name)) + handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name), + with_kwargs=True) self.handles['xy'].append(handle) if not self.forward_only: handle = submodule.register_backward_hook(bwd_hook_fun) @@ -758,22 +830,30 @@ class TrainerMon: context = self.grad_context @_no_grad() - def param_hook(grad, context_dict, param, key): + def param_hook(grad, context_dict, param, name): + key = name + if self.monitor_mbs_grad: + key += f'{MonitorConst.NAME_SEP}{param.micro_step}' + key = get_summary_writer_tag_name(key, 'acc_grad', self.rank) + self.register_param_call_id("param_hook", key) param.micro_step += 1 - self._register_param_call_id("param_hook", key) + + if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number): + context_dict[key] = grad if param.micro_step == self.micro_batch_number: param.micro_step = 0 - context_dict[key] = grad - def param_hook_wrapper(param_hook, context_dict, param, key): + def param_hook_wrapper(param_hook, context_dict, param, name): def wrapper(grad): - return param_hook(grad, context_dict, param, key) + return param_hook(grad, context_dict, param, name) + return wrapper + logger.info("hooking weights.") for param, name in self.param2name.items(): - key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) setattr(param, 'micro_step', 0) - handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key)) + handle = param.register_hook( + param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name)) self.handles['wgrads'].append(handle) self.weight_hooked = True @@ -799,17 +879,7 @@ class TrainerMon: return pattern return "" - def _register_param_call_id(self, hook_name: str, key: str): - """ - :param hook_name: - :param key: str, '0:relu_0/output_grad' - :return: - """ - logger.debug(f"{hook_name} {key}: {self.call_id}") - self.param_name_call_id[key] = self.call_id - self.call_id += 1 - - def _remove_all_hooks(self): + def _remove_all_hooks(self, optimizer): # 清空hook handle for handle in self.handles['xy']: handle.remove() @@ -827,9 +897,8 @@ class TrainerMon: self.weight_hooked = False if self.optimizer_hooked: - for handle in self.handles['optimizer']: - handle.remove() - self.handles['optimizer'].clear() + self.pre_step_hooks.clear() + self.post_step_hooks.clear() for _, context in self.optimizer_context.items(): context.reset() self.optimizer_hooked = False @@ -837,6 +906,7 @@ class TrainerMon: for handle in self.handles['cc']: handle.remove() self.handles['cc'].clear() + api_register.restore_api() for _, context in self.cc_context.items(): context.reset() @@ -867,4 +937,4 @@ class TrainerMon: except Exception as e: logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!") logger.info("Finish monitor") - self._remove_all_hooks() + self._remove_all_hooks(optimizer) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py deleted file mode 100644 index c06e8ea10f6a2178c3670e596ad64e333db44cab..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import abc -from mindspore import Tensor - -from msprobe.core.common.log import logger - - -# 用于存储所有validator实现类的注册表 -config_validator_registry = {} - - -def register_config_validator(cls): - """装饰器 用于注册ConfigValidator的实现类""" - config_validator_registry[cls.__name__] = cls - return cls - - -class ConfigValidator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def check_pattern_match(self, config_spec: str): - pass - - @abc.abstractmethod - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - pass - - -@register_config_validator -class TensorValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tensor") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - if not isinstance(actual_data, Tensor): - raise ValueError( - f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.") - - -@register_config_validator -class TupleValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - length, index = pattern_match.groups() - if index is None: - index = 0 - length, index = int(length), int(index) - - if not (0 <= index < length): - raise ValueError( - f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'." - f"y must be greater than or equal to 0 and less than x.") - if not isinstance(actual_data, tuple): - raise ValueError( - f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.") - if len(actual_data) != length: - raise ValueError( - f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, " - f"actual is {len(actual_data)} please check.") - return index - - -def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str): - focused_col = None - for _, validator_cls in config_validator_registry.items(): - config_validator = validator_cls() - pattern_match = config_validator.check_pattern_match(config_spec) - if pattern_match: - try: - focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) - except ValueError as e: - logger.warning(f"config spec validate failed: {str(e)}") - return focused_col - logger.warning(f"config spec in {module_name} {data_type} not supported, " - f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") - return focused_col \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0b633345077503280641aed76d58be9e7a3328 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/optimizer_collect.py @@ -0,0 +1,331 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod + +from mindspore import mint, ops + +from msprobe.mindspore.common.log import logger +from msprobe.core.common.const import MonitorConst + + +class OptimizerMon(object): + def __init__(self, optim) -> None: + self.fp16_to_fp32_param = {} + self.optim = optim + self.state = {} + + def narrow_from_flatten(self, param, flatten_state): + return flatten_state + + def get_state(self, optim): + if hasattr(optim, 'chained_optimizers'): + for opt in optim.chained_optimizers: + self._get_single_state(opt) + else: + self._get_single_state(optim) + + def fetch_grad(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.optim) + + grad_dict = {} + first_param = True + for param, name in params2name.items(): + if monitor.duplicate_param.get(name, False): + continue + if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param: + continue + grad = param.main_grad if monitor.params_have_main_grad else param.grad + element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel() + if param.numel() != element_in_cur_partition: + if first_param: + grad = grad.flatten()[-element_in_cur_partition:] + else: # supposed to be the last one + grad = grad.flatten()[:element_in_cur_partition] + first_param = False + if grad is None: + continue + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + return grad_dict + + def map_fp16_to_fp32_param(self, optim): + pass + + def fetch_mv(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.optim) + if not self.state: + self.get_state(self.optim) + + exp_avg_dict = {} + exp_avg_sq_dict = {} + update_dict = {} + ratio_dict = {} + + if not self.state: + logger.warning('optimizer state can not accessed') + return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + + for lp_param, name in params2name.items(): + if lp_param in self.fp16_to_fp32_param: + hp_param = self.fp16_to_fp32_param[lp_param] + else: + hp_param = lp_param + + if hp_param in self.state: + state_param = self.state.get(hp_param, {}) + exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) + exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) + if monitor.mv_distribution: + exp_avg_dict[name] = exp_avg + exp_avg_sq_dict[name] = exp_avg_sq + if monitor.mg_direction: + exp_avg_dict[name] = exp_avg + if monitor.ur_distribution: + if len(self.optim.param_groups) > 1: + logger.info(f"the length of optim.param_groups is {len(self.optim.param_groups)}.") + if 'step' in state_param: + step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) + elif 'step' in self.optim.param_groups[0]: + step = self.optim.param_groups[0]['step'] # AdamW from mindspeed + else: + logger.warning(f"step of {name} is None, maybe something wrong happened.") + continue + exp_avg_hat = exp_avg / (1 - self.optim.defaults['betas'][0] ** step) + exp_avg_sq_hat = exp_avg_sq / (1 - self.optim.defaults['betas'][1] ** step) + update_dict[name] = exp_avg_hat / (mint.sqrt(exp_avg_sq_hat) + self.optim.defaults['eps']) + ratio_dict[name] = exp_avg_hat / mint.sqrt(exp_avg_sq_hat) + monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) + monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) + return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + + def _get_single_state(self, optim): + state = {} + if hasattr(optim, 'param_to_cpu_states_map'): + state = optim.param_to_cpu_states_map + elif hasattr(optim, 'state'): + state = optim.state + elif hasattr(optim, 'optimizer') and hasattr(optim.optimizer, 'state'): + state = optim.optimizer.state + self.state.update(state) + + +class MixPrecisionOptimizerMon(OptimizerMon): + """ + 混合精度优化器监控类。在混合精度训练中监控和管理优化器。 + 混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。 + """ + def map_fp16_to_fp32_param(self, optim): + for fp16_group, fp32_group in zip(optim.float16_groups, optim.fp32_from_float16_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + self.fp16_to_fp32_param[fp16_param] = fp32_param + + +class MegatronDistributedOptimizerMon(OptimizerMon): + def map_fp16_to_fp32_param(self, optim): + if not (hasattr(optim, "model_float16_groups") and + hasattr(optim, "shard_fp32_from_float16_groups")): + raise Exception( + "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, " + "if not, please check megatron-lm version") + for fp16_group, shard_fp32_group in zip(optim.model_float16_groups, + optim.shard_fp32_from_float16_groups): + for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): + self.fp16_to_fp32_param[fp16_param] = shard_fp32_param + + +class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): + def map_fp16_to_fp32_param(self, optim): + for opt in optim.chained_optimizers: + super().map_fp16_to_fp32_param(opt) + + +class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): + def map_fp16_to_fp32_param(self, optim): + for opt in optim.chained_optimizers: + super().map_fp16_to_fp32_param(opt) + + +class DeepSpeedZeroOptimizerMon(OptimizerMon): + """ + Base monitor class for DeepSpeed ZeRO optimizer. + ZeRO stage 0 no partition + ZeRO stage 1 partitions optimizer states across data parallel processes. + ZeRO stage 2 additionally partitions gradients. + ZeRO stage 3 additionally partitions parameters. + + This class provides monitoring capabilities for ZeRO optimizers by: + - Handling gradient collection for different ZeRO stages + - Managing optimizer state access for monitoring + """ + def __init__(self, optim): + super().__init__(optim) + self.stage = '' + self.bit16_groups = [] + self.fp32_flat_groups = [] + self.param2group = () + self.param2index = [] + self.group_offset = {} + + @abstractmethod + def get_grad_for_param(self, lp_param, group_idx, param_id): + raise NotImplementedError + + def param_not_in_partition(self, lp_param, group_idx): + param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param)) + return hp_address is None + + def get_position(self, lp_param, group_idx): + param_slice_mapping = self.optim.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.optim.param_names.get(lp_param)) + return hp_address.start, hp_address.numel + + def get_group_index(self): + param2group = {} + for group_idx, bit16_group in enumerate(self.bit16_groups): + for param in bit16_group: + param2group[param] = group_idx + return param2group + + def get_param_index(self, lp_param, group_idx): + if not self.param2index: + for group in self.bit16_groups: + param2index = {} + for index, param in enumerate(group): + param2index[param] = index + self.param2index.append(param2index) + + return self.param2index[group_idx][lp_param] + + def narrow_from_flatten(self, param, flatten_state): + if flatten_state is None: + return flatten_state + group_idx = self.param2group[param] + if self.param_not_in_partition(param, group_idx): + return None + start, numel = self.get_position(param, group_idx) + return flatten_state.narrow(0, start, numel) + + def map_fp16_to_fp32_param(self, optim): + for group_idx, group in enumerate(self.bit16_groups): + for param in group: + self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx] + + def fetch_grad(self, monitor, params2name): + grad_dict = {} + for lp_param, name in params2name.items(): + group_idx = self.param2group[lp_param] + param_id = self.get_param_index(lp_param, group_idx) + if self.param_not_in_partition(lp_param, group_idx): + continue + if self.stage == '1or2': + param_id = param_id - self.group_offset[group_idx] - 1 + grad = self.get_grad_for_param(lp_param, group_idx, param_id) + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + + return grad_dict + + +class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '0' + self.bit16_groups = optim.bf16_groups + self.fp32_flat_groups = optim.fp32_groups_flat_partition + self.param2group = self.get_group_index() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.optim.fp32_groups_gradient_dict[group_idx][param_id] + + +class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '1or2' + self.bit16_groups = optim.bit16_groups + self.fp32_flat_groups = optim.single_partition_of_fp32_groups + self.param2group = self.get_group_index() + self.group_offset = {} + self.get_group_offset() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + if getattr(self.optim, "cpu_offload", False): + grads = self.optim.single_partition_of_fp32_groups[group_idx].grad + start, numel = self.get_position(lp_param, group_idx) + grad = grads.narrow(0, start, numel) + else: + grad = self.optim.averaged_gradients[group_idx][param_id] + return grad + + def get_group_offset(self): + for group_idx, group in enumerate(self.bit16_groups): + self.group_offset[group_idx] = -1 + for lp_param in group: + if self.param_not_in_partition(lp_param, group_idx): + self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx) + else: + break + + +class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, optim): + super().__init__(optim) + self.stage = '3' + self.bit16_groups = optim.fp16_groups + self.fp32_flat_groups = optim.fp32_partitioned_groups_flat + self.param2group = self.get_group_index() + + def param_not_in_partition(self, lp_param, group_idx): + """Each param partioned across all zero ranks""" + return False + + def get_position(self, lp_param, group_idx): + param_id = self.optim.get_param_id(lp_param) + return self.optim.grad_position[param_id][1:] + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.optim.averaged_gradients[group_idx][param_id] + + +class OptimizerMonFactory: + _optimizer_mon_map = { + "FP32Optimizer": OptimizerMon, + "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, + "DistributedOptimizer": MegatronDistributedOptimizerMon, + "SwapDistributedOptimizer": MegatronDistributedOptimizerMon, + "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, + "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon, + "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon, + "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon, + "Adam": OptimizerMon + } + + @staticmethod + def create_optimizer_mon(optimizer): + # auto replace opt_ty + optimizer_class = optimizer.__class__.__name__ + if optimizer_class == "ChainedOptimizer": + optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__ + logger.info(f'The optimizer type is {optimizer_class}') + + optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon) + return optimizer_mon_class(optimizer) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py index a27172f19ead537f276c5ce0820b405d7abb6e25..e0817eb2a4efc11000a96c6b328f6fbd07145060 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py @@ -24,18 +24,24 @@ from msprobe.core.common.log import logger from msprobe.core.common.file_utils import check_file_or_directory_path -def get_single_metrics(op_list, tag, tensor, output=None): +def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None): if output is None: output = {} if tag not in output: output[tag] = {} for op in op_list: func = FUNC_MAP.get(op) - statistic = func(tensor) + if op == "zeros": + statistic = func(tensor, eps) + else: + statistic = func(tensor) if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16: statistic = float(statistic) statistic = Tensor(statistic) - output[tag][op] = statistic.astype(mstype.float32) + if isinstance(statistic, Tensor): + output[tag][op] = statistic.astype(mstype.float32) + else: + output[tag][op] = statistic def get_metrics(op_list, tag2tensor, eps, output=None): @@ -44,7 +50,7 @@ def get_metrics(op_list, tag2tensor, eps, output=None): for tag, tensor in tag2tensor.items(): if tag not in output: output[tag] = {} - get_single_metrics(op_list, tag, tensor, output) + get_single_metrics(op_list, tag, tensor, eps, output) return output @@ -91,6 +97,11 @@ def validate_ops(ops): default_op = MonitorConst.OP_LIST[0] valid_ops.append(default_op) logger.info(f"There is no valid ops, default op {default_op} is used") + # 增加默认shape和dtype参数 + if "shape" not in valid_ops: + valid_ops.append("shape") + if "dtype" not in valid_ops: + valid_ops.append("dtype") return valid_ops @@ -171,7 +182,7 @@ def validate_alert(alert): args = rule.get("args") if args and isinstance(args, dict): threshold = args.get("threshold") - if not isinstance(threshold, float) or threshold < 0: + if not isinstance(threshold, (float, int)) or threshold < 0: raise TypeError('threshold must be float and not less than 0') dump = alert.get('dump') if dump and not isinstance(dump, bool): @@ -212,6 +223,18 @@ def validate_collect_times(collect_times): raise ValueError("collect_times must greater than 1") +def validate_dynamic_on(dynamic_on): + if not isinstance(dynamic_on, bool): + raise TypeError('dynamic_on should be a bool') + + +def validate_monitor_mbs_grad(monitor_mbs_grad): + if not isinstance(monitor_mbs_grad, bool): + logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') + return False + return monitor_mbs_grad + + def validate_config(config): config['ops'] = validate_ops(config.get('ops', [])) @@ -258,9 +281,14 @@ def validate_config(config): step_interval = config.get('step_interval', 1) validate_step_interval(step_interval) - collect_times = config.get('collect_times', 1e8) + collect_times = config.get('collect_times', int(1e8)) validate_collect_times(collect_times) + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) + + dynamic_on = config.get('dynamic_on', False) + validate_dynamic_on(dynamic_on) + if not targets: if xy_distribution: config["all_xy"] = True diff --git a/debug/accuracy_tools/msprobe/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py index f20ed804c5bb8d8fbe4dba3e208060e8f52a3120..4b73ad5bdebbc3b6b2bdceb2d34b89264aa4f013 100644 --- a/debug/accuracy_tools/msprobe/mindspore/ms_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py @@ -28,7 +28,9 @@ class TensorConfig(BaseConfig): super().__init__(json_config) self.check_mode = None self.file_format = json_config.get("file_format") + self.td_config_path = json_config.get("td_config_path") self.check_config() + self._check_summary_mode() self._check_config() def _check_config(self): @@ -42,12 +44,23 @@ class StatisticsConfig(BaseConfig): self.file_format = None self.check_mode = None self.check_config() - self._check_config() + self._check_summary_mode() - def _check_config(self): - single_opt = ["statistics", "md5"] + self.tensor_list = json_config.get("tensor_list", []) + self._check_str_list_config(self.tensor_list, "tensor_list") + self.stat_cal_mode = json_config.get("device", "host") + self.device_stat_precision_mode = json_config.get("precision", "high") + self._check_stat_params() + + def _check_stat_params(self): + if self.stat_cal_mode not in ["device", "host"]: + raise Exception("Config param [device] is invalid, expected from [\"device\", \"host\"]") + if self.device_stat_precision_mode not in ["high", "low"]: + raise Exception("Config param [precision] is invalid, expected from [\"high\", \"low\"]") + + def _check_summary_mode(self): muti_opt = ["md5", "max", "min", "mean", "l2norm"] - if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt: + if isinstance(self.summary_mode, str) and self.summary_mode not in Const.SUMMARY_MODE: raise Exception("summary_mode is invalid") if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode): raise Exception("summary_mode is invalid") @@ -132,14 +145,3 @@ def parse_task_config(task, json_config): if task not in TaskDict: raise Exception("task is invalid.") return TaskDict.get(task)(task_map) - - -def parse_json_config(json_file_path): - if not json_file_path: - raise Exception("json file path is None") - json_config = load_json(json_file_path) - common_config = parse_common_config(json_config) - if not common_config.task: - common_config.task = Const.STATISTICS - task_config = parse_task_config(common_config.task, json_config) - return common_config, task_config diff --git a/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py index a2d3e290bd6b16b3deeb7f22a5e7d327ebaa2bc4..0b3ed6221c2c4fc2d380072a480f35d5815cb89e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from msprobe.core.common.log import logger from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck @@ -44,6 +45,7 @@ class OverflowCheckToolFactory: raise Exception("Valid level is needed.") tool = tool.get(config.execution_mode) if not tool: - raise Exception(f"Overflow check is not supported in {config.execution_mode} mode " - f"when level is {config.level}.") - return tool(config) + logger.error(f"Overflow check is not supported in {config.execution_mode} mode " + f"when level is {config.level}.") + raise ValueError + return (tool(config),) diff --git a/debug/accuracy_tools/msprobe/mindspore/runtime.py b/debug/accuracy_tools/msprobe/mindspore/runtime.py index 0191a484cbc096b2e211b22b5abce147eac23b97..9ea2e5d32f9db0fe4cc13a26eca52026dae9e599 100644 --- a/debug/accuracy_tools/msprobe/mindspore/runtime.py +++ b/debug/accuracy_tools/msprobe/mindspore/runtime.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from msprobe.mindspore.common.const import Const + + class Runtime: step_count: int = 0 rank_id: int = -1 is_running: bool = False + run_mode: str = Const.PYNATIVE_MODE diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py deleted file mode 100644 index 5afbd046be4caf29c4b247a0f8fdd655c5208fd0..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import functools -import os -from collections import defaultdict - -import mindspore as ms -from mindspore import nn -from mindspore.common.api import _no_grad -from mindspore.ops.primitive import Primitive - -try: - from mindspore.common._pijit_context import PIJitCaptureContext -except ImportError: - pijit_label = False -else: - pijit_label = True - -from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation -from msprobe.core.data_dump.data_collector import build_data_collector -from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, - ModuleBackwardInputs) -from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor -from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs, - is_mindtorch, register_backward_hook_functions) -from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from msprobe.mindspore.dump.jit_dump import JitDump -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json - -if is_mindtorch(): - import torch - - -class Service: - def __init__(self, config): - self.model = None - self.config = copy.deepcopy(config) - self.config.level = self.config.level_ori - self.data_collector = build_data_collector(self.config) - self.cell_processor = CellProcessor(self.data_collector.scope) - self.primitive_hook_service = PrimitiveHookService(self) - self.switch = False - self.inner_switch = False - self.primitive_switch = False - self.current_iter = 0 - self.first_start = True - self.current_rank = None - self.dump_iter_dir = None - self.start_call = False - self.should_stop_service = False - self.params_grad_info = {} - self.hook_handle_dict = {} - # 提前注册,确保注册尽可能多的API hook - self.register_api_hook() - self.init_for_debug_level() - - @staticmethod - def check_model_valid(models): - target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") - if models is None or isinstance(models, target_module_type[0]): - return models - error_model = None - if isinstance(models, (list, tuple)): - for model in models: - if not isinstance(model, target_module_type[0]): - error_model = model - break - else: - error_model = models - - if error_model is not None: - error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " - f"type, currently there is a {type(error_model)} type.") - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, error_info) - return models - - @staticmethod - def prepare_module_input_output(target_type, cell, input_data, output): - if target_type == BaseScope.Module_Type_Module: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output) - else: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) - return module_input_output - - def build_hook(self, target_type, name): - def pre_hook(api_or_cell_name, cell, input_data): - if not self.should_execute_hook(target_type, cell, True): - clean_input_kwargs(cell) - return None - - with _no_grad(): - self.inner_switch = True - if target_type == BaseScope.Module_Type_Module: - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - else: - cell.forward_data_collected = True - HOOKCell.add_cell_count(name) - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output) - self.inner_switch = False - return input_data - - def grad_hook(cell, ori_name, param_name): - def hook_fn(grad): - if not self.should_execute_hook(target_type, cell, False): - return None - self.inner_switch = True - self.data_collector.params_data_collect(ori_name, param_name, pid, grad) - self.inner_switch = False - return None - - return hook_fn - - def register_param_hook(ori_name, cell, params_dict): - ''' - 注册参数hook - ''' - # data_mode为forward时,不注册参数hook - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - for param_name, param in params_dict.items(): - if param.requires_grad: - name = ori_name + Const.SEP + param_name - old_handle = self.hook_handle_dict.get(name) - if old_handle and hasattr(old_handle, "remove"): - old_handle.remove() - handle = param.register_hook(grad_hook(cell, ori_name, param_name)) - self.hook_handle_dict[name] = handle - - def init_params_grad_info(cell, params_dict): - ''' - 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 - ''' - if not params_dict: - return - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None - # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 - if not self.params_grad_info.get(grad_name): - data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} - # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 - if data_info.get(grad_name): - # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, - flush=self.data_collector.data_processor.is_terminated) - # 记录当前模块的参数梯度信息已占位 - self.params_grad_info[grad_name] = True - - def forward_hook(api_or_cell_name, cell, input_data, output): - if not self.should_execute_hook(target_type, cell, True): - clean_input_kwargs(cell) - return None - with _no_grad(): - self.inner_switch = True - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output) - if target_type == BaseScope.Module_Type_Module: - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - params_dict = {} - if self.config.task != Const.STRUCTURE: - params_dict = { - key.split(Const.SEP)[-1]: value - for key, value in cell.parameters_dict(recurse=False).items() - } - setattr(module_input_output, Const.PARAMS, params_dict) - # 判断是否需要注册参数hook - if params_dict: - ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0] - grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD - # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook - setattr(cell, 'params_grad_name', grad_name) - register_param_hook(ori_name, cell, params_dict) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output) - init_params_grad_info(cell, params_dict) - else: - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output) - - if self.data_collector.if_return_forward_new_output(): - forward_new_output = self.data_collector.get_forward_new_output() - self.inner_switch = False - return forward_new_output - clean_input_kwargs(cell) - self.inner_switch = False - return output - - def backward_hook(api_or_cell_name, cell, grad_input, grad_output): - if not self.should_execute_hook(target_type, cell, False): - return - self.inner_switch = True - - need_exchange = True - if target_type == BaseScope.Module_Type_Module: - if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called: - need_exchange = False - api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - - self.data_collector.update_api_or_module_name(api_or_cell_name) - if self.data_collector: - # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入 - if need_exchange: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) - else: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) - self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output) - self.inner_switch = False - - def pre_backward_hook(api_or_cell_name, cell, grad_input): - if not self.should_execute_hook(target_type, cell, False): - return - self.inner_switch = True - module_input = ModuleBackwardInputs(grad_input=grad_input) - self.data_collector.update_api_or_module_name(api_or_cell_name) - self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input) - - self.inner_switch = False - - pid = os.getpid() - if target_type == BaseScope.Module_Type_Module: - full_forward_name = name + Const.FORWARD - full_backward_name = name + Const.BACKWARD - else: - full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD - full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD - pre_forward_hook = functools.partial(pre_hook, full_forward_name) - forward_hook = functools.partial(forward_hook, full_forward_name) - backward_hook = functools.partial(backward_hook, full_backward_name) - pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name) - - def wrap_pre_forward_hook(cell, input_data): - return pre_forward_hook(cell, input_data) - - def wrap_forward_hook(cell, input_data, output_data): - return forward_hook(cell, input_data, output_data) - - def wrap_backward_hook(cell, grad_input, grad_output): - return backward_hook(cell, grad_input, grad_output) - - def wrap_pre_backward_hook(cell, grad_input): - return pre_backward_hook(cell, grad_input) - - return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook - - def update_primitive_counters(self, primitive_name): - if primitive_name not in self.primitive_counters: - self.primitive_counters[primitive_name] = 0 - else: - self.primitive_counters[primitive_name] += 1 - - def step(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - self.current_iter += 1 - self.data_collector.update_iter(self.current_iter) - self.reset_status() - - def start(self, model=None): - if self.config.level == Const.LEVEL_DEBUG: - return - self.start_call = True - if self.should_stop_service: - return - if self.need_end_service(): - self.should_stop_service = True - self.switch = False - self.primitive_switch = False - print_tools_ends_info() - return - if self.config.step and self.current_iter not in self.config.step: - return - self.model = self.check_model_valid(model) - - logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully") - - if self.first_start: - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - - if self.config.rank and self.current_rank not in self.config.rank: - return - self.register_primitive_hook() - self.register_cell_hook() - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: - JitDump.set_config(self.config) - JitDump.set_data_collector(self.data_collector) - if hasattr(ms.common.api, "_MindsporeFunctionExecutor"): - ms.common.api._MindsporeFunctionExecutor = JitDump - else: - ms.common.api._JitExecutor = JitDump - ms.common.api._PyNativeExecutor.grad = JitDump.grad - if pijit_label: - PIJitCaptureContext.__enter__ = self.empty - PIJitCaptureContext.__exit__ = self.empty - self.first_start = False - - api_register.api_set_hook_func() - self.switch = True - self.primitive_switch = True - logger.info(f"Dump switch is turned on at step {self.current_iter}. ") - self.create_dirs() - logger.info(f"Dump data will be saved in {self.dump_iter_dir}.") - JitDump.jit_dump_switch = True - - def stop(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " - "Please set debugger.start() to turn on the dump switch again. ") - if not self.start_call: - logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.") - raise Exception("debugger.start() is not set in the current scope.") - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - self.switch = False - self.primitive_switch = False - self.start_call = False - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - JitDump.jit_dump_switch = False - - def need_end_service(self): - if self.config.step and self.current_iter > max(self.config.step): - return True - if self.data_collector and self.data_collector.data_processor.is_terminated: - return True - return False - - def should_execute_hook(self, hook_type, cell, is_forward): - is_cell_hook = hook_type == BaseScope.Module_Type_Module - if is_cell_hook and not self.switch: - return False - elif not is_cell_hook and is_forward and not self.switch: - return False - elif not is_cell_hook and not is_forward and not cell.forward_data_collected: - return False - - if self.inner_switch: - return False - if not self.data_collector or self.data_collector.data_processor.is_terminated: - return False - return True - - def create_dirs(self): - create_directory(self.config.dump_path) - self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") - cur_rank = self.current_rank if self.current_rank is not None else '' - if self.config.level == Const.LEVEL_L2: - create_directory(self.dump_iter_dir) - kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) - self.config.kernel_config_path = kernel_config_path - return - - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - self.data_collector.update_dump_paths(dump_path_aggregation) - - self.data_collector.initialize_json_file( - framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK - ) - - def empty(self, *args, **kwargs): - pass - - def register_api_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: - logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.") - api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) - api_register.api_set_hook_func() - - def get_cells_and_names(self): - cells_and_names_with_index = {} - - def get_cell_or_module(model): - return model.named_modules() if is_mindtorch() else model.cells_and_names() - - if isinstance(self.model, (list, tuple)): - for index, model in enumerate(self.model): - cells_and_names_with_index[str(index)] = get_cell_or_module(model) - else: - cells_and_names_with_index["-1"] = get_cell_or_module(self.model) - return cells_and_names_with_index - - def register_primitive_hook(self): - if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]: - return - if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST: - return - - primitive_set = set() - cells_and_names_with_index = self.get_cells_and_names() - for cells_and_names in cells_and_names_with_index.values(): - for _, cell in cells_and_names: - for attribute, value in vars(cell).items(): - if isinstance(value, Primitive): - primitive_set.add((attribute, value)) - - for pname, primitive in primitive_set: - primitive_class_name = primitive.__class__.__name__ - primitive_combined_name = pname + Const.SEP + primitive_class_name - new_primitive = type('NewPrimitive', (primitive.__class__,), - {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, - primitive_combined_name)}) - primitive.__class__ = new_primitive - - def register_cell_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]: - logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.") - if not self.model: - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, - f"The current level is {self.config.level}, the model cannot be None") - model_type = Const.MODULE if is_mindtorch() else Const.CELL - cells_and_names_with_index = self.get_cells_and_names() - - for index, cells_and_names in cells_and_names_with_index.items(): - model = self.model if index == "-1" else self.model[int(index)] - for name, cell in cells_and_names: - if cell == model: - continue - cell_index = (index + Const.SEP) if index != "-1" else "" - prefix = (model_type + Const.SEP + cell_index + name + - Const.SEP + cell.__class__.__name__ + Const.SEP) - _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix) - cell.register_forward_hook(forward_hook) - cell.register_forward_pre_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START)) - cell.register_forward_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - - register_backward_hook_functions["full"](cell, backward_hook) - register_backward_hook_functions["pre"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - register_backward_hook_functions["full"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) - - def reset_status(self): - self.primitive_hook_service.primitive_counters.clear() - self.data_collector.reset_status() - JitDump.jit_count = defaultdict(int) - self.params_grad_info.clear() - if self.config.level == Const.LEVEL_L2: - self.data_collector.data_processor.reset_status() - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - - def init_for_debug_level(self): - if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]): - return - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - # dir: dump_path -- rank{} -- debug.json - self.dump_iter_dir = self.config.dump_path - cur_rank = self.current_rank if self.current_rank is not None else '' - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file( - framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK - ) - self.debug_variable_counter = defaultdict(int) - - def save(self, variable, name, save_backward): - ''' - Args: - variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int] - name: str - save_backward: boolean - Return: - void - ''' - if self.config.level != Const.LEVEL_DEBUG: - return - count = self.debug_variable_counter[name] - self.debug_variable_counter[name] += 1 - - name_with_count = f"{name}.{count}" - grad_name_with_count = f"{name}_grad.{count}" - - # forward save - self.data_collector.debug_data_collect_forward(variable, name_with_count) - - # backward save - if save_backward: - self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index a9cb5e6dd4037dcdeffe3c4d9584ad93c42022d6..10b74ea22b02d0668d0b3b17a569c5e1a67c1dd8 100644 --- a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -29,11 +29,14 @@ class TaskHandlerFactory: } @staticmethod - def create(config: DebuggerConfig): + def create(config: DebuggerConfig, model=None): task = TaskHandlerFactory.tasks.get(config.task) if not task: raise Exception("Valid task is needed.") - handler = task.create(config) + if task == DumpToolFactory: + handler = task.create(config, model) + else: + handler = task.create(config) if not handler: raise Exception("Can not find task handler") return handler diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index 8e0386fde6dccc071c3d9d8e1a86729a2c483c7c..eb857621d66331039b5bd50d5f99ba74d345d6fb 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -51,11 +51,12 @@ def main(): graph_service_cmd_parser = subparsers.add_parser('graph') op_generate_cmd_parser = subparsers.add_parser('op_generate') merge_result_parser = subparsers.add_parser('merge_result') + config_checking_parser = subparsers.add_parser('config_checking') + nan_analyze_parser = subparsers.add_parser('nan_analyze') _compare_parser(compare_cmd_parser) _merge_result_parser(merge_result_parser) is_torch_available = is_module_available("torch") - if len(sys.argv) < 4: parser.print_help() sys.exit(0) @@ -71,6 +72,9 @@ def main(): from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \ _run_operator_generate_commond + from msprobe.pytorch.config_checking.config_checking import _config_checking_parser, \ + _run_config_checking_command + from msprobe.nan_analyze.analyzer import _nan_analyze_parser, _run_nan_analyze _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) @@ -80,6 +84,8 @@ def main(): _run_overflow_check_parser(run_overflow_check_cmd_parser) _pt_graph_service_parser(graph_service_cmd_parser) _op_generator_parser(op_generate_cmd_parser) + _config_checking_parser(config_checking_parser) + _nan_analyze_parser(nan_analyze_parser) elif framework_args.framework == Const.MS_FRAMEWORK: from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command @@ -118,6 +124,10 @@ def main(): compare_cli(args) elif sys.argv[3] == "merge_result": merge_result_cli(args) + elif sys.argv[3] == "config_checking": + _run_config_checking_command(args) + elif sys.argv[3] == "nan_analyze": + _run_nan_analyze(args) else: if not is_module_available(Const.MS_FRAMEWORK): logger.error("MindSpore does not exist, please install MindSpore library") diff --git a/debug/accuracy_tools/msprobe/nan_analyze/__init__.py b/debug/accuracy_tools/msprobe/nan_analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..e147f23b7c7bd514a13251830e0365928876bc75 --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/analyzer.py @@ -0,0 +1,255 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections import defaultdict +import os +from itertools import dropwhile, chain + +from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst, + analyze_anomaly_in_group) +from msprobe.nan_analyze.graph import DataNode, CommunicationNode + + +class NanAnalyzer: + def __init__(self, input_path, output_path): + self._input_path = input_path + self._output_path = output_path + self._paths = {} + self._resolve_input_path() + self._anomaly_nodes = [] # 记录所有异常节点 + self._cache = FileCache() + self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id + self._after_comm_anomalies = {} # 记录各rank下发生在通信节点之后的异常计算节点 + self._rank_comm_nodes_dict = {} # 记录各rank的通信节点 + + def analyze(self): + for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]: + analyze_func() + if self._anomaly_nodes: + self._gen_analyze_info() + return + logger.info('Cannot find any anomaly node, no need to generate analyze file.') + + def _resolve_input_path(self): + contents = os.listdir(self._input_path) + for path in contents: + if not path.startswith('rank'): + continue + rank_str = path.strip('rank') + if not rank_str: + rank = 0 + elif not rank_str.isdigit(): + continue + else: + rank = int(rank_str) + dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE) + construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE) + stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE) + self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path) + + def _pre_analyze(self): + logger.info('Start searching anomaly node before communication.') + for path in self._paths.values(): + dump_data = self._cache.load_json(path.dump_path).get('data') + if not dump_data: + logger.warning(f'Rank {path.rank} has no dump data!') + continue + for op_name, op_data in dump_data.items(): + if is_communication_op(op_name): + self._first_comm_nodes[path.rank] = op_name + break + data_node = DataNode(op_name, path.rank, op_data) + if data_node.is_anomaly(): + self._anomaly_nodes.append(data_node) + break + + def _analyze(self): + logger.info('Start searching anomaly node during communication.') + self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths} + self._connect_comm_nodes() + self._pruning() + self._search_first_anomaly() + + def _post_analyze(self): + logger.info('Start searching anomaly node after communication.') + for nodes in self._after_comm_anomalies.values(): + if nodes: + self._anomaly_nodes.append(nodes[0]) + + def _gen_analyze_info(self): + if not os.path.exists(self._output_path): + make_dir(self._output_path) + file_name = f'anomaly_analyze_{time.time_ns()}.json' + result_file = os.path.join(self._output_path, file_name) + result_content = defaultdict(list) + for node in self._anomaly_nodes: + result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank])) + save_json(result_file, result_content, 2) + logger.info(f"The analyze result is saved in: {result_file}") + + def _analyze_comm_nodes(self, rank): + path = self._paths[rank] + data = self._cache.load_json(path.dump_path).get('data') + communication_nodes = {} + if rank not in self._first_comm_nodes: # 此rank没有通信节点 + return communication_nodes + last_node_id = None # 记录上一个通信节点的node_id + compute_ops = [] # 记录两个通信节点之间的计算节点 + sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数 + for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data): + node_id = f'{rank}.{op_name}' + op_data = data[op_name] + if is_communication_op(op_name): + comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer), + compute_ops=compute_ops) + if last_node_id: + communication_nodes.get(last_node_id).add_next(comm_node) + communication_nodes[node_id] = comm_node + last_node_id = node_id + compute_ops = [] + sub_layer = 0 + elif not is_ignore_op(op_name): + data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer) + if data_node.is_anomaly(): + compute_ops.append(data_node) + sub_layer += 1 + if compute_ops: + self._after_comm_anomalies[rank] = compute_ops + return communication_nodes + + def _connect_comm_nodes(self): + searched_ranks = set() + for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]: + searched_ranks.add(rank) + seen_nodes = set() + for cur_node in nodes.values(): + conn_info = cur_node.find_connected_nodes() + if not conn_info.get('ranks'): + conn_info['ranks'] = self._rank_comm_nodes_dict.keys() + if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes): + logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".') + + def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes): + def connect(): + seen_nodes.add(search_node.node_id) + if search_node.type == NanAnalyseConst.DST: + cur_node.add_dst(search_node) + elif search_node.type == NanAnalyseConst.SRC: + search_node.layer = cur_node.layer + search_node.add_dst(cur_node) + else: + cur_node.add_link(search_node) + + found = cur_node.connected + for connected_rank in conn_info['ranks']: + if connected_rank in searched_ranks: + continue + tar_id_prefix = f'{connected_rank}.{conn_info["api"]}' + for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items(): + if search_id in seen_nodes: + continue + if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')): + continue + search_conn_ranks = search_node.find_connected_nodes().get('ranks') + if ((not search_conn_ranks and search_node.api not in NanAnalyseConst.DIRECTED_API) or + cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank + connect() + found = True + break + return found + + def _pruning(self): + deleted_node_id = [] + for nodes in self._rank_comm_nodes_dict.values(): + for node_id in list(nodes.keys()): + node = nodes[node_id] + if node.has_nan_inf() or node.compute_ops: + continue + deleted_node_id.append(node_id) + node.delete() + del nodes[node_id] + logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]') + + def _search_first_anomaly(self): + nodes_queues = [] + for comm_nodes in self._rank_comm_nodes_dict.values(): + nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer)) + seen_nodes = set() + + def get_next_node(node_list): + while node_list: + next_node = node_list.pop(0) + if next_node.node_id not in seen_nodes: + return next_node + return None + + def find_all_members(ori_node): + ids = get_relative_ids(ori_node) + id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids])) + while id_queue: + new_id = id_queue.pop(0) + ids.add(new_id) + id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids)) + return ids + + def get_relative_ids(ori_node): + if not ori_node: + return set() + return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() | + ori_node.dst_nodes.keys()) + + while any(nodes_queues): + groups = [] + all_ids_in_groups = set() + for nodes in nodes_queues: + node = get_next_node(nodes) + if not node: + continue + if not groups or node.node_id in all_ids_in_groups: + new_group = find_all_members(node) + groups.append(new_group) + all_ids_in_groups.update(new_group) + for group in groups: + seen_nodes.update(group) + self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group])) + if self._anomaly_nodes: + self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))] + return + + def _get_node_by_id(self, node_id): + splits = node_id.split(Const.SEP, 1) + if len(splits) < 2 or not splits[0].isdigit(): + logger.error(f'invalid node_id {node_id}') + raise RuntimeError(f'invalid node_id {node_id}') + rank = int(splits[0]) + return self._rank_comm_nodes_dict.get(rank, {}).get(node_id) + + +def _nan_analyze_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", default="", type=str, + help=" The dump file path, over step level. eg: \"xxx/step_0/\".", + required=True) + parser.add_argument("-o", "--output_path", dest="output_path", default="./output", type=str, + help=" The nan inf analyze result output file path.", + required=False) + + +def _run_nan_analyze(args): + check_file_or_directory_path(args.input_path, True) + NanAnalyzer(args.input_path, args.output_path).analyze() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/nan_analyze/graph.py b/debug/accuracy_tools/msprobe/nan_analyze/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4f8fb87296a39796b5124854ba7060be71d53a --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/graph.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException +from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst + + +@dataclass +class DataNode: + op_name: str + rank: int + inputs: list + input_args: list + input_kwargs: dict + outputs: dict + layer: int = 0 # 和communication_node的layer保持一致 + sub_layer: int = 0 # 调用顺序,越小表示越先调用 + + def __init__(self, op_name, rank, op_data, **kwargs): + self.op_name = op_name + self.rank = rank + self.inputs = op_data.get(Const.INPUT, []) + self.input_args = op_data.get(Const.INPUT_ARGS, []) + self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {}) + self.outputs = op_data.get(Const.OUTPUT, {}) + self.sub_layer = kwargs.get('sub_layer', 0) + + @staticmethod + def find_complete_construct(construct_info, op_name): + construct = [op_name] + seen = set(op_name) + while True: + op_name = construct_info.get(op_name) + if not op_name or op_name in seen: + return construct + construct.insert(0, op_name) + seen.add(op_name) + + def find_stack(self, stack_info): + for item in stack_info.values(): + if not isinstance(item, list): + raise MsprobeException(MsprobeException.UNSUPPORTED_TYPE_ERROR, + f'The value\'s type in stack.json should be a list, not {type(item)}!') + if len(item) >= 2 and self.op_name in item[0]: + return item[1] + return {} + + def is_anomaly(self) -> bool: + if is_ignore_op(self.op_name): + return False + is_input_anomaly = (check_item_anomaly(self.inputs) or check_item_anomaly(self.input_args) or + check_item_anomaly(self.input_kwargs)) + is_output_anomaly = check_item_anomaly(self.outputs) + return (not is_input_anomaly) and is_output_anomaly + + def gen_node_info(self, path: RankPath): + cache = FileCache() + construct = cache.load_json(path.construct_path) + stack = cache.load_json(path.stack_path) + if Const.FORWARD in self.op_name: + data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs, + Const.OUTPUT: self.outputs} + else: + data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs} + return {'op_name': self.op_name, + 'data_info': data_info_list, + 'construct_info': self.find_complete_construct(construct, self.op_name), + 'stack_info': self.find_stack(stack)} + + +class CommunicationNode: + def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs): + self.node_id = node_id + self.rank = rank + self.data = data + self.layer = layer + op_name_split = self.data.op_name.split(Const.SEP) + if len(op_name_split) < 4: + logger.error(f'invalid op_name: {self.data.op_name}') + raise RuntimeError(f'invalid op_name: {self.data.op_name}') + self.api = op_name_split[1] + self.call_cnt = op_name_split[2] + self.pre_node = kwargs.get('pre_node') + self.link_nodes = kwargs.get('link_nodes', {}) + self.dst_nodes = kwargs.get('dst_nodes', {}) + self.src_nodes = kwargs.get('src_nodes', {}) + self.next_nodes = kwargs.get('next_nodes', {}) + self.compute_ops = kwargs.get('compute_ops', []) + self.type = self._resolve_type() + self.connected = False + + def add_next(self, node): + self.next_nodes[node.node_id] = node + node.pre_node = self + node.layer = self.layer + 1 + node.data.layer = node.layer + + def add_link(self, node): + self.link_nodes[node.node_id] = node + node.link_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def add_dst(self, node): + self.dst_nodes[node.node_id] = node + node.src_nodes[self.node_id] = self + node.layer = self.layer + node.data.layer = node.layer + self.connected = True + node.connected = True + + def delete(self): + for node in self.next_nodes.values(): + node.pre_node = None + for node in self.dst_nodes.values(): + node.src_nodes.pop(self.node_id) + for node in self.src_nodes.values(): + node.dst_nodes.pop(self.node_id) + for node in self.link_nodes.values(): + node.link_nodes.pop(self.node_id) + if self.pre_node: + self.pre_node.next_nodes.pop(self.node_id) + + def has_nan_inf(self): + return self.input_has_nan_inf() or check_item_anomaly(self.data.outputs) + + def input_has_nan_inf(self): + return check_item_anomaly(self.data.input_args) or check_item_anomaly(self.data.input_kwargs) + + def find_connected_nodes(self): + """ + 根据 api/类型/入参/调用次数 确定相连接的node的op_name + """ + tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api) + ranks = set() + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs: + dst_value = self.data.input_kwargs.get(dst) + if dst_value: + ranks.add(dst_value.get('value')) + break + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs: + src_value = self.data.input_kwargs.get(src) + if src_value: + ranks.add(src_value.get('value')) + break + if not ranks: + for item in self.data.input_args: + if isinstance(item, dict) and item.get(Const.TYPE) == 'int': + ranks.add(item.get('value')) + group = self.data.input_kwargs.get('group') + if group: + ranks.update(group.get('group_ranks')) + return {'ranks': ranks, 'api': f'Distributed.{tar_api}', + 'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)} + + def _resolve_type(self): + for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]: + if src in self.data.input_kwargs and self.data.input_kwargs[src]: + if self.data.input_kwargs[src].get('value') == self.rank: + return NanAnalyseConst.SRC + else: + return NanAnalyseConst.DST + for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]: + if dst in self.data.input_kwargs and self.data.input_kwargs[dst]: + if self.data.input_kwargs[dst].get('value') == self.rank: + return NanAnalyseConst.DST + else: + return NanAnalyseConst.SRC + if self.api in NanAnalyseConst.DIRECTED_API: + for item in self.data.input_args: + if item.get(Const.TYPE) == 'int': + node_type = NanAnalyseConst.DIRECTED_API[self.api] + return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type] + return NanAnalyseConst.LINK diff --git a/debug/accuracy_tools/msprobe/nan_analyze/utils.py b/debug/accuracy_tools/msprobe/nan_analyze/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aaed65106f32f6c6cfc11911596bc11accd6a0df --- /dev/null +++ b/debug/accuracy_tools/msprobe/nan_analyze/utils.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from dataclasses import dataclass +import sys +import time +import psutil + +from msprobe.core.common.const import CompareConst +from msprobe.core.common.file_utils import check_file_or_directory_path, load_json + + +@dataclass +class RankPath: + rank: int + dump_path: str + construct_path: str + stack_path: str + + def __init__(self, rank, dump_path, construct_path, stack_path): + self.rank = rank + check_file_or_directory_path(dump_path) + self.dump_path = dump_path + check_file_or_directory_path(construct_path) + self.construct_path = construct_path + check_file_or_directory_path(stack_path) + self.stack_path = stack_path + + +class FileCache: + """ + lazy load file + """ + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + self._max_memory_usage = psutil.virtual_memory().available / 4 # 最大占用当前可用内存空间的1/4 + self._cache = OrderedDict() + self._access_cnt = {} + self._access_time = {} + self._size = {} + + @staticmethod + def _sizeof(obj): + seen = set() + objs = [obj] + size = 0 + while objs: + obj = objs.pop() + obj_id = id(obj) + if obj_id in seen: + continue + seen.add(obj_id) + size += sys.getsizeof(obj) + if isinstance(obj, dict): + objs.extend(obj.keys()) + objs.extend(obj.values()) + elif isinstance(obj, (list, tuple, set, frozenset)): + objs.extend(obj) + return size + + def load_json(self, json_path): + if json_path in self._cache: + self._access_cnt[json_path] += 1 + self._access_time[json_path] = time.monotonic() + self._cache.move_to_end(json_path) + return self._cache[json_path] + self._cleanup() + return self._load(json_path) + + def _load(self, json_path): + data = load_json(json_path) + self._add_to_cache(json_path, data) + return data + + def _add_to_cache(self, key, data): + if key in self._cache: + self._cache.move_to_end(key) + else: + self._cache[key] = data + self._access_cnt[key] = 0 + self._access_time[key] = time.monotonic() + self._size[key] = self._sizeof(data) + + def _calc_cache_size(self): + return sys.getsizeof(self._cache) + sum(self._size.values()) + + def _cleanup(self): + while self._calc_cache_size() > self._max_memory_usage and self._cache: + least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k]) + least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k]) + largest_key = max(self._cache.keys(), key=lambda k: self._size[k]) + key_to_rm = min([least_frequent_key, least_recent_key, largest_key], + key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k])) + del self._cache[key_to_rm] + del self._access_cnt[key_to_rm] + del self._access_time[key_to_rm] + del self._size[key_to_rm] + + +def is_communication_op(op_name): + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + return (op_name.startswith('Distributed.') and + any(keyword in op_name for keyword in NanAnalyseConst.COMMUNICATION_KEYWORDS)) + + +def is_ignore_op(op_name): + ignore_keywords = [ + 'Torch.empty', + 'Torch.fill' + ] + return any(keyword in op_name for keyword in ignore_keywords) + + +def check_item_anomaly(param): + def has_nan_inf(dict_obj, key): + return str(dict_obj.get(key)).lower() in CompareConst.OVERFLOW_LIST + + items = [] + if isinstance(param, list): + items = param + elif isinstance(param, dict): + items = param.values() + for item in items: + if not isinstance(item, dict): + continue + if has_nan_inf(item, 'Max') or has_nan_inf(item, 'Min'): + return True + return False + + +def analyze_anomaly_in_group(nodes_group): + anomaly_nodes = [] + + def get_compute_ops_from_comm_nodes(comm_nodes): + for comm_node in comm_nodes: + for op_node in comm_node.compute_ops: + op_node.layer = comm_node.layer + anomaly_nodes.append(op_node) + + def get_comm_ops(comm_nodes): + for node in comm_nodes: + node.data.layer = node.layer + anomaly_nodes.append(node.data) + + # 先看src或link中input是否有异常 + src_list = list(filter(lambda node: node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK], nodes_group)) + input_anomaly_nodes = list(filter(lambda node: node.input_has_nan_inf(), src_list)) + # 如果有异常回溯计算节点找到异常来源 + # 使用cpu模拟节点进行计算,查看结果是否有问题。需要对所有计算节点录入/映射,暂不实现。 + get_compute_ops_from_comm_nodes(input_anomaly_nodes) + # 筛选入参没问题但出参有问题的通信节点 + output_anomaly_nodes = list(filter(lambda node: node.data.is_anomaly(), nodes_group)) + get_comm_ops(output_anomaly_nodes) + return anomaly_nodes + + +class NanAnalyseConst: + COMMUNICATION_KEYWORDS = { + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor', # reduce_scatter_tensor 算子 + 'send_object_list', # send_object_list 算子 + 'recv_object_list' # recv_object_list 算子 + } + P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend', + 'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'} + SRC = 'src' + DST = 'dst' + SRC_GROUP = 'src_group' + DST_GROUP = 'dst_group' + LINK = 'link' + DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC, + 'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC} + OPPOSITE_DIR = {SRC: DST, DST: SRC} + DUMP_FILE = "dump.json" + CONSTRUCT_FILE = "construct.json" + STACK_FILE = "stack.json" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index f2b2d6a30463c62846bcc02e147c9c319f55d1b8..f9827b52b9737f97b80d17a28263598622109a0c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -52,7 +52,9 @@ class Config: 'host': str, 'port': int, 'rank_list': list, - 'tls_path': str + 'tls_path': str, + 'master_ip': str, + 'master_port': str } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -72,6 +74,10 @@ class Config: RunUTConfig.check_nfs_path_config(value) if key == 'tls_path': RunUTConfig.check_tls_path_config(value) + if key == 'master_ip': + RunUTConfig.check_master_ip_config(value) + if key == 'master_port': + RunUTConfig.check_master_port_config(value) return value @@ -91,6 +97,8 @@ class CheckerConfig: self.port = msCheckerConfig.port self.rank_list = msCheckerConfig.rank_list self.tls_path = msCheckerConfig.tls_path + self.master_ip = msCheckerConfig.master_ip + self.master_port = msCheckerConfig.master_port if task_config: self.load_config(task_config) @@ -105,6 +113,8 @@ class CheckerConfig: self.port = task_config.port self.rank_list = task_config.rank_list self.tls_path = task_config.tls_path + self.master_ip = task_config.master_ip + self.master_port = task_config.master_port def get_online_config(self): return OnlineConfig( @@ -125,8 +135,8 @@ class CheckerConfig: save_error_data=config_params.get('save_error_data'), is_continue_run_ut=config_params.get('is_continue_run_ut'), real_data_path=config_params.get('real_data_path'), - white_list=self.white_list, - black_list=self.black_list, + white_list=self.white_list.copy() if self.white_list else [], + black_list=self.black_list.copy() if self.black_list else [], error_data_path=config_params.get('error_data_path'), online_config=self.get_online_config() ) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index ddee254c2b1085f9af96fe2774c53fb88c5821f4..abe8f2b4b3cd1cf8195fc86ed5c6a07e1daddf15 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -261,3 +261,54 @@ def compare_bool_tensor(bench_output, device_output): error_rate = float(error_nums / bench_output.size) result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR return error_rate, result, "" + + +def maximize_kahan_loss(cumsum, addend, negative=False): + """ + Calculate the precision loss in Kahan summation and select the maximum or minimum loss. + + Parameters: + cumsum (torch.Tensor): The current cumulative sum. + addend (torch.Tensor): The value to be added in the current step. + negative (bool): Whether to select the negative direction of loss. + Default is False (select positive direction which minimizes the sum). + + Returns: + loss_res (torch.Tensor): The selected maximum or minimum loss value. + mask (torch.Tensor): + A boolean mask indicating whether the loss value should be compensated. + """ + loss_all = (cumsum + addend) - cumsum - addend + if negative: + loss_res = torch.min(loss_all, dim=0)[0] + mask = loss_res <= 0 + else: + loss_res = torch.max(loss_all, dim=0)[0] + mask = loss_res >= 0 + return loss_res, mask + + +def kahan_range(tensors, negative=False): + """ + Perform Kahan summation on a list of tensors and track precision loss. + + Parameters: + tensors (list of torch.Tensor): The list of tensors to be summed. + negative (bool): Whether to select the negative direction of loss. + Default is False (select positive direction which minimizes the sum). + Returns: + sum_max: The summation results. + """ + if len(tensors) < 1: + raise ValueError("tensors should have at least 1 element") + cumsum_temp = torch.clone(tensors[0]).unsqueeze(dim=0) + sum_max = torch.clone(tensors[0]) + loss_max = torch.tensor(0) + + for tensor in tensors[1:]: + addend = tensor - loss_max + loss_max, mask = maximize_kahan_loss(cumsum_temp, addend, negative) + sum_max = sum_max + (addend - torch.where(mask, loss_max, 0)) + loss_max = torch.where(mask, 0, loss_max) + cumsum_temp = torch.cat((cumsum_temp, sum_max.unsqueeze(dim=0))) + return sum_max diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 8f7db73b58f42a4a64728bb0f12d25cf6f9f9ebe..55e93d271cec67334fe21c1f6466df2d0254a36b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory from msprobe.pytorch.common.log import logger -from msprobe.core.common.utils import CompareException +from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid from msprobe.core.common.const import Const, CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) @@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config): message = '' compare_column = ApiPrecisionOutputColumn() full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] + check_op_str_pattern_valid(full_api_name_with_direction_status) row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status) if not api_full_name: @@ -430,6 +431,7 @@ def _api_precision_compare(parser=None): _api_precision_compare_parser(parser) args = parser.parse_args(sys.argv[1:]) _api_precision_compare_command(args) + logger.info("Compare task completed.") def _api_precision_compare_command(args): @@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The api precision compare task result out path.", required=False) - - -if __name__ == '__main__': - _api_precision_compare() - logger.info("Compare task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index cf5928e509e3138ea762cd9d7af6fc26a5d2c5c9..de80bf6f59ced66acf3589cc217b5479c6b4f175 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -40,6 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dty DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments from msprobe.pytorch.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', @@ -165,6 +166,41 @@ class Comparator: accumulative_error_compare = AccumulativeErrorCompare(input_data) accumulative_error_compare.compare() + @recursion_depth_decorator("compare_core") + def _compare_core(self, api_name, bench_output, device_output): + compare_column = CompareColumn() + if not isinstance(bench_output, type(device_output)): + status = CompareConst.ERROR + message = "bench and npu output type is different." + elif isinstance(bench_output, dict): + b_keys, n_keys = set(bench_output.keys()), set(device_output.keys()) + if b_keys != n_keys: + status = CompareConst.ERROR + message = "bench and npu output dict keys are different." + else: + status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), + list(device_output.values())) + elif isinstance(bench_output, torch.Tensor): + copy_bench_out = bench_output.detach().clone() + copy_device_output = device_output.detach().clone() + compare_column.bench_type = str(copy_bench_out.dtype) + compare_column.npu_type = str(copy_device_output.dtype) + compare_column.shape = tuple(device_output.shape) + status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, + compare_column) + elif isinstance(bench_output, (bool, int, float, str)): + compare_column.bench_type = str(type(bench_output)) + compare_column.npu_type = str(type(device_output)) + status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column) + elif bench_output is None: + status = CompareConst.SKIP + message = "Bench output is None, skip this test." + else: + status = CompareConst.ERROR + message = "Unexpected output type in compare_core: {}".format(type(bench_output)) + + return status, compare_column, message + def write_csv_title(self): summary_test_rows = [ [self.COLUMN_API_NAME, @@ -293,40 +329,6 @@ class Comparator: test_final_success = CompareConst.WARNING return test_final_success, detailed_result_total - def _compare_core(self, api_name, bench_output, device_output): - compare_column = CompareColumn() - if not isinstance(bench_output, type(device_output)): - status = CompareConst.ERROR - message = "bench and npu output type is different." - elif isinstance(bench_output, dict): - b_keys, n_keys = set(bench_output.keys()), set(device_output.keys()) - if b_keys != n_keys: - status = CompareConst.ERROR - message = "bench and npu output dict keys are different." - else: - status, compare_column, message = self._compare_core(api_name, list(bench_output.values()), - list(device_output.values())) - elif isinstance(bench_output, torch.Tensor): - copy_bench_out = bench_output.detach().clone() - copy_device_output = device_output.detach().clone() - compare_column.bench_type = str(copy_bench_out.dtype) - compare_column.npu_type = str(copy_device_output.dtype) - compare_column.shape = tuple(device_output.shape) - status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output, - compare_column) - elif isinstance(bench_output, (bool, int, float, str)): - compare_column.bench_type = str(type(bench_output)) - compare_column.npu_type = str(type(device_output)) - status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column) - elif bench_output is None: - status = CompareConst.SKIP - message = "Bench output is None, skip this test." - else: - status = CompareConst.ERROR - message = "Unexpected output type in compare_core: {}".format(type(bench_output)) - - return status, compare_column, message - def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column): cpu_shape = bench_output.shape npu_shape = device_output.shape diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 549230d0a9e200283f545eed608a8da5df6a53a8..89c4401b2cac863bc609cce14a9f4c3ca03951b7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -73,27 +73,27 @@ DETAIL_TEST_ROWS = [ precision_configs = { - torch.float16 : { - 'small_value' : [ + torch.float16: { + 'small_value': [ 1e-3 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-5 ] }, torch.bfloat16: { - 'small_value' : [ + 'small_value': [ 1e-3 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-5 ] }, - torch.float32:{ - 'small_value' : [ + torch.float32: { + 'small_value': [ 1e-6 ], - 'small_value_atol' : [ + 'small_value_atol': [ 1e-9 ] } @@ -101,33 +101,33 @@ precision_configs = { ULP_PARAMETERS = { - torch.float16 : { - 'min_eb' : [ + torch.float16: { + 'min_eb': [ -14 ], - 'exponent_num' : [ + 'exponent_num': [ 10 ] }, - torch.bfloat16 : { - 'min_eb' : [ + torch.bfloat16: { + 'min_eb': [ -126 ], - 'exponent_num' : [ + 'exponent_num': [ 7 ] }, - torch.float32 : { - 'min_eb' : [ + torch.float32: { + 'min_eb': [ -126 ], - 'exponent_num' : [ + 'exponent_num': [ 23 ] } } - - + + class ApiPrecisionCompareColumn: API_NAME = 'API Name' DEVICE_DTYPE = 'DEVICE Dtype' @@ -202,7 +202,7 @@ class ApiPrecisionCompareColumn: CompareMessage = { - "topk" : "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 如果topk精度不达标,请检查是否是该原因导致的。" + "topk": "在npu上,topk的入参sorted=False时不生效,会返回有序tensor,而cpu上会返回无序tensor。 如果topk精度不达标,请检查是否是该原因导致的。" } diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 2ec9251009e61ef68dbfed987abe457d47b91e9a..30cea3b8e01f1c1a8a3a3d25620ba4bb2c9e709a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -8,3 +8,5 @@ host: "" port: -1 rank_list: [0] tls_path: "./" +master_ip: '127.0.0.1' +master_port: '2688' diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py index 797210f09c3b55a64002a4aa84a3d39770ae803c..c58c058674f31d8acb24a008104cdd32b1969726 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py @@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st ulp_standard_api, thousandth_standard_api from msprobe.core.common.file_utils import FileOpen, load_json, save_json from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int -from msprobe.core.common.const import Const, MonitorConst, MsgConst +from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import make_dir -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.file_utils import make_dir, change_mode +from msprobe.core.common.decorator import recursion_depth_decorator TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] TORCH_BOOL_TYPE = ["torch.bool"] @@ -50,6 +50,7 @@ DATA_NAME = "data_name" API_MAX_LENGTH = 30 PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD] DATAMODE_LIST = ["random_data", "real_data"] +ITER_MAX_TIMES = 1000 class APIInfo: @@ -97,6 +98,8 @@ class CommonConfig: iter_t = self.iter_times if iter_t <= 0: raise ValueError("iter_times should be an integer bigger than zero!") + if iter_t > ITER_MAX_TIMES: + raise ValueError("iter_times should not be greater than 1000!") json_file = self.extract_api_path propagation = self.propagation @@ -117,7 +120,7 @@ class CommonConfig: # Retrieve the first API name and dictionary forward_item = next(iter(json_content.items()), None) - if not forward_item or not isinstance(forward_item[1], dict): + if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]: raise ValueError(f'Invalid forward API data in json_content!') # if propagation is backward, ensure json file contains forward and backward info @@ -127,7 +130,7 @@ class CommonConfig: # if propagation is backward, ensure it has valid data if propagation == Const.BACKWARD: backward_item = list(json_content.items())[1] - if not isinstance(backward_item[1], dict): + if not isinstance(backward_item[1], dict) or not backward_item[1]: raise ValueError(f'Invalid backward API data in json_content!') return json_content @@ -169,7 +172,7 @@ class APIExtractor: value = self.load_real_data_path(value, real_data_path) new_data[key] = value if not new_data: - logger.error(f"Error: The api '{self.api_name}' does not exist in the file.") + logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.") else: save_json(self.output_file, new_data, indent=4) logger.info( @@ -183,6 +186,7 @@ class APIExtractor: self.update_data_name(v, dump_data_dir) return value + @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name") def update_data_name(self, data, dump_data_dir): if isinstance(data, list): for item in data: @@ -407,19 +411,16 @@ class OperatorScriptGenerator: return kwargs_dict_generator - def _op_generator_parser(parser): - parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str, - help=" Path of config json file", required=True) + parser.add_argument("-i", "--config_input", dest="config_input", type=str, + help=" Path of config json file", required=True) parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str, - help=" Path of extract api_name.json.", - required=True) + help=" Path of extract api_name.json.", required=True) def parse_json_config(json_file_path): if not json_file_path: - config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - json_file_path = os.path.join(config_dir, "config.json") + raise Exception("config_input path can not be empty, please check.") json_config = load_json(json_file_path) common_config = CommonConfig(json_config) return common_config @@ -467,6 +468,7 @@ def _run_operator_generate_commond(cmd_args): fout.write(code_template.format(**internal_settings)) except OSError: logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.") + change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY) logger.info(f"Generate operator script successfully and the name is {operator_script_path}.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template index 131fd211ad82dad8256c48e59195fc335efa936b..c60d84994745e94bef6d05a78d83fae81df7ed1e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template @@ -1,6 +1,6 @@ -import json import os -import math +import re +import stat from enum import Enum, auto import torch try: @@ -25,6 +25,31 @@ RAISE_PRECISION = {{ }} THOUSANDTH_THRESHOLDING = 0.001 BACKWARD = 'backward' +DIR = "dir" +FILE = "file" +READ_ABLE = "read" +WRITE_ABLE = "write" +READ_WRITE_ABLE = "read and write" +DIRECTORY_LENGTH = 4096 +FILE_NAME_LENGTH = 255 +SOFT_LINK_ERROR = "检测到软链接" +FILE_PERMISSION_ERROR = "文件权限错误" +INVALID_FILE_ERROR = "无效文件" +ILLEGAL_PATH_ERROR = "非法文件路径" +ILLEGAL_PARAM_ERROR = "非法打开方式" +FILE_TOO_LARGE_ERROR = "文件过大" +FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" +FILE_SIZE_DICT = {{ + ".pkl": 1073741824, # 1 * 1024 * 1024 * 1024 + ".npy": 10737418240, # 10 * 1024 * 1024 * 1024 + ".json": 1073741824, # 1 * 1024 * 1024 * 1024 + ".pt": 10737418240, # 10 * 1024 * 1024 * 1024 + ".csv": 1073741824, # 1 * 1024 * 1024 * 1024 + ".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024 + ".yaml": 1073741824, # 1 * 1024 * 1024 * 1024 + ".ir": 1073741824 # 1 * 1024 * 1024 * 1024 +}} +COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024 class CompareStandard(Enum): BINARY_EQUALITY_STANDARD = auto() @@ -33,13 +58,189 @@ class CompareStandard(Enum): BENCHMARK_STANDARD = auto() THOUSANDTH_STANDARD = auto() +class FileChecker: + """ + The class for check file. + + Attributes: + file_path: The file or dictionary path to be verified. + path_type: file or dictionary + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + self.is_script = is_script + + @staticmethod + def _check_path_type(path_type): + if path_type not in [DIR, FILE]: + print(f'ERROR: The path_type must be {{DIR}} or {{FILE}}.') + raise Exception(ILLEGAL_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + FileChecker.check_path_exists(self.file_path) + FileChecker.check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + FileChecker.check_path_length(self.file_path) + FileChecker.check_path_type(self.file_path, self.path_type) + self.check_path_ability() + if self.is_script: + FileChecker.check_path_owner_consistent(self.file_path) + FileChecker.check_path_pattern_valid(self.file_path) + FileChecker.check_common_file_size(self.file_path) + FileChecker.check_file_suffix(self.file_path, self.file_type) + if self.path_type == FILE: + FileChecker.check_dirpath_before_read(self.file_path) + return self.file_path + + def check_path_ability(self): + if self.ability == WRITE_ABLE: + FileChecker.check_path_writability(self.file_path) + if self.ability == READ_ABLE: + FileChecker.check_path_readability(self.file_path) + if self.ability == READ_WRITE_ABLE: + FileChecker.check_path_readability(self.file_path) + FileChecker.check_path_writability(self.file_path) + + @staticmethod + def check_path_exists(path): + if not os.path.exists(path): + print(f'ERROR: The file path %s does not exist.' % path) + raise Exception() + + @staticmethod + def check_link(path): + abs_path = os.path.abspath(path) + if os.path.islink(abs_path): + print('ERROR: The file path {{}} is a soft link.'.format(path)) + raise Exception(SOFT_LINK_ERROR) + + @staticmethod + def check_path_length(path, name_length=None): + file_max_name_length = name_length if name_length else FILE_NAME_LENGTH + if len(path) > DIRECTORY_LENGTH or \ + len(os.path.basename(path)) > file_max_name_length: + print(f'ERROR: The file path length exceeds limit.') + raise Exception(ILLEGAL_PATH_ERROR) + + @staticmethod + def check_path_type(file_path, file_type): + if file_type == FILE: + if not os.path.isfile(file_path): + print(f"ERROR: The {{file_path}} should be a file!") + raise Exception(INVALID_FILE_ERROR) + if file_type == DIR: + if not os.path.isdir(file_path): + print(f"ERROR: The {{file_path}} should be a dictionary!") + raise Exception(INVALID_FILE_ERROR) + + @staticmethod + def check_path_owner_consistent(path): + file_owner = os.stat(path).st_uid + if file_owner != os.getuid() and os.getuid() != 0: + print('ERROR: The file path %s may be insecure because is does not belong to you.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + @staticmethod + def check_path_pattern_valid(path): + if not re.match(FILE_VALID_PATTERN, path): + print('ERROR: The file path %s contains special characters.' % (path)) + raise Exception(ILLEGAL_PATH_ERROR) + + @staticmethod + def check_common_file_size(file_path): + if os.path.isfile(file_path): + for suffix, max_size in FILE_SIZE_DICT.items(): + if file_path.endswith(suffix): + FileChecker.check_file_size(file_path, max_size) + return + FileChecker.check_file_size(file_path, COMMOM_FILE_SIZE) + + @staticmethod + def check_file_size(file_path, max_size): + try: + file_size = os.path.getsize(file_path) + except OSError as os_error: + print(f'ERROR: Failed to open "{{file_path}}". {{str(os_error)}}') + raise Exception(INVALID_FILE_ERROR) from os_error + if file_size >= max_size: + print(f'ERROR: The size ({{file_size}}) of {{file_path}} exceeds ({{max_size}}) bytes, tools not support.') + raise Exception(FILE_TOO_LARGE_ERROR) + + @staticmethod + def check_file_suffix(file_path, file_suffix): + if file_suffix: + if not file_path.endswith(file_suffix): + print(f"The {{file_path}} should be a {{file_suffix}} file!") + raise Exception(INVALID_FILE_ERROR) + + @staticmethod + def check_dirpath_before_read(path): + path = os.path.realpath(path) + dirpath = os.path.dirname(path) + if FileChecker.check_others_writable(dirpath): + print(f"WARNING: The directory is writable by others: {{dirpath}}.") + try: + FileChecker.check_path_owner_consistent(dirpath) + except Exception: + print(f"WARNING: The directory {{dirpath}} is not yours.") + + @staticmethod + def check_others_writable(directory): + dir_stat = os.stat(directory) + is_writable = ( + bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写 + bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写 + ) + return is_writable + + @staticmethod + def check_path_readability(path): + if not os.access(path, os.R_OK): + print('ERROR: The file path %s is not readable.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + @staticmethod + def check_path_writability(path): + if not os.access(path, os.W_OK): + print('ERROR: The file path %s is not writable.' % path) + raise Exception(FILE_PERMISSION_ERROR) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + path_checker = FileChecker(path, DIR, WRITE_ABLE) + else: + path_checker = FileChecker(path, FILE, READ_ABLE) + path_checker.common_check() + def load_pt(pt_path, to_cpu=False): pt_path = os.path.realpath(pt_path) + check_file_or_directory_path(pt_path) try: if to_cpu: - pt = torch.load(pt_path, map_location=torch.device("cpu")) + pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True) else: - pt = torch.load(pt_path) + pt = torch.load(pt_path, weights_only=True) except Exception as e: raise RuntimeError(f"load pt file {{pt_path}} failed") from e return pt @@ -202,6 +403,7 @@ def compare_tensor(out_device, out_bench, api_name): else: abs_err = torch.abs(out_device - out_bench) abs_bench = torch.abs(out_bench) + eps = 2 ** -23 if dtype_bench == torch.float32: eps = 2 ** -23 if dtype_bench == torch.float64: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 9d89b2de32f70c6fa7abf38add49b58a13531d7a..15e14b68c7da4f2c7fadd4e0285c79fec5fa78f1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -1,9 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -15,20 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import math -import torch +import os + import numpy +import torch -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api +from msprobe.core.common.const import Const, FileCheckConst, CompareConst, DistributedCheckConst +from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \ CompareException, get_module_and_atttribute_name, get_attribute -from msprobe.core.common.file_utils import FileChecker, load_npy +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import load_pt -from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.pytorch.hook_module.api_register import get_api_register +api_register = get_api_register(return_new=True) +api_register.initialize_hook(None) +distribute_api_key = Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_DIST +distribute_api_list = list(api_register.ori_api_attr.get(distribute_api_key, {}).keys()) + TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] FLOAT_TYPE = [ @@ -68,7 +73,7 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): data = gen_random_tensor(info, convert_type) if api_name in hf_32_standard_api and data.dtype == torch.float32: data = fp32_to_hf32_to_fp32(data) - if info.get('requires_grad') and need_grad: + if info.get('requires_grad') and need_grad and api_name not in distribute_api_list: data.requires_grad_(True) temp_data = data * 1 data = temp_data.type_as(data) @@ -261,11 +266,14 @@ def gen_args(args_info, api_name, func_options): Function Description: Based on API basic information, generate input parameters: args, for API forward running Parameter: - api_info: API basic information. List + args_info: API basic information. DICT api_name: API name - need_grad: set Tensor grad for backward - convert_type: convert ori_type to dist_type flag. - real_data_path: the root directory for storing real data. + func_options: the options for generating args. Dict + need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. + real_data_path: the root directory for storing real data. + depth: the depth of recursion. + kwargs_params: the input kwargs parameters. """ check_object_type(args_info, list) args_result = [] @@ -274,6 +282,7 @@ def gen_args(args_info, api_name, func_options): convert_type = func_options.get('convert_type', None) real_data_path = func_options.get('real_data_path', None) depth = func_options.get('depth', 0) + kwargs_params = func_options.get('input_kwargs', {}) if depth > Const.MAX_DEPTH: logger.error("The depth of args is too large, please check the input args.") @@ -284,7 +293,11 @@ def gen_args(args_info, api_name, func_options): func_options['depth'] = depth + 1 data = gen_args(arg, api_name, func_options) elif isinstance(arg, dict): - data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) + if arg.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: + data = None + kwargs_params[DistributedCheckConst.GROUP] = arg + else: + data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) elif arg is None: data = None else: @@ -311,6 +324,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None): kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path) elif value is None: kwargs_params[key] = None + elif key == DistributedCheckConst.GROUP and value.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: + kwargs_params[key] = value elif key == 'atten_mask' and api_name == 'npu_fusion_attention': sparse_mode = kwargs_params.get('sparse_mode', {}) if isinstance(sparse_mode, dict): @@ -415,17 +430,19 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d if convert_type and convert_type not in Const.CONVERT: error_info = f"convert_type params not support {convert_type}." raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) + func_options = { 'need_grad': need_grad, 'convert_type': convert_type, 'real_data_path': real_data_path, - 'depth': 0 + 'depth': 0, + 'input_kwargs': api_info.get("input_kwargs", {}) } if api_info.get("input_args"): args_params = gen_args(api_info.get("input_args"), api_name, func_options) else: logger.warning(f'Warning: No args in {api_info} ') args_params = [] + kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) output_dtype = get_output_dtype(api_info) return args_params, kwargs_params, output_dtype diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py new file mode 100644 index 0000000000000000000000000000000000000000..18ff05bc00c2c5271e965dbd91fd54be1d410876 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from msprobe.core.common.const import DistributedCheckConst +from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import kahan_range +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_distributed_args + + +def sort_all_input(inputs): + ranks = len(inputs) + if ranks <= 1: + return inputs + combined_tensor = torch.stack(inputs) + sorted_indices = torch.argsort(combined_tensor, descending=True, dim=0) + combined_tensor = torch.gather(combined_tensor, 0, sorted_indices) + sorted_inputs = [combined_tensor[i] for i in range(ranks)] + return sorted_inputs + + +def reduce_sum(tensors): + min_bound = torch.min( + kahan_range(tensors, negative=False), + kahan_range(tensors[::-1], negative=False), + ) + max_bound = torch.max( + kahan_range(tensors, negative=True), kahan_range(tensors[::-1], negative=True) + ) + tensors_sorted = sort_all_input(tensors) + min_sorted_bound = torch.min( + kahan_range(tensors_sorted, negative=False), + kahan_range(tensors_sorted[::-1], negative=False), + ) + max_sorted_bound = torch.max( + kahan_range(tensors_sorted, negative=True), + kahan_range(tensors_sorted[::-1], negative=True), + ) + return torch.min(min_bound, min_sorted_bound), torch.max( + max_bound, max_sorted_bound + ) + + +def reduce_product(tensors): + return torch.stack(tensors).prod(dim=0) + + +def reduce_min(tensors): + return torch.stack(tensors).min(dim=0).values + + +def reduce_max(tensors): + return torch.stack(tensors).max(dim=0).values + + +def reduce_band(tensors): + reduce_tensor = tensors[0].clone() + if len(tensors) > 1: + for t in tensors[1:]: + reduce_tensor &= t + return reduce_tensor + + +def reduce_bor(tensors): + reduce_tensor = tensors[0].clone() + if len(tensors) > 1: + for t in tensors[1:]: + reduce_tensor |= t + return reduce_tensor + + +def reduce_bxor(tensors): + reduce_tensor = tensors[0].clone() + if len(tensors) > 1: + for t in tensors[1:]: + reduce_tensor ^= t + return reduce_tensor + + +def mock_broadcast(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + if len(input_args) < 1 or len(input_kwargs) < 1: + raise ValueError("input_args and input_kwargs should have at least 1 element") + + src = get_distributed_args(api_name, input_args[0], input_kwargs[0], DistributedCheckConst.SRC) + + group = get_distributed_args(api_name, input_args[0], input_kwargs[0], DistributedCheckConst.GROUP) + group_ranks = group.get(DistributedCheckConst.GROUP_RANKS, []) + if not group_ranks: + raise ValueError("group_ranks should not be empty") + real_src = src - min(group_ranks) + if len(input_args) <= real_src: + raise ValueError("input_args should have at least {} element".format(real_src + 1)) + + return input_args[real_src][0] + + +def mock_reduce(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + if len(input_args) < 1 or len(input_kwargs) < 1: + raise ValueError("input_args and input_kwargs should have at least 1 element") + + reduce_op = get_distributed_args(api_name, input_args[0], input_kwargs[0], DistributedCheckConst.OP) + tensors = [] + for arg in input_args: + if len(arg) > 0: + tensors.append(arg[0]) + reduce_tensor = None + if not tensors: + return reduce_tensor + reduce_ops = { + DistributedCheckConst.REDOPTYPE_SUM: reduce_sum, + DistributedCheckConst.REDOPTYPE_PRODUCT: reduce_product, + DistributedCheckConst.REDOPTYPE_MIN: reduce_min, + DistributedCheckConst.REDOPTYPE_MAX: reduce_max, + DistributedCheckConst.REDOPTYPE_BAND: reduce_band, + DistributedCheckConst.REDOPTYPE_BOR: reduce_bor, + DistributedCheckConst.REDOPTYPE_BXOR: reduce_bxor, + } + if reduce_op not in reduce_ops: + raise ValueError(f"Unsupported reduce operation: {reduce_op}") + reduce_tensor = reduce_ops[reduce_op](tensors) + + return reduce_tensor + + +def mock_scatter(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + if len(input_args) < 1 or len(input_kwargs) < 1: + raise ValueError("input_args and input_kwargs should have at least 1 element") + + src = get_distributed_args(api_name, input_args[0], input_kwargs[0], DistributedCheckConst.SRC) + group = get_distributed_args(api_name, input_args[0], input_kwargs[0], DistributedCheckConst.GROUP) + group_ranks = group.get(DistributedCheckConst.GROUP_RANKS, []) + if not group_ranks: + raise ValueError("group_ranks should not be empty") + real_src = src - min(group_ranks) + if len(input_args) <= real_src: + raise ValueError("input_args should have at least {} element".format(real_src + 1)) + scatter_list = get_distributed_args(api_name, input_args[real_src], input_kwargs[real_src], + DistributedCheckConst.SCATTER_LIST) + return scatter_list + + +def mock_all_gather(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + gather_tensor = [] + for data in input_args: + if len(data) > 1: + gather_tensor.append(data[1]) + return gather_tensor + + +def mock_all_to_all(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + input_tensor_list = [] + for data in input_args: + if len(data) >= 2: + input_tensor_list.append(data[1]) + world_size = len(input_tensor_list) + output_tensor_list = [] + for rank in range(world_size): + output_chunk = [] + for data in input_tensor_list: + if len(data) <= rank: + raise ValueError("input_tensor_list should have at least {} element".format(rank + 1)) + output_chunk.append(data[rank]) + output_tensor_list.append(output_chunk) + return output_tensor_list + + +def mock_all_to_all_single(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + input_tensor_list = [] + for data in input_args: + if len(data) >= 2: + input_tensor_list.append(data[1]) + if not input_tensor_list: + return [] + input_tensor = torch.stack(input_tensor_list) + output_tensor = input_tensor.t() + output_tensor_list = [tensor.clone() for tensor in output_tensor] + return output_tensor_list diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cf95a1d0d9060b75a45a360e6a4d5d8b087637 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import torch +import tqdm + +from msprobe.core.common.const import CompareConst, DistributedCheckConst + + +def cumulative_check(rank, inputs, output, min_bound, max_bound): + # 检查每个元素是否在最小值和最大值之间 + res = CompareConst.PASS + out_of_bounds = torch.nonzero((output < min_bound) | (output > max_bound)) + if out_of_bounds.shape[0] == 0: + return res + # 对超出范围的值进行累加序遍历检查 + perms = list(itertools.permutations(list(range(len(inputs))))) + if len(out_of_bounds) > DistributedCheckConst.MAX_CUMSUM_CHECK_NUM: + res = CompareConst.WARNING + out_of_bounds = out_of_bounds[: DistributedCheckConst.MAX_CUMSUM_CHECK_NUM] + pbar = tqdm.tqdm( + out_of_bounds, + position=rank + 1, + desc=f"Suspicious cumulative result check for rank{rank}", + ) + for indice in pbar: + indice_tuple = tuple(indice) + input_values = torch.stack([input_[indice_tuple] for input_ in inputs])[perms] + for i in range(1, len(inputs)): + input_values[:, 0] += input_values[:, i] + if output[indice_tuple] not in input_values[:, 0]: + res = CompareConst.ERROR + break + pbar.close() + return res + + +def compare_broadcast(device_out, bench_out, **kwargs): + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + compare_result = torch.equal(device_out[0].cpu(), bench_out) + + return CompareConst.PASS if compare_result else CompareConst.ERROR + + +def compare_all_reduce(device_out, bench_out, **kwargs): + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + if isinstance(bench_out, tuple): + rank = kwargs.get("local_rank", 0) + input_args = kwargs.get("input_args", []) + tensors = [] + for arg in input_args: + if len(arg) > 0: + tensors.append(arg[0]) + if len(tensors) < 1: + raise ValueError("input_args should have at least 1 element") + result = cumulative_check(rank, tensors, device_out[0].cpu(), *bench_out) + else: + compare_result = torch.equal(device_out[0].cpu(), bench_out) + result = CompareConst.PASS if compare_result else CompareConst.ERROR + return result + + +def compare_scatter(device_out, bench_out, **kwargs): + rank = kwargs.get("local_rank", 0) + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + if len(bench_out) <= rank: + raise ValueError("bench_out should have at least rank+1 outputs") + compare_result = torch.equal(device_out[0].cpu(), bench_out[rank]) + + return CompareConst.PASS if compare_result else CompareConst.ERROR + + +def compare_all_gather(device_out, bench_out, **kwargs): + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + device_out_cpu = [tensor.cpu() for tensor in device_out[0]] + compare_result = all(torch.equal(a, b) for a, b in zip(device_out_cpu, bench_out)) + + return CompareConst.PASS if compare_result else CompareConst.ERROR + + +def compare_all_to_all(device_out, bench_out, **kwargs): + rank = kwargs.get("local_rank", 0) + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + device_out_cpu = [tensor.cpu() for tensor in device_out[0]] + compare_result = all(torch.equal(a, b) for a, b in zip(device_out_cpu, bench_out[rank])) + + return CompareConst.PASS if compare_result else CompareConst.ERROR + + +def compare_all_to_all_single(device_out, bench_out, **kwargs): + rank = kwargs.get("local_rank", 0) + if len(device_out) < 1: + raise ValueError("device_out should not be empty") + compare_result = torch.equal(device_out[0].cpu(), bench_out[rank]) + + return CompareConst.PASS if compare_result else CompareConst.ERROR diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..6758b4ff4f8b286477880f74cd34e3516060c3fb --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import \ + mock_broadcast, mock_reduce, mock_scatter, mock_all_gather, mock_all_to_all, \ + mock_all_to_all_single +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \ + compare_broadcast, compare_all_reduce, compare_scatter, \ + compare_all_gather, compare_all_to_all, compare_all_to_all_single +from msprobe.core.common.const import DistributedCheckConst + + +class DistributedFunctionRegistry: + def __init__(self): + self.compare_functions = {} + self.bench_functions = {} + self.support_api_list = [DistributedCheckConst.BROADCAST, DistributedCheckConst.ALL_REDUCE, + DistributedCheckConst.SCATTER, DistributedCheckConst.ALL_GATHER, + DistributedCheckConst.ALL_TO_ALL, DistributedCheckConst.ALL_TO_ALL_SINGLE] + + def register_compare_function(self, api_name: str, function: Callable): + self.compare_functions[api_name] = function + + def register_bench_function(self, api_name: str, function: Callable): + self.bench_functions[api_name] = function + + def register_functions(self, functions_dict): + for api_name, (bench_function, compare_function) in functions_dict.items(): + self.register_bench_function(api_name, bench_function) + self.register_compare_function(api_name, compare_function) + + def get_compare_function(self, api_name: str) -> Callable: + if not self.compare_functions.get(api_name): + raise Exception("No compare function registered for api: {}".format(api_name)) + return self.compare_functions.get(api_name) + + def get_bench_function(self, api_name: str) -> Callable: + if not self.bench_functions.get(api_name): + raise Exception("No benchmark function registered for api: {}".format(api_name)) + return self.bench_functions.get(api_name) + + +functions_map = { + DistributedCheckConst.BROADCAST: (mock_broadcast, compare_broadcast), + DistributedCheckConst.ALL_REDUCE: (mock_reduce, compare_all_reduce), + DistributedCheckConst.SCATTER: (mock_scatter, compare_scatter), + DistributedCheckConst.ALL_GATHER: (mock_all_gather, compare_all_gather), + DistributedCheckConst.ALL_TO_ALL: (mock_all_to_all, compare_all_to_all), + DistributedCheckConst.ALL_TO_ALL_SINGLE: (mock_all_to_all_single, compare_all_to_all_single) +} +distributed_func_registry = DistributedFunctionRegistry() +distributed_func_registry.register_functions(functions_map) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 498102b475f564564d6039a81e305fba3bceec17..110685e9900b9557f78da537bf23dd9ac1c14b11 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api): backward_data[f"{data_name}.backward"] = backward_data.pop(data_name) input_data = load_json(input_file) + if "dump_data_dir" not in input_data.keys(): + logger.error("Invalid input file, 'dump_data_dir' field is missing") + raise CompareException("Invalid input file, 'dump_data_dir' field is missing") if input_data.get("data") is None: logger.error("Invalid input file, 'data' field is missing") raise CompareException("Invalid input file, 'data' field is missing") @@ -67,7 +70,7 @@ def split_json_file(input_file, num_splits, filter_api): split_forward_data = dict(items[start:end]) temp_data = { **input_data, - "data":{ + "data": { **split_forward_data, **backward_data } @@ -84,10 +87,6 @@ def signal_handler(signum, frame): raise KeyboardInterrupt() -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - - ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', 'result_csv_path', 'total_items', 'config_path']) @@ -97,7 +96,7 @@ def run_parallel_ut(config): processes = [] device_id_cycle = cycle(config.device_id) if config.save_error_data_flag: - logger.info("UT task error datas will be saved") + logger.info("UT task error data will be saved") logger.info(f"Starting parallel UT with {config.num_splits} processes") progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items") @@ -129,6 +128,9 @@ def run_parallel_ut(config): sys.stdout.flush() except ValueError as e: logger.warning(f"An error occurred while reading subprocess output: {e}") + finally: + if process.poll() is None: + process.stdout.close() def update_progress_bar(progress_bar, result_csv_path): while any(process.poll() is None for process in processes): @@ -139,7 +141,7 @@ def run_parallel_ut(config): for api_info in config.api_files: cmd = create_cmd(api_info, next(device_id_cycle)) - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False) processes.append(process) threading.Thread(target=read_process_output, args=(process,), daemon=True).start() @@ -185,8 +187,8 @@ def run_parallel_ut(config): def prepare_config(args): - api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE, - ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) + api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE, + ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) api_info = api_info_file_checker.common_check() out_path = args.out_path if args.out_path else Const.DEFAULT_PATH create_directory(out_path) @@ -195,11 +197,11 @@ def prepare_config(args): split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api) config_path = args.config_path if args.config_path else None if config_path: - config_path_checker = FileChecker(config_path, FileCheckConst.FILE, + config_path_checker = FileChecker(config_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) config_path = config_path_checker.common_check() result_csv_path = args.result_csv_path or os.path.join( - out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") + out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") if not args.result_csv_path: details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv") comparator = Comparator(result_csv_path, details_csv_path, False) @@ -214,14 +216,12 @@ def prepare_config(args): def main(): + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) parser = argparse.ArgumentParser(description='Run UT in parallel') _run_ut_parser(parser) - parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, + parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64') args = parser.parse_args() config = prepare_config(args) run_parallel_ut(config) - - -if __name__ == '__main__': - main() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py new file mode 100644 index 0000000000000000000000000000000000000000..54f3790bbc048a9265419a52e18519b77ab25de8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -0,0 +1,254 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys +import time +from collections import namedtuple +import copy + +import torch_npu +import torch.distributed as dist +import torch.multiprocessing as mp + +from msprobe.core.common.const import Const, FileCheckConst, DistributedCheckConst, CompareConst +from msprobe.core.common.file_utils import FileChecker, write_csv, create_directory +from msprobe.core.compare.utils import check_and_return_dir_contents +from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig +from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params, get_group_info, \ + is_port_in_use +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.pytorch.hook_module.api_register import get_api_register +from msprobe.pytorch.pt_config import parse_json_config + + +api_register = get_api_register(return_new=True) +api_register.initialize_hook(None) +distribute_api_key = Const.PT_FRAMEWORK + Const.SEP + Const.PT_API_TYPE_DIST +distributed_func = api_register.ori_api_attr.get(distribute_api_key, {}) + +current_time = time.strftime("%Y%m%d%H%M%S") +RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" +RESULT_CSV_HEADER = [['API_NAME', 'RANK', 'COMPARE_RESULT', 'MESSAGE']] +DistributedCheckParams = namedtuple("DistributedCheckParams", ["api_full_name", "all_args", "all_kwargs", + "group_ranks", "result_file_path", "checker_config"]) +special_rank_api_list = [DistributedCheckConst.SCATTER, + DistributedCheckConst.ALL_TO_ALL, + DistributedCheckConst.ALL_TO_ALL_SINGLE] + + +def cleanup(): + dist.destroy_process_group() + + +def distributed_setup(rank, world_size, master_ip, master_port): + init_method = DistributedCheckConst.TCP + Const.COLON + Const.DOUBLE_SLASH + master_ip + Const.COLON + master_port + dist.init_process_group(backend=DistributedCheckConst.HCCL, init_method=init_method, + world_size=world_size, rank=rank) + + +def parse_distributed_api(forward_content): + distributed_api = {} + for api_full_name, api_info_dict in forward_content.items(): + split_name = api_full_name.split(Const.SEP)[0] + if split_name == Const.DISTRIBUTED: + distributed_api.update({api_full_name: api_info_dict}) + return distributed_api + + +def _run_distributed_parser(parser): + parser.add_argument("-api_info", "--api_info_dir", dest="api_info_dir", default="", type=str, + help=" The api param tool result dir: generate from api param tool. ", + required=True) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The ut task result out path.", + required=False) + parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str, + help=" The path of config.json", required=False) + + +def _run_distributed(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + _run_distributed_parser(parser) + args = parser.parse_args(sys.argv[1:]) + run_distributed_command(args) + + +def run_distributed_command(args): + input_checker = FileChecker(args.api_info_dir, FileCheckConst.DIR, ability=FileCheckConst.READ_ABLE) + api_info_dir = input_checker.common_check() + ranks = sorted(check_and_return_dir_contents(api_info_dir, Const.RANK)) + file_paths = [os.path.join(api_info_dir, rank, 'dump.json') for rank in ranks] + forward_contents = [] + real_data_paths = [] + for file_path in file_paths: + forward_content, _, real_data_path = parse_json_info_forward_backward(file_path) + if real_data_path: + dump_path = os.path.dirname(file_path) + real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA) + distributed_api = parse_distributed_api(forward_content) + forward_contents.append(distributed_api) + real_data_paths.append(real_data_path) + + out_path = args.out_path if args.out_path else Const.DEFAULT_PATH + create_directory(out_path) + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + result_file_path = os.path.join(out_path, RESULT_FILE_NAME) + write_csv(RESULT_CSV_HEADER, result_file_path) + if args.config_path: + config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE, + FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) + checked_config_path = config_path_checker.common_check() + _, task_config = parse_json_config(checked_config_path, Const.RUN_UT) + checker_config = CheckerConfig(task_config) + else: + checker_config = CheckerConfig() + run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config) + + +def run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config): + for rank, forward_content in enumerate(forward_contents): + logger.info("Start to check distributed api in rank {}.".format(rank)) + + for api_full_name, api_info_dict in forward_content.items(): + _, api_name = extract_basic_api_segments(api_full_name) + + if api_name not in distributed_func_registry.support_api_list: + message = "The api {} doesn't support distributed check.".format(api_full_name) + logger.warning(message) + result_rows = [] + df_row = list([api_full_name, rank, CompareConst.SKIP, message]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) + continue + + if api_info_dict.get('used'): + continue + + group_ranks, group_id = get_group_info(api_full_name, api_name, api_info_dict) + if not group_ranks or not group_id: + logger.warning("The api {} doesn't support distributed check.".format(api_full_name)) + continue + all_args, all_kwargs = get_distributed_args_kwargs(forward_contents, api_full_name, + real_data_paths, group_ranks) + try: + distributed_check_params = DistributedCheckParams(api_full_name, all_args, all_kwargs, group_ranks, + result_file_path, checker_config) + distributed_check(distributed_check_params) + except Exception as e: + logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) + result_rows = [] + df_row = list([api_full_name, rank, CompareConst.ERROR, str(e)]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) + + +def distributed_check(distributed_check_params): + api_full_name = distributed_check_params.api_full_name + all_args = distributed_check_params.all_args + all_kwargs = distributed_check_params.all_kwargs + group_ranks = distributed_check_params.group_ranks + result_file_path = distributed_check_params.result_file_path + checker_config = distributed_check_params.checker_config + + _, api_name = extract_basic_api_segments(api_full_name) + nprocs = len(group_ranks) + distributed_config = {} + distributed_config[DistributedCheckConst.API_FULL_NAME] = api_full_name + distributed_config[DistributedCheckConst.API_NAME] = api_name + distributed_config[DistributedCheckConst.GROUP_RANKS] = group_ranks + distributed_config[DistributedCheckConst.ALL_ARGS] = all_args + distributed_config[DistributedCheckConst.ALL_KWARGS] = all_kwargs + distributed_config[DistributedCheckConst.RESULT_FILE_PATH] = result_file_path + benchmark_function = distributed_func_registry.get_bench_function(api_name) + distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(api_name, all_args, all_kwargs) + distributed_config[DistributedCheckConst.MASTER_IP] = checker_config.master_ip + distributed_config[DistributedCheckConst.MASTER_PORT] = checker_config.master_port + distributed_config[DistributedCheckConst.WORLD_SIZE] = nprocs + + if is_port_in_use(checker_config.master_port, checker_config.master_ip): + raise ValueError( + f"Warning: Port {checker_config.master_port} on host " + f"{checker_config.master_ip} is already in use." + ) + logger.info(f"Port {checker_config.master_port} on host {checker_config.master_ip} is available.") + + mp.spawn(run_hccl, + args=(distributed_config,), + nprocs=nprocs) + + +def run_hccl(rank, distributed_config): + local_rank = distributed_config[DistributedCheckConst.GROUP_RANKS][rank] + torch_npu.npu.set_device(local_rank) + world_size = distributed_config[DistributedCheckConst.WORLD_SIZE] + master_ip = distributed_config[DistributedCheckConst.MASTER_IP] + master_port = distributed_config[DistributedCheckConst.MASTER_PORT] + distributed_setup(rank, world_size, master_ip, master_port) + api_full_name = distributed_config[DistributedCheckConst.API_FULL_NAME] + api_name = distributed_config[DistributedCheckConst.API_NAME] + input_args = distributed_config[DistributedCheckConst.ALL_ARGS] + rank_args = input_args[rank] + rank_kwargs = distributed_config[DistributedCheckConst.ALL_KWARGS][rank] + result_file_path = distributed_config[DistributedCheckConst.RESULT_FILE_PATH] + benchmark_result = distributed_config[DistributedCheckConst.BENCHMARK_RESULT] + device_args, _ = generate_device_params(rank_args, rank_kwargs, False, api_name) + logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) + distributed_func.get(api_name)(*device_args) + dist.barrier() + if api_name in special_rank_api_list: + local_rank = rank + kwargs = { + "local_rank": local_rank, + "input_args": input_args + } + compare_function = distributed_func_registry.get_compare_function(api_name) + status = compare_function(device_args, benchmark_result, **kwargs) + message = '' + result_rows = [] + df_row = list([api_full_name, local_rank, status, message]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) + cleanup() + + +def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks): + all_args, all_kwargs = [], [] + _, api_name = extract_basic_api_segments(api_full_name) + for group_rank in group_ranks: + target_api_info = forward_contents[group_rank].get(api_full_name) + if not target_api_info: + logger.warning("The api {} doesn't exist in rank {}.".format(api_full_name, group_rank)) + continue + if target_api_info.get('used'): + continue + target_api_info['used'] = True + args, kwargs, _ = get_api_info(target_api_info, api_name, real_data_paths[group_rank]) + all_args.append(args) + all_kwargs.append(kwargs) + return all_args, all_kwargs + + +if __name__ == '__main__': + logger.info("Start to run distributed ut task.") + _run_distributed() + logger.info("End to run distributed ut task.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 6214d892906bef44d94474c6415674f39099357b..0f184d14b66d84607a6767ba9ef5210ff4fc5b69 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i from msprobe.core.common.file_utils import check_link, FileChecker from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments from msprobe.core.common.const import FileCheckConst, Const +from msprobe.core.common.utils import check_op_str_pattern_valid from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.core.common.decorator import recursion_depth_decorator def check_tensor_overflow(x): @@ -63,6 +65,7 @@ def check_tensor_overflow(x): return False +@recursion_depth_decorator("check_data_overflow") def check_data_overflow(x, device): if isinstance(x, (tuple, list)): if not x: @@ -75,6 +78,7 @@ def check_data_overflow(x, device): return torch_npu.npu.utils.npu_check_overflow(x) +@recursion_depth_decorator("is_bool_output") def is_bool_output(x): if isinstance(x, (tuple, list)): if not x: @@ -91,6 +95,7 @@ def run_overflow_check(forward_file): dump_path = os.path.dirname(forward_file) real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA) for api_full_name, api_info_dict in tqdm(forward_content.items()): + check_op_str_pattern_valid(api_full_name) if is_unsupported_api(api_full_name, is_overflow_check=True): continue try: @@ -161,6 +166,7 @@ def _run_overflow_check(parser=None): _run_overflow_check_parser(parser) args = parser.parse_args(sys.argv[1:]) _run_overflow_check_command(args) + logger.info("UT task completed.") def _run_overflow_check_command(args): @@ -175,8 +181,3 @@ def _run_overflow_check_command(args): logger.error(f"Set NPU device id failed. device id is: {args.device_id}") raise NotImplementedError from error run_overflow_check(api_info) - - -if __name__ == '__main__': - _run_overflow_check() - logger.info("UT task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 905687c1bfc932883396481410c333a7566fd342..fc7814e6d4cf533410fcc26a7cc94b94808b9b44 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -45,11 +45,11 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.core.common.file_utils import FileChecker, change_mode, \ - create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid + create_directory, get_json_contents, read_csv, check_file_or_directory_path from msprobe.pytorch.common.log import logger from msprobe.pytorch.pt_config import parse_json_config from msprobe.core.common.const import Const, FileCheckConst, CompareConst -from msprobe.core.common.utils import safe_get_value, CompareException +from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid from msprobe.pytorch.common.utils import seed_all from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher @@ -65,6 +65,8 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" not_backward_list = ['repeat_interleave'] unsupported_backward_list = ['masked_select'] +unsupported_api_list = ["to", "empty", "empty_like", "empty_strided", "new_empty", "new_empty_strided", + "empty_with_format"] tqdm_params = { @@ -83,6 +85,9 @@ tqdm_params = { } +seed_all() + + def run_ut(config): logger.info("start UT test") if config.online_config.is_online: @@ -93,7 +98,7 @@ def run_ut(config): logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: - logger.info(f"UT task error_datas will be saved in {config.error_data_path}") + logger.info(f"UT task error_data will be saved in {config.error_data_path}") compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config) if config.online_config.is_online: @@ -117,6 +122,7 @@ def run_ut(config): def run_api_offline(config, compare, api_name_set): err_column = CompareColumn() for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): + check_op_str_pattern_valid(api_full_name) if api_full_name in api_name_set: continue if is_unsupported_api(api_full_name): @@ -218,6 +224,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list): If api is both in black_list and black_list, black_list first. return: False for exec api, True for not exec """ + black_list.extend(unsupported_api_list) if black_list and api_name in black_list: return True if white_list and api_name not in white_list: @@ -317,7 +324,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content): if kwargs.get("device"): del kwargs["device"] - device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs) + device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None) + device_out = exec_api(device_exec_params) device_out = move2device_exec(device_out, "cpu") return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) @@ -344,6 +352,9 @@ def need_to_backward(grad_index, out): def run_backward(args, grad, grad_index, out): if grad_index is not None: + if not is_int(grad_index): + logger.error(f"{grad_index} dtype is not int") + raise TypeError(f"{grad_index} dtype is not int") if grad_index >= len(out): logger.error(f"Run backward error when grad_index is {grad_index}") raise IndexError(f"Run backward error when grad_index is {grad_index}") @@ -430,6 +441,7 @@ def preprocess_forward_content(forward_content): arg_cache = {} for key, value in forward_content.items(): + check_op_str_pattern_valid(key) base_key = key.rsplit(Const.SEP, 1)[0] if key not in arg_cache: @@ -491,7 +503,10 @@ def checked_online_config(online_config): check_file_or_directory_path(online_config.tls_path, isdir=True) check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key")) check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt")) - check_crt_valid(os.path.join(online_config.tls_path, "server.crt")) + check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt")) + crl_path = os.path.join(online_config.tls_path, "crl.pem") + if os.path.exists(crl_path): + check_file_or_directory_path(crl_path) # host and port if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host): @@ -561,7 +576,14 @@ def run_ut_command(args): error_data_path = checker_config.error_data_path if save_error_data: if args.result_csv_path: - time_info = result_csv_path.split('.')[0].split('_')[-1] + parts_by_dot = result_csv_path.split(Const.SEP) + if len(parts_by_dot) < 2 or not parts_by_dot[0]: + raise ValueError("result_csv_path does not contain a valid file name with an extension.") + file_name_part = parts_by_dot[0] + parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER) + if len(parts_by_underscore) < 2: + raise ValueError("File name part does not contain enough '_' separated segments.") + time_info = parts_by_underscore[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info error_data_path = initialize_save_error_data(error_data_path) @@ -579,9 +601,8 @@ def run_ut_command(args): } run_ut_config = checker_config.get_run_ut_config(**config_params) run_ut(run_ut_config) + logger.info("UT task completed.") if __name__ == '__main__': - seed_all() _run_ut() - logger.info("UT task completed.") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index dc0174212e3f8f8cf70fa1701aadc664138dbcdf..63e347d971e1cc222e72e82c59a7fc0f168bf8fa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -1,9 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -16,10 +14,11 @@ # limitations under the License. import os +import socket from collections import namedtuple import re -import torch +import torch try: import torch_npu except ImportError: @@ -29,15 +28,13 @@ else: current_device = "npu" from torch_npu.npu.amp import autocast -from msprobe.core.common.const import FileCheckConst, Const, CompareConst +from msprobe.core.common.const import FileCheckConst, Const, CompareConst, DistributedCheckConst from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException +from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate -from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate -from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate -from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate + hf_32_standard_api = ["conv1d", "conv2d"] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} @@ -108,17 +105,28 @@ def exec_api(exec_params): kwargs = exec_params.kwargs is_autocast = exec_params.is_autocast autocast_dtype = exec_params.autocast_dtype - - if api_type == "Functional": - torch_api = FunctionalOPTemplate(api_name, str, False) - if api_type == "Tensor": - torch_api = TensorOPTemplate(api_name, str, False) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - if api_type == "Aten": + out = None + + prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {}) + if not prefix_map or api_type not in prefix_map.values() or \ + api_type not in ( + Const.FUNCTIONAL_API_TYPE_PREFIX, + Const.TENSOR_API_TYPE_PREFIX, + Const.TORCH_API_TYPE_PREFIX, + Const.ATEN_API_TYPE_PREFIX, + Const.NPU_API_TYPE_PREFIX + ): + return out + + if api_type == Const.ATEN_API_TYPE_PREFIX: torch_api = AtenOPTemplate(api_name, None, False) - if api_type == "NPU": - torch_api = NpuOPTemplate(api_name, None, False, device) + else: + api_register = get_api_register() + api_register.initialize_hook(None) + api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)] + api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name) + + torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device) if is_autocast: with autocast(dtype=autocast_dtype): out = torch_api.forward(*args, **kwargs) @@ -225,7 +233,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): origin_dtype = need_raise_dtypes.pop() raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32) autocast_dtype = origin_dtype - + elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 need_raise_dtypes.discard(torch.float32) @@ -252,3 +260,65 @@ def is_unsupported_api(api_name, is_overflow_check=False): if flag: logger.info(f"{split_name} api is not supported for run ut. SKIP.") return flag + + +def get_args_index(api_name, args_name): + """ + 根据 API 名字和参数名获取参数索引。获取 group_index 或者 src_index。 + :param api_name: API 名字,如 "broadcast" 或 "all_reduce" + :param args_name: 参数名,如 "group" 或 "src" + :return: 参数索引 或 None(如果 API 名字或参数名不存在) + """ + api_info = DistributedCheckConst.API_ARGS_INDEX.get(api_name) + if api_info: + return api_info.get(args_name) + return None + + +def get_distributed_args(api_name, input_args, input_kwargs, args_name): + res = None + res = input_kwargs.get(args_name) + if res: + return res + res_index = get_args_index(api_name, args_name) + if not res_index or len(input_args) <= res_index: + return None + res = input_args[res_index] + return res + + +def get_group_info(api_full_name, api_name, api_info_dict): + input_args = api_info_dict.get('input_args', []) + input_kwargs = api_info_dict.get('input_kwargs', {}) + group = get_distributed_args(api_name, input_args, input_kwargs, DistributedCheckConst.GROUP) + + if not group: + logger.warning("The api {} doesn't have group info.".format(api_full_name)) + return None, None + group_ranks = group.get('group_ranks') + if not group_ranks: + logger.warning("The group of api {} doesn't have group_ranks info.".format(api_full_name)) + return None, None + group_id = group.get('group_id') + if not group_id: + logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name)) + return None, None + return group_ranks, group_id + + +def is_port_in_use(port, host): + """ + 检测指定端口是否被占用。 + :param port: 要检测的端口号 + :param host: 主机地址 + :return: 如果端口被占用返回 True,否则返回 False + """ + if not isinstance(port, str) or not port.isdigit(): + raise Exception(f"port: {port} is invalid. Port must be a numeric string.") + port = int(port) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((host, port)) + return False # 端口未被占用 + except socket.error: + return True # 端口已被占用 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py index f31c29c6bb6fa8a863b83bf09d15aba09645436f..2cfc355ec035d245261ca9c817e02687c684d471 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py @@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer from msprobe.core.common.file_utils import remove_path from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl +from msprobe.core.common.decorator import recursion_depth_decorator BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] @@ -168,11 +169,12 @@ class ATTL: return buffer +@recursion_depth_decorator("move2device_exec") def move2device_exec(obj, device): if isinstance(obj, (tuple, list)): data_list = [move2device_exec(val, device) for val in obj] return data_list if isinstance(obj, list) else tuple(data_list) - if isinstance(obj, dict): + if isinstance(obj, dict): return {key: move2device_exec(val, device) for key, val in obj.items()} elif isinstance(obj, torch.Tensor): obj = obj.detach() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py index fbb087deec73bb6e77c0d7581128c74e2d9be9fa..a55ecae283105ed3d3127b862fc817ca371732db 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py @@ -12,23 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import hashlib +from functools import partial +import zlib import io import struct import time import os -import signal from queue import Queue from threading import Thread from typing import Union -from twisted.internet import reactor, protocol, endpoints +from twisted.internet import reactor, protocol, endpoints, ssl from twisted.protocols.basic import FileSender from msprobe.pytorch.common.utils import logger from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \ - STR_TO_BYTES_ORDER as bytes_order + STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem MAX_SENDING_QUEUE_SIZE = 20 @@ -104,11 +103,28 @@ class TCPClient: self.factory = MessageClientFactory() self.factory.protocol = cur_protocol if self.tls_path: - from twisted.internet import ssl - client_key = os.path.join(self.tls_path, "client.key") - client_crt = os.path.join(self.tls_path, "client.crt") - client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt) - endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory) + client_key, client_crt, ca_crt, crl_pem = load_ssl_pem( + key_file=os.path.join(self.tls_path, "client.key"), + cert_file=os.path.join(self.tls_path, "client.crt"), + ca_file=os.path.join(self.tls_path, "ca.crt"), + crl_file=os.path.join(self.tls_path, "crl.pem") + ) + + ssl_options = ssl.CertificateOptions( + privateKey=client_key, + certificate=client_crt, + method=ssl.SSL.TLSv1_2_METHOD, + verify=True, + requireCertificate=True, + caCerts=[ca_crt], # 信任的CA证书列表 + ) + ssl_context = ssl_options.getContext() + ssl_context.set_cipher_list(cipher_list) + ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION) + ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + partial(verify_callback, crl=crl_pem)) + + endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options) else: endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port) d = endpoint.connect(self.factory) @@ -299,12 +315,12 @@ class ClientProtocol(protocol.Protocol): def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0): length = len(data) - md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else "" + data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else "" data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \ sequence_number.to_bytes(8, byteorder=bytes_order) + \ rank.to_bytes(8, byteorder=bytes_order) + \ step.to_bytes(8, byteorder=bytes_order) + \ - md5_hash.encode() + \ + data_crc.encode() + \ data logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}") @@ -346,7 +362,7 @@ class ClientProtocol(protocol.Protocol): def connectionLost(self, reason): self.signal_exit = True self.factory.num_connections -= 1 - logger.info(f"Lost connection with server, reason is : {reason}") + logger.info(f"Lost connection with server, reason is : {reason.value}") class MessageClientFactory(protocol.ClientFactory): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py index 8777af9cc37ad03dacfa82bf29854fb1c1babe95..6fc36bcdecac81ae302ec9fd64079758f74e4071 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py @@ -29,7 +29,6 @@ from msprobe.pytorch.common.log import logger from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params - # NPU vs GPU api list CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api) @@ -43,6 +42,15 @@ OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig', CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config']) +def get_gpu_device(): + is_gpu = False + try: + import torch_npu + except ImportError: + is_gpu = True + return is_gpu + + def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file): """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue. :param xpu_id: int @@ -51,7 +59,9 @@ def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file :param api_precision_csv_file: list, length is 2, result file name and details file name :return: """ - gpu_device = torch.device(f'cuda:{xpu_id}') + device_info = "cuda" if get_gpu_device() else "npu" + logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.") + gpu_device = torch.device(f'{device_info}:{xpu_id}') while True: if consumer_queue.empty(): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py index 411e36d4cb3014b75a46d58ebec99b7e8b7c7c44..d51138941c7711e404d561e0c92389de581c3b3c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py @@ -12,19 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import os.path +from functools import partial +import os import struct -import hashlib +import zlib import time import io from threading import Thread -from twisted.internet import reactor, protocol, endpoints +from twisted.internet import reactor, protocol, endpoints, ssl from msprobe.pytorch.common.utils import logger from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \ - STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order + STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem class TCPServer: @@ -44,15 +44,28 @@ class TCPServer: self.factory.protocol = self.build_protocol if self.tls_path: - from OpenSSL import SSL - from twisted.internet import ssl - server_key = os.path.join(self.tls_path, "server.key") - server_crt = os.path.join(self.tls_path, "server.crt") - server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD) - server_context_ = server_context_factory.getContext() - server_context_.set_cipher_list(cipher_list) - server_context_.set_options(SSL.OP_NO_RENEGOTIATION) - endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory) + server_key, server_crt, ca_crt, crl_pem = load_ssl_pem( + key_file=os.path.join(self.tls_path, "server.key"), + cert_file=os.path.join(self.tls_path, "server.crt"), + ca_file=os.path.join(self.tls_path, "ca.crt"), + crl_file=os.path.join(self.tls_path, "crl.pem") + ) + + ssl_options = ssl.CertificateOptions( + privateKey=server_key, + certificate=server_crt, + method=ssl.SSL.TLSv1_2_METHOD, + verify=True, + requireCertificate=True, + caCerts=[ca_crt], # 信任的CA证书列表 + ) + ssl_context = ssl_options.getContext() + ssl_context.set_cipher_list(cipher_list) + ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION) + ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + partial(verify_callback, crl=crl_pem)) + + endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options) else: endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port) endpoint.listen(self.factory) @@ -85,10 +98,10 @@ class ServerProtocol(protocol.Protocol): self.consumer_queue = shared_queue self.check_sum = check_sum self.length_width = 8 - self.md5_width = 32 + self.crc_width = 8 self.obj_length = None self.tell = 0 - self.obj_md5 = None + self.obj_crc = None self.obj_body = None self.sequence_number = -1 self.rank = -1 @@ -99,7 +112,7 @@ class ServerProtocol(protocol.Protocol): self.buffer = io.BytesIO() self.obj_length = None self.tell = 0 - self.obj_md5 = None + self.obj_crc = None self.obj_body = None self.factory.transport_dict[self.transport] = 1 self.factory.transport_list.append(self.transport) @@ -132,11 +145,12 @@ class ServerProtocol(protocol.Protocol): time.sleep(0.1) obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step) + # get the crc value of a 16-bit string with a length of 8 + recv_crc = f"{zlib.crc32(self.obj_body):08x}" - recv_md5 = hashlib.md5(self.obj_body).hexdigest() - if self.check_sum and recv_md5 != self.obj_md5: - # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client. - logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}") + if self.check_sum and recv_crc != self.obj_crc: + # when needs check hash value and check no pass, indicates received data error, send b"ERROR" to client. + logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}") self.send_ack(self.ACK_ERROR) else: if self.obj_body == self.ACK_STOP: @@ -146,7 +160,7 @@ class ServerProtocol(protocol.Protocol): if obj_key in self.sequence_number_dict: logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}") else: - self.sequence_number_dict[obj_key] = self.obj_md5 + self.sequence_number_dict[obj_key] = self.obj_crc self.consumer_queue.put(self.obj_body, block=True) self.reset_env() @@ -173,7 +187,7 @@ class ServerProtocol(protocol.Protocol): self.sequence_number = -1 self.rank = -1 self.step = -1 - self.obj_md5 = None + self.obj_crc = None self.obj_body = None def dataReceived(self, data): @@ -192,15 +206,15 @@ class ServerProtocol(protocol.Protocol): logger.debug( f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}") - # If needs check md5 but not parse md5 yet, read 32b md5 values - check_sum_and_md5 = (self.check_sum + # If needs check hash but not parse crc yet, read 8b crc values + check_sum_and_crc = (self.check_sum and self.obj_length is not None - and self.obj_md5 is None - and len(self.buffer.getvalue()) - self.tell >= self.md5_width) - if check_sum_and_md5: - self.obj_md5 = self.buffer.read(self.md5_width).decode() - self.tell += self.md5_width - logger.debug(f"MD5: {self.obj_md5}") + and self.obj_crc is None + and len(self.buffer.getvalue()) - self.tell >= self.crc_width) + if check_sum_and_crc: + self.obj_crc = self.buffer.read(self.crc_width).decode() + self.tell += self.crc_width + logger.debug(f"Hash value: {self.obj_crc}") current_length = len(self.buffer.getvalue()) - self.tell if self.obj_length is not None and 0 < self.obj_length <= current_length: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py index aace2f13cc0eeb34a51c03907c9a87a6479617c4..05dd50a3f2bbc6637926c45f7c96f7d90e01edbf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py @@ -12,6 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc +import os +from datetime import datetime, timezone + +from OpenSSL import crypto +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from dateutil import parser + +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.log import logger cipher_list = ":".join( ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", @@ -42,3 +53,146 @@ cipher_list = ":".join( STRUCT_UNPACK_MODE = "!Q" STR_TO_BYTES_ORDER = "big" + + +def is_certificate_revoked(cert, crl): + # 获取证书的序列号 + cert_serial_number = cert.get_serial_number() + + # 检查证书是否在CRL中 + revoked_serials = [revoked_cert.serial_number for revoked_cert in crl] + if cert_serial_number in revoked_serials: + logger.error(f"证书已吊销:{cert_serial_number:020x}") + return True + + return False + + +def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None): + """ + 验证对端证书的有效性 + :param conn: OpenSSL.SSL.Connection, SSL 连接对象 + :param cert: OpenSSL.crypto.X509, 当前证书 + :param errno: int, OpenSSL错误代码, 0:无错误 | 9:证书过期 | 18: 自签名证书 + :param depth: int, 当前证书在证书链中的深度 (0=叶子节点), 1:中间CA证书 -1:根CA证书 2+:更高级别CA证书 + :param preverify_ok: int, 验证结果 (1=通过, 0=失败) + :param crl: _CRLInternal, CRL证书对象 + :return: bool, True表示接受证书, False表示拒绝 + """ + + if not preverify_ok: + from OpenSSL import SSL + error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode() + logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}") + return False + + if crl and is_certificate_revoked(cert, crl): + return False + + return preverify_ok + + +def load_ssl_pem(key_file, cert_file, ca_file, crl_file): + """ + Load SSL PEM files. + + Args: + key_file (str): The path to the private key file. + cert_file (str): The path to the certificate file. + ca_file (str): The path to the CA certificate file. + crl_file (str): The path to the CRL file. + + Returns: + tuple: (key, crt, ca_crt, crl) + + Raises: + Exception: If the file paths are invalid or the file contents are incorrect, exceptions may be thrown. + """ + + try: + # your_private_key_password + import pwinput + passphrase = pwinput.pwinput("Enter your password: ") + with FileOpen(key_file, "rb") as f: + key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode()) + del passphrase + gc.collect() + with FileOpen(cert_file, "rb") as f: + crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + check_crt_valid(crt) + + crt_serial_number = hex(crt.get_serial_number())[2:] + logger.info(f"crt_serial_number: {crt_serial_number}") + + check_certificate_match(crt, key) + + with FileOpen(ca_file, "rb") as f: + ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + check_crt_valid(ca_crt) + + ca_serial_number = hex(ca_crt.get_serial_number())[2:] + logger.info(f"ca_serial_number: {ca_serial_number}") + crl = None + if os.path.exists(crl_file): + with FileOpen(crl_file, "rb") as f: + crl = x509.load_pem_x509_crl(f.read(), default_backend()) + check_crl_valid(crl, ca_crt) + for revoked_cert in crl: + logger.info(f"Serial Number: {revoked_cert.serial_number}, " + f"Revocation Date: {revoked_cert.revocation_date_utc}") + + except Exception as e: + raise RuntimeError(f"The SSL certificate is invalid") from e + + return key, crt, ca_crt, crl + + +def check_crt_valid(pem): + """ + Check the validity of the SSL certificate. + + Raises: + RuntimeError: If the SSL certificate is invalid or expired. + """ + try: + pem_start = parser.parse(pem.get_notBefore().decode("UTF-8")) + pem_end = parser.parse(pem.get_notAfter().decode("UTF-8")) + logger.info(f"The SSL certificate passes the verification and the validity period " + f"starts from {pem_start} ends at {pem_end}.") + except Exception as e: + raise RuntimeError(f"The SSL certificate is invalid") from e + + now_utc = datetime.now(tz=timezone.utc) + if pem.has_expired() or not (pem_start <= now_utc <= pem_end): + raise RuntimeError(f"The SSL certificate has expired.") + + +def check_certificate_match(certificate, private_key): + """ + Check certificate and private_key is match or not. if mismatched, an exception is thrown. + :param certificate: + :param private_key: + :return: + """ + test_data = os.urandom(256) + try: + signature = crypto.sign(private_key, test_data, "sha256") + crypto.verify( + certificate, # 包含公钥的证书 + signature, # 生成的签名 + test_data, # 原始数据 + "sha256", # 哈希算法 + ) + logger.info("公钥和私钥匹配") + except Exception as e: + raise RuntimeError("公钥和私钥不匹配") from e + + +def check_crl_valid(crl, ca_crt): + # 验证CRL签名(确保CRL未被篡改) + if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()): + raise RuntimeError("CRL签名无效!") + + # 检查CRL有效期 + if not (crl.last_update <= datetime.utcnow() <= crl.next_update): + raise RuntimeError("CRL已过期或尚未生效!") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_input.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_input.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9b41008925fd99e2e71e18c67af51555eebdc0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_input.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class PrecisionCompareInput: + def __init__(self, compare_row, dtype, compare_column): + self.row_npu = compare_row + self.row_gpu = None # 由于复用了msprobe中的BasePrecisionCompare,需要补上row_gpu属性,无作用 + self.dtype = dtype + self.compare_column = compare_column diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9908b95c8bd4ba028e53f797f7ce6a9d528406 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/common/compare_utils.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pandas as pd + +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS +from msprobe.core.common.const import CompareConst + +from msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_standard.triton_standard_register import \ + absolute_standard_api_list, binary_standard_api_list, ulp_standard_api_list, thousandth_standard_api_list + +accumulative_error_eb_threshold = { + 'torch.float16': 2 ** -20, + 'torch.bfloat16': 2 ** -7, + 'torch.float32': 2 ** -14, + 'default': 2 ** -14 +} + +ulp_err_threshold = { + 'torch.float32': { + 'mean_ulp_error': 64, + 'ulp_err_proportion': 0.05 + }, + 'torch.float16': { + 'ulp_err_proportion': 0.001 + } +} + + +def convert_compare_column_to_row(compare_column, api_name): + compare_column_list = compare_column.to_column_value("pass", " ") + compare_column_list.insert(0, api_name) + compare_row = pd.Series(compare_column_list, DETAIL_TEST_ROWS[0]) + return compare_row + + +def print_check_details(compare_column, api_name): + if api_name in absolute_standard_api_list: + standard = CompareConst.ABSOLUTE_THRESHOLD + metrics = ['inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio'] + values = [compare_column.inf_nan_error_ratio, compare_column.rel_err_ratio, compare_column.abs_err_ratio] + thresholds = [0, 0, 0] + elif api_name in binary_standard_api_list: + standard = CompareConst.BINARY_CONSISTENCY + metrics = ['error_rate'] + values = [compare_column.error_rate] + thresholds = [0] + elif api_name in ulp_standard_api_list: + standard = CompareConst.ULP_COMPARE + metrics = ['mean_ulp_error', 'ulp_error_proportion'] + values = [compare_column.mean_ulp_error, compare_column.ulp_error_proportion] + thresholds = [ulp_err_threshold] + elif api_name in thousandth_standard_api_list: + standard = CompareConst.THOUSANDTH_STANDARD + metrics = ['rel_err_thousandth'] + values = [compare_column.rel_err_thousandth] + thresholds = [CompareConst.THOUSANDTH_PASS_VALUE] + else: + standard = CompareConst.ACCUMULATIVE_ERROR_COMPARE + metrics = ['inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio', 'eb'] + values = [compare_column.inf_nan_error_ratio, compare_column.rel_err_ratio, compare_column.abs_err_ratio, + compare_column.eb] + thresholds = [0, 0, 0, accumulative_error_eb_threshold] + + logger.info( + f"Checked by precision standard:{standard}, metrics:{metrics}, values:{values}, thresholds:{thresholds}") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/get_compare_result.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/get_compare_result.py new file mode 100644 index 0000000000000000000000000000000000000000..7503b2a6b1423efb00a1fe2026a20208675d1655 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/get_compare_result.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from msprobe.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_UNSUPPORT_LIST, \ + ApiPrecisionCompareColumn +from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import record_absolute_threshold_result, \ + record_binary_consistency_result, record_thousandth_threshold_result, record_accumulative_error_compare_result +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn + +from msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_standard.ulp_compare import record_ulp_compare_result +from msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_standard.triton_standard_register import \ + TritonStandardRegister +from msprobe.pytorch.api_accuracy_checker.triton_adapter.common.compare_input import PrecisionCompareInput + + +def register_compare_func(): + registry = TritonStandardRegister() + registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result) + registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result) + registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result) + registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result) + registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result) + return registry + + +def get_api_status(compare_row, api_name, compare_column, registry): + # compare_row is CompareColumn by run_ut + # 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 + if (compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace() or + compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE] in API_PRECISION_COMPARE_UNSUPPORT_LIST or + compare_row[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE): + compare_column.compare_result = CompareConst.SKIP + new_status = CompareConst.SKIP + else: + compare_column.api_name = api_name + dtype = compare_row[ApiPrecisionCompareColumn.DEVICE_DTYPE] + input_data = PrecisionCompareInput(compare_row, dtype, compare_column) + comparison_func = registry.get_comparison_function(api_name, dtype) + new_status = comparison_func(input_data) + return new_status + + +def get_compare_result(run_ut_column, api_name): + compare_column = ApiPrecisionOutputColumn() + registry = register_compare_func() + status = get_api_status(run_ut_column, api_name, compare_column, registry) + return status diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7af24500e0b57bc042e364e58740c7f3d0b1ab --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_compare.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ + get_rel_err_origin +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput +from msprobe.core.common.const import CompareConst + +from msprobe.pytorch.api_accuracy_checker.triton_adapter.precision_standard.triton_standard_register import \ + TritonStandardRegister + + +class Comparator: + def __init__(self): + self.registry = self._register_compare_func() + + @staticmethod + def _absolute_standard_compare(input_data): + absolute_compare = AbsolutethdCompare(input_data) + absolute_compare.compare() + + @staticmethod + def _binary_standard_compare(input_data): + binary_compare = BinaryCompare(input_data) + binary_compare.compare() + + @staticmethod + def _ulp_compare(input_data): + ulp_compare = UlpCompare(input_data) + ulp_compare.compare() + + @staticmethod + def _thousandth_standard_compare(input_data): + thousandth_compare = ThousandthStdCompare(input_data) + thousandth_compare.compare() + + @staticmethod + def _benchmark_compare(input_data): + benchmark_compare = BenchmarkCompare(input_data) + benchmark_compare.compare() + + @staticmethod + def _accumulative_error_compare(input_data): + accumulative_error_compare = AccumulativeErrorCompare(input_data) + accumulative_error_compare.compare() + + def perform_comparison(self, api_name, input_data): + comparison_func = self.registry.get_comparison_function(api_name, None) + comparison_func(input_data) + + def _register_compare_func(self): + registry = TritonStandardRegister() + registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare) + registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare) + registry.register(CompareConst.ULP_COMPARE, self._ulp_compare) + registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare) + registry.register(CompareConst.BENCHMARK, self._benchmark_compare) + registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare) + return registry + + +def precision_compare(api_name, expected, actual, dtype): + compare_column = CompareColumn() + compare_column.bench_type = str(expected.dtype) + compare_column.npu_type = str(actual.dtype) + compare_column.shape = tuple(actual.shape) + + # to float32 for numpy without bfloat16 + if dtype == torch.bfloat16: + expected = expected.to(torch.float32) + actual = actual.to(torch.float32) + + fx_output = expected.cpu().numpy() # fx_output and triton_output need to be numpy data + triton_output = actual.cpu().numpy() + + _, abs_bench_with_eps = get_abs_bench_with_eps(fx_output, dtype) + abs_err = get_abs_err(fx_output, triton_output) + rel_err_origin = get_rel_err_origin(abs_err, abs_bench_with_eps) + + input_data = CompareInput(fx_output, triton_output, compare_column, dtype, rel_err_origin) + comparator = Comparator() + comparator.perform_comparison(api_name, input_data) + return compare_column diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_op_precision_standard.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_op_precision_standard.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48de05086259f408bcac44d6adc51cb4db43ff4d --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_op_precision_standard.yaml @@ -0,0 +1,14 @@ +AbsoluteThreshStandard: + - triton_unk_fused_repeat_5 + +BinaryCompareStandard: + - test_triton + +ULPStandard: + - test_triton + +ThousandthStandard: + - test_triton + +AccumulativeErrorStandard: + - test_triton diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_standard_register.py new file mode 100644 index 0000000000000000000000000000000000000000..3959af46a46489cdbac9c13ec2b2c9f2cb77fa75 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/triton_standard_register.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +standard_yaml_path = os.path.join(cur_dir, 'triton_op_precision_standard.yaml') +apis = load_yaml(standard_yaml_path) +absolute_standard_api_list = apis.get('AbsoluteThreshStandard') +binary_standard_api_list = apis.get('BinaryCompareStandard') +ulp_standard_api_list = apis.get('ULPStandard') +thousandth_standard_api_list = apis.get('ThousandthStandard') +accumulative_error_standard_api_list = apis.get('AccumulativeErrorStandard') + + +class TritonStandardRegister(StandardRegistry): + def __init__(self): + super().__init__() + self.api_standard_function_map[CompareConst.ABSOLUTE_THRESHOLD] = absolute_standard_api_list + self.api_standard_function_map[CompareConst.BINARY_CONSISTENCY] = binary_standard_api_list + self.api_standard_function_map[CompareConst.ULP_COMPARE] = ulp_standard_api_list + self.api_standard_function_map[CompareConst.THOUSANDTH_STANDARD] = thousandth_standard_api_list + self.api_standard_function_map[CompareConst.ACCUMULATIVE_ERROR_COMPARE] = accumulative_error_standard_api_list + + +def exist_in_precision_standard(kernel_name): + for api_list in apis.values(): + if kernel_name in api_list: + return True + return False diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/ulp_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..64e228b64d6e0da6b7603ed7e5ee5b5b0c68869b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/triton_adapter/precision_standard/ulp_compare.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import namedtuple + +import torch + +from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig +from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BasePrecisionCompare +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, convert_str_to_float +from msprobe.core.common.const import Const, CompareConst + +ULP_ERR_MSG = "ERROR: ULP误差不满足标准\n" +UlpInfNanConsistency = namedtuple('UlpInfNanConsistency', + ['mean_ulp_err_inf_nan_consistency', + 'ulp_err_proportion_ratio_inf_nan_consistency']) + + +class UlpPrecisionCompare(BasePrecisionCompare): + def __init__(self, input_data): + super().__init__(input_data) + self.compare_algorithm = CompareConst.ULP_COMPARE_ALGORITHM_NAME + + @staticmethod + def _get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion): + mean_ulp_err_threshold, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float32) + if mean_ulp_err < mean_ulp_err_threshold: + return CompareConst.PASS, "" + elif ulp_err_proportion < ulp_err_proportion_threshold: + return CompareConst.PASS, "" + return CompareConst.ERROR, ULP_ERR_MSG + + @staticmethod + def _get_fp16_ulp_err_status(ulp_err_proportion): + _, ulp_err_proportion_threshold, _ = StandardConfig.get_ulp_threshold(torch.float16) + if ulp_err_proportion < ulp_err_proportion_threshold: + return CompareConst.PASS, "" + return CompareConst.ERROR, ULP_ERR_MSG + + def _compute_mean_ulp_err(self): + column_name = ApiPrecisionCompareColumn.MEAN_ULP_ERR + npu_value = self._get_and_convert_values(column_name) + return npu_value, "" + + def _compute_ulp_err_proportion(self): + column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION + npu_value = self._get_and_convert_values(column_name) + return npu_value + + def _get_status(self, metrics, inf_nan_consistency): + ulp_inf_nan_consistency = inf_nan_consistency.mean_ulp_err_inf_nan_consistency and \ + inf_nan_consistency.ulp_err_proportion_ratio_inf_nan_consistency + + if not ulp_inf_nan_consistency: + compare_result = CompareConst.ERROR + metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + ULP_ERR_MSG + metrics.update({CompareConst.COMPARE_RESULT: compare_result}) + return metrics + + dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE) + mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR) + ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION) + + if dtype == Const.TORCH_FLOAT32: + status, final_message = self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion) + else: + status, final_message = self._get_fp16_ulp_err_status(ulp_err_proportion) + metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message + + status_dict = { + CompareConst.ULP_ERR_STATUS: status + } + compare_result = status + metrics.update(status_dict) + metrics.update({CompareConst.COMPARE_RESULT: compare_result}) + return metrics + + def _compute_ratio(self): + compare_message = "" + mean_ulp_err, mean_ulp_err_message = self._compute_mean_ulp_err() + compare_message += mean_ulp_err_message + npu_ulp_err_proportion = self._compute_ulp_err_proportion() + + metrics = { + CompareConst.MEAN_ULP_ERR: mean_ulp_err, + CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion, + CompareConst.COMPARE_MESSAGE: compare_message + } + return metrics, UlpInfNanConsistency(True, True) + + def _get_and_convert_values(self, column_name): + npu_value = self.row_npu.get(column_name) + if npu_value is None: + raise ValueError(f"value for column '{column_name}' is None.") + npu_value = convert_str_to_float(npu_value) + return npu_value + + +def record_ulp_compare_result(input_data): + us = UlpPrecisionCompare(input_data) + compare_result = us.compare() + return compare_result diff --git a/debug/accuracy_tools/msprobe/pytorch/attl_manager.py b/debug/accuracy_tools/msprobe/pytorch/attl_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d05f546ae1523a27e78e364a2cdc9c9db9dcb091 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/attl_manager.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.runtime import Runtime +from msprobe.core.common.utils import Const +from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData +from msprobe.pytorch.common.log import logger + + +class ATTLManager: + def __init__(self, config): + self.config = config + self.attl = None + + def attl_init(self): + if self.config.online_run_ut: + from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL + attl_config = ATTLConfig(is_benchmark_device=False, + connect_ip=self.config.host, + connect_port=self.config.port, + nfs_path=self.config.nfs_path, + tls_path=self.config.tls_path) + need_dump = len(self.config.rank) == 0 or Runtime.current_rank in self.config.rank + self.attl = ATTL('npu', attl_config, need_dump=need_dump) + if self.config.nfs_path: + self.attl.upload("start") + + def attl_send(self, name, args, kwargs, output): + api_data = ApiData( + name[:-len(Const.FORWARD_NAME_SUFFIX)], + args, + kwargs, + output, + Runtime.current_iter, + Runtime.current_rank + ) + logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}") + api_type, _, _ = api_data.name.split(Const.SEP) + if api_type in [Const.DISTRIBUTED]: + logger.info(f"api {api_data.name} is not supported, skip") + return + if self.config.nfs_path: + self.attl.upload(api_data) + else: + self.attl.send(api_data) + + def attl_stop(self): + if self.config.nfs_path: + self.attl.upload("end") + elif self.attl.socket_manager is not None: + logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.") + self.attl.socket_manager.send_stop_signal() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py index be15935ce9c9f77bc0a8447902f7f4a7b536a7fb..07655ba841120a80f64a9975a74abd7556569a41 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py @@ -29,6 +29,8 @@ def softmax_func(x, axis=None): def npu_moe_gating_top_k_softmax(x, finished_optional, k): input_dtype = x.dtype + if x.dim() < 1: + raise ValueError("Input x must have at least 1 dimensions.") num_expert = x.shape[-1] softmax = softmax_func(x, -1) softmax = softmax.to(input_dtype) @@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k): expert_idx = expert_idx[:, :k] y = torch.gather(softmax, index=expert_idx, dim=-1) if finished_optional is not None: + if finished_optional.dim() < 1: + raise ValueError("Finished_optional must have at least 1 dimensions.") finished_optional = finished_optional.view(finished_optional.shape[0], 1) finished_optional = finished_optional.expand(-1, k) expert_idx = torch.where(finished_optional, num_expert, expert_idx) + if y.dim() < 2: + raise ValueError("Variable y must have at least 2 dimensions.") row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t() return y, expert_idx, row_idx diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 16067f6d2bee70645bcc337d1809a14f41ae5b96..8f10660c713228e61c17d43ef1aa297060670f4d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,11 +24,12 @@ from functools import wraps import numpy as np import torch import torch.distributed as dist + from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import (FileCheckConst, change_mode, check_file_or_directory_path, check_path_before_create, FileOpen) from msprobe.core.common.log import logger -from msprobe.core.common.utils import check_seed_all +from msprobe.core.common.utils import check_seed_all, is_save_variable_valid from packaging import version try: @@ -38,7 +39,9 @@ except ImportError: else: is_gpu = False + torch_without_guard_version = torch.__version__ >= '2.1' +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not is_gpu and not torch_without_guard_version: from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard @@ -57,7 +60,7 @@ def parameter_adapter(func): @wraps(func) def inner(self, *args, **kwargs): - if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): + if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): input_tensor = args[0] indices = args[1] if indices.dtype == torch.uint8: @@ -77,7 +80,7 @@ def parameter_adapter(func): else: res = [input_tensor[tensor_index] for tensor_index in indices] return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0) - if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None: + if self.api_name == "__eq__" and len(args) > 1 and args[1] is None: return False return func(self, *args, **kwargs) @@ -261,6 +264,10 @@ class Const: NPU = 'NPU' DISTRIBUTED = 'Distributed' + HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor" + FLOAT8_E5M2_TYPE = "torch.float8_e5m2" + FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn" + RAISE_PRECISION = { torch.float16: torch.float32, torch.bfloat16: torch.float32, @@ -309,14 +316,14 @@ def print_rank_0(message): logger.info(message) -def load_pt(pt_path, to_cpu=False): +def load_pt(pt_path, to_cpu=False, weights_only=True): pt_path = os.path.realpath(pt_path) check_file_or_directory_path(pt_path) try: if to_cpu: - pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True) + pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=weights_only) else: - pt = torch.load(pt_path, weights_only=True) + pt = torch.load(pt_path, weights_only=weights_only) except Exception as e: raise RuntimeError(f"load pt file {pt_path} failed") from e return pt @@ -391,7 +398,7 @@ def save_api_data(api_data): io_buff = io.BytesIO() torch.save(api_data, io_buff) except Exception as e: - raise RuntimeError(f"save api_data to io_buff failed") from e + raise RuntimeError("save api_data to io_buff failed") from e return io_buff @@ -399,9 +406,9 @@ def load_api_data(api_data_bytes): """Load data from bytes stream""" try: buffer = io.BytesIO(api_data_bytes) - buffer = torch.load(buffer, map_location="cpu") + buffer = torch.load(buffer, map_location="cpu", weights_only=False) except Exception as e: - raise RuntimeError(f"load api_data from bytes failed") from e + raise RuntimeError("load api_data from bytes failed") from e return buffer @@ -419,7 +426,11 @@ def is_recomputation(): bool: True if in the re-computation phase, False otherwise. """ backward_function_indices = [] - call_stack = inspect.stack() + try: + call_stack = inspect.stack() + except Exception as e: + logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.") + return False # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file. for frame_info in call_stack: @@ -449,9 +460,11 @@ def is_recomputation(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): + valid_data_types = (torch.Tensor, int, float, str) + if not is_save_variable_valid(variable, valid_data_types): + valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list) logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, torch.Tensor, int, float or string. " + f"should be one of {valid_data_types_with_nested_types}" "Skip current save process.") raise ValueError if not isinstance(name, str): @@ -466,10 +479,31 @@ def check_save_param(variable, name, save_backward): raise ValueError -def replace_last_occurrence(text, old, new): - if text is None: - return text - index = text.rfind(old) - if index != -1: - return text[:index] + text[index:].replace(old, new, 1) - return text +def is_torch_nn_module(variable): + return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule) + + +def is_hifloat8_tensor(tensor): + if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor): + return True + return False + + +def is_float8_tensor(tensor): + if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]: + return True + return is_hifloat8_tensor(tensor) + + +def register_forward_pre_hook(module, forward_pre_hook): + if torch_version_above_or_equal_2: + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + else: + module.register_forward_pre_hook(forward_pre_hook) + + +def register_forward_hook(module, forward_hook): + if torch_version_above_or_equal_2: + module.register_forward_hook(forward_hook, with_kwargs=True) + else: + module.register_forward_hook(forward_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index de62af421b5a37e39140a9836fb16853443740d7..b706a7544506b723a0c366866cec490eb5a4ff5f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,10 @@ import os -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json from msprobe.pytorch.common.log import logger -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): @@ -50,4 +46,10 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): "bench_json_path": bench_path, "is_print_compare_log": is_print_compare_log } - compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs) + try: + compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}', **kwargs) + except CompareException as e: + if e.code == CompareException.INVALID_DATA_ERROR: + logger.error(f"Invalid or missing 'data' in dump.json. Skipping {nr} comparison.") + if e.code == CompareException.INVALID_TASK_ERROR: + logger.error(f"Invalid or missing 'task' in dump.json. Skipping {nr} comparison.") diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py index 308a82b3d6e9beb67a669ea05b83d7b8a6eddc90..8acaf70c3e078c0c259cf64fe97dca63704cacb5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,92 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os.path +from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison +from msprobe.pytorch.compare.utils import read_pt_data -import torch -from msprobe.core.common.const import FileCheckConst -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml -from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \ - set_dump_path -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.utils import set_stack_json_path -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import load_pt - - -class PTComparator(Comparator): - def __init__(self, mode_config, data_mapping=None): - super().__init__(mode_config) - - self.stack_mode = mode_config.stack_mode - self.auto_analyze = mode_config.auto_analyze - self.fuzzy_match = mode_config.fuzzy_match - self.dump_mode = mode_config.dump_mode - - self.frame_name = PTComparator.__name__ - self.data_mapping = data_mapping - if isinstance(self.data_mapping, str) or self.data_mapping is None: - self.data_mapping_dict = self.load_mapping_file(self.data_mapping) - elif isinstance(self.data_mapping, dict): - self.data_mapping_dict = self.data_mapping - else: - raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " - f"{type(self.data_mapping)}") - - @staticmethod - def load_mapping_file(mapping_file): - if isinstance(mapping_file, str): - mapping_dict = load_yaml(mapping_file) - else: - mapping_dict = {} - return mapping_dict - - def read_npy_data(self, dir_path, file_name): - if not file_name: - return None - data_path = os.path.join(dir_path, file_name) - path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.PT_SUFFIX, False) - data_path = path_checker.common_check() - try: - # detach because numpy can not process gradient information - data_value = load_pt(data_path, to_cpu=True).detach() - except RuntimeError as e: - # 这里捕获 load_pt 中抛出的异常 - logger.error(f"Failed to load the .pt file at {data_path}.") - raise CompareException(CompareException.INVALID_FILE_ERROR) from e - except AttributeError as e: - # 这里捕获 detach 方法抛出的异常 - logger.error(f"Failed to detach the loaded tensor.") - raise CompareException(CompareException.DETACH_ERROR) from e - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - return data_value +def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tuple: + n_value = read_pt_data(npu_dir, npu_data_name) + b_value = read_pt_data(bench_dir, bench_data_name) + return n_value, b_value def compare(input_param, output_path, **kwargs): - try: - auto_analyze = kwargs.get('auto_analyze', True) - fuzzy_match = kwargs.get('fuzzy_match', False) - data_mapping = kwargs.get('data_mapping', None) - suffix = kwargs.get('suffix', '') - - set_dump_path(input_param) - dump_mode = get_dump_mode(input_param) - if "stack_json_path" in input_param: - stack_mode = kwargs.get('stack_mode', False) - else: - stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param - check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True)) - create_directory(output_path) - check_compare_param(input_param, output_path, dump_mode, stack_mode) - except (CompareException, FileCheckException) as error: - logger.error('Compare failed. Please check the arguments and do it again!') - raise CompareException(error.code) from error + config = setup_comparison(input_param, output_path, **kwargs) - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - pt_comparator = PTComparator(mode_config, data_mapping) - pt_comparator.compare_core(input_param, output_path, suffix=suffix) + mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match, config.dump_mode, + config.first_diff_analyze) + mapping_config = MappingConfig(data_mapping=config.data_mapping) + pt_comparator = Comparator(read_real_data, mode_config, mapping_config) + pt_comparator.compare_core(input_param, output_path, suffix=config.suffix) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..b558a20b6f592ac9ebd758a0041155beee413caa --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_diff_analyze.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from msprobe.pytorch.compare.distributed_compare import compare_distributed + + +def pt_diff_analyze(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze): + compare_distributed(npu_dump_dir, bench_dump_dir, output_path, first_diff_analyze=first_diff_analyze) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/utils.py b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16473ff386d89de5f3bbb269e69837c07a950ea5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +from msprobe.core.common.utils import logger, CompareException +from msprobe.core.common.file_utils import FileChecker, FileCheckConst +from msprobe.pytorch.common.utils import load_pt + + +def read_pt_data(dir_path, file_name): + if not file_name: + return None + + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + try: + # detach because numpy can not process gradient information + data_value = load_pt(data_path, to_cpu=True).detach() + except RuntimeError as e: + # 这里捕获 load_pt 中抛出的异常 + logger.error(f"Failed to load the .pt file at {data_path}.") + raise CompareException(CompareException.INVALID_FILE_ERROR) from e + except AttributeError as e: + # 这里捕获 detach 方法抛出的异常 + logger.error(f"Failed to detach the loaded tensor.") + raise CompareException(CompareException.DETACH_ERROR) from e + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value diff --git a/debug/accuracy_tools/msprobe/pytorch/parse.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/__init__.py similarity index 87% rename from debug/accuracy_tools/msprobe/pytorch/parse.py rename to debug/accuracy_tools/msprobe/pytorch/config_checking/__init__.py index 3dfd88f03d1b944f6943a58ce860c7de9c4a3424..7d60f07881d378bb0a7a9c6faf6147af07a915b2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/__init__.py @@ -13,7 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from msprobe.pytorch.parse_tool import cli - -if __name__ == '__main__': - cli.parse() +import msprobe.pytorch.config_checking.checkers diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/__init__.py similarity index 34% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py rename to debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/__init__.py index 05ee3bc92257be9882c20cf825ebb7561f41ddb1..c1bed99b3a091f7163431e4cc3e6c02cad69530d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/__init__.py @@ -13,48 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import torch +__all__ = ['BaseChecker', 'apply_patches'] -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard +import msprobe.pytorch.config_checking.checkers.env_args_checker +import msprobe.pytorch.config_checking.checkers.pip_checker +import msprobe.pytorch.config_checking.checkers.dataset_checker +import msprobe.pytorch.config_checking.checkers.weights_checker +import msprobe.pytorch.config_checking.checkers.hyperparameter_checker +import msprobe.pytorch.config_checking.checkers.random_checker +from msprobe.pytorch.config_checking.checkers.random_checker import apply_patches -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_vf_ops(): - yaml_data = load_yaml(yaml_path) - wrap_vf_ops = yaml_data.get('_VF') - return wrap_vf_ops - - -class HOOKVfOP(object): - pass - - -class VfOPTemplate(HOOKModule): - def __init__(self, op_name, hook): - self.op_name_ = op_name - self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_vf_op(op_name, hook): - def vf_op_template(*args, **kwargs): - return VfOPTemplate(op_name, hook)(*args, **kwargs) - - return vf_op_template - - -def wrap_vf_ops_and_bind(hook): - _vf_ops = get_vf_ops() - for op_name in _vf_ops: - setattr(HOOKVfOP, "wrap_" + op_name, wrap_vf_op(op_name, hook)) +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/base_checker.py similarity index 31% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py rename to debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/base_checker.py index f93c09a12415f22d96306ebc9de919520c025236..e61a4ec56344cfc199df7d084264de83a98e51c0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/base_checker.py @@ -14,56 +14,47 @@ # limitations under the License. import os +from abc import ABC, abstractmethod import torch -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.const import FileCheckConst -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +class PackInput: + def __init__(self, output_zip_path, model, shell_path): + self.output_zip_path = output_zip_path + self.shell_path = shell_path + self.model = model[0] if isinstance(model, list) and len(model) > 0 else model + self.check_input_params() -def get_tensor_ops(): - _tensor_ops = dir(torch.Tensor) - yaml_data = load_yaml(yaml_path) - wrap_tensor_ops = yaml_data.get('tensor') - return set(wrap_tensor_ops) & set(_tensor_ops) + def check_input_params(self): + if self.model and not isinstance(self.model, torch.nn.Module): + raise Exception(f"model is not torch.nn.Module or module list.") + if not isinstance(self.output_zip_path, str) or not self.output_zip_path.endswith(FileCheckConst.ZIP_SUFFIX): + raise Exception(f"output zip path must be a string and ends with '.zip'") -TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()} +class BaseChecker(ABC): + input_needed = None + target_name_in_zip = None + multi_rank = False + @staticmethod + @abstractmethod + def pack(pack_input): + pass -class HOOKTensor(object): - pass + @staticmethod + @abstractmethod + def compare(bench_dir, cmp_dir, output_path): + pass - -class TensorOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - @parameter_adapter - def forward(self, *args, **kwargs): - return TensorOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_tensor_op(op_name, hook): - - def tensor_op_template(*args, **kwargs): - return TensorOPTemplate(op_name, hook)(*args, **kwargs) - - return tensor_op_template - - -def wrap_tensor_ops_and_bind(hook): - _tensor_ops = get_tensor_ops() - for op_name in _tensor_ops: - setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) + @classmethod + def compare_ex(cls, bench_dir, cmp_dir, output_path): + bench_filepath = os.path.join(bench_dir, cls.target_name_in_zip) + cmp_filepath = os.path.join(cmp_dir, cls.target_name_in_zip) + if not os.path.exists(bench_filepath) or not os.path.exists(cmp_filepath): + return None, None, None + return cls.compare(bench_dir, cmp_dir, output_path) diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/dataset_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/dataset_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..89217ac18ba4fb2fffe0eaaf325a634ee5d32eb3 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/dataset_checker.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import torch +import pandas as pd +from msprobe.core.common.file_utils import create_file_in_zip, load_json +from msprobe.pytorch.common.utils import get_rank_id +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.pytorch.config_checking.utils.utils import config_checking_print +from msprobe.core.common.decorator import recursion_depth_decorator + + +def process_tensor(tensor): + return { + 'max': float(tensor.max().item()), + 'min': float(tensor.min().item()), + 'mean': float(tensor.mean().item()), + 'norm': float(torch.norm(tensor).item()) + } + + +@recursion_depth_decorator("config_check: process_obj") +def process_obj(obj): + if isinstance(obj, torch.Tensor): + return process_tensor(obj) + elif isinstance(obj, (tuple, list)): + return {i: process_obj(x) for i, x in enumerate(obj)} + elif isinstance(obj, dict): + return {k: process_obj(v) for k, v in obj.items()} + else: + return "" + + +def parse_args_and_kargs(args, kwargs): + processed_args = process_obj(args) + processed_kargs = process_obj(kwargs) + + return { + 'args': processed_args, + 'kwargs': processed_kargs + } + + +@recursion_depth_decorator("config_check: compare_dataset_dicts") +def compare_dataset_dicts(dict1, dict2, tag=''): + results = [] + for key, value1 in dict1.items(): + new_tag = f"{tag}.{key}" if tag else key + value2 = dict2[key] + # 若为包含四个指定键的字典,不再递归 + if not isinstance(value1, dict): + continue + if set(value1.keys()) == {'max', 'min', 'mean', 'norm'}: + equal = value1 == value2 + relative_diffs = { + f"{k}_relative_diff": (abs(value1[k] - value2[k]) / value1[k]) \ + if value1[k] != 0 else None \ + for k in ['max', 'min', 'mean', 'norm'] + } + result = {'tag': new_tag, 'equal': equal} + result.update(relative_diffs) + results.append(result) + else: + results.extend(compare_dataset_dicts(value1, value2, new_tag)) + return results + + +def compare_dataset(bench_dir, cmp_dir): + all_results = [] + for step in os.listdir(bench_dir): + step_path_bench = os.path.join(bench_dir, step) + if not os.path.isdir(step_path_bench): + continue + step_path_cmp = os.path.join(cmp_dir, step) + for rank in os.listdir(step_path_bench): + rank_path_bench = os.path.join(step_path_bench, rank, 'dataset.json') + rank_path_cmp = os.path.join(step_path_cmp, rank, 'dataset.json') + if not os.path.isfile(rank_path_bench) or not os.path.isfile(rank_path_cmp): + continue + + dict1 = load_json(rank_path_bench) + dict2 = load_json(rank_path_cmp) + results = compare_dataset_dicts(dict1, dict2) + for result in results: + result['step'] = step + result['rank'] = rank + all_results.extend(results) + + df = pd.DataFrame(all_results, columns=DatasetChecker.result_header) + df = df.sort_values(by=['step', 'rank'], ascending=[True, True]) + return df + + +@register_checker_item("dataset") +class DatasetChecker(BaseChecker): + input_needed = "model" + multi_rank = True + + target_name_in_zip = "dataset" + result_header = ['step', 'rank', 'tag', 'equal', 'max_relative_diff', + 'min_relative_diff', 'mean_relative_diff', 'norm_relative_diff'] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + + def collect_input(model, args, kwargs, step): + features = parse_args_and_kargs(args, kwargs) + dataset_filepath = os.path.join(DatasetChecker.target_name_in_zip, + f"step{step}", f"rank{get_rank_id()}", "dataset.json") + create_file_in_zip(output_zip_path, dataset_filepath, json.dumps(features, indent=4)) + config_checking_print(f"add first dataset input features to zip") + + register_pre_forward_fun_list(collect_input) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_dataset_pack_path = os.path.join(bench_dir, DatasetChecker.target_name_in_zip) + cmp_dataset_pack_path = os.path.join(cmp_dir, DatasetChecker.target_name_in_zip) + + df = compare_dataset(bench_dataset_pack_path, cmp_dataset_pack_path) + pass_check = False not in df['equal'].values + return DatasetChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/env_args_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/env_args_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..9eaeb1a05729f7e9ae4ff8c9727d8f3589b3a166 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/env_args_checker.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +import pandas as pd + +from msprobe.core.common.file_utils import load_json, load_yaml, create_file_with_content, create_file_in_zip +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.config_checker import register_checker_item +from msprobe.pytorch.config_checking.utils.utils import config_checking_print +from msprobe.core.common.file_utils import save_excel + + +dirpath = os.path.dirname(__file__) +env_yaml_path = os.path.join(dirpath, "../resource/env.yaml") + + +def collect_env_data(): + result = {} + for key, value in os.environ.items(): + result[key] = value + return result + + +def compare_env_data(npu_path, bench_path): + necessary_env = load_yaml(env_yaml_path) + npu_data = load_json(npu_path) + bench_data = load_json(bench_path) + data = [] + for _, value in necessary_env.items(): + npu_env_name = value[0]["name"] + npu_value = npu_data.get(npu_env_name) if npu_data.get(npu_env_name) else value[0]["default_value"] + if len(value) == 1: + data.append([npu_env_name, "only npu has this env", npu_value, "", "warning"]) + continue + bench_env_name = value[1]["name"] + bench_value = bench_data.get(bench_env_name) if bench_data.get(bench_env_name) else value[1]["default_value"] + if npu_value != bench_value: + data.append([npu_env_name, bench_env_name, npu_value, bench_value, "error"]) + df = pd.DataFrame(data, columns=EnvArgsChecker.result_header) + return df + + +@register_checker_item("env") +class EnvArgsChecker(BaseChecker): + + target_name_in_zip = "env" + result_header = ["bench_env_name", "cmp_env_name", "bench_value", "cmp_value", "level"] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + env_args_dict = collect_env_data() + create_file_in_zip(output_zip_path, EnvArgsChecker.target_name_in_zip, json.dumps(env_args_dict, indent=4)) + config_checking_print(f"add env args to zip") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_env_data = os.path.join(bench_dir, EnvArgsChecker.target_name_in_zip) + cmp_env_data = os.path.join(cmp_dir, EnvArgsChecker.target_name_in_zip) + df = compare_env_data(bench_env_data, cmp_env_data) + pass_check = "error" not in df['level'].values + return EnvArgsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/hyperparameter_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/hyperparameter_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac1cd61fc5483c1a002bf0109d56a341aeab120 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/hyperparameter_checker.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import re +import tempfile +from difflib import SequenceMatcher + +from typing import Union, List, Dict, Any + +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.config_checker import register_checker_item +from msprobe.pytorch.config_checking.utils.utils import compare_dict, config_checking_print +from msprobe.core.common.file_utils import (os_walk_for_files, create_file_in_zip, load_json, create_file_with_list, + FileOpen) +from msprobe.core.common.const import FileCheckConst, Const + + +@register_checker_item("hyperparameter") +class HyperparameterChecker(BaseChecker): + input_needed = "shell_path" + target_name_in_zip = "hyperparameters" + + PARAMETER_NAME_MAPPING = { + "learning_rate": ["lr", "learningrate"], + "batch_size": ["batch", "bs", "batch_size_per_gpu"], + "epochs": ["num_epochs", "max_epochs", "epoch"], + "weight_decay": ["wd", "weightdecay"], + "dropout_rate": ["dropout", "drop_rate"], + } + + @staticmethod + def pack(pack_input): + shell_path = pack_input.shell_path + output_zip_path = pack_input.output_zip_path + + if not isinstance(shell_path, list): + raise TypeError("shell_path should be a list of file paths.") + + for index, script_path in enumerate(shell_path): + if os.path.isfile(script_path): + hyperparameters = HyperparameterChecker._extract_hyperparameters_from_script(script_path) + if hyperparameters: + create_file_in_zip(output_zip_path, os.path.join(HyperparameterChecker.target_name_in_zip, + HyperparameterChecker.target_name_in_zip + + Const.REPLACEMENT_CHARACTER + str(index) + + FileCheckConst.JSON_SUFFIX), + json.dumps(hyperparameters, indent=4)) + config_checking_print(f"add hyperparameters args to zip") + else: + config_checking_print(f"Warning: Failed to extract hyperparameters from script {script_path}") + else: + config_checking_print(f"Warning: Script path {script_path} is not a file.") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_model_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip) + cmp_model_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip) + bench_hyperparameters = HyperparameterChecker.load_hyperparameters(bench_model_dir) + cmp_hyperparameters = HyperparameterChecker.load_hyperparameters(cmp_model_dir) + + if len(bench_hyperparameters) != len(cmp_hyperparameters): + config_checking_print("The shell path length dose not match!") + raise Exception("The shell path length dose not match!") + + all_diffs = [] + all_files = set(bench_hyperparameters.keys()) | set(cmp_hyperparameters.keys()) + + for filename in all_files: + bench_params = bench_hyperparameters.get(filename, {}) + cmp_params = cmp_hyperparameters.get(filename, {}) + + if bench_params and cmp_params: + all_diffs.extend(HyperparameterChecker.compare_param(bench_params, cmp_params, filename)) + + elif bench_params is not None: + all_diffs.append(f"[Only in benchmark] File: {filename}") + else: + all_diffs.append(f"[Only in compare] File: {filename}") + return HyperparameterChecker.target_name_in_zip, True, None + + @staticmethod + def compare_param(bench_params, cmp_params, filename): + all_diffs = [] + file_diffs = [] + bench_param_names = bench_params.keys() + for bench_param_name in bench_param_names: + matched_cmp_param_name = HyperparameterChecker._fuzzy_match_parameter(bench_param_name, cmp_params) + if matched_cmp_param_name: + bench_param_value = bench_params[bench_param_name] + cmp_param_value = cmp_params[matched_cmp_param_name] + if bench_param_value != cmp_param_value: + diff = compare_dict({bench_param_name: bench_param_value}, + {matched_cmp_param_name: cmp_param_value}) + if diff: + file_diffs.extend( + [f" Parameter '{bench_param_name}' (matched with '{matched_cmp_param_name}'): {d}" + for d in diff]) + del cmp_params[matched_cmp_param_name] + else: + file_diffs.append( + f" [Only in benchmark] Parameter: '{bench_param_name}': {bench_params[bench_param_name]}") + for cmp_param_name, cmp_param_value in cmp_params.items(): + file_diffs.append(f" [Only in compare] Parameter: '{cmp_param_name}': {cmp_param_value}") + if file_diffs: + file_diffs.sort() + all_diffs.append(f"File: {filename}") + all_diffs.extend(file_diffs) + return all_diffs + + @staticmethod + def load_hyperparameters(model_dir): + hyperparameters = {} + if not os.path.exists(model_dir): + return hyperparameters + subfiles = os_walk_for_files(model_dir, Const.MAX_TRAVERSAL_DEPTH) + for files in subfiles: + if files["file"].endswith(FileCheckConst.JSON_SUFFIX): + filepath = os.path.join(files["root"], files["file"]) + relative_filepath = os.path.relpath(filepath, model_dir) + params = load_json(filepath) + if params: + hyperparameters[relative_filepath] = params + return hyperparameters + + @staticmethod + def _extract_hyperparameters_from_script(script_path: str) -> Dict[str, Any]: + """ + Extracts arguments from bash script used to run a model training. + """ + hyperparameters = {} + script_content_list = [] + with FileOpen(script_path, 'r') as file: + for line in file: + stripped_line = line.lstrip() + if not stripped_line.startswith('#'): + line = line.split('#')[0].rstrip() + '\n' + if line.strip(): + script_content_list.append(line) + script_content = ''.join(script_content_list) + + command_line = re.search(r'torchrun\s[^|]*|python -m torch.distributed.launch\s[^|]*', script_content, + re.DOTALL) + if command_line: + command_line = command_line.group() + + blocks = re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"', script_content, re.DOTALL) + block_contents = {} + for block_name, block_content in blocks: + block_content = block_content.replace('\n', ' ') + block_contents[block_name] = block_content + command_line = command_line.replace(f"${block_name}", block_content) + + matches = re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', command_line) + for match in matches: + key, value = match + args_key = re.match(r'\$\{?(\w+)}?', value) + if args_key: + env_vars = re.findall(rf'{args_key.group(1)}=\s*(.+)', script_content) + if env_vars: + value = env_vars[-1] + hyperparameters[key] = value if value else True + + return hyperparameters + + @staticmethod + def _fuzzy_match_parameter(param_name: str, available_params: Dict[str, Any]) -> Union[str, None]: + """ + Fuzzy matches a parameter name against available parameter names using predefined + mappings and string similarity. + """ + if param_name in available_params: + return param_name + + canonical_name = None + for standard_name, aliases in HyperparameterChecker.PARAMETER_NAME_MAPPING.items(): + if param_name == standard_name or param_name in aliases: + canonical_name = standard_name + break + + if canonical_name: + if canonical_name in available_params: + return canonical_name + for alias in HyperparameterChecker.PARAMETER_NAME_MAPPING[canonical_name]: + if alias in available_params: + config_checking_print( + f"Matched '{param_name}' to alias '{alias}' via canonical name '{canonical_name}'") + return alias + + best_match_name = None + best_match_ratio = 0.8 + for available_param_name in available_params: + ratio = SequenceMatcher(None, param_name.lower(), available_param_name.lower()).ratio() + if ratio > best_match_ratio: + best_match_ratio = ratio + best_match_name = available_param_name + + if best_match_name: + config_checking_print( + f"Fuzzy matched parameter '{param_name}' to '{best_match_name}' (similarity: {best_match_ratio:.2f})") + return best_match_name + + return None diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/pip_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/pip_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..15c02d16843d72103293834a10d6952c59cf73f3 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/pip_checker.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pandas as pd +try: + import importlib.metadata as metadata +except ImportError: + import importlib_metadata as metadata + +from msprobe.core.common.file_utils import load_yaml, create_file_in_zip +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.config_checker import register_checker_item +from msprobe.pytorch.config_checking.utils.utils import config_checking_print +from msprobe.core.common.file_utils import FileOpen, save_excel + +dirpath = os.path.dirname(__file__) +depend_path = os.path.join(dirpath, "../resource/dependency.yaml") + + +def load_pip_txt(file_path): + output_dir = {} + with FileOpen(file_path, 'r', encoding='utf-8') as file: + lines = file.readlines() + for line in lines: + info_list = line.strip().split("=") + output_dir[info_list[0]] = "" if len(info_list) != 2 else info_list[1] + return output_dir + + +def collect_pip_data(): + result = "" + packages = metadata.distributions() + for pkg in packages: + if pkg.metadata: + result += f"{pkg.metadata.get('Name')}={pkg.version}\n" + return result + + +def compare_pip_data(bench_pip_path, cmp_pip_path): + necessary_dependency = load_yaml(depend_path)["dependency"] + bench_data = load_pip_txt(bench_pip_path) + cmp_data = load_pip_txt(cmp_pip_path) + data = [] + for package in necessary_dependency: + bench_version = bench_data.get(package) + cmp_version = cmp_data.get(package) + + if bench_version != cmp_version: + data.append([package, bench_version if bench_version else 'None', + cmp_version if cmp_version else 'None', + "error"]) + + df = pd.DataFrame(data, columns=PipPackageChecker.result_header) + return df + + +@register_checker_item("pip") +class PipPackageChecker(BaseChecker): + + target_name_in_zip = "pip" + result_header = ['package', 'bench version', 'cmp version', 'level'] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + pip_data = collect_pip_data() + create_file_in_zip(output_zip_path, PipPackageChecker.target_name_in_zip, pip_data) + config_checking_print(f"add pip info to zip") + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_pip_path = os.path.join(bench_dir, PipPackageChecker.target_name_in_zip) + cmp_pip_path = os.path.join(cmp_dir, PipPackageChecker.target_name_in_zip) + df = compare_pip_data(bench_pip_path, cmp_pip_path) + pass_check = "error" not in df['level'].values + return PipPackageChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/random_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/random_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..883144d617d8ea33afbedf8d510f7181450fae6c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/random_checker.py @@ -0,0 +1,182 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from functools import wraps +from typing import Callable +import inspect +import os +import json +from collections import defaultdict + +import numpy as np +import torch +import pandas as pd +from msprobe.pytorch.config_checking.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.pytorch.common.utils import get_rank_id +from msprobe.core.common.file_utils import create_file_in_zip, load_json, save_excel +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.utils.utils import config_checking_print + + +random_log_dict = defaultdict(dict) + + +def load_json_files(directory): + json_data = {} + for file in os.listdir(directory): + file_path = os.path.join(directory, file) + if file.startswith('rank') and file.endswith('.json'): + json_data.update(load_json(file_path)) + return json_data + + +def get_file_and_line(position): + parts = position.rsplit(':', 1) + if len(parts) == 2: + file_name = os.path.basename(parts[0]) + line_num = parts[1] + return f"{file_name}:{line_num}" + return position + + +def compare_json_files(bench_data, cmp_data): + results = [] + for op in set(bench_data) | set(cmp_data): + bench_records = bench_data.get(op, {}) + cmp_records = cmp_data.get(op, {}) + all_positions = set() + for position in set(bench_records) | set(cmp_records): + all_positions.add(get_file_and_line(position)) + + for position in all_positions: + bench_count = 0 + cmp_count = 0 + for original_position, count in bench_records.items(): + if get_file_and_line(original_position) == position: + bench_count += count + for original_position, count in cmp_records.items(): + if get_file_and_line(original_position) == position: + cmp_count += count + results.append([op, position, bench_count == cmp_count, bench_count, cmp_count]) + return results + + +def compare_random(bench_dir='bench', cmp_dir='cmp'): + bench_data = load_json_files(bench_dir) + cmp_data = load_json_files(cmp_dir) + results = compare_json_files(bench_data, cmp_data) + df = pd.DataFrame(results, columns=RandomChecker.result_header) + return df + + +def track_random_call(func: Callable, name: str): + @wraps(func) + def wrapper(*args, **kwargs): + frame = inspect.currentframe() + caller_frame = frame.f_back + caller_info = inspect.getframeinfo(caller_frame) + location = f"{os.path.abspath(caller_info.filename)}:{caller_info.lineno}" + + global random_log_dict + random_log_dict.setdefault(name, {}) + random_log_dict[name][location] = random_log_dict[name].get(location, 0) + 1 + + try: + result = func(*args, **kwargs) + return result + except Exception as e: + raise e + finally: + del frame, caller_frame + + return wrapper + + +def apply_patches(): + random_patches = { + 'random': random.random, + 'randint': random.randint, + 'uniform': random.uniform, + 'choice': random.choice + } + for name, func in random_patches.items(): + setattr(random, name, track_random_call(func, f"random.{name}")) + + np_random_patches = { + 'rand': np.random.rand, + 'randint': np.random.randint, + 'choice': np.random.choice, + 'normal': np.random.normal + } + for name, func in np_random_patches.items(): + setattr(np.random, name, track_random_call(func, f"np.random.{name}")) + + torch_patches = { + 'rand': torch.rand, + 'randint': torch.randint, + 'randn': torch.randn, + 'rand_like': torch.rand_like, + 'randint_like': torch.randint_like, + 'randn_like': torch.randn_like, + 'manual_seed': torch.manual_seed + } + for name, func in torch_patches.items(): + setattr(torch, name, track_random_call(func, f"torch.{name}")) + + tensor_patches = { + 'exponential_': torch.Tensor.exponential_, + 'geometric_': torch.Tensor.geometric_, + 'log_normal_': torch.Tensor.log_normal_, + 'cauchy_': torch.Tensor.cauchy_ + } + for name, func in tensor_patches.items(): + setattr(torch.Tensor, name, track_random_call(func, f"torch.Tensor.{name}")) + + + +@register_checker_item("random") +class RandomChecker(BaseChecker): + input_needed = None + + target_name_in_zip = "random" + result_header = ['op', 'position', 'equal', 'bench_count', 'cmp_count'] + write_once = False + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + + def collect_input(model, args, kwargs, step): + if RandomChecker.write_once: + return + + random_log_filepath = os.path.join(RandomChecker.target_name_in_zip, f"rank{get_rank_id()}.json") + create_file_in_zip(output_zip_path, random_log_filepath, json.dumps(random_log_dict, indent=4)) + config_checking_print(f"add first random_log input features to zip") + RandomChecker.write_once = True + + register_pre_forward_fun_list(collect_input) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_random_log_pack_path = os.path.join(bench_dir, RandomChecker.target_name_in_zip) + cmp_random_log_pack_path = os.path.join(cmp_dir, RandomChecker.target_name_in_zip) + + df = compare_random(bench_random_log_pack_path, cmp_random_log_pack_path) + pass_check = False not in df['equal'].values + return RandomChecker.target_name_in_zip, pass_check, df + diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/weights_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/weights_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfd758f713861241ba55e36574e7e75e2041636 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/checkers/weights_checker.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import torch +import pandas as pd + +from msprobe.core.common.file_utils import create_file_in_zip, save_excel, os_walk_for_files, load_json +from msprobe.pytorch.common.utils import get_rank_id +from msprobe.pytorch.config_checking.checkers.base_checker import BaseChecker +from msprobe.pytorch.config_checking.config_checker import register_checker_item, register_pre_forward_fun_list +from msprobe.pytorch.config_checking.utils.utils import config_checking_print, get_tensor_features + + +def collect_weights_data(model): + weights_data = {} + for name, param in model.named_parameters(): + if param.dtype == torch.bfloat16: + param = param.float() + weights_data[name] = get_tensor_features(param) + return weights_data + + +def compare_weight_file(bench_file, cmp_file): + bench_data = load_json(bench_file) + cmp_data = load_json(cmp_file) + + results = [] + for weight_name in set(bench_data.keys()) | set(cmp_data.keys()): + result = { + "weight_name": weight_name, + "equal": None, + "max_relative_diff": None, + "min_relative_diff": None, + "mean_relative_diff": None, + "norm_relative_diff": None + } + + if weight_name not in bench_data: + result["equal"] = "only cmp have" + results.append(result) + continue + + if weight_name not in cmp_data: + result["equal"] = "only bench have" + results.append(result) + continue + + bench_vals = bench_data[weight_name] + cmp_vals = cmp_data[weight_name] + keys = ["max", "min", "mean", "norm"] + equal = all([bench_vals[k] == cmp_vals[k] for k in keys]) + result["equal"] = equal + + for key in keys: + diff_key = f"{key}_relative_diff" + result[diff_key] = (abs(bench_vals[key] - cmp_vals[key]) / bench_vals[key]) \ + if bench_vals[key] != 0 else None + + results.append(result) + + return results + + +def compare_weight(bench_dir, cmp_dir): + all_results = [] + bench_files_info = os_walk_for_files(bench_dir, 10) + for info in bench_files_info: + if not info["file"].endswith('.json'): + continue + bench_file = os.path.join(info["root"], info["file"]) + relative_path = os.path.relpath(info["root"], bench_dir) + cmp_root = os.path.join(cmp_dir, relative_path) + cmp_file = os.path.join(cmp_root, info["file"]) + + path_list = relative_path.split(os.sep) + if len(path_list) < 2: + raise Exception("Can not compare weights because the extracted file has been corrupted!") + step = int(path_list[0].replace("step", "")) + rank = int(path_list[1].replace("rank", "")) + + if not os.path.exists(cmp_file): + bench_data = load_json(bench_file) + for weight_name in bench_data.keys(): + result = { + "step": step, + "rank": rank, + "weight_name": weight_name, + "equal": "only bench have", + "max_relative_diff": None, + "min_relative_diff": None, + "mean_relative_diff": None, + "norm_relative_diff": None + } + all_results.append(result) + else: + results = compare_weight_file(bench_file, cmp_file) + for res in results: + res["step"] = step + res["rank"] = rank + all_results.append(res) + + df = pd.DataFrame(all_results, columns=WeightsChecker.result_header) + df = df.sort_values(by=['step', 'rank'], ascending=[True, True]) + return df + + +@register_checker_item("weights") +class WeightsChecker(BaseChecker): + input_needed = "model" + multi_rank = True + + target_name_in_zip = "weights" + result_header = ["step", "rank", "weight_name", "equal", "max_relative_diff", + "min_relative_diff", "mean_relative_diff", "norm_relative_diff"] + + @staticmethod + def pack(pack_input): + output_zip_path = pack_input.output_zip_path + + def collect_weights(model, args, kwargs, step): + weights_data_dict = collect_weights_data(model) + weights_data_filepath = os.path.join(WeightsChecker.target_name_in_zip, + f"step{step}", f"rank{get_rank_id()}", "weight.json") + create_file_in_zip(output_zip_path, weights_data_filepath, json.dumps(weights_data_dict, indent=4)) + config_checking_print(f"add weights info to zip") + register_pre_forward_fun_list(collect_weights) + + @staticmethod + def compare(bench_dir, cmp_dir, output_path): + bench_weight_pack_path = os.path.join(bench_dir, WeightsChecker.target_name_in_zip) + cmp_weight_pack_path = os.path.join(cmp_dir, WeightsChecker.target_name_in_zip) + df = compare_weight(bench_weight_pack_path, cmp_weight_pack_path) + pass_check = False not in df['equal'].values + return WeightsChecker.target_name_in_zip, pass_check, df diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/compare_weight.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/compare_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c49fc3a8e0ed7838de451f9e8dcfbcf4363388 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/compare_weight.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from tqdm import tqdm + +from msprobe.core.common.file_utils import save_json, check_file_or_directory_path +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.config_checking.ckpt_compare.megatron_loader import load_megatron_weights +from msprobe.pytorch.config_checking.ckpt_compare.metrics import METRIC_FUNC + + +def compare_checkpoints(ckpt_path1, ckpt_path2, output_path) -> Dict: + """Compare weights between two checkpoints using cosine similarity and L2 distance. + + Args: + ckpt_path1 (str): Path to first checkpoint directory + ckpt_path2 (str): Path to second checkpoint directory + output_path (str): Path to save comparison results JSON file + + Returns: + Dict: Dictionary containing comparison metrics for each parameter. The dictionary has the following structure: + { + "param_name": { + "cosine_similarity": float, # Cosine similarity between parameter tensors + "l2_distance": float, # L2 distance between parameter tensors + "shape": List[int] # Shape of the parameter tensors + }, + ... + } + """ + + # Load both checkpoints + check_file_or_directory_path(output_path) + weights1 = load_megatron_weights(ckpt_path1) + weights2 = load_megatron_weights(ckpt_path2) + + # Initialize results dictionary + results = {} + + # Compare weights with matching keys + common = set(weights1) & set(weights2) + logger.warning(f'Parameters not in ckpt2: {set(weights1) - set(weights2)}') + logger.warning(f'Parameters not in ckpt1: {set(weights2) - set(weights1)}') + for key in tqdm(common): + tensor1 = weights1[key].float() + tensor2 = weights2[key].float() + + results[key] = {} + for metric, func in METRIC_FUNC.items(): + try: + results[key][metric] = func(tensor1, tensor2) + except Exception as e: + logger.warning(f'Error when calculate {metric} for reason: {e}') + + # Write results to JSON file + save_json(output_path, results, indent=4) + logger.info(f"Comparison results written to {output_path}") + return results diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/megatron_loader.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/megatron_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..3dea9792360a3253e6e917eb8427a53ca44481e1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/megatron_loader.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from collections import defaultdict +from typing import Dict +import torch +from msprobe.pytorch.common.log import logger +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen, load_yaml +from msprobe.pytorch.common.utils import load_pt + +try: + import megatron +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megatron', which is required to load a megatron ckpt") from e + + +COLUMN_PARALLEL_PARAMS = ['linear_qkv', 'linear_fc1', 'word_embeddings.weight'] +ARGS = 'args' +LAYER_IDX_PATTERN = re.compile('layers\.(\d+)\.') +EXPERT_IDX_PATTERN = re.compile('experts\.(\d+)\.') + + +@recursion_depth_decorator('') +def _get_parameter(weights, prefix=''): + for k, v in weights.items(): + name = Const.SEP.join([prefix, k]).strip(Const.SEP) + if isinstance(v, dict): + yield from _get_parameter(v, prefix=name) + elif isinstance(v, torch.Tensor): + yield name, v + + +def _map_to_mcore_local_names(param_name: str) -> str: + """Map parameter names to mcore + local transformer implementation names.""" + mcore_local_map = load_yaml(os.path.join(os.path.dirname(__file__), 'name_mapping.yaml')) + for other_name, mcore_local_name in mcore_local_map.items(): + param_name = param_name.replace(other_name, mcore_local_name) + + return param_name + + +def _parse_real_layer_idx(param_name, num_layers_per_stage, pp_size, pp_rank): + """Map local (virtual) pipeline stage layer index to global layer index. + + For virtual pipeline parallel, each pipeline stage is further divided into virtual stages. + The global layer index needs to account for both pipeline stage and virtual stage. + + Args: + param_name (str): Parameter name containing layer index + num_layers_per_stage (int): Number of layers per pipeline stage + pp_size (int): Pipeline parallel size + + Returns: + int: Global layer index accounting for both pipeline and virtual pipeline stages + """ + # Extract local layer index from parameter name + layer_match = re.search(LAYER_IDX_PATTERN, param_name) + param_name, vpp_stage = param_name.split(Const.SCOPE_SEPARATOR) + if not layer_match: + return param_name + + local_layer_idx = int(layer_match.group(1)) + vpp_stage = int(vpp_stage) + + # Calculate global layer index based on pipeline stage and virtual stage + real_layer_idx = local_layer_idx + (pp_size * vpp_stage + pp_rank) * num_layers_per_stage + + return param_name.replace(f'layers.{local_layer_idx}', f'layers.{real_layer_idx}') + + +def _parse_real_expert_idx(param_name, num_experts_per_rank, exp_rank): + """Map local expert index to global expert index. TODO: shared expert + + For expert parallel, experts are distributed across ranks. This function maps + the local expert index on a rank to its global index across all ranks. + + Args: + param_name (str): Parameter name containing local expert index + num_experts_per_rank (int): Number of experts on each rank + exp_rank (int): Expert parallel rank + + Returns: + str: Parameter name with local expert index replaced by global expert index + """ + # Extract local layer index from parameter name + expert_match = re.search(EXPERT_IDX_PATTERN, param_name) + if not expert_match: + return param_name + + local_expert_idx = int(expert_match.group(1)) + # Calculate global layer index based on pipeline stage and virtual stage + real_experts_idx = local_expert_idx + exp_rank * num_experts_per_rank + + return param_name.replace(f'experts.{local_expert_idx}', f'experts.{real_experts_idx}') + + +def _consolidate_tp_weights(weights: Dict) -> Dict: + """Consolidate weights from different tensor parallel ranks into combined tensors. + + Args: + weights: Dictionary of weights with rank information in keys + + Returns: + Dict: Consolidated weights without rank information + """ + consolidated = {} + for key, tensors in weights.items(): + if any([name in key for name in COLUMN_PARALLEL_PARAMS]): + # Column parallel - concatenate along input dimension (dim 0) + combined = torch.cat(tensors, dim=0) + elif "linear_proj.weight" in key or "linear_fc2.weight" in key: + # Row parallel - concatenate along output dimension (dim 1) + combined = torch.cat(tensors, dim=1) + else: + # For other params, verify identical and use first + if not all(torch.allclose(tensors[0], t) for t in tensors[1:]): + logger.warning(f"Inconsistent values for {key} across TP ranks") + combined = tensors[0] + + consolidated[key] = combined + return consolidated + + +def _parse_num_layers_per_stage(tp_partition): + match = [re.findall(LAYER_IDX_PATTERN, key) for key in tp_partition.keys()] + layer_idx = [int(i[0]) for i in match if i] + num_layers_per_pipeline_stage = max(layer_idx) + 1 + + return num_layers_per_pipeline_stage + + +def parse_parallel_size(checkpoint_dir: str): + """Parse tensor, pipeline and expert parallel sizes from checkpoint filenames. + + Args: + checkpoint_dir (str): Directory containing checkpoint files + + Returns: + Namespace + """ + # Find all rank directories + rank_dirs = [d for d in os.listdir(checkpoint_dir) if d.startswith('mp_rank_')] + + if not rank_dirs: + raise ValueError(f"No checkpoint rank directories found in {checkpoint_dir}") + + ckpt = load_pt(os.path.join(checkpoint_dir, rank_dirs[0], 'model_optim_rng.pt'), to_cpu=True, weights_only=False) + args = ckpt[ARGS] + return ( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.expert_model_parallel_size, + args.num_experts + ) + + +def parse_iteration(checkpoint_path: str) -> Dict: + iteration = None + latest_iteration = None + tracker_file = os.path.join(checkpoint_path, "latest_checkpointed_iteration.txt") + if os.path.exists(tracker_file): + with FileOpen(tracker_file, 'r') as f: + iteration = latest_iteration = int(f.read().strip()) + else: + match = re.findall('iter_([\d]{7})', checkpoint_path) + if match: + iteration = int(match[0]) + + # Checkpoint directory for this iteration + logger.info(f"Loaded checkpoint from iteration {iteration}") + if latest_iteration: + checkpoint_path = os.path.join(checkpoint_path, f'iter_{iteration:07d}') + if not os.path.exists(checkpoint_path): + raise ValueError(f"Checkpoint directory not found: {checkpoint_path}") + + return checkpoint_path + + +def get_weights_from_state_dict(state_dict): + weights = {} + if 'model' in state_dict: + model_weights = state_dict['model'] + vpp_stage = 0 + + for key, value in _get_parameter(model_weights): + key = _map_to_mcore_local_names(key) + weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value + + elif 'model0' in state_dict: + #vpp enabled + vpp_size = 0 + while f'model{vpp_size}' in state_dict: + model_weights = state_dict[f'model{vpp_stage}'] + for key, value in _get_parameter(model_weights): + key = _map_to_mcore_local_names(key) + weights[f"{key}{Const.SCOPE_SEPARATOR}{vpp_stage}"] = value + vpp_size += 1 + return weights + + +def load_megatron_weights(checkpoint_path: str) -> Dict: + """Load Megatron parallel checkpoint weights into a single dictionary. + + Args: + checkpoint_path (str): Base checkpoint directory path + + Returns: + combined_weights: Dict with weights from all ranks, keys include rank info + """ + # Find latest iteration if not specified + checkpoint_path = parse_iteration(checkpoint_path) + + # Parse parallel sizes from checkpoint directory structure + tp_size, pp_size, exp_size, num_experts = parse_parallel_size(checkpoint_path) + combined_weights = {} + + # Load checkpoints from all ranks + for exp_rank in range(exp_size): + num_layers_per_pipeline_stage = 0 + for pp_rank in range(pp_size): + tp_partition = defaultdict(list) + for tp_rank in range(tp_size): + # Construct checkpoint path based on parallel ranks + if pp_size > 1: + rank_dir = f'mp_rank_{tp_rank:02d}_{pp_rank:03d}' + else: + rank_dir = f'mp_rank_{tp_rank:02d}' + + if exp_size > 1: + rank_dir = f'{rank_dir}_{exp_rank:03d}' + + ckpt_file = os.path.join(checkpoint_path, rank_dir, 'model_optim_rng.pt') + try: + state_dict = load_pt(ckpt_file, to_cpu=True, weights_only=False) + partition = get_weights_from_state_dict(state_dict) + for key, weight in partition.items(): + tp_partition[key].append(weight) + + except Exception as load_error: + logger.warning(f"Error loading {ckpt_file}: {load_error}") + + if not tp_partition: + raise ValueError('No state loaded.') + + if not num_layers_per_pipeline_stage: + num_layers_per_pipeline_stage = _parse_num_layers_per_stage(tp_partition) + + consolidated_weight = _consolidate_tp_weights(tp_partition) + for key, value in consolidated_weight.items(): + key = _parse_real_layer_idx(key, num_layers_per_pipeline_stage, pp_size, pp_rank) + if num_experts: + key = _parse_real_expert_idx(key, num_experts // exp_size, exp_rank) + combined_weights[key] = value + + logger.info(f"Found {len(combined_weights)} total parameters across all ranks") + + return combined_weights diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/metrics.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..65b5feb659f2fc515d5f2f57faf107d65937d16c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/metrics.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F + +from msprobe.pytorch.common.log import logger + +MAX_SLICE = 1000000 + + +def in_different_shape(a, b): + if a.shape != b.shape: + logger.warning(f"a, b are in different shape. a: {a.shape}, b: {b.shape}") + return True + return False + + +def l2_distance(a, b): + if a is None or b is None: + return None + if in_different_shape(a, b): + return None + return (a - b).square().sum().sqrt().item() + + +def cos_sim(a, b, eps=1e-8): + if a is None or b is None: + return None + if a.dtype not in [torch.float64, torch.float32, torch.float16, torch.bfloat16]: + return None + + if in_different_shape(a, b): + return None + if a.dim() > 0: + a = a.flatten().squeeze() + b = b.flatten().squeeze() + + num_element = a.numel() + if num_element > MAX_SLICE: + logger.info(f'num parameters: {num_element}. Calculate cos by chunks') + n_batch = num_element // MAX_SLICE + 1 + sim = 0 + total_norm_a = eps + total_norm_b = eps + for i in range(n_batch): + slice_a = a[i * MAX_SLICE: min((i + 1) * MAX_SLICE, num_element)] + slice_b = b[i * MAX_SLICE: min((i + 1) * MAX_SLICE, num_element)] + slice_sim = (slice_a * slice_b).sum().item() + total_norm_a += (slice_a ** 2).sum().item() + total_norm_b += (slice_a ** 2).sum().item() + sim += slice_sim + sim = sim / total_norm_a ** 0.5 / total_norm_b ** 0.5 + + else: + sim = F.cosine_similarity(a, b, dim=0, eps=eps).item() + + return sim + + +def numel(a, b): + n1 = a.numel() + n2 = b.numel() + if n1 != n2: + logger.warning('parameters have different number of element') + return (n1, n2) + return n1 + + +def shape(a, b): + s1 = a.shape + s2 = b.shape + if in_different_shape(a, b): + return [list(s1), list(s2)] + return list(s1) + + +METRIC_FUNC = { + 'l2': l2_distance, + 'cos': cos_sim, + 'numel': numel, + 'shape': shape + } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/name_mapping.yaml b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/name_mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0caecc53a73b108939435867fe1b6e614bd91812 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/ckpt_compare/name_mapping.yaml @@ -0,0 +1,12 @@ +self_attention.linear_qkv.layer_norm_: input_layernorm. +language_model.: '' +encoder: decoder +.input_norm.: .input_layernorm. +query_key_value: linear_qkv +.dense.: .linear_proj. +post_attention_norm: pre_mlp_layernorm +dense_h_to_4h: linear_fc1 +dense_4h_to_h: linear_fc2 +mlp.local_experts: mlp.experts.local_experts +final_norm: final_layernorm +word_embeddings_for_head: output_layer diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checker.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5d2ff3b036f0f150b19345ccad51e4b58e95f4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checker.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +import torch +import torch.distributed as dist +import pandas as pd + +from msprobe.core.common.file_utils import save_excel, split_zip_file_path, \ + create_directory, extract_zip, make_dir +from msprobe.pytorch.config_checking.checkers.base_checker import PackInput +from msprobe.pytorch.config_checking.utils.utils import config_checking_print + + + +class ConfigChecker: + checkers = {} + pre_forward_fun_list = [] + result_filename = "result.xlsx" + result_header = ["filename", "pass_check"] + step = 0 + + def __init__(self, model=None, shell_path=None, output_zip_path="./config_check_pack.zip"): + self.pack_input = PackInput(output_zip_path, model, shell_path) + file_path, file_name = split_zip_file_path(self.pack_input.output_zip_path) + if not os.path.exists(file_path): + create_directory(file_path) + self.pack() + else: + if os.path.exists(self.pack_input.output_zip_path): + raise Exception("The output file path already exist!") + self.pack() + + + @staticmethod + def compare(bench_zip_path, cmp_zip_path, outpath): + if os.path.exists(outpath): + shutil.rmtree(outpath) + bench_dir = os.path.join(outpath, "bench") + cmp_dir = os.path.join(outpath, "cmp") + extract_zip(bench_zip_path, bench_dir) + config_checking_print(f"extract zip file {bench_zip_path} to {bench_dir}") + extract_zip(cmp_zip_path, cmp_dir) + config_checking_print(f"extract zip file {cmp_zip_path} to {cmp_dir}") + + result = [] + summary_result = [] + for checker in ConfigChecker.checkers.values(): + checker_name, pass_check, df = checker.compare_ex(bench_dir, cmp_dir, outpath) + if checker_name: + summary_result.append([checker_name, pass_check]) + if df is not None: + result.append((df, checker_name)) + summary_result_df = pd.DataFrame(summary_result, columns=ConfigChecker.result_header) + result.insert(0, (summary_result_df, "summary")) + save_excel(os.path.join(outpath, ConfigChecker.result_filename), result) + config_checking_print(f"config checking result save to {os.path.realpath(outpath)}") + + + def pack(self): + config_checking_print(f"pack result zip path {os.path.realpath(self.pack_input.output_zip_path)}") + if dist.is_initialized() and dist.get_rank() == 0: + config_checking_print(f"pack result zip path {self.pack_input.output_zip_path}") + if os.path.exists(self.pack_input.output_zip_path): + os.remove(self.pack_input.output_zip_path) + + def hook(model, args, kwargs): + for collect_func in self.pre_forward_fun_list: + collect_func(model, args, kwargs, ConfigChecker.step) + ConfigChecker.step += 1 + + if self.pack_input.model: + self.pack_input.model.register_forward_pre_hook(hook, with_kwargs=True) + for checker in ConfigChecker.checkers.values(): + if checker.input_needed and not getattr(self.pack_input, checker.input_needed): + continue + if dist.is_initialized() and dist.get_rank() != 0 and not checker.multi_rank: + continue + checker.pack(self.pack_input) + + +def register_checker_item(key, cls=None): + if cls is None: + # 无参数时,返回装饰器函数 + return lambda cls: register_checker_item(key, cls) + ConfigChecker.checkers[key] = cls + return cls + + +def register_pre_forward_fun_list(func): + ConfigChecker.pre_forward_fun_list.append(func) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checking.py similarity index 30% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py rename to debug/accuracy_tools/msprobe/pytorch/config_checking/config_checking.py index 6164169476dab66ac2bdb8d0cbc41a04ddce6713..a8cc15ab6ee36907bdc7a061cd04359b4b83ebf8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/config_checking.py @@ -13,54 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_functional_ops(): - yaml_data = load_yaml(yaml_path) - wrap_functional_ops = yaml_data.get('functional') - _all_functional_ops = dir(torch.nn.functional) - return set(wrap_functional_ops) & set(_all_functional_ops) - - -TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()} - - -class HOOKFunctionalOP(object): - pass +from msprobe.pytorch.config_checking.config_checker import ConfigChecker +from msprobe.pytorch.config_checking.ckpt_compare.compare_weight import compare_checkpoints -class FunctionalOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) +def pack(config_filepath): + ConfigChecker(config_filepath) - @torch_device_guard - def forward(self, *args, **kwargs): - return TorchFunctions[str(self.op_name_)](*args, **kwargs) +def compare(bench_zip_path, cmp_zip_path, outpath): + ConfigChecker.compare(bench_zip_path, cmp_zip_path, outpath) -def wrap_functional_op(op_name, hook): - def functional_op_template(*args, **kwargs): - return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) - return functional_op_template +def _config_checking_parser(parser): + parser.add_argument('-pack', '--pack', help='Pack a directory into a zip file') + parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files or ckpt dir') + parser.add_argument('-s', '--ckpt-sim', default=False, action='store_true', + help='Calculate the similarity of two ckpt') + parser.add_argument('-o', '--output', help='output path, default is current directory') -def wrap_functional_ops_and_bind(hook): - _functional_ops = get_functional_ops() - for op_name in _functional_ops: - setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) +def _run_config_checking_command(args): + if args.pack: + pack(args.pack) + elif args.compare: + if args.ckpt_sim: + output_path = args.output if args.output else "./ckpt_compare_out.json" + compare_checkpoints(args.compare[0], args.compare[1], output_path) + else: + output_dirpath = args.output if args.output else "./config_check_result" + compare(args.compare[0], args.compare[1], output_dirpath) + else: + logger.error("The param is not correct, you need to give '-pack' for pack or '-c' for compare.") + raise Exception("The param is not correct, you need to give '-pack' for pack or '-c' for compare.") diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/resource/dependency.yaml b/debug/accuracy_tools/msprobe/pytorch/config_checking/resource/dependency.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4f73a5fce97f20608a3c9bacb92e53f1747f092 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/resource/dependency.yaml @@ -0,0 +1,24 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +dependency: + - transformers + - deepspeed + - megatron + - numpy + - datasets + - torch + - torchversion + - peft \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/resource/env.yaml similarity index 52% rename from debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py rename to debug/accuracy_tools/msprobe/pytorch/config_checking/resource/env.yaml index 48d0918ca68d7f429cc97fc64c5ba7d7f884960b..13ea0e39f89b4807b72a6322ddc865145d9fde9d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/kernel_dump/kernel_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/resource/env.yaml @@ -13,21 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +HCCL_DETERMINISTIC: + - name: HCCL_DETERMINISTIC + default_value: False -from msprobe.core.common.file_utils import save_json +HCCL_ALGO: + - name: HCCL_ALGO + default_value: None +HCCL_INTRA_ROCE_ENABLE: + - name: HCCL_INTRA_ROCE_ENABLE + default_value: 0 -def create_kernel_config_json(dump_path, cur_rank): - kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json" - kernel_config_path = os.path.join(dump_path, kernel_config_name) - config_info = { - "dump": { - "dump_list": [], - "dump_path": dump_path, - "dump_mode": "all", - "dump_op_switch": "on" - } - } - save_json(kernel_config_path, config_info, indent=4) - return kernel_config_path +HCCL_INTRA_PICE_ENABLE: + - name: HCCL_INTRA_PICE_ENABLE + default_value: 1 + +ASCEND_LAUNCH_BLOCKING: + - name: ASCEND_LAUNCH_BLOCKING + default_value: False + +ASCEND_RT_VISIBLE_DEVICE: + - name: ASCEND_RT_VISIBLE_DEVICE + default_value: None \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/config_checking/utils/utils.py b/debug/accuracy_tools/msprobe/pytorch/config_checking/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f8cef378ef3479aa0892786e835a66861eb6637 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/config_checking/utils/utils.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import hashlib + +import torch + +from msprobe.pytorch.common.log import logger + + +def merge_keys(dir_0, dir_1): + output_list = list(dir_0.keys()) + output_list.extend(list(dir_1.keys())) + return set(output_list) + + +def compare_dict(bench_dict, cmp_dict): + result = [] + for key in set(bench_dict.keys()) | set(cmp_dict.keys()): + if key in bench_dict and key in cmp_dict: + if bench_dict[key] != cmp_dict[key]: + result.append(f"{key}: {bench_dict[key]} -> {cmp_dict[key]}") + elif key in bench_dict: + result.append(f"{key}: [deleted] -> {bench_dict[key]}") + else: + result.append(f"{key}: [added] -> {cmp_dict[key]}") + return result + + +def config_checking_print(msg): + logger.info(f"[config checking log] {msg}") + + +def tensor_to_hash(tensor): + """Compute the hash value of a tensor""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + return bytes_hash(tensor_bytes) + + +def get_tensor_features(tensor): + features = { + "max": lambda x: torch.max(x).item(), + "min": lambda x: torch.min(x).item(), + "mean": lambda x: torch.mean(x).item(), + "norm": lambda x: torch.norm(x).item(), + } + + if not tensor.is_floating_point() or tensor.dtype == torch.float64: + tensor = tensor.float() + return {key: features.get(key)(tensor) for key in features} + + +def compare_dicts(dict1, dict2, path=''): + deleted = [] + added = [] + changed = [] + result = {} + + for key in dict1: + if key not in dict2: + deleted.append(f"[Deleted]: {path + key}") + result[key] = "[deleted]" + else: + if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + sub_deleted, sub_added, sub_changed, sub_result = compare_dicts( + dict1[key], dict2[key], path + key + '/') + deleted.extend(sub_deleted) + added.extend(sub_added) + changed.extend(sub_changed) + if sub_result: + result[key] = sub_result + elif dict1[key] != dict2[key]: + changed.append(f"[Changed]: {path + key} : {dict1[key]} -> {dict2[key]}") + result[key] = f"[changed]: {dict1[key]} -> {dict2[key]}" + for key in dict2: + if key not in dict1: + added.append(f"[Added]: {path + key}") + result[key] = "[added]" + return deleted, added, changed, result + + +def bytes_hash(obj: bytes): + hex_dig = hashlib.sha256(obj).hexdigest() + short_hash = int(hex_dig, 16) % (2 ** 16) + return short_hash diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 77e78bc38063602e64b533291d60b9b12fd2ae00..2ed2b3a08b9021f40075b40304855bcc809142e8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch - from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import is_torch_nn_module class DebuggerConfig: @@ -60,6 +59,7 @@ class DebuggerConfig: if isinstance(task_config.online_run_ut_recompute, bool) else False self.check() + self._check_statistics_config(task_config) if self.level == Const.LEVEL_L2: self.is_backward_kernel_dump = False @@ -78,10 +78,13 @@ class DebuggerConfig: if not isinstance(self.async_dump, bool): raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"The parameters async_dump should be bool.") - if self.async_dump and self.task == Const.TENSOR and not self.list: - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, - f"The parameters async_dump is true in tensor task, the parameters list cannot be " - f"empty.") + if self.async_dump and self.task == Const.TENSOR: + if self.level == Const.LEVEL_DEBUG: + self.list = [] # async_dump + debug level case ignore list + if not self.list and self.level != Const.LEVEL_DEBUG: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"The parameters async_dump is true in tensor task, the parameters list cannot be " + f"empty.") if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]: logger.warning_on_rank_0( f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. " @@ -93,25 +96,24 @@ class DebuggerConfig: self.check_kwargs() return True - def check_model(self, instance, start_model): - if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]: - if instance.model is not None or start_model is not None: - logger.info_on_rank_0( - f"The current level is not L0 or mix level, so the model parameters will not be used.") + def check_model(self, instance, start_model, token_range=None): + instance.model = start_model if start_model is not None else instance.model + if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None: return - if start_model is None and instance.model is None: + + if instance.model is None: logger.error_on_rank_0( - f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.") + f"For level {self.level} or non-empty token_range, " + f"PrecisionDebugger or start interface must receive a 'model' parameter.") raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'") - instance.model = start_model if start_model is not None else instance.model - if isinstance(instance.model, torch.nn.Module): + if is_torch_nn_module(instance.model): return error_model = None if isinstance(instance.model, (list, tuple)): for model in instance.model: - if not isinstance(model, torch.nn.Module): + if not is_torch_nn_module(model): error_model = model break else: @@ -119,7 +121,7 @@ class DebuggerConfig: if error_model is not None: error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] " - f"type, currently there is a {type(error_model)} type.") + f"type, currently there is an unsupported {type(error_model)} type.") raise MsprobeException( MsprobeException.INVALID_PARAM_ERROR, error_info) @@ -130,8 +132,23 @@ class DebuggerConfig: if not self.list or len(self.list) != 1: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"When level is set to L2, the list must be configured as a list with one api name.") + if self.task != Const.TENSOR: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"When level is set to L2, the task must be set to tensor.") + api_name = self.list[0] if api_name.endswith(Const.BACKWARD): self.is_backward_kernel_dump = True api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD self.list.append(api_forward_name) + + def _check_statistics_config(self, task_config): + if self.task != Const.STATISTICS: + return + self.tensor_list = [] + if not hasattr(task_config, "tensor_list"): + return + if self.level == Const.LEVEL_DEBUG and task_config.tensor_list: + logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.") + return + self.tensor_list = task_config.tensor_list diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 5bb1d3a14e82d7b4bce9d7da8921a1d701e82222..be3ad5bf27dae51f0b26204b24f260cccb4132da 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -13,36 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from torch.utils.data import dataloader -import torch -from msprobe.core.common.const import Const, FileCheckConst, MsgConst +from msprobe.core.common.const import Const, MsgConst from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.utils import check_token_range +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import check_save_param +from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor -from msprobe.pytorch.pt_config import parse_json_config -from msprobe.pytorch.service import Service -from torch.utils.data import dataloader - -ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", - "dump_path", "level", "model"]) +from msprobe.pytorch.pt_config import parse_task_config +from msprobe.pytorch.pytorch_service import PytorchService -class PrecisionDebugger: - _instance = None - tasks_not_need_debugger = [Const.GRAD_PROBE] - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super(PrecisionDebugger, cls).__new__(cls) - cls._instance.config = None - cls._instance.enable_dataloader = False - return cls._instance +class PrecisionDebugger(BasePrecisionDebugger): def __init__( self, @@ -53,90 +39,65 @@ class PrecisionDebugger: model=None, step=None ): - if not hasattr(self, "initialized"): - config_params = ConfigParameters(config_path, - task, - dump_path, - level, - model) - self.check_input_params(config_params) - - self.initialized = True - self.model = model - common_config, task_config = parse_json_config(config_path, task) - self.task = task if task else common_config.task - if self.task == Const.GRAD_PROBE: - self.gm = GradientMonitor(common_config, task_config) - return - if step is not None: - common_config.step = get_real_step_or_rank(step, Const.STEP) - self.config = DebuggerConfig( - common_config, task_config, task, dump_path, level - ) - self.service = Service(self.config) - self.module_dumper = ModuleDumper(self.service) - self.enable_dataloader = self.config.enable_dataloader - if self.enable_dataloader: - logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") - dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) - - @property - def instance(self): - return self._instance + if self.initialized: + return + super().__init__(config_path, task, dump_path, level, step) + self.model = model + if self.task == Const.GRAD_PROBE: + self.gm = GradientMonitor(self.common_config, self.task_config) + return + self.config = DebuggerConfig( + self.common_config, self.task_config, task, dump_path, level + ) + self.service = PytorchService(self.config) + self.module_dumper = ModuleDumper(self.service) + self.ori_customer_func = {} + self.enable_dataloader = self.config.enable_dataloader + self._param_warning() @staticmethod - def check_input_params(args): - if args.config_path is not None: - if not isinstance(args.config_path, str): - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") - file_checker = FileChecker( - file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) - file_checker.common_check() - - if args.task is not None and args.task not in Const.TASK_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") + def _get_task_config(task, json_config): + return parse_task_config(task, json_config) - if args.dump_path is not None: - if not isinstance(args.dump_path, str): + @staticmethod + def _iter_tracer(func): + def func_wrapper(*args, **kwargs): + debugger_instance = PrecisionDebugger._instance + if not debugger_instance: raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") + MsprobeException.INTERFACE_USAGE_ERROR, + f"PrecisionDebugger must be instantiated before executing the dataloader iteration" + ) - if args.level is not None and args.level not in Const.LEVEL_LIST: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + debugger_instance.enable_dataloader = False + if not debugger_instance.service.first_start: + debugger_instance.stop() + debugger_instance.step() + result = func(*args, **kwargs) + debugger_instance.start() + debugger_instance.enable_dataloader = True + return result - if args.model is not None: - logger.warning_on_rank_0( - "The 'model' parameter in the PrecisionDebugger will be deprecated in the future." - "It is recommended to pass the 'model' parameter in the start interface instead." - ) + return func_wrapper @classmethod - def start(cls, model=None): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task in PrecisionDebugger.tasks_not_need_debugger: + def start(cls, model=None, token_range=None): + instance = cls._get_instance() + if instance is None: return - instance.config.check_model(instance, model) + + check_token_range(token_range) + instance.config.check_model(instance, model, token_range) + if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, start() skipped.") else: - instance.service.start(instance.model) - - @classmethod - def forward_backward_dump_end(cls): - instance = cls._instance - instance.stop() + instance.service.start(instance.model, token_range) @classmethod def stop(cls): - instance = cls._instance - if not instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task in PrecisionDebugger.tasks_not_need_debugger: + instance = cls._get_instance() + if instance is None: return if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.") @@ -145,9 +106,8 @@ class PrecisionDebugger: @classmethod def step(cls): - if not cls._instance: - raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger: + instance = cls._get_instance() + if instance is None: return cls._instance.service.step() @@ -172,12 +132,23 @@ class PrecisionDebugger: return instance.service.save(variable, name, save_backward) + def _param_warning(self): + if self.model is not None: + logger.warning_on_rank_0( + "The 'model' parameter in the PrecisionDebugger will be deprecated in the future." + "It is recommended to pass the 'model' parameter in the start interface instead." + ) + if self.enable_dataloader: + logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") + dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__) + def module_dump(module, dump_name): - if not isinstance(module, torch.nn.Module): + if not is_torch_nn_module(module): raise MsprobeException( MsprobeException.INVALID_PARAM_ERROR, - f"the module argument in module_dump must be a torch.nn.Module subclass" + f"the module argument in module_dump must be a torch.nn.Module type, " + f"but currently there is an unsupported {type(module)} type." ) if not isinstance(dump_name, str): raise MsprobeException( @@ -201,17 +172,3 @@ def module_dump_end(): f"PrecisionDebugger must be instantiated before using module_dump_end interface" ) instance.module_dumper.stop_module_dump() - - -def iter_tracer(func): - def func_wrapper(*args, **kwargs): - debugger_instance = PrecisionDebugger.instance - debugger_instance.enable_dataloader = False - if not debugger_instance.service.first_start: - debugger_instance.stop() - debugger_instance.step() - result = func(*args, **kwargs) - debugger_instance.start() - debugger_instance.enable_dataloader = True - return result - return func_wrapper diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..0434e3e62686ac0f8011ea8e58daadd9da81c3c0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps + +import torch +from torch.utils.hooks import BackwardHook + +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import is_float8_tensor + + +def wrap_setup_backward_hook(func): + def requires_clone(tensor): + return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \ + tensor.requires_grad and torch.is_grad_enabled() + + @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH) + def parse_tensor(item, tensor_list): + if requires_clone(item): + tensor_list.append(item) + elif isinstance(item, (list, tuple)): + for value in item: + parse_tensor(value, tensor_list) + elif isinstance(item, dict): + for value in item.values(): + parse_tensor(value, tensor_list) + + @recursion_depth_decorator("Dump: wrap_setup_backward_hook.rebuild_args", max_depth=Const.DUMP_MAX_DEPTH) + def rebuild_args(item, tensor_iter): + if requires_clone(item): + result = next(tensor_iter) + if hasattr(result, "_base") and result._base is not None: + if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0): + torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0)) + return result + if isinstance(item, list): + for index, value in enumerate(item): + item[index] = rebuild_args(value, tensor_iter) + return item + if isinstance(item, dict): + for key, value in item.items(): + item[key] = rebuild_args(value, tensor_iter) + return item + if isinstance(item, tuple): + if hasattr(item, '_fields'): + return type(item)(*[rebuild_args(i, tensor_iter) for i in item]) + return type(item)([rebuild_args(i, tensor_iter) for i in item]) + return item + + @wraps(func) + def wrap_setup_hook_func(*args, **kwargs): + if len(args) < 2: + return func(*args, **kwargs) + + actual_args = args[1] + + tensor_list = [] + + parse_tensor(actual_args, tensor_list) + + new_args = args[0], tuple(tensor_list) + hooked_tensors = func(*new_args, **kwargs) + + tensor_iter = iter(hooked_tensors) + try: + new_data = rebuild_args(actual_args, tensor_iter) + except Exception as e: + logger.debug(f"Unsupported data in setup input/output hook. The detail info: {e}") + new_data = actual_args + + return new_data + + return wrap_setup_hook_func + + +def wrap_setup_input_output_hook(): + BackwardHook.setup_input_hook = wrap_setup_backward_hook(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = wrap_setup_backward_hook(BackwardHook.setup_output_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py index 4700de6f1f9f3b5ddfb9507decb6f8739b5eda9b..5bf26f7ac0d91cce630a3b9c8e648453ae4ab65c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py @@ -13,74 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from msprobe.core.common.const import Const -from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.common.log import logger -from msprobe.pytorch.hook_module.api_registry import api_register - -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.hook_module.api_register import get_api_register class ModuleDumper: def __init__(self, service): self.service = service - self.hook_handle_list = [] + self.api_register = get_api_register() def start_module_dump(self, module, dump_name): - api_register.api_originality() - self.register_hook(module, dump_name) - - def stop_module_dump(self): - api_register.api_modularity() - for hook_handle in self.hook_handle_list: - if isinstance(hook_handle, torch.utils.hooks.RemovableHandle): - hook_handle.remove() - self.hook_handle_list.clear() + if hasattr(module, 'msprobe_hook') and not hasattr(module, 'msprobe_module_dump'): + logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.") + return - def register_hook(self, module, dump_name): - prefix_name = ( - BaseScope.Module_Type_Module + Const.SEP + - dump_name + Const.SEP + - module.__class__.__name__ + Const.SEP - ) - module_processor = self.service.module_processor - _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook( - BaseScope.Module_Type_Module, - prefix_name - ) + ModuleProcesser.enable_module_dump = True + self.api_register.restore_all_api() + if not hasattr(module, 'msprobe_module_dump'): + self.service.module_processor.register_module_hook(module, self.service.build_hook, + recursive=False, module_names=[dump_name]) + setattr(module, 'msprobe_module_dump', True) - if module_processor.has_register_backward_hook(module): - logger.warning( - f"The {dump_name} module has registered deprecated register_backward_hook," - f"which may cause abnormal data dump. The backward data dump for this module will be skipped." - ) - if torch_version_above_or_equal_2: - forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True) - else: - if not module_processor.has_register_backward_hook(module): - backward_hook_handle = module.register_full_backward_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) - ) - self.hook_handle_list.append(backward_hook_handle) - forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2) - self.hook_handle_list.append(forward_hook_handle) - if not module_processor.has_register_backward_hook(module): - backward_hook_handle = module.register_full_backward_hook(backward_hook) - self.hook_handle_list.append(backward_hook_handle) - - forward_pre_hook_handle = module.register_forward_pre_hook( - module_processor.node_hook(prefix_name + Const.FORWARD, Const.START) - ) - forward_hook_handle = module.register_forward_hook( - module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP) - ) - self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle]) - if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module): - backward_pre_hook_handle = module.register_full_backward_pre_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START) - ) - backward_hook_handle = module.register_full_backward_hook( - module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) - ) - self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle]) + def stop_module_dump(self): + ModuleProcesser.enable_module_dump = False + self.api_register.register_all_api() diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index b5ca1da461fd4235a09172de4b9dcea34a624e58..c770f99f338d05b708c48a02245c88e2e5e0d291 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -13,18 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import wraps +import sys +from collections import OrderedDict import torch +from torch.utils.hooks import BackwardHook, RemovableHandle + from msprobe.core.common.const import Const from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import replace_last_occurrence -from torch.utils.checkpoint import checkpoint as origin_checkpoint -from torch.utils.checkpoint import set_checkpoint_early_stop -from torch.utils.hooks import BackwardHook +from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook +from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_output_hook torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +if torch_version_above_or_equal_2: + from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop def checkpoint_without_early_stop(*args, **kwargs): @@ -33,7 +36,18 @@ def checkpoint_without_early_stop(*args, **kwargs): def replace_checkpoint(): - torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop + if torch_version_above_or_equal_2: + torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop + + +def wrap_megatron_deallocate(func): + def wrapper_func(out, deallocate_pipeline_outputs=False): + if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None: + out_clone = out.clone() + out.data = torch.empty((1,), device=out.device, dtype=out.dtype, ) + return func(out_clone, deallocate_pipeline_outputs) + return func(out, deallocate_pipeline_outputs) + return wrapper_func class ModuleProcesser: @@ -41,37 +55,33 @@ class ModuleProcesser: module_stack = [] api_parent_node = "" module_node = {} + module_bw_hook_kernels = {} + module_with_backward_hook = {} + enable_module_dump = False def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None - BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) - BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + wrap_setup_input_output_hook() replace_checkpoint() + try: + from megatron.core.pipeline_parallel import schedules + origin_func_id = id(schedules.deallocate_output_tensor) + schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor) + for module in list(sys.modules.values()): + if module.__name__ == 'schedules': + continue + for func in module.__dict__: + if id(module.__dict__[func]) == origin_func_id: + module.__setattr__(func, schedules.deallocate_output_tensor) + logger.debug(f'patch {module.__name__}.{func}.') + logger.info_on_rank_0("Patch megatron method success.") + except ImportError: + logger.info_on_rank_0("No megatron find.") + except Exception as e: + logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}") @staticmethod - def clone_return_value(func): - @wraps(func) - def clone_return_value_func(*args, **kwargs): - result = func(*args, **kwargs) - return ModuleProcesser.clone_if_tensor(result) - - return clone_return_value_func - - @staticmethod - def clone_if_tensor(result): - if isinstance(result, torch.Tensor): - return result.clone() - elif type(result) is tuple: - return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) - elif type(result) is list: - return list(ModuleProcesser.clone_if_tensor(x) for x in result) - elif type(result) is dict: - return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} - else: - return result - - @staticmethod - def module_count_func(module_name): + def set_and_get_calls_number(module_name): if module_name not in ModuleProcesser.module_count: ModuleProcesser.module_count[module_name] = 0 else: @@ -85,13 +95,19 @@ class ModuleProcesser: module._is_full_backward_hook is False @staticmethod - def get_modules_and_names(models): + def get_modules_and_names(models, recursive, module_names): modules_and_names_with_index = {} if isinstance(models, (list, tuple)): + if not recursive and len(module_names) != len(models): + return modules_and_names_with_index for index, model in enumerate(models): - modules_and_names_with_index[str(index)] = model.named_modules() + modules_and_names_with_index[str(index)] = model.named_modules() if recursive else \ + [(module_names[index], model)] else: - modules_and_names_with_index["-1"] = models.named_modules() + if not recursive and len(module_names) != 1: + return modules_and_names_with_index + modules_and_names_with_index["-1"] = models.named_modules() if recursive else \ + [(module_names[0], models)] return modules_and_names_with_index @classmethod @@ -100,105 +116,134 @@ class ModuleProcesser: cls.module_stack = [] cls.api_parent_node = "" cls.module_node = {} + cls.module_bw_hook_kernels = {} + cls.enable_module_dump = False + + def register_module_hook(self, models, build_hook, recursive=True, module_names=None): + if module_names is None: + module_names = [] - def register_module_hook(self, models, build_hook): - logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.") - modules_and_names_with_index = self.get_modules_and_names(models) + modules_and_names_with_index = self.get_modules_and_names(models, recursive, module_names) for index, modules_and_names in modules_and_names_with_index.items(): model = models if index == "-1" else models[int(index)] for name, module in modules_and_names: - if module == model: + if recursive and module == model: continue + if not is_torch_nn_module(module): + logger.warning( + f"The module dump does not support {type(module)} type. " + f"The data dump for this module will be skipped." + ) + continue + if module.__class__.__name__ == "FullyShardedDataParallel": + continue + setattr(module, 'msprobe_hook', True) module_index = (index + Const.SEP) if index != "-1" else "" - prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index + - name + Const.SEP + module.__class__.__name__ + Const.SEP) - pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook( - BaseScope.Module_Type_Module, - prefix_name - ) + prefix_name = f'{BaseScope.Module_Type_Module}{Const.SEP}{module_index}{name}{Const.SEP}' + \ + f'{module.__class__.__name__}{Const.SEP}' + + forward_pre_hook = self.build_module_hook(prefix_name, build_hook) if self.has_register_backward_hook(module): logger.warning( f"The {prefix_name[:-1]} has registered deprecated register_backward_hook," f"which may cause abnormal data dump. The backward data dump for this module will be skipped." ) + ModuleProcesser.module_with_backward_hook[prefix_name] = True + register_forward_pre_hook(module, forward_pre_hook) + + def build_module_hook(self, module_name, build_data_hook): + def forward_pre_hook(module, args, kwargs=None): + if kwargs is None: + kwargs = {} + + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return (args, kwargs) if torch_version_above_or_equal_2 else args + + index = ModuleProcesser.set_and_get_calls_number(module_name) + full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' + full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}' + + self.set_construct_info_in_pre_hook(full_forward_name) + + if not hasattr(module, 'msprobe_forward_hook'): + forward_hooks_dict = getattr(module, '_forward_hooks', OrderedDict()) + handle = RemovableHandle(forward_hooks_dict) + forward_hooks_dict[handle.id] = forward_hook + forward_hooks_dict.move_to_end(handle.id, last=False) if torch_version_above_or_equal_2: - module.register_forward_hook(forward_hook, with_kwargs=True) + forward_hooks_with_kwargs_dict = getattr(module, '_forward_hooks_with_kwargs', OrderedDict()) + forward_hooks_with_kwargs_dict[handle.id] = True + + setattr(module, 'msprobe_forward_hook', True) + + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name) + + def get_backward_pre_hook(full_backward_name): + def backward_pre_hook_fn(module, grad_output): + self.set_construct_info_in_pre_hook(full_backward_name) + return backward_pre_hook_fn + + def get_backward_hook(backward_data_hook, full_backward_name): + def backward_hook_fn(module, grad_input, grad_output): + new_output = backward_data_hook(module, grad_input, grad_output) + self.set_construct_info_in_hook(full_backward_name, is_forward=False) + return new_output + return backward_hook_fn + + if not ModuleProcesser.module_with_backward_hook.get(module_name): + backward_pre_hook = get_backward_pre_hook(full_backward_name) + backward_hook = get_backward_hook(hook_set.backward_hook, full_backward_name) + if torch_version_above_or_equal_2: + bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook]) else: - if not self.has_register_backward_hook(module): - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) - module.register_forward_hook(forward_hook_torch_version_below_2) - if not self.has_register_backward_hook(module): - module.register_full_backward_hook(backward_hook) - - module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START)) - module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP)) - if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module): - module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START)) - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) - - def node_hook(self, name_prefix, start_or_stop, **kwargs): - - def pre_hook(module, input, output=None): - try: - index = ModuleProcesser.module_count_func(name_prefix) - except IndexError as e: - index = None - pass - full_name = name_prefix + Const.SEP + str(index) - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - module.mindstudio_reserved_name = [] - module.mindstudio_reserved_name.append(full_name) - if self.module_stack: - ModuleProcesser.module_node[full_name] = self.module_stack[-1] + bw_hook = BackwardHook(module, [backward_hook]) + ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook + args = bw_hook.setup_input_hook(args) + return (args, kwargs) if torch_version_above_or_equal_2 else args + + def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + + index = ModuleProcesser.module_count.get(module_name) + full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' + + hook_set = build_data_hook(BaseScope.Module_Type_Module, full_name) + hook_result = hook_set.forward_hook(module, args, kwargs_or_output, output_or_kwargs) + self.set_construct_info_in_hook(full_name) + + if hook_result is not None: + result = hook_result else: - ModuleProcesser.module_node[full_name] = None + result = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output - ModuleProcesser.module_stack.append(full_name) - if self.module_stack: - ModuleProcesser.api_parent_node = self.module_stack[-1] - if self.scope: - self.scope.begin_module(full_name) + bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name) + if bw_hook: + result = bw_hook.setup_output_hook(result) - def end_hook(module, input, output=None): + return result + + return forward_pre_hook + + def set_construct_info_in_pre_hook(self, full_name): + if self.module_stack: + ModuleProcesser.module_node[full_name] = self.module_stack[-1] + else: + ModuleProcesser.module_node[full_name] = None + ModuleProcesser.module_stack.append(full_name) + ModuleProcesser.api_parent_node = full_name + if self.scope: + self.scope.begin_module(full_name) + + def set_construct_info_in_hook(self, full_name, is_forward=True): + if torch_version_above_or_equal_2 or is_forward: if self.module_stack: ModuleProcesser.module_stack.pop() - if self.module_stack: - ModuleProcesser.api_parent_node = self.module_stack[-1] - else: - ModuleProcesser.api_parent_node = None - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - raise RuntimeError(f"module reserve name is None when pop") - current_name = module.mindstudio_reserved_name.pop() + ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None if self.scope: - self.scope.end_module(current_name) - - def backward_hook(module, input, output=None): - try: - index = ModuleProcesser.module_count_func(name_prefix) - except IndexError as e: - index = None - pass - full_name = name_prefix + Const.SEP + str(index) - if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: - module.mindstudio_reserved_name = [] - module.mindstudio_reserved_name.append(full_name) - forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD) - ModuleProcesser.module_node[full_name] = replace_last_occurrence( - ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD) - ModuleProcesser.api_parent_node = None + self.scope.end_module(full_name) + else: if self.scope: self.scope.begin_module(full_name) - - if torch_version_above_or_equal_2: - if Const.START in start_or_stop: - return pre_hook - else: - return end_hook - else: - if Const.FORWARD in name_prefix and Const.START in start_or_stop: - return pre_hook - elif Const.BACKWARD in name_prefix: - return backward_hook - else: - return end_hook + ModuleProcesser.api_parent_node = full_name diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index e3fd2b69fef2772354401a22344376258e77a008..6baa684cbff27001ac489eddf11269ba2c71dfae 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -16,7 +16,7 @@ import torch from msprobe.core.common.exceptions import FreeBenchmarkException -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark.common.enums import DeviceType diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py index 49e845da4011565f1b6ccf0c0e1193fb3fcffcbf..a5f18946c44c09bf1670173d45cc99ace3b0e79d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py @@ -16,7 +16,7 @@ import math import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.utils import TorchC diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index 41ec39e3a3b6233720c047d5d2b736d91bba989e..754e3b06e9670a04fcf7c20d5af3d7e1733b7af1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode @@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer): except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"when calculate maximun value, tensor is changed to float32." + f"when calculating the maximum value, the tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"Maximun value is less than the minimun threshold. Cancel add noise." + f"maximum value is less than the minimum threshold. Cancel adding noise." ) return False return True diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index df1a73127aa0b69e42254cce1d3334810319f7cf..aec0c3ca96e39958316f6835261618c148c7ad4e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode @@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer): except Exception: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"when calculate maximun value, tensor is changed to float32." + f"when calculate the maximum value, the tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.api_name}, " - f"Maximun value is less than the minimun threshold. Cancel add noise." + f"maximum value is less than the minimum threshold. Cancel adding noise." ) return False return True diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index c4fbeaf82f8fcafba235a7faa6dd9073d4d556d8..521637a1d8b3bca226a6eacfc5f6f5a0d4bc1921 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode from msprobe.pytorch.free_benchmark.common.params import DataParams diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index 095e77ffaff39a795cb1418c1695608d91d7427b..daa271976f3b05f81b9997bd1775ee2809b776c9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -15,7 +15,7 @@ import torch from msprobe.core.common.const import Const -from msprobe.core.common.utils import recursion_depth_decorator +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import CommonField from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py index 47f93ab7b89f44bdd4f92ceafc6e9dbe503d0374..e0d583dd012364f3bb75eb4d030dca21cfea2bc6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py @@ -186,6 +186,8 @@ class FuzzHandler(ABC): ratio = self.ratio_calculate( origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM ) + if threshold == 0: + raise ValueError("Threshold cannot be zero. Check `get_threshold` implementation.") if ratio == ThresholdConfig.SYMBOL_FLIPPING: is_consistent = False else: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py index 9feec1531b16ff8ba63910f3f7c40aa275d0104e..d088cd1d1647a59c167f705702d9ad6afcf6e21b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py @@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler): except Exception as e: logger.warning_on_rank_0( f"[msprobe] Free Benchmark: For {self.params.api_name}, " - f"when campare the result exception raise {e}" + f"when comparing the results, an exception is raised: {e}" ) return data_params.original_result diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py index 247e2cd0ed5ea11047cc0d75954dbc1e92b889f4..f515b5d4783c0e20a2303579f6954d42a7b9deac 100644 --- a/debug/accuracy_tools/msprobe/pytorch/function_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -70,7 +70,7 @@ class Register(dict): def add_register_item(key, value): if key in self._dict: - logger.warning(f"{value.__name__} has been registered before, so we will overriden it.") + logger.warning(f"{value.__name__} has been registered before, so we will override it.") self[key] = value return value diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py index 926476b8fb353531e54a485ccb47c4c59860c5d0..81d7575fc251c0b90703b13c537f61f778cf5136 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py @@ -46,7 +46,7 @@ class GradientMonitor: if not os.path.exists(self._output_path): create_directory(self._output_path) else: - logger.warning(f"the file in {self._output_path} will be recoverd") + logger.warning(f"the file in {self._output_path} will be deleted") self._step = -1 self._param2name = defaultdict(str) @@ -97,7 +97,7 @@ class GradientMonitor: create_directory(output_dirpath) output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv") if os.path.exists(output_path): - logger.warning(f"{output_path} will be recoverd") + logger.warning(f"{output_path} will be deleted") remove_path(output_path) header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds) output_lines.insert(0, header_result) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py index 6391f8f5e1d00f62240c002c35757bba623c3929..bf72a7fb0ff31fddfd7fe6582582b733428eee2b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py @@ -17,6 +17,7 @@ from abc import ABC, abstractmethod from collections import namedtuple import hashlib from functools import wraps +import zlib import torch from msprobe.core.grad_probe.constant import GradConst @@ -74,8 +75,8 @@ class CsvMd5(CsvItem): def generate_csv_content(csv_content_input): grad = csv_content_input.grad tensor_bytes = grad.cpu().detach().float().numpy().tobytes() - md5_hash = hashlib.md5(tensor_bytes) - return [md5_hash.hexdigest()] + md5_hash = f"{zlib.crc32(tensor_bytes):08x}" + return [md5_hash] @register_csv_item(GradConst.DISTRIBUTION) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..552a62a7a57562bb2fc3ca65a3748d3d59eb2a79 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import inspect + +import torch +import torch.distributed as dist + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import ( + torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter +) +from msprobe.pytorch.function_factory import npu_custom_functions +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.hook_module.utils import dynamic_import_op +from msprobe.core.common.file_utils import load_yaml + +try: + import mindspeed.ops +except ImportError: + mindspeed_enable = False +else: + mindspeed_enable = True + + +torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' + +_inner_used_api = {} +_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) +_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} +dist_data_collect_func = {} +origin_wait = getattr(dist.Work, 'wait') + +_api_types = { + Const.PT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: (torch, (torch,)), + Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)), + Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d)) + } +} +if not is_gpu: + import torch_npu + if torch_without_guard_version: + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu)) + } + ) + else: + _api_types.get(Const.PT_FRAMEWORK).update( + {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))} + ) + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed, + torch_npu.distributed.distributed_c10d)) + } + ) + if mindspeed_enable: + _api_types.get(Const.PT_FRAMEWORK).update({Const.PT_API_TYPE_MINDSPEED: (mindspeed.ops, (mindspeed.ops,))}) + mindspeed_op_list = load_yaml(_supported_api_list_path[0]).get(Const.PT_API_TYPE_MINDSPEED) + mindspeed_op_file_list = [op.split(Const.SEP)[0] + Const.PY_SUFFIX for op in mindspeed_op_list] + dynamic_import_op(mindspeed.ops, mindspeed_op_file_list) + + +@parameter_adapter +def tensor_module_forward(module, *args, **kwargs): + return module.api_func(*args, **kwargs) + + +def dist_module_forward(module, *args, **kwargs): + handle = module.api_func(*args, **kwargs) + try: + bound = inspect.signature(module.api_func).bind(*args, **kwargs) + bound.apply_defaults() + use_async_op_flag = bound.arguments.get("async_op", False) + except Exception as e: + use_async_op_flag = False + logger.warning(f"fail to get dist api's func signature because {e}, no wait") + + def create_async_callback_func(catch_func): + def store_data(): + module.async_op_dump_flag = False + catch_func(module, args, kwargs, handle) + return store_data + + if len(module._forward_hooks.values()) == 0: + return handle + if use_async_op_flag or module.api_name in ['isend', 'irecv']: + module.async_op_dump_flag = True + dist_data_collect_func[handle] = create_async_callback_func(list(module._forward_hooks.values())[0]) + if module.api_name == 'batch_isend_irecv': + if isinstance(handle, list): + for req in handle: + dist_data_collect_func[req] = create_async_callback_func(list(module._forward_hooks.values())[0]) + return handle + + +def redirect_wait(): + def wrapped_wait(work): + def wrapped_wait(*args, **kwargs): + origin_wait(*args, **kwargs) + if args[0] in dist_data_collect_func: + store_func = dist_data_collect_func.pop(args[0]) + store_func() + return wrapped_wait + dist.Work.wait = wrapped_wait(dist.Work) + + +def npu_module_forward(module, *args, **kwargs): + if not module.need_hook: + if module.api_name not in npu_custom_functions: + raise Exception(f'There is not bench function {module.api_name}') + if module.device == Const.CUDA_LOWERCASE: + module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name) + if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: + return npu_custom_functions[module.api_name](*args, **kwargs) + return module.api_func(*args, **kwargs) + + +forward_methods = { + "Tensor": tensor_module_forward, + "Distributed": dist_module_forward, + "NPU": npu_module_forward +} + + +class ApiTemplate(HOOKModule): + def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE): + self.api_name = api_name + self.api_func = api_func + self.prefix = prefix + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + self.need_hook = need_hook + self.device = device + self.async_op_dump_flag = False + if self.need_hook: + super().__init__(hook_build_func) + if prefix == Const.DIST_API_TYPE_PREFIX: + self.op_is_distributed = True + + @torch_device_guard + def forward(self, *args, **kwargs): + exec_func = forward_methods.get(self.prefix) + exec_func = functools.partial(exec_func, self) if exec_func else self.api_func + return exec_func(*args, **kwargs) + + +api_register = None + + +def get_api_register(return_new=False): + if return_new: + return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + + global api_register + if api_register is None: + api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + return api_register diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py deleted file mode 100644 index 1aad89bd6e89ae839513001b1d51572b50d8280b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed as dist - -from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops -from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops -from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops -from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops -from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops -from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu -from msprobe.core.common.const import Const - -torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' - -if not is_gpu: - import torch_npu - from . import wrap_npu_custom - from .wrap_npu_custom import get_npu_ops - - -class ApiRegistry: - def __init__(self): - self.tensor_ori_attr = {} - self.torch_ori_attr = {} - self.functional_ori_attr = {} - self.distributed_ori_attr = {} - self.npu_distributed_ori_attr = {} - self.vf_ori_attr = {} - self.aten_ori_attr = {} - self.torch_npu_ori_attr = {} - - self.tensor_hook_attr = {} - self.torch_hook_attr = {} - self.functional_hook_attr = {} - self.distributed_hook_attr = {} - self.npu_distributed_hook_attr = {} - self.vf_hook_attr = {} - self.aten_hook_attr = {} - self.torch_npu_hook_attr = {} - - @staticmethod - def store_ori_attr(ori_api_group, api_list, api_ori_attr): - for api in api_list: - if '.' in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(ori_api_group, sub_module_name) - api_ori_attr[api] = getattr(sub_module, sub_op) - else: - api_ori_attr[api] = getattr(ori_api_group, api) - - @staticmethod - def set_api_attr(api_group, attr_dict): - for api, api_attr in attr_dict.items(): - if '.' in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(api_group, sub_module_name, None) - if sub_module is not None: - setattr(sub_module, sub_op, api_attr) - else: - setattr(api_group, api, api_attr) - - def api_modularity(self): - self.set_api_attr(torch.Tensor, self.tensor_hook_attr) - self.set_api_attr(torch, self.torch_hook_attr) - self.set_api_attr(torch.nn.functional, self.functional_hook_attr) - self.set_api_attr(dist, self.distributed_hook_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_hook_attr) - self.set_api_attr(torch._VF, self.vf_hook_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_hook_attr) - - def api_originality(self): - self.set_api_attr(torch.Tensor, self.tensor_ori_attr) - self.set_api_attr(torch, self.torch_ori_attr) - self.set_api_attr(torch.nn.functional, self.functional_ori_attr) - self.set_api_attr(dist, self.distributed_ori_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_ori_attr) - self.set_api_attr(torch._VF, self.vf_ori_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_ori_attr) - - def initialize_hook(self, hook, online_run_ut=False): - """ - initialize_hook - Args: - hook (_type_): initialize_hook - online_run_ut (bool): default False, whether online run_ut or not. - If online_run_ut is True, the hook will not wrap the aten ops. - """ - self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr) - wrap_tensor.wrap_tensor_ops_and_bind(hook) - for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name) - - self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr) - wrap_torch.wrap_torch_ops_and_bind(hook) - for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name) - - self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr) - wrap_functional.wrap_functional_ops_and_bind(hook) - for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name) - - self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr) - wrap_distributed.wrap_distributed_ops_and_bind(hook) - if not is_gpu and not torch_without_guard_version: - self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr) - for attr_name in dir(wrap_distributed.HOOKDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) - if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: - self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, - attr_name) - - if torch_version_above_2 and not online_run_ut: - self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr) - wrap_aten.wrap_aten_ops_and_bind(hook) - for attr_name in dir(wrap_aten.HOOKAtenOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name) - - self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr) - wrap_vf.wrap_vf_ops_and_bind(hook) - for attr_name in dir(wrap_vf.HOOKVfOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name) - - if not is_gpu: - self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr) - wrap_npu_custom.wrap_npu_ops_and_bind(hook) - for attr_name in dir(wrap_npu_custom.HOOKNpuOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name) - - -api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index b59d4be82f2b55326c2a1d6a8a9e127a8470bff6..0a55f6a9deb8fddf89af112ab7e49a6b74af750e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,35 +21,33 @@ import torch import torch.nn as nn import torch.utils.hooks as full_hooks -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +from msprobe.core.common.runtime import Runtime +from msprobe.pytorch.common.utils import is_float8_tensor, register_forward_pre_hook, register_forward_hook class HOOKModule(nn.Module): module_count = defaultdict(int) inner_stop_hook = {} - def __init__(self, build_hook) -> None: + def __init__(self, hook_build_func) -> None: super(HOOKModule, self).__init__() self.has_overflow = False - self.prefix = "" self.current_thread = threading.current_thread().ident if self.current_thread not in HOOKModule.inner_stop_hook: HOOKModule.inner_stop_hook[self.current_thread] = False self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False) if not self.stop_hook: - if hasattr(self, "prefix_op_name_"): - self.prefix = self.prefix_op_name_ - self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix) - if torch_version_above_or_equal_2: - self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) - self.register_forward_hook(forward_hook, with_kwargs=True) - else: - self.register_forward_pre_hook(forward_pre_hook) - self.register_forward_hook(forward_hook) - self.register_backward_hook(backward_hook) + + if not Runtime.is_running: + return + prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else "" + if callable(hook_build_func): + hook_set = hook_build_func(prefix) + register_forward_pre_hook(self, hook_set.forward_pre_hook) + register_forward_hook(self, hook_set.forward_hook) + self.register_backward_hook(hook_set.backward_hook) def __call__(self, *args, **kwargs): changed = False @@ -78,13 +76,7 @@ class HOOKModule(nn.Module): if len(self._backward_hooks) > 0: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() for hook in self._forward_pre_hooks.values(): - result_args, result_kwargs = hook(self, args, kwargs) - if result_args is not None: - if not isinstance(result_args, tuple): - result_args = (result_args,) - args = result_args - if result_kwargs is not None: - kwargs = result_kwargs + hook(self, args, kwargs) bw_hook = None if len(full_backward_hooks) > 0: bw_hook = full_hooks.BackwardHook(self, full_backward_hooks) @@ -111,6 +103,10 @@ class HOOKModule(nn.Module): return result else: return result + + if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()): + return result + grad_fn = var.grad_fn if grad_fn is not None: for hook in non_full_backward_hooks: diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2ee39ae79544b5a699800cb1e7dc9e0fc9066b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from msprobe.pytorch.hook_module.api_register import get_api_register + + +def wrap_jit_script_func(): + def patched_script(*args, **kwargs): + all_api_registered = api_register.all_api_registered + if all_api_registered: + api_register.restore_all_api() + result = original_script(*args, **kwargs) + if all_api_registered: + api_register.register_all_api() + return result + + original_script = torch.jit.script + api_register = get_api_register() + torch.jit.script = patched_script diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..413ad3da00adfaebfa3b2652e079dd71267457c2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from contextlib import nullcontext + +from msprobe.core.common.const import Const +from msprobe.core.common.utils import replace_last_occurrence +from msprobe.core.hook_manager import BaseHookManager, HookSet +from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2 +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class PytorchHookManager(BaseHookManager): + @property + def _is_recompute(self): + return is_recomputation() + + @staticmethod + def _no_grad_context(): + return nullcontext() + + @staticmethod + def _add_count(name): + HOOKModule.add_module_count(name) + + @staticmethod + def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): + kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} + output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + return kwargs, output + + def build_hook(self, hook_type, name): + if hook_type == Const.API: + full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD + else: + full_forward_name = name + full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD) + hookset = HookSet( + forward_hook=self._build_forward_hook(hook_type, full_forward_name), + forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name), + backward_hook=self._build_backward_hook(hook_type, full_backward_name) + ) + return hookset + + def _need_exchange(self, module): + return True + + def _get_params_dict(self, module): + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in module.named_parameters(recurse=False) + } + return params_dict diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py index 75be9fc4532ea5863ed3daad569c062c4ccb91ba..b4f9a5f50639752e8094b38961ef600cc6d7b101 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/register_optimizer_hook.py @@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector): def patch_clip_grad(func): def wrapper(*args, **kwargs): data_collector.optimizer_status = Const.CLIP_GRAD - func(*args, **kwargs) + result = func(*args, **kwargs) data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD + return result return wrapper diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index 4bc22f51ceb5497f307fb4ac3226c8c590ea459a..f0dedc0dd81d109d9f9883a6208f7fdd01369b95 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -149,9 +149,9 @@ tensor: - __bool__ - __div__ - __eq__ + - __floordiv__ - __ge__ - __gt__ - - __getitem__ - __iadd__ - __iand__ - __idiv__ @@ -160,23 +160,33 @@ tensor: - __imod__ - __imul__ - __ior__ + - __ipow__ - __irshift__ - __isub__ - __ixor__ + - __le__ - __lshift__ + - __lt__ - __matmul__ - __mod__ - __mul__ + - __ne__ - __nonzero__ - __or__ + - __pow__ - __radd__ + - __rdiv__ + - __rmod__ - __rmul__ + - __ror__ + - __rpow__ - __rshift__ + - __rsub__ + - __rxor__ - __setitem__ - __sub__ - __truediv__ - __xor__ - - __pow__ - abs - abs_ - absolute @@ -199,12 +209,14 @@ tensor: - addmv_ - addr - addr_ + - adjoint - align_as - align_to - all - allclose - amax - amin + - aminmax - angle - any - arccos @@ -216,12 +228,15 @@ tensor: - arcsinh - arcsinh_ - arctan + - arctan2 + - arctan2_ - arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -236,39 +251,51 @@ tensor: - baddbmm_ - bernoulli - bernoulli_ + - bfloat16 - bincount - bitwise_and - bitwise_and_ + - bitwise_left_shift + - bitwise_left_shift_ - bitwise_not - bitwise_not_ - bitwise_or - bitwise_or_ + - bitwise_right_shift + - bitwise_right_shift_ - bitwise_xor - bitwise_xor_ - bmm + - bool - broadcast_to + - byte - cauchy_ - ceil - ceil_ + - cfloat + - char - cholesky + - cholesky_inverse + - cholesky_solve - chunk - clamp - - cholesky_solve - - cholesky_inverse - clamp_ - clamp_max - clamp_max_ - - clip - clamp_min - clamp_min_ + - clip - clip_ + - conj_physical - copysign - copysign_ + - corrcoef - cos - cos_ - cosh - cosh_ - count_nonzero + - cov - cummax - cummin - cumprod @@ -282,20 +309,23 @@ tensor: - diag_embed - diagflat - diagonal + - diagonal_scatter - diff - - dist - digamma - digamma_ + - dist - div - div_ - divide - divide_ - dot + - double + - dsplit - eig - eq - eq_ - - erf - equal + - erf - erf_ - erfc - erfc_ @@ -304,18 +334,21 @@ tensor: - exp - exp2 - exp2_ - - expm1 - exp_ + - expand + - expand_as + - expm1 - expm1_ - exponential_ - fill_ - - fix - fill_diagonal_ + - fix - fix_ + - flatten - flip - fliplr - - flatten - flipud + - float - float_power - float_power_ - floor @@ -328,6 +361,7 @@ tensor: - fmod_ - frac - frac_ + - frexp - gather - gcd - gcd_ @@ -338,31 +372,37 @@ tensor: - ger - greater - greater_ - - gt - - gt_ - greater_equal - greater_equal_ + - gt + - gt_ + - half - hardshrink - heaviside - heaviside_ - histc + - histogram + - hsplit - hypot - hypot_ + - i0 + - i0_ - igamma - igamma_ - igammac - igammac_ - index_add - index_add_ - - inverse - index_copy - index_copy_ - index_fill - index_fill_ - index_put - index_put_ - - inner - index_select + - inner + - int + - inverse - isclose - isfinite - isinf @@ -380,7 +420,6 @@ tensor: - le_ - lerp - lerp_ - - where - less - less_ - less_equal @@ -397,43 +436,47 @@ tensor: - log_ - log_normal_ - log_softmax - - logcumsumexp - - logdet - logaddexp - logaddexp2 + - logcumsumexp + - logdet - logical_and - logical_and_ - logical_not - - logit - logical_not_ - logical_or - logical_or_ - logical_xor - logical_xor_ + - logit - logit_ - logsumexp + - long - lstsq - lt - lt_ + - lu - lu_solve - map2_ - map_ - masked_fill - - matmul - masked_fill_ - masked_scatter - masked_scatter_ - masked_select + - matmul - matrix_exp + - matrix_power - max - maximum - mean - - matrix_power - median - min - minimum - mm - mode + - moveaxis + - movedim - msort - mul - mul_ @@ -443,6 +486,11 @@ tensor: - mv - mvlgamma - mvlgamma_ + - nan_to_num + - nan_to_num_ + - nanmean + - nanmedian + - nanquantile - nansum - narrow - narrow_copy @@ -452,20 +500,29 @@ tensor: - neg_ - negative - negative_ + - nextafter + - nextafter_ - nonzero - norm - normal_ - not_equal - not_equal_ + - numpy + - orgqr + - ormqr + - outer - permute - pinverse - polygamma + - polygamma_ - pow - pow_ - - polygamma_ - prelu - prod - put_ + - q_zero_point + - qr + - quantile - rad2deg - rad2deg_ - ravel @@ -474,15 +531,16 @@ tensor: - relu - relu_ - remainder - - repeat_interleave - - reshape - remainder_ - renorm - renorm_ - repeat + - repeat_interleave + - reshape - reshape_as - resize_ - resize_as_ + - resolve_neg - roll - rot90 - round @@ -496,6 +554,7 @@ tensor: - select - sgn - sgn_ + - short - sigmoid - sigmoid_ - sign @@ -507,11 +566,13 @@ tensor: - sinc_ - sinh - sinh_ + - slice_scatter - slogdet - smm - softmax - solve - sort + - split - split_with_sizes - sqrt - sqrt_ @@ -521,21 +582,29 @@ tensor: - squeeze_ - sspaddmm - std + - stft + - stride - sub - sub_ + - subtract - sum - sum_to_size - svd + - swapaxes + - swapdims + - swapdims_ - symeig - t - t_ - take + - take_along_dim - tan - tan_ - tanh - tanh_ - tensor_split - tile + - to - topk - transpose - transpose_ @@ -543,8 +612,8 @@ tensor: - tril - tril_ - triu - - true_divide - triu_ + - true_divide - true_divide_ - trunc - trunc_ @@ -552,37 +621,20 @@ tensor: - unbind - unflatten - unfold + - unique + - unique_consecutive - unsafe_chunk - - unsqueeze - unsafe_split - unsafe_split_with_sizes + - unsqueeze + - unsqueeze_ - var - vdot - - unsqueeze_ - view_as + - vsplit + - where - xlogy - xlogy_ - - split - - stft - - nan_to_num - - dsplit - - orgqr - - bitwise_left_shift_ - - arctan2 - - histogram - - q_zero_point - - adjoint - - ormqr - - bitwise_right_shift_ - - nanquantile - - lu - - quantile - - arctan2_ - - qr - - diagonal_scatter - - corrcoef - - vsplit - - aminmax torch: - linalg.norm @@ -624,6 +676,7 @@ torch: - _batch_norm_impl_index - _convolution - _foreach_norm + - _fused_adamw_ - _softmax_backward_data - abs - abs_ @@ -642,13 +695,14 @@ torch: - addmv - addmv_ - addr - - amax - affine_grid_generator - align_tensors - all - alpha_dropout - - amin - alpha_dropout_ + - amax + - amin + - aminmax - angle - any - arange @@ -661,12 +715,14 @@ torch: - arcsinh - arcsinh_ - arctan + - arctan2 - arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -687,13 +743,13 @@ torch: - batch_norm_elemt - batch_norm_gather_stats - batch_norm_gather_stats_with_counts - - bernoulli - batch_norm_stats - batch_norm_update_stats + - bernoulli - bilinear + - binary_cross_entropy_with_logits - bincount - binomial - - binary_cross_entropy_with_logits - bitwise_and - bitwise_not - bitwise_or @@ -739,9 +795,9 @@ torch: - conv_transpose1d - conv_transpose2d - conv_transpose3d - - cos - convolution - copysign + - cos - cos_ - cosh - cosh_ @@ -755,14 +811,16 @@ torch: - cummin - cumprod - cumsum + - cumulative_trapezoid - deg2rad - deg2rad_ - det - diag - diag_embed - - diff - diagflat - diagonal + - diagonal_scatter + - diff - digamma - dist - div @@ -771,12 +829,15 @@ torch: - dropout - dropout_ - dsmm + - dsplit - dstack - eig - einsum - embedding - embedding_bag - embedding_renorm_ + - empty + - empty_like - eq - equal - erf @@ -791,12 +852,12 @@ torch: - expm1 - expm1_ - eye - - feature_dropout - feature_alpha_dropout - feature_alpha_dropout_ + - feature_dropout - feature_dropout_ - - fix - fill_ + - fix - fix_ - flatten - flip @@ -811,8 +872,9 @@ torch: - fmod - frac - frac_ - - full + - frexp - frobenius_norm + - full - full_like - gather - gcd @@ -824,8 +886,8 @@ torch: - greater_equal - grid_sampler - grid_sampler_2d - - group_norm - grid_sampler_3d + - group_norm - gru - gru_cell - gt @@ -835,23 +897,29 @@ torch: - heaviside - hinge_embedding_loss - histc + - histogram + - histogramdd - hsmm + - hsplit - hspmm - hstack - hypot + - i0 + - i0_ - igamma - igammac - index_add - index_copy - - inner - index_fill - index_put - index_put_ - index_select + - inner - instance_norm - inverse - isclose - isfinite + - isin - isinf - isnan - isneginf @@ -879,8 +947,8 @@ torch: - log1p_ - log2 - log2_ - - log_softmax - log_ + - log_softmax - logaddexp - logaddexp2 - logcumsumexp @@ -899,18 +967,18 @@ torch: - lt - lu_solve - lu_unpack - - masked_fill - margin_ranking_loss + - masked_fill - masked_scatter - masked_select - - matrix_exp - matmul + - matrix_exp - matrix_power - matrix_rank - max - max_pool1d - - max_pool2d - max_pool1d_with_indices + - max_pool2d - max_pool3d - maximum - mean @@ -929,18 +997,20 @@ torch: - mvlgamma - nan_to_num - nan_to_num_ + - nanmean - nanmedian + - nanquantile - nansum - narrow + - narrow_copy - native_batch_norm - native_group_norm - - narrow_copy - native_layer_norm - native_norm - ne - neg - - negative - neg_ + - negative - negative_ - nextafter - nonzero @@ -972,30 +1042,31 @@ torch: - ravel - real - reciprocal - - relu - reciprocal_ + - relu - relu_ - remainder - renorm - repeat_interleave - reshape - resize_as_ + - resolve_neg - roll - rot90 - round - round_ + - row_stack - rrelu - rrelu_ - rsqrt - - row_stack - rsqrt_ - rsub - saddmm - scalar_tensor - scatter - - select - scatter_add - searchsorted + - select - selu - selu_ - sgn @@ -1015,12 +1086,12 @@ torch: - solve - sort - sparse_coo_tensor - - square - split - split_with_sizes - spmm - sqrt - sqrt_ + - square - square_ - squeeze - sspaddmm @@ -1042,8 +1113,8 @@ torch: - tan_ - tanh - tanh_ - - tensordot - tensor_split + - tensordot - threshold - threshold_ - tile @@ -1059,19 +1130,21 @@ torch: - true_divide - trunc - trunc_ - - unique_consecutive - - xlogy - unbind + - unflatten + - unique_consecutive - unsafe_chunk - unsafe_split - - vander - - var - - vdot - unsafe_split_with_sizes - unsqueeze + - vander + - var - var_mean + - vdot + - vsplit - vstack - where + - xlogy - xlogy_ _VF: @@ -1165,6 +1238,28 @@ torch_npu: - npu_moe_finalize_routing - npu_moe_gating_top_k_softmax - npu_trans_quant_param + - npu_gelu + - npu_ffn + - npu_quant_matmul + - npu_format_cast_ + - npu_dynamic_quant + - npu_moe_compute_expert_tokens + - npu_weight_quant_batchmatmul + - npu_dynamic_quant_asymmetric + - npu_grouped_matmul + - npu_quant_scatter_ + - npu_group_quant + - npu_fused_infer_attention_score + - npu_quantize + - npu_fast_gelu + - npu_weight_quant_batchmatmul + - scatter_update + - scatter_update_ + - npu_moe_init_routing + - npu_scatter_nd_update_ + - npu_scatter_nd_update + - npu_prefetch + - npu_dynamic_block_quant aten: - signbit @@ -1912,4 +2007,27 @@ distributed: - all_to_all - all_gather_into_tensor - reduce_scatter_tensor - - batch_isend_irecv \ No newline at end of file + - batch_isend_irecv + +npu_distributed: + - isend + - irecv + +mindspeed: + - dropout_add_layer_norm.npu_dropout_add_layer_norm + - npu_rotary_position_embedding.npu_rotary_position_embedding + - fusion_attention_v2.npu_fusion_attention + - npu_mm_all_reduce_add_rms_norm.npu_mm_all_reduce_add_rms_norm + - npu_mm_all_reduce_add_rms_norm_.npu_mm_all_reduce_add_rms_norm_ + - gmm.npu_gmm + - gmm.npu_gmm_v2 + - npu_grouped_mat_mul_all_reduce.npu_grouped_mat_mul_all_reduce + - ffn.npu_ffn + - npu_moe_token_permute.npu_moe_token_permute + - npu_moe_token_unpermute.npu_moe_token_unpermute + - npu_ring_attention_update.npu_ring_attention_update + - npu_matmul_add.npu_matmul_add_fp32 + - npu_groupmatmul_add.npu_groupmatmul_add_fp32 + - quant_gmm.npu_quant_gmm + - quant_gmm.npu_quant_gmm_v2 + - npu_apply_fused_ema_adamw.npu_apply_fused_ema_adamw \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py index 41869403a547fc526ec422ecbb123af18ff81a39..68e434d0ad151fc70d2a7bbb333b195d4bbe0e2f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,11 @@ # limitations under the License. import os -from msprobe.core.common.file_utils import load_yaml +import importlib +import inspect + +from msprobe.core.common.file_utils import load_yaml, check_link +from msprobe.core.common.log import logger def get_ops(): @@ -26,3 +30,25 @@ def get_ops(): wrap_torch = ops.get('torch') wrap_npu_ops = ops.get('torch_npu') return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch) | set(wrap_npu_ops) + + +def dynamic_import_op(package, white_list): + package_name = package.__name__ + ops = {} + ops_dir, _ = os.path.split(package.__file__) + check_link(ops_dir) + for file_name in os.listdir(ops_dir): + if file_name in white_list: + sub_module_name = file_name[:-3] + module_name = f"{package_name}.{sub_module_name}" + try: + module = importlib.import_module(module_name) + except Exception as e: + logger.warning(f"import {module_name} failed!") + continue + + func_members = inspect.getmembers(module, inspect.isfunction) + for func_member in func_members: + func_name, func = func_member[0], func_member[1] + ops[f"{sub_module_name}.{func_name}"] = func + return ops diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py deleted file mode 100644 index 1cd11842c31bacdad7c1bb90f98ac81c3415a40e..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from functools import wraps -import torch.distributed as dist - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -distributed_func = {} -for f in dir(dist): - distributed_func[f] = getattr(dist, f) - - -def get_distributed_ops(): - _all_distributed_ops = dir(dist) - yaml_data = load_yaml(yaml_path) - wrap_distributed_ops = yaml_data.get('distributed') - return set(wrap_distributed_ops) & set(_all_distributed_ops) - - -class HOOKDistributedOP(object): - pass - - -class DistributedOPTemplate(HOOKModule): - def __init__(self, op_name, build_hook): - self.op_name_ = op_name - self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP - super().__init__(build_hook) - if not self.stop_hook: - self.op_is_distributed = True - - @torch_device_guard - def forward(self, *args, **kwargs): - handle = distributed_func.get(self.op_name_)(*args, **kwargs) - if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: - if handle and hasattr(handle, 'wait'): - handle.wait() - if self.op_name_ == "batch_isend_irecv": - if isinstance(handle, list): - for req in handle: - req.wait() - return handle - - -def wrap_distributed_op(op_name, hook): - @wraps(DistributedOPTemplate) - def distributed_op_template(*args, **kwargs): - return DistributedOPTemplate(op_name, hook)(*args, **kwargs) - - distributed_op_template.__name__ = op_name - return distributed_op_template - - -def wrap_distributed_ops_and_bind(hook): - _distributed_ops = get_distributed_ops() - for op_name in _distributed_ops: - setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py deleted file mode 100644 index 1c0afc59f50c069fbcd7e9a546c5b57c467400a9..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import load_yaml -from msprobe.pytorch.function_factory import npu_custom_functions - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -cuda_func_mapping = {"npu_fusion_attention" : "gpu_fusion_attention"} - - -def get_npu_ops(): - if torch_without_guard_version: - _npu_ops = dir(torch.ops.npu) - else: - _npu_ops = dir(torch_npu._C._VariableFunctionsClass) - yaml_data = load_yaml(yaml_path) - wrap_npu_ops = yaml_data.get('torch_npu') - return set(wrap_npu_ops) & set(_npu_ops) - - -class HOOKNpuOP(object): - pass - - -class NpuOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE): - self.op_name_ = op_name - self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP - self.need_hook = need_hook - self.device = device - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - if not self.need_hook: - if self.op_name_ not in npu_custom_functions: - raise Exception(f'There is not bench function {self.op_name_}') - if self.device == Const.CUDA_LOWERCASE: - self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_) - if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: - return npu_custom_functions[self.op_name_](*args, **kwargs) - if torch_without_guard_version: - return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) - else: - return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_npu_op(op_name, hook): - def npu_op_template(*args, **kwargs): - return NpuOPTemplate(op_name, hook)(*args, **kwargs) - return npu_op_template - - -def wrap_npu_ops_and_bind(hook): - _npu_ops = get_npu_ops() - for op_name in _npu_ops: - setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py deleted file mode 100644 index fc9d61c206bcfaeda7fefb5cb8b90fda2d67cb16..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_torch_ops(): - _torch_ops = [] - yaml_data = load_yaml(yaml_path) - wrap_torch_ops = yaml_data.get('torch') - for operation in wrap_torch_ops: - if '.' in operation: - operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) - operation_sub_module = getattr(torch, operation_sub_module_name) - if operation_sub_op in dir(operation_sub_module): - _torch_ops.append(operation) - else: - if hasattr(torch, operation): - _torch_ops.append(operation) - return set(_torch_ops) - - -TorchOps = {} -for op in get_torch_ops(): - if '.' in op: - sub_module_name, sub_op = op.rsplit('.', 1) - sub_module = getattr(torch, sub_module_name) - TorchOps[op] = getattr(sub_module, sub_op) - else: - TorchOps[op] = getattr(torch, op) - - - -class HOOKTorchOP(object): - pass - - -class TorchOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return TorchOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_torch_op(op_name, hook): - - def torch_op_template(*args, **kwargs): - return TorchOPTemplate(op_name, hook)(*args, **kwargs) - - return torch_op_template - - -def wrap_torch_ops_and_bind(hook): - _torch_ops = get_torch_ops() - for op_name in _torch_ops: - setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py new file mode 100644 index 0000000000000000000000000000000000000000..92334ec7a9d6dcfbd892013e90a447ab340a5c3b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py @@ -0,0 +1,497 @@ +# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import os +import re +import sqlite3 +from collections import OrderedDict, defaultdict +from collections.abc import Iterable +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass +from typing import List, Optional + +import pytz +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.file_utils import ( + create_directory, + read_csv, + recursive_chmod, + remove_path, +) +from msprobe.core.common.utils import is_int +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.monitor.utils import get_target_output_dir +from tqdm import tqdm + +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated" +] +DEFAULT_INT_VALUE = 0 +MAX_PROCESS_NUM = 128 +CSV_FILE_PATTERN = r"(\w+)_(\d+)-(\d+)\.csv" +BATCH_SIZE = 50000 + + +def check_process_num(process_num): + if not is_int(process_num) or process_num <= 0: + raise ValueError( + f"process_num is not a positive integer") + if process_num > MAX_PROCESS_NUM: + raise ValueError( + f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.") + + +def check_step_partition(step_partition): + if not is_int(step_partition) or step_partition <= 0: + raise ValueError( + f"step_partition is not a positive integer") + + +def check_data_type_list(data_type_list): + if data_type_list is None or data_type_list == []: + logger.info( + f"data_type_list is None, use default all_data_type_list: {all_data_type_list}") + return + if not isinstance(data_type_list, list): + raise ValueError(f"data_type_list({data_type_list}) is not a list") + for data_type in data_type_list: + if data_type not in all_data_type_list: + raise ValueError( + f"data type({data_type}) is not supported, supported data type: {all_data_type_list}") + + +def update_with_order_dict(main_dict, new_list): + for item in new_list: + if item not in main_dict: + main_dict[item] = None + return main_dict + + +def get_ordered_stats(stats): + if not isinstance(stats, Iterable): + return [] + return [stat for stat in MonitorConst.OP_LIST if stat in stats] + + +def pre_scan_single_rank(rank, files): + metrics = set() + min_step = None + max_step = 0 + metric_stats = defaultdict(set) + targets = OrderedDict() + + for file_path in files: + file_name = os.path.basename(file_path) + match = re.match(CSV_FILE_PATTERN, file_name) + if not match: + continue + metric_name, step_start, step_end = match.groups() + + step_start, step_end = int(step_start), int(step_end) + metrics.add(metric_name) + min_step = step_start if min_step is None or min_step > step_start else min_step + max_step = step_end if max_step < step_end else max_step + + data = read_csv(file_path) + stats = [k for k in data.keys() if k in MonitorConst.OP_LIST] + metric_stats[metric_name].update(stats) + + for _, row in data.iterrows(): + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target = (vpp_stage, name, micro_step) + if target not in targets: + targets[target] = None + + return { + 'max_rank': int(rank), + 'metrics': metrics, + 'min_step': min_step, + 'max_step': max_step, + 'metric_stats': metric_stats, + 'targets': list(targets.keys()) + } + + +def process_single_rank(task, metric_id_dict, target_dict, step_partition_size, db_path): + rank, files = task + total_inserted = 0 # 跟踪插入行数 + + # 优化连接配置 + conn = sqlite3.connect(db_path, timeout=300) + conn.execute("PRAGMA journal_mode = WAL") # WAL模式提高并发性 + conn.execute("PRAGMA synchronous = OFF") # 关闭同步提升速度 + conn.execute("PRAGMA cache_size = -200000") # 200MB缓存 + conn.execute("PRAGMA temp_store = MEMORY") + + # 使用字典跟踪每个表的批处理数据 + table_batches = defaultdict(list) + + try: + for file in files: + filename = os.path.basename(file) + match = re.match(CSV_FILE_PATTERN, filename) + if not match: + continue + metric_name, _, _ = match.groups() + + metric_info = metric_id_dict.get(metric_name) + if not metric_info: + continue + metric_id, stats = metric_info + + for _, row in read_csv(file).iterrows(): + name = row.get(MonitorConst.HEADER_NAME) + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target_id = target_dict.get((name, vpp_stage, micro_step)) + if not target_id: + continue + + step = int(row['step']) + try: + data = [rank, step, target_id] + [ + float(row[stat]) if stat in row else None for stat in stats + ] + except ValueError as e: + logger.error( + "CSV float conversion failed | " + f"file:{file} | " + f"error={str(e)} | " + f"row={row} | " + ) + + # 计算表名 + partition_start = ( + step // step_partition_size) * step_partition_size + table_name = f"metric_{metric_id}_step_{partition_start}_{partition_start + step_partition_size - 1}" + + # 添加到批处理 + table_batches[table_name].append(tuple(data)) + + # 当表批处理达到阈值时执行插入 + if len(table_batches[table_name]) >= BATCH_SIZE: + with conn: # 显式事务控制 + placeholders = ', '.join( + ['?'] * len(table_batches[table_name][0])) + conn.executemany( + f"INSERT INTO {table_name} VALUES ({placeholders})", + table_batches[table_name] + ) + total_inserted += len(table_batches[table_name]) + table_batches[table_name] = [] # 清空批处理 + + # 插入剩余数据 + with conn: + for table_name, rows in table_batches.items(): + if rows: + placeholders = ', '.join(['?'] * len(rows[0])) + conn.executemany( + f"INSERT INTO {table_name} VALUES ({placeholders})", + rows + ) + total_inserted += len(rows) + + logger.info(f"Rank {rank} inserted {total_inserted} rows") + except Exception as e: + logger.error(f"Error processing {rank}: {e}") + if conn: + conn.rollback() # 错误时回滚事务 + finally: + conn.close() + + +class MonitorDB: + def __init__(self, db_path, step_partition_size): + self.db_path = db_path + self.step_partition_size = step_partition_size + self.conn = self._get_connection() + self._init_schema() + + @staticmethod + def _get_metric_table_name(metric_id, step_start, step_end): + return f"metric_{metric_id}_step_{step_start}_{step_end}" + + def import_data(self, data_dirs, data_type_list, workers=4): + # 在主进程中查询 metric_id_dict 和 target_dict + rank_tasks = self._pre_scan(data_dirs, data_type_list, workers) + # 获得扫描的Metric信息 + try: + metric_id_dict = self.conn.execute( + "SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) " + "FROM monitoring_metrics m " + "LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id " + "GROUP BY m.metric_id" + ).fetchall() + except sqlite3.OperationalError as e: + logger.error(f"Failed to execute metric_id_dict query: {str(e)}") + return + + metric_id_dict = { + row[1]: [row[0], get_ordered_stats(row[2].split(','))] + for row in metric_id_dict + } + try: + target_dict = self.conn.execute( + "SELECT target_id, target_name, vpp_stage, micro_step FROM monitoring_targets" + ).fetchall() + except sqlite3.OperationalError as e: + logger.error(f"Failed to execute target_dict query: {str(e)}") + return + + target_dict = {(row[1], row[2], row[3]): row[0] for row in target_dict} + + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = [] + for rank, files in rank_tasks.items(): + future = executor.submit( + process_single_rank, + (rank, files), + metric_id_dict, + target_dict, + self.step_partition_size, + self.db_path + ) + futures.append(future) + for _ in tqdm(as_completed(futures), total=len(futures), desc="Processing ranks"): + pass + + def _get_connection(self): + """SQLite性能优化配置""" + conn = sqlite3.connect(self.db_path, timeout=300) + # 关键性能参数 + conn.execute("PRAGMA journal_mode = WAL") + conn.execute("PRAGMA synchronous = OFF") + conn.execute("PRAGMA cache_size = -2000") + conn.execute("PRAGMA temp_store = MEMORY") + conn.execute("PRAGMA auto_vacuum = FULL") + return conn + + def _init_schema(self): + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS monitoring_targets ( + target_id INTEGER PRIMARY KEY AUTOINCREMENT, + target_name TEXT NOT NULL, + vpp_stage INTEGER NOT NULL, + micro_step INTEGER NOT NULL DEFAULT 0, + UNIQUE(target_name, vpp_stage, micro_step) + )""") + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS monitoring_metrics ( + metric_id INTEGER PRIMARY KEY AUTOINCREMENT, + metric_name TEXT UNIQUE NOT NULL + )""") + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS metric_stats ( + metric_id INTEGER NOT NULL, + stat_name TEXT NOT NULL, + PRIMARY KEY (metric_id, stat_name), + FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id) + ) WITHOUT ROWID""") + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS global_stats ( + stat_name TEXT PRIMARY KEY, + stat_value INTEGER NOT NULL + ) WITHOUT ROWID""") + self.conn.executemany( + "INSERT OR IGNORE INTO global_stats VALUES (?, ?)", + [('max_rank', 0), ('min_step', 0), ('max_step', 0), + ('step_partition_size', self.step_partition_size)] + ) + self.conn.commit() + + def _pre_scan(self, data_dirs, data_type_list, workers=1): + """Pre-scan all targets, metrics, and statistics""" + logger.info("Scanning dimensions...") + rank_files = defaultdict(list) + + for rank, dir_path in data_dirs.items(): + files = os.listdir(dir_path) + for file in files: + match = re.match(CSV_FILE_PATTERN, file) + if not match: + continue + metric_name, _, _ = match.groups() + if metric_name not in data_type_list: + continue + rank_files[rank].append(os.path.join(dir_path, file)) + + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(pre_scan_single_rank, rank, files): rank + for rank, files in rank_files.items() + } + results = [] + with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar: + for future in as_completed(futures): + rank = futures[future] + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error(f"Error pre-scanning rank {rank}: {e}") + pbar.update(1) + + targets = OrderedDict() + metrics = set() + min_step = None + max_step = 0 + max_rank = 0 + metric_stats = defaultdict(set) + + for rank_result in results: + max_rank = max(max_rank, rank_result['max_rank']) + metrics.update(rank_result['metrics']) + min_step = rank_result['min_step'] if min_step is None or min_step > rank_result['min_step'] else min_step + max_step = rank_result['max_step'] if max_step < rank_result['max_step'] else max_step + for metric, stats in rank_result['metric_stats'].items(): + metric_stats[metric].update(stats) + targets = update_with_order_dict(targets, rank_result['targets']) + + # Batch insert dimensions + self.conn.executemany( + "INSERT OR IGNORE INTO monitoring_targets (vpp_stage, target_name, micro_step) VALUES (?, ?, ?)", + [m for m in targets] + ) + self.conn.executemany( + "INSERT OR IGNORE INTO monitoring_metrics (metric_name) VALUES (?)", + [(m,) for m in metrics] + ) + + # Insert metric-stat relationships + for metric, stats in metric_stats.items(): + # Get metric_id + metric_id = self._get_metric_id(metric) + ordered_stats = get_ordered_stats(stats) + + # Insert statistics + self.conn.executemany( + "INSERT OR IGNORE INTO metric_stats (metric_id, stat_name) VALUES (?, ?)", + [(metric_id, stat) for stat in ordered_stats] + ) + + # 计算需要创建的分区范围 + first_partition = min_step // self.step_partition_size + last_partition = max_step // self.step_partition_size + + # 为每个分区创建一次表 + for partition in range(first_partition, last_partition + 1): + step_start = partition * self.step_partition_size + self._create_metric_table(metric_id, step_start, ordered_stats) + + self.conn.commit() + self._update_global_stats(max_rank, min_step, max_step) + return rank_files + + def _create_metric_table(self, metric_id, partition_start_step, stats): + """创建按分区划分的指标表""" + table_name = self._get_metric_table_name( + metric_id, + partition_start_step, + partition_start_step + self.step_partition_size - 1 + ) + if not self._table_exists(table_name): + stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] + create_sql = f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + rank INTEGER NOT NULL, + step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} + AND {partition_start_step + self.step_partition_size - 1}), + target_id INTEGER NOT NULL, + {', '.join(stat_columns)}, + PRIMARY KEY (rank, step, target_id), + FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id) + ) WITHOUT ROWID + """ + self.conn.execute(create_sql) + self.conn.commit() + return table_name + + def _table_exists(self, table_name: str) -> bool: + cursor = self.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + return cursor.fetchone() is not None + + def _update_global_stats(self, max_rank, min_step, max_step): + min_step = 0 if min_step is None else min_step + with self.conn: + self.conn.execute( + "UPDATE global_stats SET stat_value = ? WHERE stat_name = 'max_rank'", + (max_rank,) + ) + self.conn.execute( + "UPDATE global_stats SET stat_value = ? WHERE stat_name = 'max_step'", + (max_step,) + ) + self.conn.execute( + "UPDATE global_stats SET stat_value = ? WHERE stat_name = 'min_step'", + (min_step,) + ) + + def _get_metric_id(self, metric_name): + cursor = self.conn.execute( + "SELECT metric_id FROM monitoring_metrics WHERE metric_name = ?", + (metric_name,) + ) + result = cursor.fetchone() + return result[0] if result else None + + +@dataclass +class CSV2DBConfig: + monitor_path: str + time_start: Optional[str] = None + time_end: Optional[str] = None + process_num: int = 1 + data_type_list: Optional[List[str]] = None + output_dirpath: Optional[str] = None + step_partition: int = 500 + + +def csv2db(config: CSV2DBConfig): + check_process_num(config.process_num) + check_step_partition(config.step_partition) + check_data_type_list(config.data_type_list) + + target_output_dirs = get_target_output_dir( + config.monitor_path, config.time_start, config.time_end) + + if config.output_dirpath is None: + local_tz = pytz.timezone("Asia/Shanghai") + cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S") + config.output_dirpath = os.path.join( + config.monitor_path, f"{cur_time}-csv2db") + + create_directory(config.output_dirpath) + db_path = os.path.join(config.output_dirpath, "monitor_metrics.db") + + if os.path.exists(db_path): + remove_path(db_path) + logger.warning(f"Existing path {db_path} will be recovered") + + db = MonitorDB(db_path, config.step_partition) + try: + db.import_data(target_output_dirs, + config.data_type_list if config.data_type_list else all_data_type_list, + workers=config.process_num) + finally: + db.conn.close() + + recursive_chmod(config.output_dirpath) + logger.info(f"Output has been saved to: {config.output_dirpath}") diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py index 6ffd1ffabe7b113ff4e61786d4d9f0709b8b605b..ebedae69d0a3c139ac8524807bb972bfeb737df0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py @@ -22,13 +22,18 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import read_csv, create_directory, remove_path +from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod from msprobe.core.common.utils import is_int +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.common.log import logger from msprobe.pytorch.monitor.utils import get_target_output_dir -all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"] +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated" +] CSV_FILE_SUFFIX = r"_\d+-\d+\.csv" +MAX_PROCESS_NUM = 128 def parse_step_line(line, ops): @@ -46,7 +51,7 @@ def parse_step_line(line, ops): def parse_step_fn(filepath): data = read_csv(filepath) - ops = [k for k in data.keys() if k in MonitorConst.OP_LIST] + ops = [k for k in data.keys() if k in MonitorConst.OP_LIST[:-2]] parse_step_result = {} for _, line in data.iterrows(): @@ -74,8 +79,10 @@ def write_step(output_dirpath, parse_step_result, rank, data_type): for op, value in ops.items(): tag = f"{vpp_name}/{op}" writer.add_scalar(tag, value, step) + writer.flush() +@recursion_depth_decorator("update_dict", max_depth=50) def update_dict(dict1, dict2): for key, value in dict2.items(): if key in dict1: @@ -115,11 +122,13 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list): def check_process_num(process_num): if not is_int(process_num) or process_num <= 0: raise ValueError(f"process_num({process_num}) is not a positive integer") + if process_num > MAX_PROCESS_NUM: + raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.") def check_data_type_list(data_type_list): if data_type_list is None: - logger.info(f"data_type_list is None, use defualt all_data_type_list: {all_data_type_list}") + logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}") return if not isinstance(data_type_list, list): raise ValueError(f"data_type_list({data_type_list}) is not a list") @@ -161,4 +170,5 @@ def csv2tensorboard_by_step( p.start() for p in processes: p.join() + recursive_chmod(output_dirpath) logger.info(f"output has been saved to: {output_dirpath}") diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py similarity index 48% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py index 63f20b1928c80e1e29d7cb8224f267c246fcaa8b..bd6bde7e9f6ede789f520acc2138492e99bac509 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py @@ -14,12 +14,8 @@ # limitations under the License. import itertools import os -import statistics as st -import sys -from abc import ABC from collections import defaultdict -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass import pandas as pd import torch @@ -27,78 +23,10 @@ from torch.utils.tensorboard import SummaryWriter from msprobe.core.common.const import FileCheckConst, MonitorConst from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.pytorch.common.log import logger -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, history, cur): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, history, cur): - baseline = st.mean(history) if isinstance(history, list) else history - - up_bound = baseline + baseline * self.threshold - if baseline > 0: - return cur > up_bound - else: - return cur < up_bound - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or rule_args is None: - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(history, cur) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -111,130 +39,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV] -BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD, - MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str @@ -253,6 +57,41 @@ class BaseWriterWithAD: self.anomaly_factory = writer_input.anomaly_factory self.anomalies = [] self.ndigits = writer_input.ndigits + self.beta = 0.99 + + @staticmethod + def stack_tensors(tensor_list): + """ + Torch not support stack cpu and xpu tensors. Group the tensors into cpu_group and xpu_group, + stack them separately, migrate xpu_group to cpu, and then restore in the order of input. + + :param tensor_list: [tensor(-1.6165), tensor(-1.0985), tensor(-1.7777), tensor(-1.8408, device='npu:0')] + :return: result: list of float + """ + cpu_tensors = [] + xpu_tensors = [] + + for tensor in tensor_list: + if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu': + # 将device上的tensor先stack后to cpu + xpu_tensors.append(tensor) + else: + cpu_tensors.append(tensor) + + xpu_stack = torch.stack(xpu_tensors).cpu() if xpu_tensors else torch.tensor([]) + + # 按照输入的顺序恢复 + result = [] + cpu_tensors_idx, xpu_tensors_idx = 0, 0 + for tensor in tensor_list: + if isinstance(tensor, torch.Tensor) and tensor.device.type != 'cpu': + result.append(xpu_stack[xpu_tensors_idx]) + xpu_tensors_idx += 1 + else: + result.append(cpu_tensors[cpu_tensors_idx]) + cpu_tensors_idx += 1 + + return result def get_anomalies(self): """返回已检测到的异常列表 @@ -271,12 +110,17 @@ class BaseWriterWithAD: Returns: None """ - detected = False - if self.ad_rules: - avg = self._update_tag2scalars(tag, scalar_value) - detected, rule_name = self._ad(scalar_value, history=avg) + if not self.ad_rules or tag[-1] in ["shape", "dtype"]: + return + if isinstance(scalar_value, torch.Tensor): + scalar_value = scalar_value.item() + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) if detected: - exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + if rule_name == AnomalyTurbulence.name and tag[-1] not in ["norm", "mean"]: + return + exception_message = (f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}, " + f"current value {scalar_value}, history mean {avg}.") logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}") # append to self.anomalies for dump if self.anomaly_factory: @@ -291,15 +135,15 @@ class BaseWriterWithAD: tensors.extend(op2tensor.values()) if not tensors: return - + n_slices = len(tensors) // MonitorConst.SLICE_SIZE with torch.no_grad(): for i in range(n_slices + 1): begin = i * MonitorConst.SLICE_SIZE - end = (i+1) * MonitorConst.SLICE_SIZE + end = (i + 1) * MonitorConst.SLICE_SIZE if begin == len(tensors): continue - metric_list = torch.stack(tensors[begin:end]).cpu() + metric_list = self.stack_tensors(tensors[begin:end]) for tag, metric in zip(tags[begin:end], metric_list): self.add_scalar(tag, metric, step) @@ -319,11 +163,11 @@ class BaseWriterWithAD: Returns: float: The average value before update. """ + abs_scalar_value = abs(scalar_value) if tag not in self.tag2scalars: - self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + self.tag2scalars[tag] = {'avg': abs_scalar_value, 'count': 0} avg = self.tag2scalars[tag]['avg'] - new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) - self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['avg'] = self.beta * avg + (1 - self.beta) * abs_scalar_value self.tag2scalars[tag]['count'] += 1 return avg @@ -364,7 +208,6 @@ class CSVWriterWithAD(BaseWriterWithAD): new_line = name.split(MonitorConst.NAME_SEP) + metric_value new_line.insert(2, step) new_data.append(new_line) - new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) @@ -376,13 +219,19 @@ class CSVWriterWithAD(BaseWriterWithAD): super().add_scalar(tag, scalar_value, global_step) name = tag[0].split('/')[0] - self.context_dict[name].append(scalar_value.item()) + if isinstance(scalar_value, torch.Tensor): + value = scalar_value.item() + elif isinstance(scalar_value, torch.Size): + value = list(scalar_value) + else: + value = scalar_value + self.context_dict[name].append(value) - def write_metrics(self, ops, metric_value, step, prefix=''): + def write_metrics(self, ops, metric_value, step, prefix='', **kwargs): super().write_metrics(ops, metric_value, step, prefix='') - if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]: - self.header = MonitorConst.CSV_HEADER_XY + ops + if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD] or kwargs.get("use_micro_step"): + self.header = MonitorConst.CSV_HEADER_MICRO_STEP + ops else: self.header = MonitorConst.CSV_HEADER + ops self.write_csv(prefix, step) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py index b2fa26a58e702120fcabd5d82f8e1e0ed27f3bc4..d819911b910ae23970acfb5e89430bbfcad03763 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py @@ -24,6 +24,7 @@ import torch.nn as nn from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import load_yaml from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name +from msprobe.pytorch.common.log import logger try: import torch_npu @@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", []) StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml") StackBlackList = load_yaml(StackBlackListPath).get("stack", []) +MAX_STRING_LENGTH = 1000 distributed_func = {} for f in dir(dist): @@ -106,7 +108,6 @@ class ApiRegistry: if args[0] in PENDING_ASYNC_CC_BY_HANDLE: store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0]) store_func() - return wrapped_wait dist.Work.wait = wrapped_wait(dist.Work) @@ -139,6 +140,8 @@ def get_process_group(process_group): def stack_filter(stack): + if len(stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.') for pattern in StackBlackList: if re.search(pattern, stack): return False @@ -188,10 +191,12 @@ def update_data(old, new): def is_target_line(codeline): - stack = get_callstack() - whole_stack = ';'.join(stack) if codeline == []: return True + stack = get_callstack() + whole_stack = ';'.join(stack) + if len(whole_stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH}. re match is skipped.') for pattern in codeline: if re.search(pattern, whole_stack): return True @@ -267,7 +272,7 @@ def create_hooks(context, monitor): RANK = dist.get_rank() if dist.is_initialized() and RANK not in monitor.module_rank_list and monitor.module_rank_list != []: return [pre_hooks, hooks] - + if monitor.cc_log_only: pre_hooks.append(cc_log_hook) return [pre_hooks, hooks] diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py index 81c029d401f9194688d332ac711d6065f126ce6a..8e2012afe668556b1cb7fbc6171fe4e028aa66a0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/features.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/features.py @@ -33,6 +33,11 @@ def get_mean(x: torch.tensor): return torch.mean(x.to(torch.float64)) +@torch.no_grad() +def get_mean(x: torch.tensor): + return torch.mean(x) + + @torch.no_grad() def get_norm(x: torch.tensor): return torch.norm(x.to(torch.float64), p=2) @@ -45,13 +50,18 @@ def get_max(x: torch.tensor): @torch.no_grad() def get_zeros(x: torch.tensor, eps: float): + if x.numel() == 0: + return torch.tensor(float('nan')) return torch.sum(torch.abs(x) < eps) / x.numel() @torch.no_grad() def get_sign_matches(x: torch.tensor, y: torch.tensor): + if y.numel() == 0: + return torch.tensor(1.) xs = x.sign() ys = y.sign() + try: same_direction_ratio = ((xs * ys).sum() / ys.numel() + 1) / 2 except RuntimeError as e: @@ -106,3 +116,258 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val): @torch.no_grad() def get_nans(t): return torch.isnan(t).sum() + + +def check_tensor_dim(tensor, n): + """检查张量维度是否大于n + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Input must be a PyTorch tensor. Got {type(tensor)} instead. " + f"Consider using torch.tensor() for conversion." + ) + + if tensor.dim() < n: + raise ValueError( + f"Tensor must have at least {n} dimensions. " + f"Got shape: {tuple(tensor.shape)} with {tensor.dim()} dims." + ) + + +@torch.no_grad() +def max_eigenvalue(input_tensor: torch.Tensor, num_iterations=3): + input_tensor = input_tensor.float() + try: + check_tensor_dim(input_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate max eigenvalue failed: {e}") + return torch.tensor(0) + in_features = input_tensor.shape[1] + u_tensor = torch.randn(in_features).to(input_tensor.device) + u_norm = u_tensor.norm() + if u_norm.item() == 0: + return torch.tensor(0) + u_tensor = u_tensor / u_tensor.norm() + input_seq = torch.matmul(input_tensor.T, input_tensor) + for _ in range(num_iterations): + v_tensor = torch.matmul(input_seq, u_tensor) + spectral_norm = torch.matmul(v_tensor.T, u_tensor) + v_norm = v_tensor.norm() + if v_norm > 0: + u_tensor = v_tensor / v_norm + else: + spectral_norm = torch.tensor(0) + break + return spectral_norm.sqrt() + + +@torch.no_grad() +def cal_entropy(qk_tensor, mask=None): + try: + check_tensor_dim(qk_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate max eigenvalue failed: {e}") + return torch.tensor(0), torch.tensor(0) + if mask is None: + mask = torch.tril(torch.ones(qk_tensor.shape[1], qk_tensor.shape[1])).to( + qk_tensor.device) + qk_tensor = qk_tensor - torch.amax(qk_tensor, dim=1, keepdim=True) + qk_tensor = qk_tensor.masked_fill(mask == 0, float('-inf')) + softmax_qkt = torch.nn.functional.softmax(qk_tensor.float(), dim=1) + softmax_max = torch.mean(torch.amax(softmax_qkt, dim=1)) + entropy = torch.mean(-torch.nansum(softmax_qkt * + torch.log(softmax_qkt), dim=1)) + return entropy, softmax_max + + +@torch.no_grad() +def cal_qkt(q_h, k_h, order="s,b,h,d"): + # q_h shape is [s, b, h, d] + try: + check_tensor_dim(q_h, 4) + check_tensor_dim(k_h, 4) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate qk tensor failed: {e}") + return torch.tensor(0) + + if order == "s,b,h,d": + qkt = torch.matmul( + q_h[:, 0, 0, :], k_h[:, 0, 0, :].t()) / q_h.shape[-1] ** 0.5 + elif order == "b,s,h,d": + qkt = torch.matmul( + q_h[0, :, 0, :], k_h[0, :, 0, :].t()) / q_h.shape[-1] ** 0.5 + else: + logger.warning("Calculate qk tensor failed: Order unsupported.") + qkt = torch.tensor(0) + return qkt + + +@torch.no_grad() +def cal_stable_rank(weight: torch.Tensor): + eig = max_eigenvalue(weight) + if eig == torch.tensor(0): + return torch.tensor(0), torch.tensor(0) + f_norm = torch.norm(weight, p="fro") + return f_norm / eig, eig + + +@torch.no_grad() +def cal_svd_entropy(weight: torch.Tensor, k=50): + epsilon = 1e-10 + if isinstance(weight, torch.Tensor): + _, s, _ = torch.svd_lowrank(weight.float(), q=k) + s_sum = torch.sum(s) + if s_sum.item() == 0: + return torch.tensor(0) + p = s / torch.sum(s) + entropy = -torch.sum(p * torch.log2(p + epsilon)) + else: + logger.warning("Calculate SVD entropy failed: Weight is not a tensor") + entropy = torch.tensor(0) + return entropy + + +@torch.no_grad() +def cal_avg_token_similarity(input_sequence: torch.Tensor): + try: + check_tensor_dim(input_sequence, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate avg token similarity failed: {e}") + return torch.tensor(0) + cos_sim_matrix = torch.nn.functional.cosine_similarity( + input_sequence.unsqueeze(0), input_sequence.unsqueeze(1), dim=-1) + abs_cos_sim_matrix = torch.abs(cos_sim_matrix) + avg_abs_cos_sim = abs_cos_sim_matrix.mean() + return avg_abs_cos_sim + + +@torch.no_grad() +def layer_norm_jacobian(input_tensor, weight, eps=1e-10): + """ + :param input: input tensor to layer-norm operator, and should be as shape [1, D]. + :param weight: 1-d weight tensor. + :param eps: default is 1e-5 + """ + try: + check_tensor_dim(input_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate layer norm jacobian failed: {e}") + return torch.tensor(0), torch.tensor(0) + x = input_tensor.clone() + x = x.detach() + x = x.to(dtype=torch.float32) + device = x.device + dim = x.shape[-1] + diag = torch.eye(dim).to(device=device) + ones = torch.ones(dim, 1).to(device=device) + matrix_a = (1 / dim) * torch.mm(ones, ones.T) + y = torch.mm(x.float(), diag - matrix_a) + std_y = torch.std(y) + jacobian_matrix = (1 / (std_y + eps)) * ((diag - matrix_a) @ + (diag - (torch.mm(y.T, y) / (std_y * std_y * dim + eps)))) + return std_y, max_eigenvalue(weight * jacobian_matrix) + + +@torch.no_grad +def rms_norm_jacobian(input_tensor, weight, eps=1e-10): + try: + check_tensor_dim(input_tensor, 2) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate rms norm jacobian failed: {e}") + return torch.tensor(0), torch.tensor(0) + x = input_tensor.clone() + x = x.detach() + x = x.to(dtype=torch.float32) + device = x.device + std_x = torch.std(x) + dim = x.shape[-1] + diag = torch.eye(dim).to(device=device) + jacobian_matrix = (1 / (std_x + eps)) * \ + (diag - (torch.mm(x.T, x) / (std_x * std_x * dim + eps))) + return std_x, max_eigenvalue(weight * jacobian_matrix) + + +@torch.no_grad() +def cal_avg_token_similarity_chunk(input_sequence: torch.Tensor, chunk_size=256): + try: + check_tensor_dim(input_sequence, 2) + if input_sequence.dim() != 2: + raise ValueError( + f"Tensor must be with 2 dimensions. " + f"Got shape: {tuple(input_sequence.shape)} " + f"with {input_sequence.dim()} dims." + ) + except (TypeError, ValueError) as e: + logger.warning(f"Calculate rms norm jacobian failed: {e}") + return torch.tensor(0) + + device = input_sequence.device + n_dim, _ = input_sequence.shape + + epsilon = 1e-10 + norms = torch.norm(input_sequence, dim=1, keepdim=True) + normalized = input_sequence / (norms + epsilon) + + sum_abs = 0.0 + count = 0 + + for start_i in range(0, n_dim, chunk_size): + end_i = min(start_i + chunk_size, n_dim) + + sub_a = normalized[start_i:end_i, :] + sub_sim = sub_a.mm(normalized.t()) + + sum_abs += torch.sum(torch.abs(sub_sim)).item() + count += sub_sim.numel() + if count != 0: + avg_abs_cos_sim = sum_abs / count + return torch.tensor(avg_abs_cos_sim, device=device) + return torch.tensor(0, device=device) + + +@torch.no_grad() +def cal_kl_divergence(input_tensor, output_tensor): + device = input_tensor.device + data1 = input_tensor.flatten().to(dtype=torch.float32, device=device) + data2 = output_tensor.flatten().to(dtype=torch.float32, device=device) + + min_val = torch.min(torch.min(data1), torch.min(data2)) + max_val = torch.max(torch.max(data1), torch.max(data2)) + if min_val == max_val: + logger.warning( + f"Calculate kl divergence failed: min_val equal to max_val") + return torch.tensor(0), torch.tensor(0) + + bins = 100 + epsilon = 1e-10 + bin_edges = torch.linspace(min_val, max_val, steps=bins + 1, device=device) + bin_width = bin_edges[-1] - bin_edges[0] + + def compute_hist(data): + idx = ((data - bin_edges[0]) / bin_width).floor().long() + idx = torch.clamp(idx, 0, bins - 1) + hist = torch.bincount(idx, minlength=bins).float() + return hist / (hist.sum() + epsilon) + + hist1 = compute_hist(data1) + hist2 = compute_hist(data2) + + kl_div = torch.sum( + hist1 * torch.log((hist1 + epsilon) / (hist2 + epsilon))) + return kl_div + + +@torch.no_grad() +def cal_dist_diff(input_tensor, output_tensor): + device = input_tensor.device + x = input_tensor.flatten().to(dtype=torch.float32, device=device) + y = output_tensor.flatten().to(dtype=torch.float32, device=device) + + mean_diff = torch.abs(x.mean() - y.mean()) + std_diff = torch.abs(x.std(unbiased=False) - y.std(unbiased=False)) + + n_dim = min(x.numel(), y.numel()) + x_sorted = torch.sort(x[:n_dim])[0] + y_sorted = torch.sort(y[:n_dim])[0] + wasserstein_1 = torch.mean(torch.abs(x_sorted - y_sorted)) + return wasserstein_1, mean_diff, std_diff diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index d0285564d3cb5c00b69933db3259b7c3339c443d..58c242e2cfd39c17c5b3c528d18780f1b914827a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -22,31 +22,40 @@ from functools import partial import pytz import torch import torch.distributed as dist +import pandas as pd from torch.utils.hooks import BackwardHook from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter +from msprobe.core.common.file_utils import write_df_to_csv +from msprobe.core.common.utils import analyze_api_call_stack from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import is_recomputation -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter -from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput +from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor +from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ get_process_group -from msprobe.pytorch.monitor.features import get_sign_matches +from msprobe.pytorch.monitor.features import get_sign_matches, cal_qkt from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \ - TensorMetrics, squash_param_name -from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec + TensorMetrics, squash_param_name, get_sr_metric, get_entropy_metric, \ + get_avg_token_similarity_metric, get_norm_stability_metric from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ - get_output_base_dir, get_target_output_dir + get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer +try: + from megatron.core import mpu +except ImportError: + MPU_IMPORT = False +else: + MPU_IMPORT = True + torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not torch_version_above_or_equal_2: raise ValueError("monitor require torch>=2.0") - FORMAT_MAPPING = { MonitorConst.TENSORBOARD: SummaryWriterWithAD, MonitorConst.CSV: CSVWriterWithAD, @@ -71,42 +80,30 @@ class ModuleHookContext: self.actvgrad = [] self.module_name = module_name self.struct = {} - self.format_by_arg = {} - self.verified = False - self.focused_in_col = 0 - self.focused_out_col = 0 - - def set_format_by_arg(self, key_name: str, target_config: dict): - """ 按照监控对象配置format_by_arg - 1) module_name 在 target 中配置监控对象 - 2) module_name 未在 targets 中配置,且 all_xy 全量监控 - 3) module_name 未在 targets 中配置,且 all_xy 未全量监控 - - :param key_name: str, one of [input, output, input_grad, output_grad] - :param target_config: target obj in config json. - :return: - """ - cared = target_config.get(self.module_name, self.struct) - if key_name in cared: - target_module_config = cared[key_name] - if isinstance(target_module_config, dict): - # current cared is self.struct, monitor all data for module_name - self.format_by_arg[key_name] = target_module_config.get('config') - elif isinstance(target_module_config, str): - # current cared is target_config[self.module_name] - self.format_by_arg[key_name] = target_module_config - else: - logger.warning_on_rank_0(f"target module config error, result maybe empty." - f"module_name: {self.module_name}, key_name: {key_name}") - self.format_by_arg[key_name] = None - else: - self.format_by_arg[key_name] = self.struct.get(key_name).get('config') + self.stack = "" def reset(self): self.actv.clear() self.actvgrad.clear() +class FeatureHookContext: + def __init__(self, module_name): + self.step = 0 + self.micro_step = 0 + self.attention_feature = {} + self.linear_feature = {} + self.token_feature = {} + self.norm_feature = {} + self.module_name = module_name + + def reset(self): + self.attention_feature.clear() + self.linear_feature.clear() + self.token_feature.clear() + self.norm_feature.clear() + + start_step = 0 @@ -173,6 +170,20 @@ class GradContext: self.actv.clear() +class ProxyContext: + def __init__(self): + self.gradient_per_sample = defaultdict(dict) + self.gradient_per_sample_metric = {} + self.global_gradient = {} + self.proxy = {} + + def reset(self): + self.gradient_per_sample.clear() + self.gradient_per_sample_metric.clear() + self.global_gradient.clear() + self.proxy.clear() + + class TrainerMon: tensor_metrics = TensorMetrics() @@ -184,8 +195,8 @@ class TrainerMon: self.params_have_main_grad = params_have_main_grad self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.origin_step_func = None self.origin_start_grad_sync = None + self.fsdp_post_backward_hook = None self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) @@ -196,6 +207,7 @@ class TrainerMon: self.unique_id = str(uuid.uuid4())[:8] self.output_base_dir = get_output_base_dir() time_tags = self.config.get("append_output", []) + self.grad_per_sample_hooked = False if dist.is_initialized(): self.rank = dist.get_rank() if time_tags: @@ -220,20 +232,21 @@ class TrainerMon: self.dp_group = None self.tp_group = None self.enable_megatron = False + self.fsdp_wrapped_module = False self.micro_batch_number = 1 - self.optimizer_class = None self.optimizer_mon = None self.optimizer_trans = None # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.feature_hook_context_by_module = defaultdict(FeatureHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) self.grad_context = GradContext() + self.proxy_context = ProxyContext() self.handles = defaultdict(list) self.param2name = defaultdict(str) - self.name2index = defaultdict() self.name2indices = defaultdict() self.name2param = {} self.duplicate_param = {} @@ -243,9 +256,12 @@ class TrainerMon: self.module_struct = defaultdict(dict) self.grad_accs = [] self.weight_hooked = False + self.proxy_hooked = False self.optimizer_hooked = False self.param_registered = False self.struct_printed = False + self.pre_step_hooks = [] + self.post_step_hooks = [] # 动静态区分 self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true' @@ -313,9 +329,14 @@ class TrainerMon: self.ur_distribution = self.config.get('ur_distribution', False) self.mv_distribution = self.config.get("mv_distribution", False) self.wg_distribution = self.config.get("wg_distribution", False) + self.proxy_model = self.config.get("proxy_model", False) self.param_distribution = self.config.get("param_distribution", False) self.mg_direction = self.config.get('mg_direction', False) self.cc_distribution = self.config.get("cc_distribution", {}) + self.stack_info = self.config.get('stack_info', False) + self.monitor_mbs_grad = self.config.get('monitor_mbs_grad', False) + self.recording_l2_features = self.config.get("recording_l2_features", False) + self.sa_order = self.config.get("sa_order", "s,b,h,d") if not self.cc_distribution.get('enable', False): self.cc_log_only = False @@ -324,8 +345,6 @@ class TrainerMon: self.cc_log_only = self.cc_distribution.get('cc_log_only', False) self.cc_logged_stack = defaultdict(set) self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) - self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) - api_register.redirect_api() self.common_info() @@ -338,11 +357,11 @@ class TrainerMon: # 初始化writer, 创建输出目录 if self.format not in FORMAT_MAPPING: - logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") + logger.warning(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") self.format = MonitorConst.CSV if self.ur_distribution and self.format != 'tensorboard': - logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution") + logger.warning("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution") self.ur_distribution = False writer = FORMAT_MAPPING[self.format] @@ -376,6 +395,10 @@ class TrainerMon: logger.info_on_rank_0("> momentum and variance of adam is not monitored. ") if not self.wg_distribution: logger.info_on_rank_0("> weight grad of specified module is not monitored. ") + if not self.proxy_model: + logger.info_on_rank_0("> proxy model of specified module is not monitored. ") + if not self.recording_l2_features: + logger.info_on_rank_0("> l2 features of specified module is not monitored. ") if not self.mg_direction: logger.info_on_rank_0('> grad and momentum direction will not be compared.') if not self.cc_distribution.get('enable', False): @@ -405,13 +428,14 @@ class TrainerMon: start_iteration=0 ): """External interface""" + grad_acc_steps, start_iteration = validate_set_monitor(grad_acc_steps, start_iteration) global start_step start_step = start_iteration logger.info(f'grad acc steps {grad_acc_steps}') self.micro_batch_number = grad_acc_steps self.dp_group = dp_group self.tp_group = tp_group - self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer) + self.optimizer_mon = OptimizerMonFactory.create_optimizer_mon(optimizer) self.hook_step_final(optimizer) if not isinstance(model, list): model = [model] @@ -426,7 +450,11 @@ class TrainerMon: self._register_param_name() self.hook_optimizer(optimizer) self._patch_grad_sync() + self._register_proxy_model_content() self.hook_modules() + if self.cc_distribution.get('enable', False): + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() self.monitoring = True def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): @@ -437,25 +465,48 @@ class TrainerMon: return self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) - def build_tbtag_tensor_map(self, module_name, tag, tensor): - key = get_summary_writer_tag_name(module_name, tag, self.rank) - self._register_param_call_id("_hook_module", key) - return {key: tensor} + def build_tbtag_tensor_map(self, module_name, suffix, tag, tensor): + """ + :param module_name: str of module name + :param suffix: + :param tag: + :param tensor: torch.tensor or tuple/list of torch.tensor + :return: tensor_map + """ + tensor_map = {} + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + if isinstance(tensor, tuple) or isinstance(tensor, list): + if len(tensor) == 1: + key = get_summary_writer_tag_name(module_name + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor[0] + else: + for i, tensor_i in enumerate(tensor): + key = get_summary_writer_tag_name(module_name + f"_{i}" + suffix, tag, self.rank) + self.register_param_call_id("_hook_module", key) + tensor_map[key] = tensor_i + return tensor_map def generate_param_map(self, tag, param_tensor): metrics = {} for name in self.param2name.values(): key = get_summary_writer_tag_name(name, tag, self.rank) - self._register_param_call_id("optimizer_pre_step_hook", key) + self.register_param_call_id("optimizer_pre_step_hook", key) if name not in param_tensor or param_tensor[name] is None: continue metrics[key] = param_tensor[name] return metrics - def generate_param_metrics(self, opt_context): + def generate_param_metrics(self, opt_context, stage=MonitorConst.PRE_PARAM): if not self.param_distribution: return - get_metrics(self.ops, self.name2param, self.eps, opt_context.param_metric) + tag2param = { + self.name2tag.get(name, {}).get(stage): param + for name, param in self.name2param.items() + if param.numel() != 0 + } + get_metrics(self.ops, tag2param, self.eps, opt_context.param_metric) def generate_mv_metrics(self, opt_context): if not self.mv_distribution: @@ -467,13 +518,27 @@ class TrainerMon: get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric) get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric) - def generate_wgrad_metrics(self): + def generate_wgrad_metrics(self, post_grad_dict): if not self.wg_distribution: return {}, {} if self.weight_hooked: get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) + get_metrics(self.ops, post_grad_dict, self.eps, self.grad_context.post) + reduced_grad = self.grad_context.post + + if self.weight_hooked: + unreduced_grad = self.grad_context.acc_metric + else: + unreduced_grad = self.grad_context.pre + + return reduced_grad, unreduced_grad + + def generate_proxy_metrics(self): + if not self.proxy_model: + return + grad_dict = {} for param, name in self.param2name.items(): if self.duplicate_param.get(name, False): @@ -482,13 +547,34 @@ class TrainerMon: if grad is None: logger.warning(f"grad is None: {name}, maybe something wrong happened.") continue - tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) - self._register_param_call_id("hook_optimizer", tag) - grad_dict[tag] = grad - - get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre - return self.grad_context.post, unreduced_grad + key = get_summary_writer_tag_name(name, 'global_gradient', self.rank) + self.register_param_call_id("global_gradient_in_proxy_hook", key) + grad_dict[name] = grad + + get_metrics(['norm'], grad_dict, self.eps, self.proxy_context.global_gradient) + + if self.grad_per_sample_hooked: + for key, value in self.proxy_context.gradient_per_sample.items(): + temp_dict = get_metrics(['norm'], value, self.eps) + self.proxy_context.gradient_per_sample_metric[key] = torch.stack( + [value['norm'] for value in temp_dict.values()]).sum() + if self.proxy_context.gradient_per_sample_metric: + unreduced_norm = torch.stack([value for value in self.proxy_context.gradient_per_sample_metric.values()]) + keys = [key for key in self.proxy_context.gradient_per_sample_metric.keys()] + rank = torch.distributed.get_rank(self.dp_group) + dp_list = [torch.empty_like(unreduced_norm) + for _ in range(torch.distributed.get_world_size(group=self.dp_group))] + dp_list[rank] = unreduced_norm + torch.distributed.all_gather(dp_list, unreduced_norm, group=self.dp_group) + reduced_norm = torch.stack(dp_list).sum(dim=0) + reduced_gradient = dict(zip(keys, reduced_norm)) + for key, item in self.proxy_context.global_gradient.items(): + temp_reduced_norm = reduced_gradient.get(key, None) + if temp_reduced_norm is not None: + self.proxy_context.proxy[key] = {} + self.proxy_context.proxy[key]['proxy'] = temp_reduced_norm / item['norm'] + else: + logger.warning(f"reduced_norm is None: {key}, maybe something wrong happened.") def generate_xy_metrics(self): actv = {} @@ -508,16 +594,27 @@ class TrainerMon: handle.remove() self.handles['xy'].clear() self.hook_modules() - for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + for fwd_context in self.module_fwd_hook_context_by_module.values(): fwd_context.actv.clear() def write_adhoc_check(self, step): self.tensor_metrics.flush(self.summary_writer) + def write_stack_info(self): + stack_data = [] + header = ["module_name", "stack_info"] + stack_data.append(header) + for fwd_context in self.module_fwd_hook_context_by_module.values(): + stack_data.append([fwd_context.module_name, fwd_context.stack]) + filepath = os.path.join(self.tensorboard_dir, f'stack_info.csv') + if not os.path.exists(filepath): + data_frame = pd.DataFrame(columns=stack_data) + write_df_to_csv(data_frame, filepath) + def write_xy_tb(self, step): if not self.xy_distribution: return - for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + for fwd_context in self.module_fwd_hook_context_by_module.values(): if len(fwd_context.actv) == 0: continue self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV) @@ -525,10 +622,37 @@ class TrainerMon: if self.grad_context.actv: self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) + def write_metrics_if_not_empty(self, features, metrics, step, hook_name): + if len(features) != 0: + self.summary_writer.write_metrics(metrics, features, step, hook_name, use_micro_step=True) + features.clear() + + def write_features_tb(self, step): + if not self.recording_l2_features: + return + for context in self.feature_hook_context_by_module.values(): + num_features = len(context.attention_feature) + len(context.linear_feature) + len( + context.token_feature) + len(context.norm_feature) + if num_features == 0: + continue + self.write_metrics_if_not_empty(context.attention_feature, ["entropy", "softmax_max"], + step, "attention_hook") + self.write_metrics_if_not_empty(context.linear_feature, ["sr", "kernel_norm"], step, "linear_hook") + self.write_metrics_if_not_empty(context.token_feature, ["token_similarity"], step, "token_hook") + self.write_metrics_if_not_empty(context.norm_feature, ["std_x", "jacobian"], step, "norm_hook") + def write_param_tb(self, opt_context): if not self.param_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM) + param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.PRE_PARAM in k} + updated_param_metrics = {k: v for k, v in opt_context.param_metric.items() if MonitorConst.POST_PARAM in k} + self.summary_writer.write_metrics(self.ops, param_metrics, opt_context.step, MonitorConst.PRE_PARAM) + self.summary_writer.write_metrics(self.ops, updated_param_metrics, opt_context.step, MonitorConst.POST_PARAM) + + def write_proxy_tb(self, step): + if not self.proxy_model: + return + self.summary_writer.write_metrics(['proxy'], self.proxy_context.proxy, step, 'proxy_model') def write_mv_tb(self, opt_context): if not self.mv_distribution: @@ -542,10 +666,11 @@ class TrainerMon: if not self.wg_distribution: return - if self.enable_megatron: - self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced') + if self.weight_hooked: + self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced', + use_micro_step=self.monitor_mbs_grad) else: - self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') + self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced') self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') def hook_optimizer(self, optimizer): @@ -567,21 +692,24 @@ class TrainerMon: # skip generate metrics if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0: return - if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3 - if not self.name2indices: - self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer) - mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices) - self.param2name = mv_result.grad - else: - mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name) - context.param_exp_avg = mv_result.exp_avg - context.param_exp_avg_sq = mv_result.exp_avg_sq - context.param_adam_update = mv_result.update - context.param_adam_ratio = mv_result.ratio - self.generate_wgrad_metrics() + grad_dict = {} + if self.wg_distribution: + grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name) + + mv_result = None + if self.mv_distribution or self.ur_distribution or self.mg_direction: + mv_result = self.optimizer_mon.fetch_mv(self, self.param2name) + if mv_result: + context.param_exp_avg = mv_result.exp_avg + context.param_exp_avg_sq = mv_result.exp_avg_sq + context.param_adam_update = mv_result.update + context.param_adam_ratio = mv_result.ratio + + _, _ = self.generate_wgrad_metrics(grad_dict) self.generate_mv_metrics(context) - self.generate_param_metrics(context) + self.generate_proxy_metrics() + self.generate_param_metrics(context, MonitorConst.PRE_PARAM) tbtag_tensor_map = {} if self.mg_direction: @@ -609,17 +737,15 @@ class TrainerMon: context.metric_dict = metric_dict return - def patch_step(func, optimizer): - def wrapper(*args, **kwargs): - optimizer_pre_step_hook(optimizer, args, kwargs) - out = func(*args, **kwargs) - return out - return wrapper + def optimizer_post_step_hook(optimizer, args, kwargs): + context = self.optimizer_context[optimizer] + self.generate_param_metrics(context, MonitorConst.POST_PARAM) if self.optimizer_hooked: return - optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + self.pre_step_hooks.append(optimizer_pre_step_hook) + self.post_step_hooks.append(optimizer_post_step_hook) self.optimizer_hooked = True return @@ -649,6 +775,7 @@ class TrainerMon: validate_config(config) self.config = config self.set_config() + self.start_step = context.step # 动态启停时不受原start_step影响,永远从下一步开始 logger.warning(f"config is updated at step{context.step - 1}, " f"will start new hook at step{context.step}.") except Exception as e: @@ -665,19 +792,27 @@ class TrainerMon: # 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False if self.monitoring: module_rank_valid = not self.module_rank_list or ( - dist.is_initialized() and dist.get_rank() in self.module_rank_list) + dist.is_initialized() and dist.get_rank() in self.module_rank_list) step_condition = (context.step >= self.start_step and ( - context.step - self.start_step) % self.step_interval == 0) + context.step - self.start_step) % self.step_interval == 0) if module_rank_valid and step_condition: self.has_collect_times += 1 if self.anomaly_data_factory: self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) + self.write_features_tb(context.step) self.write_grad_tb(context.step) + self.write_proxy_tb(context.step) self.write_mv_tb(context) self.write_param_tb(context) self.write_adhoc_check(context.step) + if self.stack_info: + self.write_stack_info() + self.stack_info = False + for handle in self.handles["stack"]: + handle.remove() + self.handles["stack"].clear() if self.ur_distribution: for param_name, _ in context.param_adam_update.items(): @@ -696,6 +831,9 @@ class TrainerMon: if self.anomaly_data_factory: self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) self.summary_writer.clear_anomalies() + + if self.format == MonitorConst.TENSORBOARD: + chmod_tensorboard_dir(self.tensorboard_dir) self.call_id = 0 self.param_name_call_id.clear() @@ -707,13 +845,17 @@ class TrainerMon: def patch_step(func, optimizer): def wrapper(*args, **kwargs): + for hook in self.pre_step_hooks: + hook(optimizer, args, kwargs) out = func(*args, **kwargs) + for hook in self.post_step_hooks: + hook(optimizer, args, kwargs) step_final_hook(optimizer, args, kwargs) return out + return wrapper optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) - self.origin_step_func = optimizer.__class__.step return def hook_modules(self): @@ -731,14 +873,16 @@ class TrainerMon: vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ 'targets'].keys() - hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + l2_target_names = self.config.get('l2_targets', '') + hooked_count += self._hook_module(targets, l2_target_names, model_chunk, vpp_stage) logger.info_on_rank_0(f"> {hooked_count} modules are monitored.") + @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor') def clone_if_tensor(args): if isinstance(args, tuple): return tuple([clone_if_tensor(arg) for arg in args]) - elif isinstance(args, torch.Tensor): + elif isinstance(args, torch.Tensor) and not is_float8_tensor(args): return args.clone() else: return args @@ -756,17 +900,33 @@ class TrainerMon: BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook) return + def register_param_call_id(self, hook_name: str, key: str): + """ + :param hook_name: + :param key: str, '0:relu_0/output_grad' + :return: + """ + logger.debug(f"{hook_name} {key}: {self.call_id}") + self.param_name_call_id[key] = self.call_id + self.call_id += 1 + def _remove_all_hooks(self, optimizer): # 清空hook handle for handle in self.handles['xy']: handle.remove() self.handles['xy'].clear() + for handle in self.handles['L2_features']: + handle.remove() + self.handles['L2_features'].clear() # 清空对应context缓存 - for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + for fwd_context in self.module_fwd_hook_context_by_module.values(): fwd_context.reset() - for _, bwd_context in self.module_bwd_hook_context_by_module.items(): + for bwd_context in self.module_bwd_hook_context_by_module.values(): bwd_context.reset() + for feature_context in self.feature_hook_context_by_module.values(): + feature_context.reset() self.grad_context.reset() # 权重梯度和激活值梯度都在这 + self.proxy_context.reset() if self.origin_start_grad_sync: # megatron try: @@ -781,14 +941,18 @@ class TrainerMon: logger.info("remove _ParamAndGradBucketGroup start_grad_sync") except ImportError: pass - else: # not megatron + elif self.fsdp_post_backward_hook: # fsdp + torch.distributed.fsdp._runtime_utils._post_backward_hook = self.fsdp_post_backward_hook + logger.info("remove patch_post_backward_hook in fsdp.") + else: # not megatron and not fsdp for handle in self.handles['wgrads']: handle.remove() self.handles['wgrads'].clear() self.weight_hooked = False if self.optimizer_hooked: - optimizer.__class__.step = self.origin_step_func + self.pre_step_hooks.clear() + self.post_step_hooks.clear() for _, context in self.optimizer_context.items(): context.reset() @@ -797,12 +961,12 @@ class TrainerMon: for handle in self.handles['cc']: handle.remove() self.handles['cc'].clear() + api_register.restore_api() for _, context in self.cc_context.items(): context.reset() # 清空节点缓存 self.param2name.clear() - self.name2index.clear() self.name2indices.clear() self.name2param.clear() self.duplicate_param.clear() @@ -841,12 +1005,17 @@ class TrainerMon: logger.info(msg) def _save_module_struct(self): - save_module_struct = (not dist.is_initialized() - or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list)) - or (not self.module_rank_list and dist.get_rank() == 0)) - + if MPU_IMPORT: + pp_group = mpu.get_pipeline_model_parallel_group() + pp_group_list = torch.distributed.get_process_group_ranks(pp_group) + save_module_struct = not dist.is_initialized() or dist.get_rank() in pp_group_list + else: + save_module_struct = (not dist.is_initialized() + or (self.module_rank_list and dist.get_rank() == min(self.module_rank_list)) + or (not self.module_rank_list and dist.get_rank() == 0)) if save_module_struct: - module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json')) + module_struct_file = os.path.realpath( + os.path.join(get_output_base_dir(), f'{dist.get_rank()}_module_struct.json')) save_json(module_struct_file, self.module_struct, indent=2) logger.info(f"> save module struct to {module_struct_file}") self.struct_printed = True @@ -862,27 +1031,33 @@ class TrainerMon: return False def _register_chunk(self, model_chunk, prefix): - index = 0 for (param_name, param) in model_chunk.named_parameters(): if not param.requires_grad: continue + if not self.fsdp_wrapped_module and param_name.startswith("_fsdp_wrapped_module"): + self.fsdp_wrapped_module = True if self._is_target_param(param_name, param, prefix): name = prefix + squash_param_name(param_name, self.squash_name) if name in self.param2name.values(): name = prefix + param_name self.param2name[param] = name self.name2param[name] = param - self.name2index[name] = index if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): self.duplicate_param[name] = True if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): self.duplicate_param[name] = True + + keywords = [ + MonitorConst.PRE_GRAD, + MonitorConst.POST_GRAD, + MonitorConst.PRE_PARAM, + MonitorConst.POST_PARAM + ] self.name2tag[name] = { - MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank), - MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank) + k: get_summary_writer_tag_name(name, k, self.rank) + for k in keywords } - index += 1 def _register_param_name(self): for vpp_stage, model_chunk in enumerate(self.model): @@ -897,19 +1072,38 @@ class TrainerMon: vpp_stage + module_name, ]: if pattern in targets: - return pattern + return vpp_stage + squash_param_name(module_name, self.squash_name) return "" - def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''): + def _is_recording_module(self, module_name, l2_targets, vpp_stage): + + if len(l2_targets) > 0: + for pattern in [ + vpp_stage + squash_param_name(module_name, self.squash_name), + vpp_stage + module_name, + ]: + if pattern in l2_targets: + return vpp_stage + squash_param_name(module_name, self.squash_name) + return "" + else: + raise NotImplementedError("If monitoring L2 features, the targets should be set specifically.") + + def _hook_module(self, target_names, l2_target_names, module: torch.nn.Module, vpp_stage=''): if '_modules' not in module.__dict__: # nothing to hook return 0 - def fwd_hook_fun(module, module_input, module_output, name): + def fwd_hook_fun(module, args, kwargs, module_output, name): if not module.training or is_recomputation(): # 1 only monitor training stage. # 2 when open recompute, skip recomputed forward stage. return + + module_input = [tensor for tensor in args if torch.is_tensor(tensor)] + if kwargs: + kwargs_tensors = [tensor for tensor in kwargs.values() if torch.is_tensor(tensor)] + module_input.extend(kwargs_tensors) + if module not in self.module_fwd_hook_context_by_module: self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] @@ -918,34 +1112,20 @@ class TrainerMon: Const.INPUT: get_param_struct(module_input), Const.OUTPUT: get_param_struct(module_output) } + if self.print_struct: self.module_struct[context.module_name].update(context.struct) return - if not context.format_by_arg: - context.set_format_by_arg(Const.INPUT, self.config['targets']) - context.set_format_by_arg(Const.OUTPUT, self.config['targets']) - if not context.format_by_arg: - return - if not context.verified: - context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT], - module_input, context.module_name, - Const.INPUT) - context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT], - module_output, context.module_name, - Const.OUTPUT) - context.verified = True - # expect output be tensor type + tbtag_tensor_map = {} - cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_input)) - cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_input)) tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_output)) + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, module_output)) get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) context.micro_step += 1 @@ -963,31 +1143,17 @@ class TrainerMon: if self.print_struct: self.module_struct[context.module_name].update(context.struct) return - if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets']) - context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets']) - if not context.format_by_arg: - return - if not context.verified: - context.focused_in_col = validate_config_spec( - context.format_by_arg[MonitorConst.INPUT_GRAD], - input_grad, context.module_name, MonitorConst.INPUT_GRAD) - context.focused_out_col = validate_config_spec( - context.format_by_arg[MonitorConst.OUTPUT_GRAD], - output_grad, context.module_name, MonitorConst.OUTPUT_GRAD) - context.verified = True tbtag_tensor_map = {} - cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_input_grad)) - cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] + f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, input_grad)) + tbtag_tensor_map.update( self.build_tbtag_tensor_map( - f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', - MonitorConst.ACTV, cared_output_grad)) + f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTVGRAD, output_grad)) if context.micro_step == 0 and context.actvgrad: logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " @@ -1001,17 +1167,127 @@ class TrainerMon: context.micro_step = 0 return + def extract_attention_feature_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + tbtag_tensor_map = {} + if len(module_input) < 2: + raise ValueError("the length of module_input in attention hook's module " + "should be greater than or equal to 2.") + q_h = module_input[0] + k_h = module_input[1] + qkt = cal_qkt(q_h, k_h, order=self.sa_order) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.attention', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'qkt', qkt) + ) + get_entropy_metric(tbtag_tensor_map, self.eps, context.attention_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def extract_linear_sr_hook(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + tbtag_tensor_map = {} + + value = module.weight.data + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.linear', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'sr', value) + ) + get_sr_metric(tbtag_tensor_map, self.eps, context.linear_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def extract_token_similarity(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + tbtag_tensor_map = {} + + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.token', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'avg_token_similarity', + module_output[0]) + ) + get_avg_token_similarity_metric(tbtag_tensor_map, self.eps, context.token_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def detect_norm_stability(module, module_input, module_output, name): + if is_recomputation() or not module.training: + return + + if module not in self.feature_hook_context_by_module: + self.feature_hook_context_by_module[module] = FeatureHookContext(name) + context: FeatureHookContext = self.feature_hook_context_by_module[module] + + tbtag_tensor_map = {} + weight = module.weight + + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}.layernorm', + f'{MonitorConst.NAME_SEP}{context.micro_step}', 'norm_similarity', + [(module_input[0], weight)]) + ) + get_norm_stability_metric(tbtag_tensor_map, self.eps, context.norm_feature) + + context.micro_step += 1 + if context.micro_step == self.micro_batch_number: + context.micro_step = 0 + context.step += 1 + return + + def stack_hook(module, args, kwargs, module_output, name): + if module not in self.module_fwd_hook_context_by_module: + self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + context.stack = analyze_api_call_stack(name) + return + if self.backward_only and self.forward_only: logger.warning('not enable backward_only and forward_only simultaneously') hooked_count = 0 - if self.xy_distribution or self.print_struct: - for module_name, submodule in module.named_modules(): - name = self._is_target_module(module_name, target_names, vpp_stage) - if not name: - continue + for module_name, submodule in module.named_modules(): + if self.stack_info: + name = vpp_stage + squash_param_name(module_name, self.squash_name) + handle = submodule.register_forward_hook(partial(stack_hook, name=name), with_kwargs=True) + self.handles['stack'].append(handle) + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if submodule.__class__.__name__ == "FullyShardedDataParallel": + continue + if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name)) + handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name), with_kwargs=True) self.handles['xy'].append(handle) if not self.forward_only and not self.has_register_backward_hook(name, submodule): handle = submodule.register_full_backward_hook(bwd_hook_fun) @@ -1019,6 +1295,28 @@ class TrainerMon: self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) logger.info_on_rank_0(f"> {name} is monitored successfully") hooked_count += 1 + if not self.print_struct and self.recording_l2_features: + for module_name, submodule in module.named_modules(): + func_map = { + "attention_hook": extract_attention_feature_hook, + "linear_hook": extract_linear_sr_hook, + "token_hook": extract_token_similarity, + "norm_hook": detect_norm_stability + } + hooks = ["attention_hook", "linear_hook", "token_hook", "norm_hook"] + for hook in hooks: + if hook in l2_target_names: + temp_names = l2_target_names[hook] + name = self._is_recording_module(module_name, temp_names, vpp_stage) + if name: + handle = submodule.register_forward_hook(partial(func_map[hook], name=name)) + print_feature_name = hook.split('_')[0] + logger.info_on_rank_0( + f'> {print_feature_name} features of {name} is monitored successfully') + self.handles["L2_features"].append(handle) + hooked_count += 1 + continue + return hooked_count def _patch_grad_sync(self): @@ -1040,7 +1338,7 @@ class TrainerMon: if tag is None: continue grad_dict[tag] = grad - self._register_param_call_id("sync_grad_func", tag) + self.register_param_call_id("sync_grad_func", tag) get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) out = sync_grad_func(bucket) return out @@ -1049,7 +1347,14 @@ class TrainerMon: if not self.wg_distribution: return + if self.fsdp_wrapped_module: + # patch fsdp _runtime_utils._post_backward_hook + self._patch_fsdp_post_backward_hook() + return + if self.monitor_mbs_grad: + self._hook_weights() + return try: from megatron.core.distributed.param_and_grad_buffer import Bucket self.origin_start_grad_sync = Bucket.start_grad_sync @@ -1066,44 +1371,113 @@ class TrainerMon: self.enable_megatron = True logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0") except ImportError: - self.enable_megatron = False + self.enable_megatron = False | self.enable_megatron + if self.enable_megatron: + return - if not self.enable_megatron: - self._hook_weights() + # default hook weights + self._hook_weights() + + def _patch_fsdp_post_backward_hook(self): + """ + FSDP runtime 需要处理整个forward和backward计算和通信的流程,通过override nn.Module的forward,定义相应的逻辑。 + 对AccumulateGrad对象注册hook,可以在backward计算grad后立刻执行,在reduce_scatter操作前采集梯度累计后,通信聚合前的梯度。 + 每个forward阶段,fsdp对AccumulateGrad重复注册hook方法,monitor工具内注册hook无法生效, + 因此对_post_backward_hook进行patch,在backward后,reduce_scatter前采集梯度。 + """ + + def patch_post_backward_hook(_post_backward_hook): + def wrapper(state, handle, *unused): + grad_dict = {} + offset = 0 + for param, name in self.param2name.items(): + limit = param.numel() + if not limit: + continue + grad = handle.flat_param.grad[offset:offset + limit] + offset += limit + tag = self.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD) + if tag is None: + continue + grad_dict[tag] = grad + self.register_param_call_id("_post_backward_hook", tag) + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre) + out = _post_backward_hook(state, handle, *unused) + return out + + return wrapper + + logger.info("Patch fsdp _post_backward_hook, collect pre_grad metrics.") + self.fsdp_post_backward_hook = torch.distributed.fsdp._runtime_utils._post_backward_hook + torch.distributed.fsdp._runtime_utils._post_backward_hook = \ + patch_post_backward_hook(torch.distributed.fsdp._runtime_utils._post_backward_hook) def _hook_weights(self): + """ + 遍历参数的梯度生成函数(grad_acc),并挂载hook,以便在该参数所有梯度计算后,采集通信聚合前梯度数据。 + """ context = self.grad_context @torch.no_grad - def param_hook(*args, context_dict, param, key, name): + def param_hook(*args, context_dict, param, name): + key = name + if self.monitor_mbs_grad: + key += f'{MonitorConst.NAME_SEP}{param.micro_step}' + + key = get_summary_writer_tag_name(key, 'acc_grad', self.rank) + self.register_param_call_id("param_hook", key) param.micro_step += 1 - self._register_param_call_id("param_hook", key) - if param.micro_step == self.micro_batch_number: - param.micro_step = 0 + + if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number): if self.params_have_main_grad: - context_dict[key] = param.main_grad.clone() + grad = param.main_grad else: - context_dict[key] = param.grad.clone() + grad = param.grad + if is_float8_tensor(grad): + grad = grad.float() + context_dict[key] = grad.clone() + + if param.micro_step == self.micro_batch_number: + param.micro_step = 0 logger.info("hooking weights.") for param, name in self.param2name.items(): - key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) setattr(param, 'micro_step', 0) param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] handle = grad_acc.register_hook( - partial(param_hook, context_dict=context.acc, param=param, key=key, name=name)) + partial(param_hook, context_dict=context.acc, param=param, name=name)) self.grad_accs.append(grad_acc) self.handles['wgrads'].append(handle) self.weight_hooked = True - def _register_param_call_id(self, hook_name: str, key: str): - """ - :param hook_name: - :param key: str, '0:relu_0/output_grad' - :return: - """ - logger.debug(f"{hook_name} {key}: {self.call_id}") - self.param_name_call_id[key] = self.call_id - self.call_id += 1 + def _register_proxy_model_content(self): + if not self.proxy_model: + return + context = self.proxy_context + + @torch.no_grad + def param_hook(*args, context_dict, param, key, name): + param.micro_step_count += 1 + self.register_param_call_id("proxy_hook", key) + if self.params_have_main_grad: + context_dict[name][param.micro_step_count] = param.main_grad.clone() + else: + context_dict[name][param.micro_step_count] = param.grad.clone() + + if param.micro_step_count == self.micro_batch_number: + param.micro_step_count = 0 + + for param, name in self.param2name.items(): + logger.info(f"grad_per_sample hook of {name} is registered successfully! ") + key = get_summary_writer_tag_name(name, "grad_per_sample", self.rank) + setattr(param, 'micro_step_count', 0) + param_tmp = param.expand_as(param) + grad_per_sample = param_tmp.grad_fn.next_functions[0][0] + handle = grad_per_sample.register_hook( + partial(param_hook, context_dict=context.gradient_per_sample, param=param, key=key, name=name)) + self.grad_accs.append(grad_per_sample) + self.handles['wgrads'].append(handle) + + self.grad_per_sample_hooked = True diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py index 87963812006413a90fd33bc70d6172a7c73c3f10..319afbd4288d205384c38da47326338a878094be 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py @@ -16,7 +16,11 @@ import re import torch -from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean +from msprobe.core.common.const import MonitorConst +from msprobe.pytorch.common.utils import is_float8_tensor +from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean, cal_entropy, \ + cal_stable_rank, rms_norm_jacobian, layer_norm_jacobian, cal_kl_divergence, cal_avg_token_similarity_chunk, \ + cal_dist_diff from msprobe.pytorch.monitor.utils import get_nan_tensor @@ -143,6 +147,20 @@ class IdentMetric(Metric): return tensor +@register_config_metric("shape") +class ShapeMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return tensor.shape + + +@register_config_metric("dtype") +class DtypeMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return tensor.dtype + + def get_metrics(ops, tag2tensor, eps, out_dict=None): """ :param ops: ["op1", "op2"] @@ -166,7 +184,108 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None): # Non-tensor in/output filled with nan. out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops}) continue + if is_float8_tensor(tensor): + tensor = tensor.float() for metric_name in ops: fun_metric = config_metric_registry.get(metric_name) out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps) return out_dict + + +def get_sr_metric(tag2tensor, eps, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if "sr" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + sr, eig = cal_stable_rank(tensor) + out_dict[tag]['sr'] = sr + out_dict[tag]['kernel_norm'] = eig + + +def get_entropy_metric(tag2tensor, eps, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if tag not in out_dict: + out_dict[tag] = {} + entropy, softmax_max = cal_entropy(tensor) + out_dict[tag]['entropy'] = entropy + out_dict[tag]['softmax_max'] = softmax_max + + +def get_avg_token_similarity_metric(tag2tensor, eps, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if "avg" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + + if tensor.dim() == 2: + cal_tensor = tensor + elif tensor.dim() == 3: + cal_tensor = tensor[:MonitorConst.CAL_SIM_SEQ_LENGTH, 0, + :tensor.shape[-1] // MonitorConst.CAL_SIM_H_CHUNK_RATIO] + avg_token_similarity = cal_avg_token_similarity_chunk(cal_tensor) + out_dict[tag]['avg_token_similarity'] = avg_token_similarity + + +def get_norm_stability_metric(tag2tensor, eps=1e-8, out_dict=None): + if out_dict is None: + out_dict = {} + for tag, tensor in tag2tensor.items(): + if tag not in out_dict: + out_dict[tag] = {} + input_tensor, weight = tensor + if input_tensor.dim() == 2: + cal_tensor = input_tensor[:, :].mean(dim=0, keepdim=True) + elif input_tensor.dim() == 3: + cal_tensor = input_tensor[:, 0, :].mean(dim=0, keepdim=True) + std_x, jacobian = rms_norm_jacobian(cal_tensor, weight, eps) + + out_dict[tag]['std_x'] = std_x + out_dict[tag]['jacobian'] = torch.log(jacobian) + + +def get_kl_divergence_metric(tag2tensor, eps, out_dict=None): + if out_dict is None: + out_dict = {} + + for tag, tensor in tag2tensor.items(): + if "kl_divergence" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + + input_tensor, output_tensor = tensor + kl_divergence = cal_kl_divergence(input_tensor[:MonitorConst.CAL_SIM_SEQ_LENGTH, 0, + :input_tensor.shape[-1] // MonitorConst.CAL_SIM_H_CHUNK_RATIO], + output_tensor[:MonitorConst.CAL_SIM_SEQ_LENGTH, 0, + :output_tensor.shape[-1] // MonitorConst.CAL_SIM_H_CHUNK_RATIO]) + out_dict[tag]['kl_divergence'] = kl_divergence + + +def get_dict_diff_metric(tag2tensor, eps, out_dict=None): + if out_dict is None: + out_dict = {} + + for tag, tensor in tag2tensor.items(): + if "dict" not in tag: + continue + if tag not in out_dict: + out_dict[tag] = {} + + input_tensor, output_tensor = tensor + wasserstein_1, mean_diff, std_diff = cal_dist_diff( + input_tensor[:MonitorConst.CAL_SIM_SEQ_LENGTH, 0, + :input_tensor.shape[-1] // MonitorConst.CAL_SIM_H_CHUNK_RATIO], + output_tensor[:MonitorConst.CAL_SIM_SEQ_LENGTH, 0, + :output_tensor.shape[-1] // MonitorConst.CAL_SIM_H_CHUNK_RATIO] + ) + out_dict[tag]['wasserstein_1'] = wasserstein_1 + out_dict[tag]['mean_diff'] = mean_diff + out_dict[tag]['std_diff'] = std_diff diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py deleted file mode 100644 index 72c35c90bf9540a31cfa1176274a3d2c66bc8946..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import abc -import torch - -from msprobe.pytorch.common.log import logger - -# 用于存储所有validator实现类的注册表 -config_validator_registry = {} - - -def register_config_validator(cls): - """装饰器 用于注册ConfigValidator的实现类""" - config_validator_registry[cls.__name__] = cls - return cls - - -class ConfigValidator(metaclass=abc.ABCMeta): - @abc.abstractmethod - def check_pattern_match(self, config_spec: str): - pass - - @abc.abstractmethod - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - pass - - -@register_config_validator -class TensorValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tensor") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - if not torch.is_tensor(actual_data): - raise ValueError( - f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.") - - -@register_config_validator -class TupleValidator(ConfigValidator): - def check_pattern_match(self, config_spec: str): - pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?") - return pattern.match(config_spec) - - def validate(self, actual_data, module_name: str, data_type: str, pattern_match): - length, index = pattern_match.groups() - if index is None: - index = 0 - length, index = int(length), int(index) - - if not (0 <= index < length): - raise ValueError( - f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'." - f"y must be greater than or equal to 0 and less than x.") - if not isinstance(actual_data, tuple): - raise ValueError( - f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.") - if len(actual_data) != length: - raise ValueError( - f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, " - f"actual is {len(actual_data)} please check.") - return index - - -def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str): - focused_col = None - if not config_spec or not isinstance(config_spec, str): - return focused_col - for _, validator_cls in config_validator_registry.items(): - config_validator = validator_cls() - pattern_match = config_validator.check_pattern_match(config_spec) - if pattern_match: - try: - focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) - except ValueError as e: - logger.warning(f"config spec validate failed: {str(e)}") - return focused_col - logger.warning(f"config spec in {module_name} {data_type} not supported, " - f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") - return focused_col diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 602514836d2531ad4a6be3a23f56bc3b942ba199..3f8140cb7dc8a73b179af589d034308aeede2350 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -12,129 +12,120 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import defaultdict +from abc import abstractmethod import torch -import torch.distributed as dist from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import MVResult, MVGradResult +from msprobe.pytorch.monitor.utils import MVResult +from msprobe.core.common.const import MonitorConst class OptimizerMon(object): - def __init__(self) -> None: + def __init__(self, torch_opt) -> None: self.fp16_to_fp32_param = {} - self.is_stage3 = False + self.torch_opt = torch_opt + self.state = {} + + def narrow_from_flatten(self, param, flatten_state): + return flatten_state + + def get_state(self, torch_opt): + if hasattr(torch_opt, 'chained_optimizers'): + for opt in torch_opt.chained_optimizers: + self._get_single_state(opt) + else: + self._get_single_state(torch_opt) - def fetch_mv(self, monitor, torch_opt, params2name): - pass + def fetch_grad(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.torch_opt) - def _fetch_mv_in_adam(self, monitor, torch_opt, params2name): - exp_avg_dict = defaultdict(float) - exp_avg_sq_dict = defaultdict(float) - update_dict = defaultdict() - ratio_dict = defaultdict() + grad_dict = {} + first_param = True for param, name in params2name.items(): - if param in self.fp16_to_fp32_param: - param = self.fp16_to_fp32_param[param] - - if param in torch_opt.state: - state_param = torch_opt.state.get(param, None) - exp_avg = state_param.get("exp_avg", None) - exp_avg_sq = state_param.get("exp_avg_sq", None) - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") - continue + if monitor.duplicate_param.get(name, False): + continue + if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param: + continue + grad = param.main_grad if monitor.params_have_main_grad else param.grad + element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel() + if param.numel() != element_in_cur_partition: + if first_param: + grad = grad.flatten()[-element_in_cur_partition:] + else: # supposed to be the last one + grad = grad.flatten()[:element_in_cur_partition] + first_param = False + + if grad is None: + if not monitor.fsdp_wrapped_module: + logger.warning(f"grad is None: {name}, maybe something wrong happened.") + continue + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + return grad_dict + + def map_fp16_to_fp32_param(self, torch_opt): + pass + + def fetch_mv(self, monitor, params2name): + if not self.fp16_to_fp32_param: + self.map_fp16_to_fp32_param(self.torch_opt) + if not self.state: + self.get_state(self.torch_opt) + + exp_avg_dict = {} + exp_avg_sq_dict = {} + update_dict = {} + ratio_dict = {} + + if not self.state: + logger.warning('optimizer state can not accessed') + return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) + + for lp_param, name in params2name.items(): + if lp_param in self.fp16_to_fp32_param: + hp_param = self.fp16_to_fp32_param[lp_param] + else: + hp_param = lp_param + + if hp_param in self.state: + state_param = self.state.get(hp_param, {}) + exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None)) + exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None)) if monitor.mv_distribution: exp_avg_dict[name] = exp_avg exp_avg_sq_dict[name] = exp_avg_sq if monitor.mg_direction: exp_avg_dict[name] = exp_avg if monitor.ur_distribution: - if len(torch_opt.param_groups) > 1: - logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.") + if len(self.torch_opt.param_groups) > 1: + logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.") if 'step' in state_param: step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) - elif 'step' in torch_opt.param_groups[0]: - step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed + elif 'step' in self.torch_opt.param_groups[0]: + step = self.torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: logger.warning(f"step of {name} is None, maybe something wrong happened.") continue - exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) - exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) - update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps']) + exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step) + exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step) + update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps']) ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat) monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) - - def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat): - exp_avg_dict = defaultdict(float) - exp_avg_sq_dict = defaultdict(float) - update_dict = defaultdict() - ratio_dict = defaultdict() - param2name = defaultdict() - fp32_partitioned_groups_flat_grad = defaultdict() - partition_id = dist.get_rank() - - def get_flatten_grad(self, optimizer, group_idx): - if fp32_partitioned_groups_flat[group_idx].grad is None: - if partition_id == dist.get_world_size() - 1 and not self.is_stage3: - fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned( - optimizer.averaged_gradients[group_idx], - int(optimizer.partition_size[group_idx]) - ).to(fp32_partitioned_groups_flat[group_idx].dtype) - else: - fp32_partitioned_groups_flat_grad = optimizer.flatten( - optimizer.averaged_gradients[group_idx] - ).to(fp32_partitioned_groups_flat[group_idx].dtype) - return fp32_partitioned_groups_flat_grad - else: - return fp32_partitioned_groups_flat[group_idx].grad - - for group_idx in range(len(fp32_partitioned_groups_flat)): - fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx) - - for name in params2name.values(): - start_idx, end_idx, group_idx, group_with_rank = name2indices[name] - if group_with_rank != partition_id and isinstance(group_with_rank, int): - continue - fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx] - fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx] - param2name[fp32_param] = name - if not torch_opt.state: - continue - state_param = list(torch_opt.state.values())[group_idx] - exp_avg = state_param.get("exp_avg", None) - exp_avg_sq = state_param.get("exp_avg_sq", None) - if exp_avg is None or exp_avg_sq is None: - logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") - continue - exp_avg = exp_avg[start_idx: end_idx] - exp_avg_sq = exp_avg_sq[start_idx: end_idx] - if monitor.mv_distribution: - exp_avg_dict[name] = exp_avg - exp_avg_sq_dict[name] = exp_avg_sq - if monitor.mg_direction: - exp_avg_dict[name] = exp_avg - if monitor.ur_distribution: - if 'step' in state_param: - step = state_param['step'] # Optimizer from pytorch or FusedAdam from apex(used by megatron) - elif 'step' in torch_opt.param_groups[group_idx]: - step = torch_opt.param_groups[group_idx]['step'] # AdamW from mindspeed - else: - logger.warning(f"step of {name} is None, maybe something wrong happened.") - continue - exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) - exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) - update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps']) - ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat) - monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) - monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) - del fp32_partitioned_groups_flat_grad - return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict, - grad=param2name) + + def _get_single_state(self, torch_opt): + state = {} + if hasattr(torch_opt, 'param_to_cpu_states_map'): + state = torch_opt.param_to_cpu_states_map + elif hasattr(torch_opt, 'state'): + state = torch_opt.state + elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'): + state = torch_opt.optimizer.state + self.state.update(state) class MixPrecisionOptimizerMon(OptimizerMon): @@ -142,21 +133,14 @@ class MixPrecisionOptimizerMon(OptimizerMon): 混合精度优化器监控类。在混合精度训练中监控和管理优化器。 混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。 """ - - def map_fp16_tp_fp32_param(self, torch_opt): + def map_fp16_to_fp32_param(self, torch_opt): for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups): for fp16_param, fp32_param in zip(fp16_group, fp32_group): self.fp16_to_fp32_param[fp16_param] = fp32_param - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - self.map_fp16_tp_fp32_param(torch_opt) - - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - class MegatronDistributedOptimizerMon(OptimizerMon): - def map_fp16_tp_fp32_param(self, torch_opt): + def map_fp16_to_fp32_param(self, torch_opt): if not (hasattr(torch_opt, "model_float16_groups") and hasattr(torch_opt, "shard_fp32_from_float16_groups")): raise Exception( @@ -167,141 +151,176 @@ class MegatronDistributedOptimizerMon(OptimizerMon): for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): self.fp16_to_fp32_param[fp16_param] = shard_fp32_param - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - self.map_fp16_tp_fp32_param(torch_opt) - - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): + def map_fp16_to_fp32_param(self, torch_opt): + for opt in torch_opt.chained_optimizers: + super().map_fp16_to_fp32_param(opt) -class MegatronFP32OptimizerMon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): + def map_fp16_to_fp32_param(self, torch_opt): + for opt in torch_opt.chained_optimizers: + super().map_fp16_to_fp32_param(opt) -class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - for opt in torch_opt.chained_optimizers: - self.map_fp16_tp_fp32_param(opt) - if not isinstance(torch_opt, torch.optim.Optimizer): - torch_opt.state = {} - for opt in torch_opt.chained_optimizers: - torch_opt.state.update(opt.optimizer.state) - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) +class DeepSpeedZeroOptimizerMon(OptimizerMon): + """ + Base monitor class for DeepSpeed ZeRO optimizer. + ZeRO stage 0 no partition + ZeRO stage 1 partitions optimizer states across data parallel processes. + ZeRO stage 2 additionally partitions gradients. + ZeRO stage 3 additionally partitions parameters. + + This class provides monitoring capabilities for ZeRO optimizers by: + - Handling gradient collection for different ZeRO stages + - Managing optimizer state access for monitoring + """ + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '' + self.bit16_groups = [] + self.fp32_flat_groups = [] + self.param2group = () + self.param2index = [] + self.group_offset = {} + + @abstractmethod + def get_grad_for_param(self, lp_param, group_idx, param_id): + raise NotImplementedError + + def param_not_in_partition(self, lp_param, group_idx): + param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param)) + return hp_address is None + + def get_position(self, lp_param, group_idx): + param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx] + hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param)) + return hp_address.start, hp_address.numel + + def get_group_index(self): + param2group = {} + for group_idx, bit16_group in enumerate(self.bit16_groups): + for param in bit16_group: + param2group[param] = group_idx + return param2group + + def get_param_index(self, lp_param, group_idx): + if not self.param2index: + for group in self.bit16_groups: + param2index = {} + for index, param in enumerate(group): + param2index[param] = index + self.param2index.append(param2index) + + return self.param2index[group_idx][lp_param] + + def narrow_from_flatten(self, param, flatten_state): + if flatten_state is None: + return flatten_state + group_idx = self.param2group[param] + if self.param_not_in_partition(param, group_idx): + return None + start, numel = self.get_position(param, group_idx) + return flatten_state.narrow(0, start, numel) + + def map_fp16_to_fp32_param(self, torch_opt): + for group_idx, group in enumerate(self.bit16_groups): + for param in group: + self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx] + + def fetch_grad(self, monitor, params2name): + grad_dict = {} + for lp_param, name in params2name.items(): + group_idx = self.param2group[lp_param] + param_id = self.get_param_index(lp_param, group_idx) + if self.param_not_in_partition(lp_param, group_idx): + continue + if self.stage == '1or2': + param_id = param_id - self.group_offset[group_idx] - 1 + grad = self.get_grad_for_param(lp_param, group_idx, param_id) + tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + monitor.register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + + return grad_dict + + +class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '0' + self.bit16_groups = torch_opt.bf16_groups + self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition + self.param2group = self.get_group_index() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id] + + +class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '1or2' + self.bit16_groups = torch_opt.bit16_groups + self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups + self.param2group = self.get_group_index() + self.group_offset = {} + self.get_group_offset() + + def get_grad_for_param(self, lp_param, group_idx, param_id): + if getattr(self.torch_opt, "cpu_offload", False): + grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad + start, numel = self.get_position(lp_param, group_idx) + grad = grads.narrow(0, start, numel) + else: + grad = self.torch_opt.averaged_gradients[group_idx][param_id] + return grad + + def get_group_offset(self): + for group_idx, group in enumerate(self.bit16_groups): + self.group_offset[group_idx] = -1 + for lp_param in group: + if self.param_not_in_partition(lp_param, group_idx): + self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx) + else: + break -class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - if not self.fp16_to_fp32_param and torch_opt is not None: - for opt in torch_opt.chained_optimizers: - self.map_fp16_tp_fp32_param(opt) +class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon): + def __init__(self, torch_opt): + super().__init__(torch_opt) + self.stage = '3' + self.bit16_groups = torch_opt.fp16_groups + self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat + self.param2group = self.get_group_index() - if not isinstance(torch_opt, torch.optim.Optimizer): - torch_opt.state = {} - for opt in torch_opt.chained_optimizers: - torch_opt.state.update(opt.optimizer.state) - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - - -class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - - -class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon): - def get_param_index(self, params2name, name2index, torch_opt): - fp16_groups = torch_opt.fp16_partitioned_groups - name2indices = defaultdict() - index_length = defaultdict() - index = 0 - idx = 0 - for group_idx, fp16_group in enumerate(fp16_groups): - for param in fp16_group: - param_length = len(param.flatten()) - index_length[idx] = (index, index + param_length, group_idx) - index += param_length - idx += 1 - for _, name in params2name.items(): - idx = name2index[name] - start_idx, end_idx, group_idx = index_length[idx] - name2indices[name] = (start_idx, end_idx, group_idx, None) - return name2indices - - def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): - self.is_stage3 = True - fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat - return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) - - -class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): - @staticmethod - def get_group_index(fp32_length, world_size, index): - for i in range(len(fp32_length) - 1): - if fp32_length[i] <= index < fp32_length[i + 1]: - interval_start = fp32_length[i] - interval_length = fp32_length[i + 1] - fp32_length[i] - sub_interval_length = interval_length // world_size - sub_index = (index - interval_start) // sub_interval_length - sub_interval_start = interval_start + sub_index * sub_interval_length - return sub_interval_start, min(sub_index, world_size - 1) - return fp32_length[-1], 0 - - def get_param_index(self, params2name, name2index, torch_opt): - padding = torch_opt.groups_padding - world_size = dist.get_world_size() - fp32_length = [0] - for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups): - fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index]) - - bf16_groups = [] - name2indices = defaultdict() - index_length = defaultdict() - index = 0 - idx = 0 - for group_idx, bf16_group in enumerate(torch_opt.bit16_groups): - bf16_groups.extend(bf16_group) - for param in bf16_group: - param_length = len(param.flatten()) - group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index) - index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank) - index += param_length - idx += 1 - group_length = len(bf16_groups) / len(torch_opt.bit16_groups) - for _, name in params2name.items(): - name_index = name2index[name] - start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index] - need_padding = True if group_with_rank == world_size - 1 else False - new_start_idx = start_idx - group_index - new_end_idx = end_idx - group_index - if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % ( - group_length - 1) == 0: - new_end_idx -= padding[int(name_index // (group_length - 1) - 1)] - name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank) - return name2indices - - def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): - fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups - return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) - - -class DummyOptimizerMon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) + def param_not_in_partition(self, lp_param, group_idx): + """Each param partioned across all zero ranks""" + return False + + def get_position(self, lp_param, group_idx): + param_id = self.torch_opt.get_param_id(lp_param) + return self.torch_opt.grad_position[param_id][1:] + + def get_grad_for_param(self, lp_param, group_idx, param_id): + return self.torch_opt.averaged_gradients[group_idx][param_id] class OptimizerMonFactory: _optimizer_mon_map = { - "FP32Optimizer": MegatronFP32OptimizerMon, + "FP32Optimizer": OptimizerMon, "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, "DistributedOptimizer": MegatronDistributedOptimizerMon, + "SwapDistributedOptimizer": MegatronDistributedOptimizerMon, "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon, "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon, "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon, "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon, - "Adam": DummyOptimizerMon + "Adam": OptimizerMon } @staticmethod @@ -310,6 +329,7 @@ class OptimizerMonFactory: optimizer_class = optimizer.__class__.__name__ if optimizer_class == "ChainedOptimizer": optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__ + logger.info(f'The optimizer type is {optimizer_class}') - optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon) - return optimizer_mon_class(), optimizer_class + optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon) + return optimizer_mon_class(optimizer) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py b/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py deleted file mode 100644 index 4d5c1a717d80ee30414f25b44a93ddc7257ef2c7..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import re -from glob import glob - -import pandas as pd - -from msprobe.pytorch.common.log import logger - - -def parse_logfile(logfile): - grad_norm = [] - step = [] - with open(logfile) as f: - for line in f.readlines(): - if 'consumed samples' in line: - grad_norm.append(float(re.findall('(?<=grad norm\: )[\d\.]*', line)[0])) - return grad_norm - - -def parse_monitor_output(output_dir): - reduced = {} - unreduced = {} - for directory in glob(output_dir + '*'): - rank = int(re.findall('(?<=rank)[\d]*', directory)[0]) - unreduced[rank] = [] - reduced[rank] = [] - for file in os.listdir(directory): - df = pd.read_csv(os.path.join(directory, file)) - if '_unreduced_' in file: - unreduced[rank].append(df) - pass - elif '_reduced_' in file: - reduced[rank].append(df) - else: - logger.info(f'unexpected file {file} in {directory}') - return reduced, unreduced - - -def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel): - steps = len(reduced[0]) - world_size = len(reduced) - errors = [] - for _, row in unreduced[0][0].iterrows(): - param = row['param_name'] - is_tp_duplicate = False - for step in range(2): - # sum reduced - reduced_mean = 0. - for rank in range(world_size): - if len(reduced[rank]) == 0: - continue - df = reduced[rank][step] - value = list(df[df['param_name'] == param]['mean']) - if not value: - if step == 0: - is_tp_duplicate = True - continue - reduced_mean += value[0] - - # sum unreduced - unreduced_mean = 0. - for rank in range(world_size): - df = unreduced[rank][step] - value = list(df[df['param_name'] == param]['mean']) - if not value: - continue - unreduced_mean += list(df[df['param_name'] == param]['mean'])[0] - - unreduced_mean /= dp_size - if is_tp_duplicate and (not sequence_parallel or 'embedding' in param): - unreduced_mean /= tp_size - try: - assert_equal(unreduced_mean, reduced_mean) - except AssertionError as e: - errors.append([param, step, e, is_tp_duplicate]) - if errors: - logger.info(errors) - else: - logger.info(f'grad mean is in consist between unreduced grad and reduced grad monitord.') - - -def assert_equal(a, b): - if b == 0 or a == 0: - return - if b == 0: - rel_diff = a - elif a == 0: - rel_diff = b - else: - rel_diff = abs(a / b - 1) - assert rel_diff < 0.01, f'{a}, {b}, {rel_diff}' - - -def valid_total_norm(total_norm, reduced, duplicate_embedding): - steps = len(total_norm) - world_size = len(reduced) - errors = [] - for step in range(steps): - calculated_norm = 0. - for rank in range(world_size): - if len(reduced[rank]) == 0: - if step == 0: - logger.info(f'rank {rank} is duplicated in dp group') - continue - for _, row in reduced[rank][step].iterrows(): - if duplicate_embedding and 'word_embedding' in row['param_name']: - continue - calculated_norm += row['norm'] ** 2 - try: - assert_equal(calculated_norm ** 0.5, total_norm[step]) - except AssertionError as e: - errors.append([step, e]) - if errors: - logger.info('total norm errors: ', errors) - else: - logger.info('grad norm in consist between training log and reduced gradients monitored') - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--monitor_output', '-m', type=str, required=True, - help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16') - parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file') - parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size') - parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size') - parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size') - parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False, - help='whether untie_embeddings_and_output_weights in pp parallel') - parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False, - help='whether sequence parallel is enabled. Add -s to store true') - - args = parser.parse_args() - - assert args.tp_size > 0, 'if tp not enabled, set tp_size = 1' - assert args.dp_size > 0, 'if tp not enabled, set dp_size = 1' - assert args.pp_size > 0, 'if tp not enabled, set pp_size = 1' - - total_norm = parse_logfile(args.logfile) - reduced, unreduced = parse_monitor_output(args.monitor_output) - - duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1 - - valid_total_norm(total_norm, reduced, duplicate_embedding) - valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 94afe56ffcfe7571a189c5f6959b2eb9a2779d81..767707479719159ae2806b54f1f706ed6faa5a20 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -22,10 +22,10 @@ import re import torch -from msprobe.core.common.const import MonitorConst, Const +from msprobe.core.common.const import MonitorConst from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import is_int -from msprobe.core.common.file_utils import check_file_or_directory_path +from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod device = "cpu" @@ -43,7 +43,6 @@ DIRECTORY_MAX_LENGTH = 4096 beijing_tz = timezone(timedelta(hours=8)) MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) -MVGradResult = namedtuple('MVGradResult', ("exp_avg", "exp_avg_sq", "update", "ratio", "grad")) class MsgConst: @@ -102,9 +101,23 @@ def validate_ops(ops): default_op = MonitorConst.OP_LIST[0] valid_ops.append(default_op) logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") + # 增加默认shape和dtype参数 + if "shape" not in valid_ops: + valid_ops.append("shape") + if "dtype" not in valid_ops: + valid_ops.append("dtype") return valid_ops +def validate_ndigits(ndigits): + if not ndigits: + return + if not is_int(ndigits) or ndigits <= 0: + raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") + if ndigits > MonitorConst.MAX_NDIGITS: + raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") + + def validate_ranks(ranks): if not isinstance(ranks, list): raise TypeError("module_ranks should be a list") @@ -190,7 +203,7 @@ def validate_alert(alert): args = rule.get("args") if args and isinstance(args, dict): threshold = args.get("threshold") - if not isinstance(threshold, float) or threshold < 0: + if not isinstance(threshold, (float, int)) or threshold < 0: raise TypeError('threshold must be float and not less than 0') dump = alert.get('dump') if dump and not isinstance(dump, bool): @@ -206,9 +219,24 @@ def validate_step_count_per_record(step_count_per_record): raise ValueError("step_count_per_record must smaller than 1e6") +def validate_dynamic_on(dynamic_on): + if not isinstance(dynamic_on, bool): + raise TypeError('dynamic_on should be a bool') + + +def validate_monitor_mbs_grad(monitor_mbs_grad): + if not isinstance(monitor_mbs_grad, bool): + logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') + return False + return monitor_mbs_grad + + def validate_config(config): config['ops'] = validate_ops(config.get('ops', [])) + ndigits = config.get('ndigits') + validate_ndigits(ndigits) + eps = config.get('eps', 1e-8) if not isinstance(eps, float): raise TypeError("eps should be a float") @@ -246,9 +274,22 @@ def validate_config(config): step_count_per_record = config.get('step_count_per_record', 1) validate_step_count_per_record(step_count_per_record) + config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", + MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) + config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", + MonitorConst.DEFAULT_MIN_COLLECT_TIMES, + MonitorConst.DEFAULT_MAX_COLLECT_TIMES) + config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", + MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) + squash_name = config.get('squash_name', True) validate_squash_name(squash_name) + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) + + dynamic_on = config.get('dynamic_on', False) + validate_dynamic_on(dynamic_on) + if not targets: if xy_distribution: config["all_xy"] = True @@ -257,6 +298,8 @@ def validate_config(config): def time_str2time_digit(time_str): time_format = '%b%d_%H-%M-%S' + if not isinstance(time_str, str): + raise TypeError(f"time_str:{time_str} should be a str") try: time_digit = datetime.strptime(time_str, time_format) except Exception as e: @@ -284,3 +327,40 @@ def get_target_output_dir(monitor_path, time_start, time_end): if start_ok and end_ok: result[rank] = os.path.join(monitor_path, dirname) return result + + +def chmod_tensorboard_dir(path): + """ + format配置为tensorboard时,需要补充文件权限设置 + """ + try: + recursive_chmod(path) + except Exception as e: + logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") + + +def validate_set_monitor(grad_acc_steps, start_iteration): + """ + validate parameters of set_monitor. + """ + grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", + MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) + + start_iteration = validate_int_arg(start_iteration, "start_iteration", + MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) + return grad_acc_steps, start_iteration + + +def validate_int_arg(value, name, minimum, default_value): + """Validate int args, if any exception occurs, use the default value.""" + if value is None: + return default_value + try: + if not is_int(value): + raise TypeError(f"{name} must be int") + if value < minimum: + raise ValueError(f"{name} must greater than {minimum}") + except Exception as e: + value = default_value + logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") + return value diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/__init__.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b25e9bbda64095ac21efcc17b577ac7a5d6b0e9f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/__init__.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b25e9bbda64095ac21efcc17b577ac7a5d6b0e9f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/anomaly_api_analyzer.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/anomaly_api_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..da6dd2bfa76fa78c5f38ba1bb11402a8f8e29689 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/anomaly_api_analyzer.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from rich.live import Live + +from acs_msprobe.analyzer.base_analyzer import Analyzer +from acs_msprobe.tools.utils import get_progress_table_and_progress +from acs_msprobe.dataset.msprobe_summary_dataset import SummaryStackDataset, SummaryItem +from acs_msprobe.result.analyzer_result import ResultCollector +from acs_msprobe.checker.factory import StaticApiHandlerFactory + + +class AnomalyAPIAnalyzer(Analyzer): + def __init__(self, dump_dataset, stack_dataset=None): + super().__init__(dump_dataset) + self.stack_dataset = stack_dataset + self.result_collector = ResultCollector() + + @property + def analyzer_result(self): + return self.result_collector + + def analyze(self): + process_table, overall_progress, overall_task = (get_progress_table_and_progress(total_len=len(self.dataset), + title="总体分析进度")) + with Live(process_table, refresh_per_second=5): + for rank_id, api_list in self.dataset.dataset.items(): + if api_list is None: + continue + unique_api_list = [] + + for index, item in enumerate(api_list): + api_item = item.get("api_name") + result = StaticApiHandlerFactory.dispatch(api_name, api_list[index], dataset=self.dataset.dataset) + if not result: + continue + if self.stack_dataset is not None: + api_name = api_name.replace("backward", "forward") + api_name_valid = ".".join(api_name.split(".")[:2]) + result.stack_info = self.stack_dataset.query(rank_id, api_name) + if (result.stack_info, api_name_valid) in unique_api_list: + continue + unique_api_list.append((result.stack_info, api_name_valid)) + self.result_collector.add(result) + overall_progress.advance(overall_task, advance=1) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/summary_analyzer_factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/summary_analyzer_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..df3f14aefedcc81a9460552b0317ae1b5ae6cdab --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/anomaly_analyzer/summary_analyzer_factory.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from acs_msprobe.analyzer.anomaly_analyzer.anomaly_api_analyzer import AnomalyAPIAnalyzer + + +class SummaryDumpAnalyzerFactory: + """ + SummaryDumpAnalyzerFactory 用于创建不同类型的SummaryAnalyzer实例 + """ + supported_analyzer = { + "anomaly_api": AnomalyAPIAnalyzer + } + default_analyzer = "anomaly_api" + + @classmethod + def create_analyzer(cls, **kwargs): + """ + 创建指定类型的SummaryAnalyzer实例 + """ + analyze_mothod = kwargs.pop("analyze_method", cls.default_analyzer) + if analyze_mothod not in cls.supported_analyzer.keys(): + raise ValueError(f"Unsupported analyzer method: {analyze_mothod}") + + return cls.get_analyzer(analyze_mothod)(**kwargs) + + @classmethod + def get_analyzer(cls, analyze_method): + """ + 获取指定类型的SummaryAnalyzer类 + """ + return cls.supported_analyzer.get(analyze_method) diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/base_analyzer.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/base_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..f63f2d582d2d5a79e37bd1aa8bdb05d0c02c61ce --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/analyzer/base_analyzer.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from acs_msprobe.dataset.base_dataset import Dataset + + +class Analyzer(ABC): + def __init__(self, dataset: Dataset, *args, **kwargs): + if not isinstance(dataset, Dataset): + raise TypeError("dataset must be an instance of Dataset.") + self.dataset = dataset + + @abstractmethod + def analyze(self, *args, **kwargs): + raise NotImplementedError("analyze method not implemented.") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/__init__.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b25e9bbda64095ac21efcc17b577ac7a5d6b0e9f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..d547a6d179c80ec491fcb94a97c0b3733d51da84 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/checker/factory.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from acs_msprobe.tools.const import COMPUTEENUM +from acs_msprobe.tools.utils import is_distributed_api + + +class StaticApiHandlerFactory: + handlers = {} + + @classmethod + def register(cls, handler, alias=None): + if alias is None: + cls.handlers[handler.__name__] = handler + else: + if not isinstance(alias, list): + alias = [alias] + for name in alias: + cls.handlers[name] = handler + + @classmethod + def create(cls, hander_name, *args, **kwargs): + if hander_name not in cls.handlers: + return cls.handlers[hander_name].handle(*args, **kwargs) + else: + raise ValueError(f"Handler '{hander_name}' not found.") + + @classmethod + def dispatch(cls, api_name, *args, **kwargs): + if not is_distributed_api(api_name): + return cls.create(COMPUTEENUM.NORMAL_INPUT_ABNORMAL_OUTPUT, *args, **kwargs) + + result = None + for key in cls.handlers.keys(): + if key in api_name: + result = cls.create(key, *args, **kwargs) + if result is not None: + return result + return result + + @classmethod + def reset(cls): + cls.handlers = {} diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/__init__.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b25e9bbda64095ac21efcc17b577ac7a5d6b0e9f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/base_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..30695e8cea77a3f9a6ad888a14fd8ee9ce1b2c12 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/base_dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from abc import ABC, abstractmethod +from typing import Dict + +from acs_msprobe.tools.log import logger + + +class Dataset(ABC): + def __init__(self, input_path, *args, **kwargs): + if input_path is None or not os.path.exists(input_path): + raise FileNotFoundError(f"Dataset file {input_path} not found.") + self.input_path = os.path.abspath(os.path.join(input_path)) + self.dataset: Dict = {} + logger.debug("Init %s with input_path: %s", self.__class__.__name__, self.input_path) + + def __len__(self): + return len(self.dataset) + + @abstractmethod + def parse(self, *args, **kwargs): + raise NotImplementedError("parse method not implemented.") \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/msprobe_summary_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/msprobe_summary_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9d47a967b1ac35c613053f1e038a85e89ba88e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/msprobe_summary_dataset.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import functools +from typing import Optional, Dict, List, Tuple + +from acs_msprobe.tools.utils import get_current_rank +from acs_msprobe.dataset.process_group_dataset import SummaryItem, ProcessGroup, ProcessGroupCounter +from acs_msprobe.tools.log import logger +from acs_msprobe.dataset.base_dataset import Dataset +from acs_msprobe.tools import const as constant + + +class ProcessGroupMonitor: + def __init__(self, process_group): + self.process_group_counter = ProcessGroupCounter() + self.process_group = process_group + + def collect(self, meta_info: Dict, api_info: Dict): + """ + meta_info数据格式: + {"type": "torch.distributed.ProcessGroup", + "global_ranks": [ 0, 1, 2, 3], + "group_id": 281471037097008} + """ + process_group = ProcessGroup(group_id=meta_info.get(constant.TORCH_GROUP_ID, -1), + global_ranks=meta_info.get(constant.TORCH_GROUP_RANKS, []) or + meta_info.get(constant.TORCH_GLOBAL_RANKS, []) or + meta_info.get("value", [])) + self.process_group_counter.update(process_group, api_info.get("rank"), api_info.get("api_name")) + api_info['group'] = process_group + + +class SummaryDumpDataset(Dataset): + + def __init__(self, input_path, framework): + super().__init__(input_path) + self.framework: str = framework + # 区分PyTorch和MindSpore的ProcessGroup字段名称 + self.process_group = constant.VALID_TORCH_GROUP if framework == constant.PYTORCH else None + self.process_group_monitor = ProcessGroupMonitor(self.process_group) + self.parse_all_ranks() + + @classmethod + def from_args(cls, **kwargs) -> "SummaryDumpDataset": + input_path: str = kwargs.get("input_path") + framework: str = kwargs.get("framework", "pytorch") + if input_path is None or not os.path.exists(input_path): + raise TypeError("Missing input_path args.") + if framework == "pytorch": + logger.debug("Using msprobe pytorch json-format dataset.") + cls.PROCESS_GROUP = "torch.distributed.ProcessGroup" + else: + raise ValueError(f"Only pytorch framework is supported in advisor mode currently, but got {framework}") + return cls(input_path=input_path, framework=framework) + + def load_data_by_rank(self, rank_id: str): + if rank_id not in self.dataset.keys(): + self.dataset[rank_id] = self.parse(os.path.join(self.input_path, rank_id), rank_id) + + def parse_all_ranks(self): + def _cmp_rank(rank1: str, rank2: str): + int_rank1 = get_current_rank(rank1) + int_rank2 = get_current_rank(rank2) + return 1 if int_rank1 > int_rank2 else -1 + + if not os.path.isdir(os.path.join(self.input_path)): + logger.error("Only directory is supported, please check your data path: %s.", self.input_path) + return None + + dir_list = sorted([item for item in os.listdir(self.input_path) if item.startswith('rank')], + key=functools.cmp_to_key(_cmp_rank)) + for rank in dir_list: + rank_int = get_current_rank(rank) + self.dataset[rank_int] = self.parse(os.path.join(self.input_path, rank), rank_int) + + def _parse_dict_item(self, api_info, meta_info: Dict) -> Tuple[Optional[List], Optional[List]]: + api_static_info = None + api_meta_info = None + dtype = meta_info.get("type") + if "Max" not in meta_info.keys(): + if dtype in ["float", "int"]: + # item类型 + value = meta_info.get("value") + dtype = constant.DTYPE_CLASS_INT if dtype == "int" \ + else constant.DTYPE_CLASS_FLOAT + shape = '[]' + api_meta_info = [dtype, shape] + api_static_info = [value, value, value, value] + elif dtype in self.process_group: + self.process_group_monitor.collect(meta_info, api_info) + self.GLOBAL_RANK_FLAG = True + else: + # Tensor类型 + api_meta_info = [dtype, meta_info.get("shape")] + api_static_info = [meta_info.get("Max"), meta_info.get("Min"), meta_info.get("Mean"), meta_info.get("Norm")] + return api_meta_info, api_static_info + + def _parse_list_object(self, api_info, api_name, meta_info_list: List) -> List: + api_info_list = [] + for index, item in enumerate(meta_info_list): + if item is None: + continue + api_name_with_index = api_name + f".{index}" if len(meta_info_list) > 1 else api_name + api_meta_info, api_static_info = self._parse_dict_item(api_info, item) + if api_static_info is not None: + api_info_list.append((api_name_with_index, *api_meta_info, api_static_info)) + return api_info_list + + def _parse_args_object(self, api_info, api_name, meta_info_list: List) -> List: + api_info_list = [] + for index, inp in enumerate(meta_info_list): + if inp is None: + continue + if isinstance(inp, dict): + inp = [inp] + api_info_parsed_list = self._parse_list_object(api_info, api_name + f".{index}", inp) + if len(api_info_parsed_list) > 0: + api_info_list.append(api_info_parsed_list) + return api_info_list + + def _parse_kwargs_object(self, api_info, api_name, meta_info_dict: Dict) -> List: + api_info_list = [] + for key, value in meta_info_dict.items(): + if value is None: + continue + api_name_with_index = api_name + f".{key}" + if not isinstance(value, dict) and not isinstance(value, list): + raise ValueError + + if isinstance(value, dict): + value = [value] + api_info_parsed_list = self._parse_list_object(api_info, api_name_with_index, value) + if len(api_info_parsed_list) > 0: + api_info_list.append(api_info_parsed_list) + return api_info_list + + def parse(self, rank_dir, rank: str) -> Optional[List]: + api_list = [] + api_dump_pkl_path = os.path.join(rank_dir, constant.API_DUMP_JSON_FILE_NAME) + if not os.path.exists(api_dump_pkl_path): + return api_list + # 优化:如果json文件过大,可以尝试使用流式读取 + with open(api_dump_pkl_path, "r") as f: + content = json.load(f) + index = 0 + for api_name, api_meta in content.get('data', {}).items(): + api_info = {} + index += 1 + fwd_bwd_flag = constant.FORWARD_FLAG if constant.FORWARD_FLAG in api_name else constant.BACKWARD_FLAG + api_info["api_name"] = api_name + api_info["line"] = index + api_info["rank"] = rank + api_info[constant.INPUT_FLAG] = [] + api_info[constant.OUTPUT_FLAG] = [] + if fwd_bwd_flag == constant.FORWARD_FLAG: + api_info_list = self._parse_args_object(api_info, api_name, + api_meta.get("input_args", {})) + api_info[constant.INPUT_FLAG].extend(api_info_list) + api_info_list = self._parse_kwargs_object(api_info, api_name, + api_meta.get("input_kwargs", {})) + api_info[constant.INPUT_FLAG].extend(api_info_list) + api_info_list = self._parse_args_object(api_info, api_name, + api_meta.get("output", [])) + api_info[constant.OUTPUT_FLAG].extend(api_info_list) + elif fwd_bwd_flag == constant.BACKWARD_FLAG: + # 兼容不同msprobe版本 + input_key = "grad_output" if "grad_output" in api_meta.keys() else "input" + output_key = "grad_input" if "grad_input" in api_meta.keys() else "output" + api_info_list = self._parse_args_object(api_info, api_name, + api_meta.get(output_key, [])) + api_info[constant.OUTPUT_FLAG].extend(api_info_list) + api_info_list = self._parse_args_object(api_info, api_name, + api_meta.get(input_key, [])) + api_info[constant.INPUT_FLAG].extend(api_info_list) + api_list.append(self.post_process(SummaryItem(**api_info))) + return api_list + + @staticmethod + def post_process(api_info: SummaryItem): + if "npu_fusion_attention" in api_info.api_name: + api_info.output = api_info.output[:3] + + if "__setitem__" in api_info.api_name and not api_info.output: + api_info.output = api_info.input + + return api_info + + +class SummaryStackDataset(Dataset): + def __init__(self, input_path, framework=None): + super().__init__(input_path) + self.framework = framework + self.ranks = [] + self.stack_json_file = None + self._raw_data = None + + def parse(self) -> Optional[Dict]: + if not os.path.exists(self.stack_json_file): + logger.warning("Fail to parse stack json file, because %s is not exist.", self.stack_json_file) + return {} + with open(self.stack_json_file, "r") as f: + return json.load(f) + + def query(self, rank: int, api_name: str) -> Optional[str]: + if rank not in self.ranks: + self.ranks.append(rank) + rank_path = "rank" + str(rank) + if rank == 0 and os.path.exists(os.path.join(self.input_path, "rank", constant.API_STACK_JSON_FILE_NAME)): + rank_path = "rank" + logger.debug("Start to parse stack info in %s.", rank_path) + + self.stack_json_file = os.path.join(self.input_path, rank_path, constant.API_STACK_JSON_FILE_NAME) + self._raw_data = self.parse() + return self._raw_data.get(api_name) if self._raw_data is not None else None diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/process_group_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/process_group_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b55904548b8c7be9b8b90a0319cab348539e7f11 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/process_group_dataset.py @@ -0,0 +1,232 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from functools import partial +from typing import List, Callable, Any + +from acs_msprobe.tools import const as constant +from acs_msprobe.tools.utils import get_current_rank, is_int_number, is_all_close, load_variable + + +class ProcessGroup: + """ + 用于存储ProcessGroup类数据格式, + group_id表示通信组id, + global_ranks表示该通信组内包含哪些rank节点 + local_order表示当前通信算子在当前rank相同的相同类型通信算子的出现顺序 + """ + __slots__ = ['group_id', 'global_ranks', "local_order"] + + def __init__(self, group_id: int, global_ranks: List): + self.group_id = group_id + self.global_ranks = global_ranks + self.local_order = 0 + + def __eq__(self, other): + return self.global_ranks == other.global_ranks and self.local_order == other.local_order + + def update_local_order(self, current_rank: int): + self.local_order = current_rank + + +class ProcessGroupCounter: + def __init__(self): + self.rank = -1 + self.group_mapping = {} + + def update(self, group: ProcessGroup, rank, dist_func_name): + if self.rank == -1 or self.rank != rank: + self.rank = rank + self.group_mapping = {} + dist_func = dist_func_name.split(constant.DELIMITER)[:2] + key_name = constant.DELIMITER.join(dist_func) + "-" + \ + "-".join([str(i) for i in group.global_ranks]) + "-" + \ + str(group.group_id) + if key_name not in self.group_mapping.keys(): + self.group_mapping[key_name] = 1 + else: + self.group_mapping[key_name] += 1 + group.update_local_order(self.group_mapping[key_name]) + + +class SummaryItem: + __slots__ = ['api_name', 'line', 'stack_info', 'rank', 'input', 'output', 'group'] + REL_TOL = load_variable("CMP_REL_TOL", constant.CMP_REL_TOL, float) + ABS_TOL = load_variable("CMP_ABS_TOL", constant.CMP_ABS_TOL, float) + + def __init__(self, api_name: str, line: int, + group=None, + rank: int = -1, + stack_info: str = None, + input_list: List = None, + output: List = None): + self.api_name = api_name + self.line = line + self.group = group + self.rank = rank + self.stack_info = stack_info + self.input = input_list if input_list is not None else [] + self.output = output if output is not None else [] + + def is_empty(self): + return self.api_name is None and self.line is None + + def is_distributed_api(self): + return constant.DISTRIBUTED in self.api_name + + def get(self, key, default=None): + return getattr(self, key) if hasattr(self, key) else default + + def __eq__(self, other): + if self.api_name != other.api_name: + return False + + if len(self.input) != len(other.input) or len(self.output) != len(other.output): + return False + + cmp_res = True + for _bench, _candidate in zip(self.input + self.output, other.input + other.output[:len(self.output)]): + if isinstance(_bench, list) and len(_bench) > 0 and isinstance(_bench[0], tuple): + _bench = _bench[0] + if isinstance(_candidate, list) and len(_candidate) > 0 and isinstance(_candidate[0], tuple): + _candidate = _candidate[0] + cmp_res &= all([is_all_close(_bench[-1], _candidate[-1], self.REL_TOL, self.ABS_TOL)]) + if not cmp_res: + break + return cmp_res + + def is_equivalent(self, other, rel_tol=None, abs_tol=None): + rel_tol = self.REL_TOL if rel_tol is None else rel_tol + abs_tol = self.ABS_TOL if abs_tol is None else abs_tol + + if self.is_distributed_api(): + self.inplace_convert(self) + self.inplace_convert(other) + output_target_index = self.check_different_item(self.output, other.output, + rel_tol=rel_tol, abs_tol=abs_tol) + return output_target_index == -1 + + def check_different_item(self, bench_list: List[Any], candidate_list: List[Any], + rel_tol: float = None, abs_tol: float = None) -> int: + # 从列表中找到第一个不相似的元素 + index = 0 + for _bench, _candidate in zip(bench_list, candidate_list): + if isinstance(_bench, list) and len(_bench) > 0 and isinstance(_bench[0], tuple): + _bench = _bench[0] + if isinstance(_candidate, list) and len(_candidate) > 0 and isinstance(_candidate[0], tuple): + _candidate = _candidate[0] + try: + if not all([is_all_close(_bench[-1], _candidate[-1], rel_tol, abs_tol)]): + return index + except IndexError as e: + raise IndexError(f"bench_list: {_bench}, candidate_list: {_candidate}") from e + index += 1 + return -1 + + def partial_compare(self) -> Callable: + return partial(is_all_close, rel_tol=self.REL_TOL, abs_tol=self.ABS_TOL) + + def is_match(self, other): + if self.api_name != other.api_name: + return False + if len(self.input) != len(other.input) or len(self.output) != len(other.output): + return False + cmp_res = True + for _bench, _candidate in zip(self.input + self.output, other.input + other.output): + if isinstance(_bench, list) and len(_bench) > 0 and isinstance(_bench[0], tuple): + _bench = _bench[0] + if isinstance(_candidate, list) and len(_candidate) > 0 and isinstance(_candidate[0], tuple): + _candidate = _candidate[0] + cmp_res &= _bench[1:3] == _candidate[1:3] + if not cmp_res: + break + return cmp_res + + @staticmethod + def inplace_convert(api_info): + api_name = api_info.api_name + for handler, dist_name_list in constant.DIST_MAPPING.items(): + for dist_name in dist_name_list: + if dist_name not in api_name or hasattr(SummaryItem, handler): + continue + getattr(SummaryItem, handler)(api_info) + return + + @staticmethod + def deepcopy_convert(api_info): + api_info_copy = copy.deepcopy(api_info) + api_name = api_info_copy.api_name + for handler, dist_name_list in constant.DIST_MAPPING.items(): + for dist_name in dist_name_list: + if dist_name not in api_name or not hasattr(SummaryItem, handler): + continue + getattr(SummaryItem, handler)(api_info_copy) + return api_info_copy + + @staticmethod + def base_distributed_handler(api_info): + # 大部分通信算子的第一个输入是输出的初始值,在compare中不需要关注 + if not isinstance(api_info.input, list) or len(api_info.input) <= 0: + return + api_info.input.pop(0) + + @staticmethod + def broadcast_handler(api_info): + src = api_info.input[1] + current_rank: int = get_current_rank(api_info.rank) + is_int, number = is_int_number(src) + if is_int and number != current_rank: + api_info.input = [] + + return api_info + + @staticmethod + def scatter_handler(api_info): + src = api_info.input[2] + current_rank: int = get_current_rank(api_info.rank) + is_int, number = is_int_number(src) + if is_int and number != current_rank: + api_info.input = [] + + return api_info + + @staticmethod + def all_reduce_handler(api_info): + return api_info + + @staticmethod + def gather_handler(api_info): + dst = api_info.input[2] + current_rank: int = get_current_rank(api_info.rank) + is_int, number = is_int_number(dst) + if is_int and number != current_rank: + api_info.output = [] + return api_info + + @staticmethod + def reduce_handler(api_info): + dst = api_info.input[1] + current_rank: int = get_current_rank(api_info.rank) + is_int, number = is_int_number(dst) + if is_int and number != current_rank: + api_info.output = [] + return api_info + + def __repr__(self): + return f"\t【名称】:\t<{self.api_name}>\n\t" \ + f"【位置】:\t{self.line}\n\t" \ + f"【输入】:\t{self.input}\n\t" \ + f"【输出】:\t{self.output}" diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/summary_dataset_factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/summary_dataset_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8de58e853c01a076de534e22140caf112a4ee788 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/dataset/summary_dataset_factory.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from acs_msprobe.dataset.msprobe_summary_dataset import SummaryDumpDataset, SummaryStackDataset +from acs_msprobe.tools import const as constant + + +class SummaryDumpDatasetFactory: + + @staticmethod + def create_dataset(**kwargs): + input_path = kwargs.get("input_path") + if input_path is None or not os.path.exists(input_path): + raise ValueError(f"Invalid input path: {input_path}") + if SummaryDumpDatasetFactory._is_json_dataset(input_path): + return SummaryDumpDataset(**kwargs), SummaryStackDataset(**kwargs) + else: + return TypeError(f"Unsupported summary dump dataset. Please check your input path: {input_path}") + + @staticmethod + def _is_json_dataset(data_dir): + for rank in os.listdir(data_dir): + if constant.API_DUMP_JSON_FILE_NAME in os.listdir(os.path.join(data_dir, rank)): + return True + return False diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/__init__.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c75b481f5a24ac56ca2bc0eb705704f0d14efb28 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/issues.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/issues.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7889944c670c16bfb85e28647a0a77fae2f365 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/result/issues.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class IssueType(Enum): + HARDWARE_ISSUE = "Hardware issue: silent data correction (SDC) happens." + DIST_API_ISSUE = "Software issue: unexpected communication api behavior." + COMPUTE_API_ISSUE = "Software issue: unexpected compute api behavior." + VERSION_ISSUE = "Version issue: software version mismatch." + EMPTY_ISSUE = "N/A" + COMPARE_ISSUE = "Match issue: unexpected difference from the benchmark." + UNEXPECTED_INPUT_ISSUE = ("Unexpected input issue: current api input is abnormal and need to double " + "check whether it occur in the first time.") + + +class ExceptionInfo: + UNCOVERED_INIT_ISSUE = "Uncovered abnormal initial input " + + +class ConfidenceEnum(Enum): + VERY_HIGH = 1.0 + HIGH = 0.8 + MEDIUM = 0.6 + LOW = 0.4 + ZERO = 0.0 diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_anomaly_api_analyzer.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_anomaly_api_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..3650f4acef6c051843c429f825c8954618f541d4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_anomaly_api_analyzer.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import MagicMock, patch +from acs_msprobe.analyzer.anomaly_analyzer.anomaly_api_analyzer import AnomalyAPIAnalyzer +from acs_msprobe.dataset.msprobe_summary_dataset import SummaryStackDataset, SummaryItem +from acs_msprobe.result.analyzer_result import ResultCollector + + +class TestAnomalyAPIAnalyzer: + @pytest.fixture + def mock_dataset(self): + dataset = MagicMock() + dataset.dataset = { + 0: [{"api_name": "test.api1"}, {"api_name": "test.api2"}], + 1: None, # Test empty rank + 2: [{"api_name": "test.api3.backward"}] # Test backward api + } + return dataset + + @pytest.fixture + def mock_stack_dataset(self): + stack_dataset = MagicMock(spec=SummaryStackDataset) + stack_dataset.query.return_value = "stack_info" + return stack_dataset + + @pytest.fixture + def mock_result_collector(self): + return MagicMock(spec=ResultCollector) + + def test_analyze_without_stack_dataset(self, mock_dataset): + analyzer = AnomalyAPIAnalyzer(mock_dataset) + + with patch('acs_msprobe.checker.factory.StaticApiHandlerFactory.dispatch') as mock_dispatch: + mock_dispatch.return_value = "test_result" + analyzer.analyze() + + assert mock_dispatch.call_count == 3 # Called for each api except None + assert analyzer.analyzer_result.add.call_count == 3 + + def test_analyze_with_stack_dataset(self, mock_dataset, mock_stack_dataset): + analyzer = AnomalyAPIAnalyzer(mock_dataset, mock_stack_dataset) + + with patch('acs_msprobe.checker.factory.StaticApiHandlerFactory.dispatch') as mock_dispatch: + mock_dispatch.return_value = MagicMock() + analyzer.analyze() + + # Verify backward api is converted to forward + mock_stack_dataset.query.assert_called_with(2, "test.api3.forward") + assert analyzer.analyzer_result.add.call_count == 2 # Only 2 unique apis + + def test_analyze_with_empty_dataset(self): + empty_dataset = MagicMock() + empty_dataset.dataset = {} + analyzer = AnomalyAPIAnalyzer(empty_dataset) + + analyzer.analyze() + assert analyzer.analyzer_result.add.call_count == 0 + + def test_analyze_with_duplicate_apis(self, mock_dataset, mock_stack_dataset): + mock_dataset.dataset = { + 0: [{"api_name": "test.api"}, {"api_name": "test.api"}] + } + analyzer = AnomalyAPIAnalyzer(mock_dataset, mock_stack_dataset) + + with patch('acs_msprobe.checker.factory.StaticApiHandlerFactory.dispatch') as mock_dispatch: + mock_dispatch.return_value = MagicMock() + analyzer.analyze() + + assert analyzer.analyzer_result.add.call_count == 1 # Only one unique api + + def test_analyze_with_no_result(self, mock_dataset): + analyzer = AnomalyAPIAnalyzer(mock_dataset) + + with patch('acs_msprobe.checker.factory.StaticApiHandlerFactory.dispatch') as mock_dispatch: + mock_dispatch.return_value = None + analyzer.analyze() + + assert analyzer.analyzer_result.add.call_count == 0 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_base_analyzer.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_base_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..75a2b671cfa75ea84a2e7a4f8dc106c414d2077e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_base_analyzer.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import Mock +from acs_msprobe.analyzer.base_analyzer import Analyzer +from torch.utils.data import Dataset + + +class TestAnalyzer(unittest.TestCase): + def setUp(self): + # Create a mock Dataset for testing + self.mock_dataset = Mock(spec=Dataset) + + def test_init_with_valid_dataset(self): + """Test initialization with a valid Dataset instance""" + analyzer = Analyzer(self.mock_dataset) + self.assertEqual(analyzer.dataset, self.mock_dataset) + + def test_init_with_invalid_dataset(self): + """Test initialization raises TypeError with invalid dataset""" + with self.assertRaises(TypeError) as context: + Analyzer("not_a_dataset") + self.assertEqual(str(context.exception), "dataset must be an instance of Dataset.") + + def test_analyze_not_implemented(self): + """Test analyze method raises NotImplementedError""" + analyzer = Analyzer(self.mock_dataset) + with self.assertRaises(NotImplementedError) as context: + analyzer.analyze() + self.assertEqual(str(context.exception), "analyze method not implemented.") + diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_summary_analyzer_factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_summary_analyzer_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..003b395358936f92490447465baa0ad50458aa49 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/analyzer/test_summary_analyzer_factory.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from acs_msprobe.analyzer.anomaly_analyzer.summary_analyzer_factory import SummaryDumpAnalyzerFactory +from acs_msprobe.analyzer.anomaly_analyzer.anomaly_api_analyzer import AnomalyAPIAnalyzer + +class TestSummaryDumpAnalyzerFactory: + def test_create_analyzer_default(self): + """ + 测试默认创建analyzer的情况 + """ + analyzer = SummaryDumpAnalyzerFactory.create_analyzer() + assert isinstance(analyzer, AnomalyAPIAnalyzer) + + def test_create_analyzer_specified(self): + """ + 测试指定创建analyzer的情况 + """ + analyzer = SummaryDumpAnalyzerFactory.create_analyzer(analyze_method="anomaly_api") + assert isinstance(analyzer, AnomalyAPIAnalyzer) + + def test_create_analyzer_with_kwargs(self): + """ + 测试创建analyzer时传递额外参数的情况 + """ + analyzer = SummaryDumpAnalyzerFactory.create_analyzer(param1="value1", param2="value2") + assert isinstance(analyzer, AnomalyAPIAnalyzer) + + def test_create_analyzer_unsupported_method(self): + """ + 测试创建不支持的analyzer类型的情况 + """ + with pytest.raises(ValueError) as excinfo: + SummaryDumpAnalyzerFactory.create_analyzer(analyze_method="unsupported_method") + assert "Unsupported analyzer method: unsupported_method" in str(excinfo.value) + + def test_get_analyzer_supported(self): + """ + 测试获取支持的analyzer类 + """ + analyzer_class = SummaryDumpAnalyzerFactory.get_analyzer("anomaly_api") + assert analyzer_class == AnomalyAPIAnalyzer + + def test_get_analyzer_unsupported(self): + """ + 测试获取不支持的analyzer类 + """ + analyzer_class = SummaryDumpAnalyzerFactory.get_analyzer("unsupported_method") + assert analyzer_class is None diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/checker/test_factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/checker/test_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..dbafe713496299f7f46c8a64f2e6377a2ba0b0a5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/checker/test_factory.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from acs_msprobe.tools.const import COMPUTEENUM +from acs_msprobe.tools.utils import is_distributed_api +from acs_msprobe.pytorch.nan_analyse.acs_msprobe.checker.factory import StaticApiHandlerFactory + + +# Mock handler classes for testing +class MockHandler1: + @classmethod + def handle(cls, *args, **kwargs): + return "MockHandler1_result" + + +class MockHandler2: + @classmethod + def handle(cls, *args, **kwargs): + return "MockHandler2_result" + + +class MockDistributedHandler: + @classmethod + def handle(cls, *args, **kwargs): + return "Distributed_result" + + +@pytest.fixture(autouse=True) +def reset_factory(): + """Fixture to reset the handler registry before each test""" + StaticApiHandlerFactory.reset() + yield + StaticApiHandlerFactory.reset() + + +def test_register_single_name(): + """Test registering a handler with default name""" + StaticApiHandlerFactory.register(MockHandler1) + assert MockHandler1.__name__ in StaticApiHandlerFactory.handlers + assert StaticApiHandlerFactory.handlers[MockHandler1.__name__] == MockHandler1 + + +def test_register_with_alias(): + """Test registering a handler with custom alias names""" + StaticApiHandlerFactory.register(MockHandler2, alias=["alias1", "alias2"]) + assert "alias1" in StaticApiHandlerFactory.handlers + assert "alias2" in StaticApiHandlerFactory.handlers + assert StaticApiHandlerFactory.handlers["alias1"] == MockHandler2 + assert StaticApiHandlerFactory.handlers["alias2"] == MockHandler2 + + +def test_create_existing_handler(): + """Test creating an existing handler""" + StaticApiHandlerFactory.register(MockHandler1) + result = StaticApiHandlerFactory.create(MockHandler1.__name__, "test_arg") + assert result == "MockHandler1_result" + + +def test_create_nonexistent_handler(): + """Test creating a non-existent handler""" + with pytest.raises(ValueError) as excinfo: + StaticApiHandlerFactory.create("nonexistent_handler") + assert "Handler 'nonexistent_handler' not found" in str(excinfo.value) + + +def test_dispatch_non_distributed_api(monkeypatch): + """Test dispatch with non-distributed API""" + monkeypatch.setattr('acs_msprobe.tools.utils.is_distributed_api', lambda x: False) + StaticApiHandlerFactory.register(MockHandler1, alias=[COMPUTEENUM.NORMAL_INPUT_ABNORMAL_OUTPUT]) + + result = StaticApiHandlerFactory.dispatch("some_api") + assert result == "MockHandler1_result" + + +def test_dispatch_distributed_api_matching(monkeypatch): + """Test dispatch with distributed API that matches a handler""" + monkeypatch.setattr('acs_msprobe.tools.utils.is_distributed_api', lambda x: True) + StaticApiHandlerFactory.register(MockDistributedHandler, alias=["distributed"]) + + result = StaticApiHandlerFactory.dispatch("distributed_api_name") + assert result == "Distributed_result" + + +def test_dispatch_distributed_api_no_match(monkeypatch): + """Test dispatch with distributed API that doesn't match any handler""" + monkeypatch.setattr('acs_msprobe.tools.utils.is_distributed_api', lambda x: True) + StaticApiHandlerFactory.register(MockHandler1, alias=["other_key"]) + + result = StaticApiHandlerFactory.dispatch("distributed_api_name") + assert result is None + + +def test_reset(): + """Test resetting the handler registry""" + StaticApiHandlerFactory.register(MockHandler1) + StaticApiHandlerFactory.register(MockHandler2, alias=["alias1"]) + assert len(StaticApiHandlerFactory.handlers) == 2 + + StaticApiHandlerFactory.reset() + assert len(StaticApiHandlerFactory.handlers) == 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_base_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..888a1de9e6d4c86c47248b989c924a23d29d0401 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_base_dataset.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest +from unittest.mock import patch +from acs_msprobe.dataset.base_dataset import Dataset + + +class TestDataset: + @pytest.fixture + def mock_logger(self): + with patch('acs_msprobe.dataset.base_dataset.logger') as mock: + yield mock + + def test_init_with_valid_path(self, tmp_path, mock_logger): + """Test initialization with valid file path""" + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + + dataset = Dataset(str(test_file)) + assert dataset.input_path == os.path.abspath(str(test_file)) + assert isinstance(dataset.dataset, dict) + mock_logger.debug.assert_called_once_with( + "Init %s with input_path: %s", "Dataset", os.path.abspath(str(test_file)) + ) + + def test_init_with_none_path(self): + """Test initialization with None path""" + with pytest.raises(FileNotFoundError) as excinfo: + Dataset(None) + assert "Dataset file None not found." in str(excinfo.value) + + def test_init_with_invalid_path(self): + """Test initialization with non-existent path""" + invalid_path = "/path/that/does/not/exist" + with pytest.raises(FileNotFoundError) as excinfo: + Dataset(invalid_path) + assert f"Dataset file {invalid_path} not found." in str(excinfo.value) + + def test_len_method(self, tmp_path): + """Test __len__ method""" + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + + dataset = Dataset(str(test_file)) + assert len(dataset) == 0 # dataset is empty dict by default + + dataset.dataset = {"key1": "value1", "key2": "value2"} + assert len(dataset) == 2 + + def test_parse_method_not_implemented(self, tmp_path): + """Test abstract parse method raises NotImplementedError""" + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + + dataset = Dataset(str(test_file)) + with pytest.raises(NotImplementedError) as excinfo: + dataset.parse() + assert "parse method not implemented." in str(excinfo.value) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_msprobe_summary_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_msprobe_summary_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..315e3503db4321c631bc7fa5ece05ae7da3db38c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_msprobe_summary_dataset.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import os +import json +from unittest.mock import patch +from acs_msprobe.dataset.msprobe_summary_dataset import ( + SummaryDumpDataset, + SummaryStackDataset, + ProcessGroupMonitor +) +from acs_msprobe.dataset.process_group_dataset import ProcessGroup, SummaryItem +from acs_msprobe.tools import const as constant + + +# Test ProcessGroupMonitor +def test_process_group_monitor_collect(): + monitor = ProcessGroupMonitor("torch.distributed.ProcessGroup") + meta_info = { + "type": "torch.distributed.ProcessGroup", + "global_ranks": [0, 1, 2, 3], + "group_id": 281471037097008 + } + api_info = {"rank": 0, "api_name": "test_api"} + monitor.collect(meta_info, api_info) + assert "group" in api_info + assert isinstance(api_info["group"], ProcessGroup) + + +# Test SummaryDumpDataset +def test_summary_dump_dataset_from_args_valid(): + with patch('os.path.exists', return_value=True): + dataset = SummaryDumpDataset.from_args(input_path="/valid/path", framework="pytorch") + assert isinstance(dataset, SummaryDumpDataset) + assert dataset.framework == "pytorch" + + +def test_summary_dump_dataset_from_args_invalid_path(): + with patch('os.path.exists', return_value=False): + with pytest.raises(TypeError): + SummaryDumpDataset.from_args(input_path="/invalid/path", framework="pytorch") + + +def test_summary_dump_dataset_from_args_unsupported_framework(): + with patch('os.path.exists', return_value=True): + with pytest.raises(ValueError): + SummaryDumpDataset.from_args(input_path="/valid/path", framework="tensorflow") + + +def test_summary_dump_dataset_parse_all_ranks(tmp_path): + # Setup test directory structure + rank_dirs = ["rank0", "rank1"] + for rank_dir in rank_dirs: + os.makedirs(tmp_path / rank_dir) + + dataset = SummaryDumpDataset(str(tmp_path), "pytorch") + dataset.parse_all_ranks() + assert len(dataset.dataset) == 2 + + +def test_summary_dump_dataset_parse_dict_item(): + dataset = SummaryDumpDataset("/dummy", "pytorch") + + # Test float type + api_info = {} + meta_info = {"type": "float", "value": 1.23} + meta, static = dataset._parse_dict_item(api_info, meta_info) + assert meta == [constant.DTYPE_CLASS_FLOAT, '[]'] + assert static == [1.23, 1.23, 1.23, 1.23] + + # Test tensor type + meta_info = {"type": "Tensor", "shape": "[2,2]", "Max": 2, "Min": 1, "Mean": 1.5, "Norm": 1.8} + meta, static = dataset._parse_dict_item(api_info, meta_info) + assert meta == ["Tensor", "[2,2]"] + assert static == [2, 1, 1.5, 1.8] + + +def test_summary_dump_dataset_parse_list_object(): + dataset = SummaryDumpDataset("/dummy", "pytorch") + api_info = {"rank": 0} + api_name = "test_api" + meta_list = [{"type": "float", "value": 1.23}, None] + result = dataset._parse_list_object(api_info, api_name, meta_list) + assert len(result) == 1 + assert result[0][0] == "test_api" + assert result[0][3] == [1.23, 1.23, 1.23, 1.23] + + +# Test SummaryStackDataset +def test_summary_stack_dataset_query(tmp_path): + # Setup test data + stack_data = {"api1": "stack1", "api2": "stack2"} + os.makedirs(tmp_path / "rank0") + with open(tmp_path / "rank0" / constant.API_STACK_JSON_FILE_NAME, 'w') as f: + json.dump(stack_data, f) + + dataset = SummaryStackDataset(str(tmp_path)) + result = dataset.query(0, "api1") + assert result == "stack1" + assert 0 in dataset.ranks + + +def test_summary_stack_dataset_query_missing_file(tmp_path): + dataset = SummaryStackDataset(str(tmp_path)) + result = dataset.query(0, "api1") + assert result is None diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_process_group_dataset.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_process_group_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..103d52743ff9076d0cb5650ce1e990b88c02589c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_process_group_dataset.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import copy +from unittest.mock import patch +from acs_msprobe.dataset.process_group_dataset import ( + ProcessGroup, ProcessGroupCounter, SummaryItem +) + + +# Test ProcessGroup class +class TestProcessGroup: + def test_init(self): + group = ProcessGroup(1, [0, 1, 2]) + assert group.group_id == 1 + assert group.global_ranks == [0, 1, 2] + assert group.local_order == 0 + + def test_eq(self): + group1 = ProcessGroup(1, [0, 1, 2]) + group2 = ProcessGroup(1, [0, 1, 2]) + group3 = ProcessGroup(2, [0, 1, 2]) + group4 = ProcessGroup(1, [0, 1]) + + assert group1 == group2 + assert not (group1 == group3) + assert not (group1 == group4) + + def test_update_local_order(self): + group = ProcessGroup(1, [0, 1, 2]) + group.update_local_order(5) + assert group.local_order == 5 + + +# Test ProcessGroupCounter class +class TestProcessGroupCounter: + def test_init(self): + counter = ProcessGroupCounter() + assert counter.rank == -1 + assert counter.group_mapping == {} + + def test_update_new_rank(self): + counter = ProcessGroupCounter() + group = ProcessGroup(1, [0, 1, 2]) + counter.update(group, 0, "dist_func") + assert counter.rank == 0 + assert len(counter.group_mapping) == 1 + + def test_update_existing_rank(self): + counter = ProcessGroupCounter() + group = ProcessGroup(1, [0, 1, 2]) + counter.update(group, 0, "dist_func") + counter.update(group, 0, "dist_func") + assert counter.group_mapping["dist_func-0-1-2-1"] == 2 + + +# Test SummaryItem class +class TestSummaryItem: + @pytest.fixture + def sample_item(self): + return SummaryItem( + api_name="test_api", + line=10, + group=ProcessGroup(1, [0, 1, 2]), + rank=0, + stack_info="stack", + input_list=[("input1", 1, 2, 3.0)], + output=[("output1", 1, 2, 3.0)] + ) + + def test_init(self, sample_item): + assert sample_item.api_name == "test_api" + assert sample_item.line == 10 + assert sample_item.group.group_id == 1 + assert sample_item.rank == 0 + assert sample_item.stack_info == "stack" + assert sample_item.input == [("input1", 1, 2, 3.0)] + assert sample_item.output == [("output1", 1, 2, 3.0)] + + def test_is_empty(self): + empty_item = SummaryItem(None, None) + assert empty_item.is_empty() + assert not sample_item.is_empty() + + def test_is_distributed_api(self): + dist_item = SummaryItem("distributed_api", 1) + non_dist_item = SummaryItem("regular_api", 1) + assert dist_item.is_distributed_api() + assert not non_dist_item.is_distributed_api() + + def test_get(self, sample_item): + assert sample_item.get("api_name") == "test_api" + assert sample_item.get("nonexistent", "default") == "default" + + def test_eq(self, sample_item): + same_item = SummaryItem( + api_name="test_api", + line=10, + input_list=[("input1", 1, 2, 3.0)], + output=[("output1", 1, 2, 3.0)] + ) + diff_item1 = SummaryItem( + api_name="different_api", + line=10, + input_list=[("input1", 1, 2, 3.0)], + output=[("output1", 1, 2, 3.0)] + ) + diff_item2 = SummaryItem( + api_name="test_api", + line=10, + input_list=[("input1", 1, 2, 4.0)], + output=[("output1", 1, 2, 3.0)] + ) + + assert sample_item == same_item + assert not (sample_item == diff_item1) + assert not (sample_item == diff_item2) + + @patch('acs_msprobe.tools.utils.is_all_close', return_value=True) + def test_is_equivalent(self, mock_is_all_close, sample_item): + other_item = copy.deepcopy(sample_item) + assert sample_item.is_equivalent(other_item) + mock_is_all_close.assert_called() + + def test_check_different_item(self, sample_item): + bench = [("item1", 1, 2, 3.0)] + candidate1 = [("item1", 1, 2, 3.0)] + candidate2 = [("item1", 1, 2, 4.0)] + + assert sample_item.check_different_item(bench, candidate1) == -1 + assert sample_item.check_different_item(bench, candidate2) == 0 + + def test_partial_compare(self, sample_item): + compare_func = sample_item.partial_compare() + assert callable(compare_func) + + def test_is_match(self, sample_item): + matching_item = SummaryItem( + api_name="test_api", + line=10, + input_list=[("input1", 1, 2, 3.0)], + output=[("output1", 1, 2, 3.0)] + ) + non_matching_item = SummaryItem( + api_name="test_api", + line=10, + input_list=[("input1", 1, 3, 3.0)], + output=[("output1", 1, 2, 3.0)] + ) + + assert sample_item.is_match(matching_item) + assert not sample_item.is_match(non_matching_item) + + @patch('acs_msprobe.tools.utils.get_current_rank', return_value=0) + def test_broadcast_handler(self, mock_get_rank): + item = SummaryItem( + api_name="broadcast_api", + line=1, + rank=0, + input_list=[("input1", 1, 2, 3.0), 0] + ) + processed_item = SummaryItem.broadcast_handler(item) + assert processed_item.input == [("input1", 1, 2, 3.0), 0] + + @patch('acs_msprobe.tools.utils.get_current_rank', return_value=1) + def test_scatter_handler(self, mock_get_rank): + item = SummaryItem( + api_name="scatter_api", + line=1, + rank=1, + input_list=[("input1", 1, 2, 3.0), ("input2", 1, 2, 3.0), 0] + ) + processed_item = SummaryItem.scatter_handler(item) + assert processed_item.input == [] + + def test_repr(self, sample_item): + repr_str = repr(sample_item) + assert "test_api" in repr_str + assert "10" in repr_str + assert "input1" in repr_str + assert "output1" in repr_str diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_summary_dataset_factory.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_summary_dataset_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..2c703b87166ae89ed80615d010b2f688f15a67fe --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/dataset/test_summary_dataset_factory.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import patch +from acs_msprobe.dataset.summary_dataset_factory import SummaryDumpDatasetFactory +from acs_msprobe.dataset.msprobe_summary_dataset import SummaryDumpDataset, SummaryStackDataset +from acs_msprobe.tools import const as constant + + +class TestSummaryDumpDatasetFactory: + + @patch('os.path.exists') + @patch('os.listdir') + def test_create_dataset_valid_json(self, mock_listdir, mock_exists): + """Test creating dataset with valid json input path""" + # Setup mock + mock_exists.return_value = True + mock_listdir.side_effect = [ + ['rank0', 'rank1'], # First call for data_dir + [constant.API_DUMP_JSON_FILE_NAME, 'other_file'], # Second call for rank0 + ] + + # Test + kwargs = {"input_path": "/valid/path"} + dump_dataset, stack_dataset = SummaryDumpDatasetFactory.create_dataset(**kwargs) + + # Verify + assert isinstance(dump_dataset, SummaryDumpDataset) + assert isinstance(stack_dataset, SummaryStackDataset) + mock_exists.assert_called_once_with("/valid/path") + + @patch('os.path.exists') + def test_create_dataset_invalid_path(self, mock_exists): + """Test creating dataset with invalid input path""" + # Setup mock + mock_exists.return_value = False + + # Test and verify + kwargs = {"input_path": "/invalid/path"} + with pytest.raises(ValueError, match="Invalid input path: /invalid/path"): + SummaryDumpDatasetFactory.create_dataset(**kwargs) + + @patch('os.path.exists') + @patch('os.listdir') + def test_create_dataset_unsupported_format(self, mock_listdir, mock_exists): + """Test creating dataset with unsupported format""" + # Setup mock + mock_exists.return_value = True + mock_listdir.side_effect = [ + ['rank0', 'rank1'], # First call for data_dir + ['file1', 'file2'], # Second call for rank0 (no json file) + ] + + # Test and verify + kwargs = {"input_path": "/unsupported/path"} + with pytest.raises(TypeError, match="Unsupported summary dump dataset"): + SummaryDumpDatasetFactory.create_dataset(**kwargs) + + @patch('os.path.exists') + @patch('os.listdir') + def test_is_json_dataset_true(self, mock_listdir, mock_exists): + """Test _is_json_dataset returns True when json file exists""" + # Setup mock + mock_exists.return_value = True + mock_listdir.side_effect = [ + ['rank0', 'rank1'], + [constant.API_DUMP_JSON_FILE_NAME, 'other_file'], + ] + + # Test and verify + assert SummaryDumpDatasetFactory._is_json_dataset("/test/path") is True + + @patch('os.path.exists') + @patch('os.listdir') + def test_is_json_dataset_false(self, mock_listdir, mock_exists): + """Test _is_json_dataset returns False when no json file exists""" + # Setup mock + mock_exists.return_value = True + mock_listdir.side_effect = [ + ['rank0', 'rank1'], + ['file1', 'file2'], + ] + + # Test and verify + assert SummaryDumpDatasetFactory._is_json_dataset("/test/path") is False \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_log.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_log.py new file mode 100644 index 0000000000000000000000000000000000000000..76b67531887df93b3373a2e373cabe8c255566cb --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_log.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import pytest +from unittest.mock import patch +from acs_msprobe.tools.log import get_log_level + + +# Mock constants for testing +class MockConstants: + ADVISOR_LOG_LEVEL = "ADVISOR_LOG_LEVEL" + DEFAULT_LOG_LEVEL = "INFO" + + +@pytest.fixture +def mock_constants(): + with patch('debug.accuracy_tools.msprobe.pytorch.nan_analyse.acs_msprobe.tools.log.const', MockConstants()): + yield + + +def test_get_log_level_default(mock_constants): + with patch.dict(os.environ, clear=True): + assert get_log_level() == "INFO" + + +def test_get_log_level_valid_env(mock_constants): + test_cases = [ + ("DEBUG", "DEBUG"), + ("INFO", "INFO"), + ("WARNING", "WARNING"), + ("ERROR", "ERROR"), + ("CRITICAL", "CRITICAL"), + ] + for env_value, expected in test_cases: + with patch.dict(os.environ, {MockConstants.ADVISOR_LOG_LEVEL: env_value}): + assert get_log_level() == expected + + +def test_get_log_level_invalid_env(mock_constants): + with patch.dict(os.environ, {MockConstants.ADVISOR_LOG_LEVEL: "INVALID"}): + with pytest.raises(AttributeError) as excinfo: + get_log_level() + assert "Invalid log level" in str(excinfo.value) + + +def test_get_log_level_case_insensitive(mock_constants): + with patch.dict(os.environ, {MockConstants.ADVISOR_LOG_LEVEL: "debug"}): + assert get_log_level() == "DEBUG" diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_utils.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cc9b4e3e8575e8cd3f39005e5cd07900b03809 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/test/tools/test_utils.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import pytest + +from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.table import Table + +from unittest.mock import patch +from acs_msprobe.tools.utils import ( + get_current_rank, + is_int_number, + safe_divide, + get_rel_err, + get_abs_err, + is_close, + is_all_close, + load_variable, + get_progress_table_and_progress +) + + +# Test get_current_rank +def test_get_current_rank(): + # Test string inputs + assert get_current_rank("rank") == 0 + assert get_current_rank("rank5") == 5 + assert get_current_rank("42") == 42 + + # Test integer input + assert get_current_rank(7) == 7 + + # Test invalid input + with pytest.raises(ValueError): + get_current_rank(3.14) + with pytest.raises(ValueError): + get_current_rank(None) + + +# Test is_int_number +def test_is_int_number(): + # Valid integer case + assert is_int_number([[None, "", "[]", [5, 5, 5, 5]]]) == (True, 5) + + # Invalid cases + assert is_int_number([[None, "", "[]", [5, 5, 5, 5]]]) == (False, -1) + assert is_int_number([[None, "", "[1]", [5, 5, 5, 5]]]) == (False, -1) + assert is_int_number([[None, "", "[]", [1, 2, 3, 4]]]) == (False, -1) + + # Invalid length + with pytest.raises(ValueError): + is_int_number([[None, "", "[]"]]) + with pytest.raises(ValueError): + is_int_number([[None, "", "[]", [1, 2, 3, 4, 5]]]) + + +# Test safe_divide +def test_safe_divide(): + assert safe_divide(10, 2) == 5 + assert safe_divide(10, 0) == 0 + assert safe_divide(0, 5) == 0 + assert safe_divide(-10, 2) == -5 + + +# Test get_rel_err +def test_get_rel_err(): + assert get_rel_err(100, 110) == pytest.approx(0.1) + assert get_rel_err(0, 0) == 0 + assert get_rel_err(1e-10, 2e-10) == pytest.approx(1.0) + assert get_rel_err(-100, -110) == pytest.approx(0.1) + + +# Test get_abs_err +def test_get_abs_err(): + assert get_abs_err(100, 110) == 10 + assert get_abs_err(0, 0) == 0 + assert get_abs_err(-100, -110) == 10 + assert get_abs_err(1e-10, 2e-10) == 1e-10 + + +# Test is_close +def test_is_close(): + # Equal values + assert is_close(100, 100) + + # Close values + assert is_close(100, 100.00001) + + # Different values + assert not is_close(100, 200) + + # Special cases + assert is_close(float('nan'), float('nan')) + assert is_close(float('inf'), float('inf')) + assert not is_close(float('inf'), float('-inf')) + assert not is_close(0, float('nan')) + + # Edge cases + assert is_close(1e-10, 1.1e-10, rel_tol=0.2) + assert not is_close(1e-10, 2e-10, rel_tol=0.1) + + +# Test is_all_close +def test_is_all_close(): + # Equal lists + assert is_all_close([1, 2, 3], [1, 2, 3]) + + # Close lists + assert is_all_close([1.0, 2.0], [1.00001, 2.00001]) + + # Different lists + assert not is_all_close([1, 2, 3], [1, 2, 4]) + + # Special cases + assert is_all_close([float('nan')], [float('nan')]) + assert not is_all_close([float('nan')], [1.0]) + assert is_all_close([None], [None]) + assert not is_all_close([None], [1]) + + # Mixed types + assert is_all_close([1, "str", True], [1, "str", True]) + assert not is_all_close([1, "str", True], [2, "str", True]) + + # Nested lists + assert is_all_close([[1, 2], [3, 4]], [[1, 2], [3, 4]]) + assert not is_all_close([[1, 2], [3, 4]], [[1, 2], [3, 5]]) + + +# Test load_variable +def test_load_variable(): + # Test with environment variable set + with patch.dict(os.environ, {"TEST_VAR": "42"}): + assert load_variable("TEST_VAR", 10, int) == 42 + + # Test with default value + with patch.dict(os.environ, {}, clear=True): + assert load_variable("TEST_VAR", 10, int) == 10 + assert load_variable("TEST_VAR", "default", str) == "default" + assert load_variable("TEST_VAR", None, None) is None + + # Test type conversion + with patch.dict(os.environ, {"TEST_VAR": "3.14"}): + assert load_variable("TEST_VAR", 0, float) == 3.14 + assert load_variable("TEST_VAR", 0, str) == "3.14" + +def test_get_progress_table_and_progress_basic(): + """Test basic functionality with positive integer input""" + total_len = 10 + title = "Test Progress" + + progress_table, overall_progress, overall_task = get_progress_table_and_progress(total_len, title) + + # Verify progress table structure + assert isinstance(progress_table, Table) + assert len(progress_table.rows) == 1 + + # Verify progress bar configuration + assert isinstance(overall_progress, Progress) + assert len(overall_progress.columns) == 6 # 6 columns in the progress bar + assert any(isinstance(col, BarColumn) for col in overall_progress.columns) + assert any(isinstance(col, TimeElapsedColumn) for col in overall_progress.columns) + assert any(isinstance(col, TimeRemainingColumn) for col in overall_progress.columns) + + # Verify task configuration + assert overall_task in overall_progress.task_ids + task = overall_progress.tasks[overall_task] + assert task.description == f"Total Rank {total_len}" + assert task.total == total_len + + +def test_get_progress_table_and_progress_zero_length(): + """Test with zero length input""" + total_len = 0 + title = "Zero Progress" + + progress_table, overall_progress, overall_task = get_progress_table_and_progress(total_len, title) + + assert overall_progress.tasks[overall_task].total == 0 + assert overall_progress.tasks[overall_task].description == f"Total Rank {total_len}" + + +def test_get_progress_table_and_progress_large_number(): + """Test with very large number""" + total_len = 1000000 + title = "Large Progress" + + progress_table, overall_progress, overall_task = get_progress_table_and_progress(total_len, title) + + assert overall_progress.tasks[overall_task].total == total_len + assert overall_progress.tasks[overall_task].description == f"Total Rank {total_len}" + + +def test_get_progress_table_and_progress_empty_title(): + """Test with empty title string""" + total_len = 5 + title = "" + + progress_table, overall_progress, overall_task = get_progress_table_and_progress(total_len, title) + + assert isinstance(progress_table, Table) + assert len(progress_table.rows) == 1 + + +def test_get_progress_table_and_progress_invalid_input(): + """Test with invalid input types""" + with pytest.raises(TypeError): + get_progress_table_and_progress("10", "Test") # string length + + with pytest.raises(TypeError): + get_progress_table_and_progress(10, 123) # numeric title + + +def test_get_progress_table_and_progress_negative_length(): + """Test with negative length""" + with pytest.raises(ValueError): + get_progress_table_and_progress(-5, "Negative Progress") diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/const.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/const.py new file mode 100644 index 0000000000000000000000000000000000000000..a55352e7d931005574dcd1e1fffb8a5e98810b9c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/const.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from collections import OrderedDict + +# log +DEFAULT_LOG_LEVEL = "WARNING" +ADVISOR_LOG_LEVEL = "ADVISOR_LOG_LEVEL" + +# file +API_DUMP_JSON_FILE_NAME = "dump.json" +API_STACK_JSON_FILE_NAME = "stack.json" + +FORWARD_FLAG = "forward" +BACKWARD_FLAG = "backward" +INPUT_FLAG = "input" +OUTPUT_FLAG = "output" +DISTRIBUTED = "Distributed" + + +# torch +TORCH_PROCESS_GROUP = "torch.ProcessGroup" +TORCH_DISTRIBUTED_PROCESS_GROUP = "torch.distributed.ProcessGroup" +VALID_TORCH_GROUP = [TORCH_PROCESS_GROUP, TORCH_DISTRIBUTED_PROCESS_GROUP] +TORCH_GROUP_RANKS = "group_ranks" +TORCH_GLOBAL_RANKS = "global_ranks" +TORCH_GROUP_ID = "group_id" + +# threshold +MINIMUM_ABS_THRESHOLD = 1e-3 +MAXIMUM_ABS_THRESHOLD = 1e3 + +DELIMITER = "." +CMP_REL_TOL = 0.5 +CMP_ABS_TOL = 0.1 + +DTYPE_CLASS_INT = "" +DTYPE_CLASS_FLOAT = "" +PYTORCH = "pytorch" + + +class DISTRIBUTEDENUM(enum.Enum): + ALL_TO_ALL = "all_to_all_single" + ALL_GATHER_BASE = "_all_gather_base" + ALL_GATHER = "all_gather" + ALL_GATHER_INTO_TENSOR = "all_gather_into_tensor" + ALL_REDUCE = "all_reduce" + REDUCE_SCATTER_BASE = "_reduce_scatter_base" + REDUCE_SCATTER_TENSOR = "reduce_scatter_tensor" + REDUCE = "reduce" + GATHER = "gather" + SCATTER = "scatter" + BROADCAST = "broadcast" + IRECV = "irecv" + ISEND = "isend" + RECV = "recv" + SEND = "send" + + +class COMPUTEENUM(enum.Enum): + NORMAL_INPUT_ABNORMAL_OUTPUT = "normal_input_abnormal_output" + + +DIST_MAPPING = OrderedDict( + base_distributed_handler=[ + DISTRIBUTEDENUM.ALL_TO_ALL.value, + DISTRIBUTEDENUM.ALL_GATHER.value, + DISTRIBUTEDENUM.ALL_GATHER_BASE.value, + DISTRIBUTEDENUM.ALL_GATHER_INTO_TENSOR.value, + DISTRIBUTEDENUM.REDUCE_SCATTER_BASE.value, + DISTRIBUTEDENUM.REDUCE_SCATTER_TENSOR.value + ], + broadcast_handler=[ + DISTRIBUTEDENUM.BROADCAST.value + ], + all_reduce_handler=[ + DISTRIBUTEDENUM.ALL_REDUCE.value + ], + scatter_handler=[ + DISTRIBUTEDENUM.SCATTER.value + ], + gather_handler=[ + DISTRIBUTEDENUM.GATHER.value + ], + reduce_handler=[ + DISTRIBUTEDENUM.REDUCE.value + ] +) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/log.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/log.py new file mode 100644 index 0000000000000000000000000000000000000000..f6bd6cbf41844ff31e9282a6584042b795a07ce5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/log.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +from rich.logging import RichHandler +from acs_msprobe.tools import const + + +def get_log_level(): + valid_log_levels = [ + logging.CRITICAL, + logging.ERROR, + logging.WARNING, + logging.INFO, + logging.DEBUG, + ] + log_level = os.getenv(const.ADVISOR_LOG_LEVEL, const.DEFAULT_LOG_LEVEL).upper() + if not hasattr(logging, log_level): + raise AttributeError(f"Invalid log level: {log_level}, supported log levels: {valid_log_levels}") + return log_level + + +def initialize_logger(): + log_level = get_log_level() + logging.basicConfig( + level=log_level, + format="%(message)s", + handlers=[RichHandler(show_time=False, show_level=False, show_path=False)] + ) + _logger = logging.getLogger("rich") + _logger.setLevel(log_level) + return _logger + + +logger = initialize_logger() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/utils.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..841cd38407efb76066df8699b13b5e72d093a789 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/acs_msprobe/tools/utils.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os + +from rich.logging import RichHandler +from rich.panel import Panel +from rich.table import Table +from rich.progress import Progress, BarColum, TextColum, TimeElapsedColumn, TimeRemainingColumn + +from acs_msprobe.tools import const + + +def get_progress_table_and_progress(total_len, title): + overall_progress = Progress( + TextColum("[progress.description]{task.description}"), + BarColum(), + TextColum("{task.completed}/{task.total}"), + TextColum("Elapsed time:"), + TimeElapsedColumn(), + TextColum("Remaining time:"), + TimeRemainingColumn(compact=True), + expand=True + ) + + overall_task = overall_progress.add_task(f"Total Rank {total_len}", total=total_len) + progress_table = Table.grid() + progress_table.add_row(Panel.fit(overall_progress, title=title, border_style="green")) + return progress_table, overall_progress, overall_task + + +def get_current_rank(rank_info): + """ + Get the current rank of the process. + Returns: + int: The current rank of the process. + """ + if isinstance(rank_info, str): + if rank_info == "rank": + return 0 + if rank_info.startswith("rank"): + return int(rank_info.replace("rank", "")) + elif rank_info.isdigit(): + return int(rank_info) + elif isinstance(rank_info, int): + return rank_info + else: + raise ValueError("Invalid rank_info type.") + + +def is_int_number(api_static_list): + """ + Check if the given list contains only integers. + Args: + api_static_list (list): The list to check. + Returns: + bool: True if the list contains only integers, False otherwise. + """ + if len(api_static_list) == 1: + api_static_list = api_static_list[0] + if len(api_static_list) != 4: + raise ValueError("Invalid api_static_list length.") + _, api_dtype, api_shape, api_static_info = api_static_list + if api_dtype != "" or api_shape!= "[]": + return False, -1 + if max(api_static_info) == min(api_static_info): + return True, max(api_static_info) + return False, -1 + + +def safe_divide(numerator, denominator): + if denominator == 0: + return 0 + return numerator / denominator + + +def get_rel_err(bench, candidate): + return safe_divide(abs(bench - candidate), max(abs(bench), abs(candidate))) + + +def get_abs_err(bench, candidate): + return abs(bench - candidate) + + +def is_close(bench, candidate, rel_tol=1e-5, abs_tol=1e-5): + if math.isnan(bench) and math.isnan(candidate): + return True + if math.isinf(bench) and math.isinf(candidate): + return True + + rel_err = get_rel_err(bench, candidate) + abs_err = get_abs_err(bench, candidate) + if abs_err < const.MINIMUM_ABS_THRESHOLD or abs_err > const.MAXIMUM_ABS_THRESHOLD: + if rel_err > rel_tol or abs_err > abs_tol: + return False + return math.isclose(bench, candidate, rel_tol=rel_tol, abs_tol=abs_tol) + + +def is_all_close(bench, candidate, rel_tol=1e-5, abs_tol=1e-5): + if len(bench) != len(candidate): + return False + for _banch, _candidate in zip(bench, candidate): + if isinstance(_banch, list) and isinstance(_candidate, list): + continue + if _banch is None: + return _candidate is None + if isinstance(_banch, str) or isinstance(_candidate, str): + continue + if isinstance(_banch, bool) or isinstance(_candidate, bool): + return _banch == _candidate + + if _banch is None and _candidate is None: + continue + if _banch is None or _candidate is None: + return False + + if math.isnan(_banch) and not math.isnan(_candidate): + return False + if not math.isnan(_banch) and math.isnan(_candidate): + return False + if math.isinf(_banch) and math.isinf(_candidate): + return False + if not math.isinf(_banch) and math.isinf(_candidate): + return False + if not is_close(_banch, _candidate, rel_tol=rel_tol, abs_tol=abs_tol): + return False + return True + + +def is_distributed_api(api_name): + return "Distributed" in api_name + + +def load_variable(parameter, default, dtype): + if not os.environ.get(parameter, None): + return dtype(default) if dtype is not None else default + return dtype(os.environ.get(parameter)) if dtype is not None else os.environ.get(parameter) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_dump_graph.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_dump_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5f80205371599f7e43ac5dae8880eea56b233e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_dump_graph.py @@ -0,0 +1,337 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Set, Optional, Tuple, Callable +from enum import Enum +from dataclasses import dataclass +from collections import defaultdict, deque + +from msprobe.core.common.log import logger +from msprobe.pytorch.nan_analyse.api_info import APIInfo +from msprobe.pytorch.nan_analyse.pre_process_dump_data import process_on_all_ranks + + +class NodeType(Enum): + COMPUTE = "compute" + SEND = "send" + RECV = "recv" + COLLECTIVE = "collective" + + +class EdgeType(Enum): + SEQUENTIAL = "sequential" + COMMUNICATION = "communication" + + +@dataclass +class Node: + node_id: str # (rank_id:api_name) + rank: int + api_info: APIInfo + node_type: NodeType + + def __hash__(self): + return hash(self.node_id) + + def __eq__(self, other): + return isinstance(other, Node) and self.node_id == other.node_id + + def __str__(self): + return self.node_id + + +class Edge: + def __init__(self, src: Node, dst: Node, edge_type: EdgeType = EdgeType.SEQUENTIAL): + self.src = src + self.dst = dst + self.edge_type = edge_type + self.edge_id = self.__generate_edge_name() + + def __generate_edge_name(self): + return f'{self.src.node_id}_{self.dst.node_id}' + + +class DistributedComputeGraph: + def __init__(self): + self.nodes: Dict[str, Node] = {} + self.edges: Dict[str, Edge] = {} + self.adj_list: Dict[Node, List[Node]] = defaultdict(list) + self.rank_to_nodes: Dict[int, List[Node]] = {} + # 添加入度统计 + self.in_degrees: Dict[Node, int] = defaultdict(int) + + def add_node(self, node: Node): + self.nodes[node.node_id] = node + if not self.rank_to_nodes.get(node.rank): + self.rank_to_nodes[node.rank] = [] + self.rank_to_nodes[node.rank].append(node) + + def add_edge(self, src: Node, dst: Node, edge_type: EdgeType = EdgeType.SEQUENTIAL): + edge = Edge(src, dst, edge_type) + # 边去重 + if self.edges.get(edge.edge_id): + return + self.edges[edge.edge_id] = edge + self.adj_list[src].append(dst) + # 更新入度 + self.in_degrees[dst] += 1 + + def get_node(self, node_id: str) -> Optional[Node]: + return self.nodes.get(node_id) + + def get_nodes_by_rank(self, rank_id: int) -> List[Node]: + return self.rank_to_nodes.get(rank_id, []) + + def get_start_nodes(self) -> List[Node]: + """获取所有入度为0的节点或者每个rank上首个节点""" + start_nodes = [node for node in self.nodes.values() if self.in_degrees[node] == 0] + if not start_nodes: + return self._get_first_nodes() + return start_nodes + + def _get_first_nodes(self): + first_nodes = [] + for rank in self.rank_to_nodes.keys(): + first_nodes.extend(self.__get_first_node_by_rank(rank)) + return first_nodes + + def __get_first_node_by_rank(self, rank): + nodes = self.rank_to_nodes.get(rank, []) + if not nodes: + return [] + return nodes[:1] + + +class GraphBuilder: + @staticmethod + def create_node(rank: int, api_info: APIInfo) -> Node: + node_id = f"{rank}:{api_info.api_name}" + + if api_info.is_communication_op: + if "send" in api_info.api_name.lower(): + node_type = NodeType.SEND + elif "recv" in api_info.api_name.lower(): + node_type = NodeType.RECV + else: + node_type = NodeType.COLLECTIVE + else: + node_type = NodeType.COMPUTE + + return Node(node_id, rank, api_info, node_type) + + @staticmethod + def build_graph(rank_ops_data: Dict[int, Dict]) -> DistributedComputeGraph: + graph = DistributedComputeGraph() + + # Step 1: Create all nodes + rank_nodes: Dict[int, List[Node]] = {} + for rank, ops in rank_ops_data.items(): + rank_nodes[rank] = [] + for _, api_info in ops.items(): + node = GraphBuilder.create_node(rank, api_info) + graph.add_node(node) + rank_nodes[rank].append(node) + + # Step 2: Connect sequential operations within each rank + for _, nodes in rank_nodes.items(): + for i in range(len(nodes) - 1): + graph.add_edge(nodes[i], nodes[i + 1], EdgeType.SEQUENTIAL) + + # Step 3: Connect communication operations between ranks + GraphBuilder._connect_p2p_operations(graph, rank_nodes) + GraphBuilder._connect_collective_operations(graph, rank_nodes) + + return graph + + @staticmethod + def _connect_p2p_operations(graph: DistributedComputeGraph, rank_nodes: Dict[int, List[Node]]): + match_list = [] + + for nodes in rank_nodes.values(): + match_list.extend(node for node in nodes if node.node_type in (NodeType.SEND, NodeType.RECV)) + + for node in match_list: + if not node.api_info.pg: + continue + + for rank in node.api_info.pg: + if rank == node.api_info.cur_rank: + continue + + for candi_node in graph.get_nodes_by_rank(rank): + if GraphBuilder._match_comm_ops(node, candi_node): + graph.add_edge(node, candi_node, EdgeType.COMMUNICATION) + break + + @staticmethod + def _connect_collective_operations(graph: DistributedComputeGraph, rank_nodes: Dict[int, List[Node]]): + collective_groups: Dict[str, List[Node]] = defaultdict(list) + + # Group collective operations by their process group + for nodes in rank_nodes.values(): + for node in nodes: + if node.node_type == NodeType.COLLECTIVE: + group_key = GraphBuilder._get_process_group_key(node.api_info) + collective_groups[group_key].append(node) + + # Connect nodes in the same collective operation + for group in collective_groups.values(): + for i, node_i in enumerate(group): + for j, node_j in enumerate(group): + if i >= j: + continue + graph.add_edge(node_i, node_j, EdgeType.COMMUNICATION) + graph.add_edge(node_j, node_i, EdgeType.COMMUNICATION) # Bidirectional for collectives + + @staticmethod + def _match_comm_ops(no1: Node, no2: Node) -> bool: + return no1.api_info == no2.api_info + + @staticmethod + def _get_process_group_key(api_info: APIInfo) -> str: + return api_info.process_group_id + + +class SortStrategy(Enum): + CALL_INDEX = "call_index" + RANK = "rank" + API_NAME = "api_name" + + +class GraphTraversal: + + @staticmethod + def sort_levels(levels: List[List[Node]], strategy: SortStrategy = SortStrategy.CALL_INDEX) -> List[List[Node]]: + """ + 对每一层的节点进行排序 + Args: + levels: 层次遍历的结果 + strategy: 排序策略 + Returns: + sorted_levels: 排序后的层次结果 + """ + sort_key = GraphTraversal._get_sort_key(strategy) + return [sorted(level, key=sort_key) for level in levels] + + @staticmethod + def bfs_by_level(graph: DistributedComputeGraph) -> List[List[Node]]: + """ + 使用BFS进行层次遍历,返回每一层的节点列表 + Args: + graph: 分布式计算图 + Returns: + levels: 每一层节点的列表的列表 + """ + start_nodes = graph.get_start_nodes() + if not start_nodes: + return [[]] + + # 记录已访问的节点和它们所在的层级 + visited = {} # 节点 -> 层级的映射 + current_level = 0 + levels = [[]] # 初始层包含起始节点 + queue = deque() # (节点, 层级)的队列 + + for n in start_nodes: + visited[n] = 0 + levels[0].append(n) + queue.append((n, 0)) + + while queue: + node, level = queue.popleft() + + # 如果遇到新的层级,创建新的层级列表 + if level > current_level: + current_level = level + levels.append([]) + + # 遍历邻接节点 + for neighbor in graph.adj_list[node]: + # 如果邻接节点未访问过,或者在更深的层级遇到了它 + if neighbor not in visited or visited[neighbor] > level + 1: + visited[neighbor] = level + 1 + queue.append((neighbor, level + 1)) + # 将节点添加到对应层级的列表中 + if len(levels) <= level + 1: + levels.append([]) + if neighbor not in levels[level + 1]: + levels[level + 1].append(neighbor) + + return levels + + @staticmethod + def get_node_info(node: Node) -> str: + """ + 获取节点的详细信息,用于调试和打印 + """ + return (f"Node(id={node.node_id}, rank={node.rank}, call_index={node.api_info.call_index}, " + f"type={node.node_type.value})") + + @staticmethod + def print_levels_info(levels: List[List[Node]]): + """ + 打印每一层的节点信息 + """ + logger.info("Level visit results:") + for i, level in enumerate(levels): + logger.info(f"level {i}:") + for node in level: + logger.info(f"node: {GraphTraversal.get_node_info(node)}") + + @staticmethod + def print_cycles_info(cycles: Set[Tuple[Node, Node]]): + """ + 打印检测到的环信息 + """ + logger.info("\n检测到的环:") + for source, target in cycles: + logger.info(f"环: {GraphTraversal.get_node_info(source)} -> {GraphTraversal.get_node_info(target)}") + + @staticmethod + def _get_sort_key(strategy: SortStrategy) -> Callable[[Node], any]: + """Get the sort key function based on the sorting strategy""" + if strategy == SortStrategy.CALL_INDEX: + return lambda node: (node.api_info.call_index, node.rank) + elif strategy == SortStrategy.RANK: + return lambda node: node.rank + elif strategy == SortStrategy.API_NAME: + return lambda node: node.api_info.api_name + else: + return lambda node: node.api_info.call_index # Default to call_index + + +def traverse_graph(graph: DistributedComputeGraph, sort_strategy: SortStrategy = SortStrategy.CALL_INDEX): + levels, cycles = GraphTraversal.bfs_by_level(graph), set() + sorted_levels = GraphTraversal.sort_levels(levels, sort_strategy) + + GraphTraversal.print_levels_info(sorted_levels) + GraphTraversal.print_cycles_info(cycles) + + return levels, cycles + + +def main(): + file_path = 'test_data/all_reduce_data' + # Load your data as before + data = process_on_all_ranks(file_path) + + # Build the graph + graph = GraphBuilder.build_graph(data) + + # Traverse the graph + _, _ = traverse_graph(graph) + + +if __name__ == '__main__': + main() diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_pp_partition.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_pp_partition.py new file mode 100644 index 0000000000000000000000000000000000000000..59e6952ce6a16260f035289f2af42d0746912436 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyze_pp_partition.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +from typing import Dict, List, Set, Optional + +from msprobe.core.common.log import logger +from msprobe.pytorch.nan_analyse.api_info import APIInfo +from msprobe.pytorch.nan_analyse.pre_process_dump_data import process_on_all_ranks +from msprobe.pytorch.nan_analyse.utils import singleton + + +MAX_RECURSIVE_DEPTH = 100 + + +def __is_send_op(op_name: str) -> bool: + if op_name.startswith('Distributed.') and 'send.' in op_name: + return True + return False + + +def __is_recv_op(op_name: str) -> bool: + if op_name.startswith('Distributed.') and 'recv.' in op_name: + return True + return False + + +def _is_first_send_op(op_name: str) -> bool: + if __is_send_op(op_name) and 'send.0' in op_name: + return True + return False + + +def _is_first_recv_op(op_name: str) -> bool: + if __is_recv_op(op_name) and 'recv.0' in op_name: + return True + return False + + +@singleton +class PPAnalyzer: + def __init__(self, rank_data: Dict[int, dict]): + # 初始化rank_to_data字典,rank_id --> dump_data + self.rank_to_data = rank_data + self.rank_to_stage = {} # 存储rank对应的pipeline stage + self.send_recv_pairs = defaultdict(list) # 存储send-recv配对信息 + + @staticmethod + def _find_start_ranks(rank_graph: Dict[int, Set[int]]) -> List[int]: + """找到没有入边的rank(pipeline的起始rank)""" + all_ranks = set(rank_graph.keys()) + target_ranks = set() + for ranks in rank_graph.values(): + target_ranks.update(ranks) + return list(all_ranks - target_ranks) + + @staticmethod + def _get_target_rank(op_info: APIInfo) -> Optional[int]: + """从send操作中提取目标rank""" + kwargs = op_info.input_kwargs + if 'dst' in kwargs: + return int(kwargs['dst'].get('value')) + return None + + @staticmethod + def _get_source_rank(op_info: APIInfo) -> Optional[int]: + """从recv操作中提取源rank""" + kwargs = op_info.input_kwargs + if 'src' in kwargs: + return kwargs['src'].get('value') + return None + + def get_pp_stages(self) -> Dict[int, List[int]]: + """获取每个stage包含的ranks""" + stages = defaultdict(list) + for rank, stage in self.rank_to_stage.items(): + stages[stage].append(rank) + return dict(stages) + + def analyze(self): + self.analyze_send_recv() + self.determine_pp_stages() + + def analyze_send_recv(self): + """分析所有rank的send和recv操作""" + rank_data = self.rank_to_data + for cur_rank, data in rank_data.items(): + self._analyze_cur_rank(cur_rank, data) + + def determine_pp_stages(self): + """确定每个rank属于哪个pipeline stage""" + # 构建rank之间的依赖关系图 + rank_graph = defaultdict(set) + for rank, pairs in self.send_recv_pairs.items(): + for op_type, other_rank in pairs: + if op_type == 'send': + rank_graph[rank].add(other_rank) + + # 没有send、recv操作,所有的rank属于同一个stage + if not rank_graph: + all_ranks = set(self.rank_to_data.keys()) + for rank in all_ranks: + self.rank_to_stage[rank] = 0 + return + + # 使用拓扑排序确定stage + visited = set() + + def dfs(rank_id: int, stage: int): + if stage >= MAX_RECURSIVE_DEPTH: + raise ValueError("Recursive depth exceeds the limit") + + if rank_id in visited: + return + visited.add(rank_id) + self.rank_to_stage[rank_id] = stage + + # 遍历所有下一个rank + for next_rank in rank_graph[rank_id]: + dfs(next_rank, stage + 1) + + # 找到起始rank(入度为0的节点)为首个PP stage + start_ranks = self._find_start_ranks(rank_graph) + for start_rank in start_ranks: + dfs(start_rank, 0) + + def _analyze_cur_rank(self, cur_rank: int, data: Dict[str, APIInfo]): + if not data: + return + + for op_name, op_info in data.items(): + if _is_first_send_op(op_name): + target_rank = self._get_target_rank(op_info) + if target_rank is None or target_rank < cur_rank: # 仅添加大于cur_rank的send操作,保证所有都是前向 + continue + self.send_recv_pairs[cur_rank].append(('send', target_rank)) + + # 不采集rcv的通信算子,仅仅从send数据分析,rcv算子用于做validation + elif _is_first_recv_op(op_name): + source_rank = self._get_source_rank(op_info) + if source_rank is None: + continue + + +def main(): + file_path = 'test_data/send_recv' + data = process_on_all_ranks(file_path) + + # 分析pp stage + analyzer = PPAnalyzer(data) + analyzer.analyze() + + pp_stages = analyzer.get_pp_stages() + + logger.info("Pipeline Parallel Stages:") + for stage, ranks in sorted(pp_stages.items()): + logger.info(f"Stage {stage}: Ranks {sorted(ranks)}") + + +if __name__ == "__main__": + main() diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyzer.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..9afe1dc16a674817c7d5fc066c4d2be81c1cc3d2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/analyzer.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprobe.pytorch.nan_analyse.analyze_dump_graph import GraphBuilder, GraphTraversal +from msprobe.pytorch.nan_analyse.pre_process_dump_data import process_on_all_ranks + + +class HeapDumpAnalyzer: + def __init__(self, dump_file_path): + """初始化分析器 + Args: + dump_file_path (str): 堆转储文件的路径 + """ + self.dump_file_path = dump_file_path + self.processed_data = None + self.analysis_results = None + self.graph = None + self.visited_levels = None + + def pre_process(self): + """预处理dump文件 + Returns: + 处理后的数据结构 + """ + self.processed_data = process_on_all_ranks(self.dump_file_path) + self.graph = GraphBuilder.build_graph(self.processed_data) + + def analyze_graph(self): + """分析预处理后的数据 + Returns: + 分析结果 + """ + if self.processed_data is None or self.graph is None: + raise ValueError("Data or graph is not processed yet") + self.visited_levels = GraphTraversal.bfs_by_level(self.graph) + + def post_process(self): + """获取分析结果""" + self.analysis_results = GraphTraversal.sort_levels(self.visited_levels) + + def apply(self): + """执行完整的分析流程""" + self.pre_process() + + self.analyze_graph() + + self.post_process() + return self.analysis_results + + +if __name__ == "__main__": + analyzer = HeapDumpAnalyzer("test_data/send_recv") + results = analyzer.apply() + GraphTraversal.print_levels_info(results) diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/api_info.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/api_info.py new file mode 100644 index 0000000000000000000000000000000000000000..17fdc88a0abadc3a8e9e3cb12012eb0a6d3d9213 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/api_info.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +from typing import Dict, List, Union, Any + +from msprobe.core.common.const import Const +from msprobe.core.overflow_check.filter import IgnoreFilter +from msprobe.pytorch.nan_analyse.utils import singleton, has_nan_inf, generate_hash + + +def is_comm_api_name_match(bench_api_name, cmp_api_name): + if 'send' in bench_api_name and 'recv' in cmp_api_name: + return True + if 'recv' in bench_api_name and 'send' in cmp_api_name: + return True + return bench_api_name == cmp_api_name + + +@dataclass +class APIInfo: + input_kwargs: Dict + output_data: List[Dict] + api_name: str + torch_api_name: str + input_args: List[Dict] + call_index: int + is_communication_op: bool + + cur_rank: int + process_group_id = str + + def __init__(self, api_name, input_args=None, input_kwargs=None, output_data=None, call_index=0, cur_rank=None): + self.input_kwargs = input_kwargs + self.output_data = output_data + self.api_name = api_name + self.input_args = input_args + self.call_index = call_index + self.cur_rank = cur_rank + self.torch_api_name = self.__extract_torch_api(self.api_name) + self.is_communication_op = self.__is_communication_operators() + self.pg, self.process_group_id = self.__generate_pg_id() + + def __eq__(self, other): + if not self.is_communication_op or not other.is_communication_op: + return False + + if not is_comm_api_name_match(self.torch_api_name, other.torch_api_name): + return False + + if self.torch_api_name != other.torch_api_name: + return False + if self.process_group_id != other.process_group_id: + return False + return True + + @staticmethod + def __extract_torch_api(api_name) -> str: + """ + Process tensor api name to extract first two fields in lowercase. + """ + # Empty string checking + if not api_name.strip(): + return "" + + parts = api_name.split(Const.SEP) + + # Handle different cases based on number of parts + if len(parts) == 0: + return "" + elif len(parts) == 1: + return parts[0].lower() + else: + return Const.SEP.join(parts[:2]).lower() + + def __is_communication_operators(self) -> bool: + # 定义通信算子的关键字,覆盖各种通信操作,如all_reduce, send, broadcast等 + # 从wrap文件中读取,先硬编码在文件中 + communication_keywords = [ + 'send', # send 算子 + 'recv', # recv 算子 + 'broadcast', # broadcast 算子 + 'all_reduce', # all_reduce 算子 + 'reduce', # reduce 算子 + 'all_gather', # all_gather 算子 + 'gather', # gather 算子 + 'isend', # isend 算子 + 'irecv', # irecv 算子 + 'scatter', # scatter 算子 + 'reduce_scatter', # reduce_scatter 算子 + '_reduce_scatter_base', # _reduce_scatter_base 算子 + '_all_gather_base', # _all_gather_base 算子 + 'all_to_all_single', # all_to_all_single 算子 + 'all_to_all', # all_to_all 算子 + 'all_gather_into_tensor', # all_gather_into_tensor 算子 + 'reduce_scatter_tensor' # reduce_scatter_tensor 算子 + ] + + # 是否以Distributed开头,并且算子名包含上述通信算子 + return (any(keyword in self.api_name for keyword in communication_keywords) or + self.api_name.startswith('Distributed.')) + + def __generate_pg_id(self): + if not self.is_communication_op: + return None, None + + process_group: List[int] = [] + if 'send' in self.api_name: + dst = int(self.input_kwargs.get('dst', {}).get('value')) + process_group.extend([self.cur_rank, dst]) + elif 'recv' in self.api_name: + src = int(self.input_kwargs.get('src', {}).get('value')) + process_group.extend([src, self.cur_rank]) + else: + process_group.extend(self.input_kwargs.get('group_ranks', [])) + + # 暂时直接使用调用的次数,而忽略pg的匹配 + call_cnt = self.api_name.split('.')[-2] + fmt = f'{call_cnt}_{str(process_group)}' + + return process_group, generate_hash(fmt) + + +@singleton +class AnomalyDetector: + def __init__(self): + self._filter = IgnoreFilter() + + @staticmethod + def _has_anomaly(data: Union[Dict, Any]) -> bool: + return has_nan_inf(data) + + def has_input_anomaly(self, api_data) -> bool: + """检查输入是否有异常(包括args和kwargs)""" + # args + args_anomaly = any(self._has_anomaly(x) for x in api_data.input_args if isinstance(x, dict)) + # kwargs + kwargs_anomaly = any(self._has_anomaly(x) for x in api_data.input_kwargs.values() if isinstance(x, dict)) + return args_anomaly or kwargs_anomaly + + def has_output_anomaly(self, api_data) -> bool: + """检查输出是否有异常""" + return any(self._has_anomaly(x) for x in api_data.output_data if isinstance(x, dict)) + + def has_overflow(self, data: APIInfo) -> bool: + # 输入输出不存在nan、inf,不存在溢出 + if not (self.has_input_anomaly(data) or self.has_output_anomaly(data)): + return False + # 是否真的溢出,并且对计算结果造成影响 + if self._filter.apply_filter(data): + return False + return True diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/pre_process_dump_data.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/pre_process_dump_data.py new file mode 100644 index 0000000000000000000000000000000000000000..73815f8b922d3c6e7ff5d7a8505524ad61a4ee15 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/pre_process_dump_data.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import Any, Dict +from collections import OrderedDict + +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import load_json +from msprobe.pytorch.nan_analyse.api_info import APIInfo, AnomalyDetector + + +def _create_api_info(api_name: str, _data: Dict, call_index: int = 0, cur_rank: int = 0) -> APIInfo: + """从原始数据创建APIInfo实例""" + return APIInfo( + api_name=api_name, + input_args=_data.get(Const.INPUT_ARGS, []), + input_kwargs=_data.get(Const.INPUT_KWARGS, {}), + output_data=_data.get(Const.OUTPUT, []), + call_index=call_index, + cur_rank=cur_rank + ) + + +def extract_essential_operators(dump_data: Any, cur_rank: int, common_overflow_num=5): + """ + 减少内存占用,仅筛选出溢出、通信算子等,用于下一步构图 + """ + # 从数据中提取通信算子和nan等溢出问题算子,使用顺序dict保存结果 + # order dict性能与list+dict性能比较,是否对这里进行改造 + extract_opts = OrderedDict() + detector = AnomalyDetector() # 单例,无额外内存占用 + cnt = 0 + index = 0 + for api_name, value in dump_data.get('data', {}).items(): + api_info = _create_api_info(api_name, value, call_index=index, cur_rank=cur_rank) + index += 1 + + is_overflow, is_comm_op = detector.has_overflow(api_info), api_info.is_communication_op + if cnt < common_overflow_num and is_overflow: + extract_opts[api_name] = api_info + cnt += 1 + continue + + return extract_opts + + +def process_on_all_ranks(base_path: str): + all_rank_ops_data = {} + + # 获取所有rank目录 + for rank_dir in os.listdir(base_path): + rank_path = os.path.join(base_path, rank_dir) + if not os.path.isdir(rank_path) or not rank_dir.startswith('rank'): + logger.warning(f"{rank_dir} is not a valid rank directory.") + continue + + dump_file = os.path.join(rank_path, 'dump.json') + if not os.path.exists(dump_file): + logger.warning(f"{dump_file} does not exist for {rank_dir}") + continue + + rank_id = get_rank_id(rank_dir) + dump_data = load_json(dump_file) + op_list = extract_essential_operators(dump_data, rank_id) + + if op_list: + all_rank_ops_data[rank_id] = op_list + else: + logger.warning(f"No essential operators found for {rank_id}") + + return all_rank_ops_data + + +def get_rank_id(rank_dir: str) -> int: + match = re.search(r'rank(\d+)', rank_dir) + + if not match: + raise ValueError(f"Invalid rank directory: {rank_dir}") + return int(match.group(1)) + + +if __name__ == '__main__': + file_path = 'test_data/all_reduce_data' + + data = process_on_all_ranks(file_path) + logger.info(data) diff --git a/debug/accuracy_tools/msprobe/pytorch/nan_analyse/utils.py b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb54dc488ccbd98ddba132a285a99655deba815 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/nan_analyse/utils.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hashlib +from typing import Any + + +CHECK_FIELDS = ['Max', 'Min', 'Mean'] +OVERFLOW_VALUES = ['inf', '-inf', 'nan'] + + +def singleton(cls): + """ + :param cls: any class + :return: singleton handle + """ + _instance = {} + + def _singleton(*args: any, **kw: any) -> any: + if cls not in _instance: + _instance[cls] = cls(*args, **kw) + return _instance.get(cls) + + return _singleton + + +def has_nan_inf(value: Any) -> bool: + """检查值是否包含NaN或Inf""" + if isinstance(value, dict): + for k, v in value.items(): + if k in CHECK_FIELDS and str(v).lower() in OVERFLOW_VALUES: + return True + return False + + +def generate_hash(input_string): + sha256_hash = hashlib.sha256() + sha256_hash.update(input_string.encode('utf-8')) + return sha256_hash.hexdigest() diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index 7a265e70fa4cbe95c897c35d68e4afa8ebd77249..18d8e0f1d0ab00fb723eafa9d0dc17d92bd164a6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -125,8 +125,6 @@ class Saver: def write_summary_csv(self, test_result): test_rows = [] - if self.stack_info: - test_rows[0].append(self.COLUMN_STACK_INFO) check_op_str_pattern_valid(test_result.api_name) df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success] diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index b9201cfaac74e38bbbaee468b6c452895f8b38f9..916a68aece20ba620877004d25b15bbbcc01c41e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -16,6 +16,7 @@ import json import os import time +import multiprocessing from multiprocessing import Pool import torch @@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode): return if dump_path is None: logger.error("Please set dump_path when dump_mode is config!") + raise DispatchException("Please set dump_path when dump_mode is config!") check_file_or_directory_path(dump_path, True) self.device_id = torch_npu._C._npu_getDevice() @@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode): self.get_ops(yaml_path) self.lock = None + max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1) + if process_num > max_process_num: + logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!") + raise DispatchException(f'process_num should be less than or equal to {max_process_num}, ' + f'but got {process_num}!') if process_num > 0: self.pool = Pool(process_num) if debug: @@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode): if len(json_line_data) == 0: break msg = json.loads(json_line_data) + if len(msg) < 2: + raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.") self.all_summary[msg[0]] = msg[1] fp_handle.close() @@ -199,8 +208,10 @@ class PtdbgDispatch(TorchDispatchMode): dispatch_workflow(run_param, data_info) else: self.lock.acquire() - self.all_summary.append([]) - self.lock.release() + try: + self.all_summary.append([]) + finally: + self.lock.release() run_param.process_flag = True if self.check_fun(func, run_param): data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out, diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index b185bc1110d4062d8a31b9cc94dc946d8fb8456c..dbf7626a2710a3f10ddc8d45795988b89081d0d5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -19,6 +19,8 @@ import os from datetime import datetime, timezone import torch +from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.file_utils import FileOpen, save_npy, save_json from msprobe.pytorch.common.log import logger @@ -91,6 +93,7 @@ def support_basic_type(data): return False +@recursion_depth_decorator("dump_data") def dump_data(data, prefix, dump_path): if isinstance(data, (tuple, list)) and data: for i, item in enumerate(data): @@ -107,8 +110,11 @@ def dump_data(data, prefix, dump_path): def save_temp_summary(api_index, single_api_summary, path, lock): summary_path = os.path.join(path, f'summary.json') lock.acquire() - data = [api_index, single_api_summary] - save_json(summary_path, data, mode='a') + try: + data = [api_index, single_api_summary] + save_json(summary_path, data, mode='a') + finally: + lock.release() def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index ae8b9435a34ced607d4e70fab615b2b017083fe9..37105551a3bccca548fe2b6594f4848324746b49 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -27,8 +27,10 @@ else: pta_cpu_device = torch.device("cpu") from msprobe.core.common.const import CompareConst +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.pytorch.common.log import logger + cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' COLOR_GREEN = '\033[32m' @@ -85,6 +87,7 @@ def get_callstack(): return callstack +@recursion_depth_decorator("data_to_cpu") def data_to_cpu(data, deep, data_cpu): global cpu_device list_cpu = [] diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py index ac6f3d234e3a6681a580f16e56d94204223102f1..7f08b7929cd46961cb5850f16aa6ad7d7eace533 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py @@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd): @catch_exception def default(self, line=""): - self.util.execute_command(line) - return False - - @catch_exception - def do_run(self, line=""): - self.util.execute_command(line) + self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n") @catch_exception def do_vc(self, line=""): diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 66229d36b8d0b532eea48f1aa5d96e178ed80cdc..2cdfe6f5106b46cbc8f69e492d9977c3f905f3b2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib import os import re import subprocess import sys import time +import zlib from collections import namedtuple import numpy as np @@ -37,8 +37,6 @@ try: from rich.table import Table from rich import print as rich_print from rich.columns import Columns - - install() except ImportError as err: install = None Panel = None @@ -114,11 +112,12 @@ class Util: @staticmethod def get_md5_for_numpy(obj): np_bytes = obj.tobytes() - md5_hash = hashlib.md5(np_bytes) - return md5_hash.hexdigest() + md5_crc = zlib.crc32(np_bytes) + return f"{md5_crc:08x}" @staticmethod def deal_with_dir_or_file_inconsistency(output_path): + logger.warning(f"Trying to delete {output_path}") remove_path(output_path) raise ParseException("Inconsistent directory structure or file.") @@ -264,7 +263,7 @@ class Util: match = re_pattern.match(name) if not match: continue - if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name): + if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern): continue file_list[name] = gen_info_func(name, match, file["root"]) return file_list diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index 8293ac969490b103eef630081b6001234ca8bb07..b44564d7c7dd401a62ee91869bec95a4ba621b47 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -16,9 +16,9 @@ import os import re -from msprobe.core.common.const import Const +from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid +from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import is_int from msprobe.core.common_config import BaseConfig, CommonConfig @@ -42,6 +42,7 @@ class TensorConfig(BaseConfig): self.tls_path = json_config.get("tls_path", "./") self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False) self.check_config() + self._check_summary_mode() self._check_file_format() if self.online_run_ut: self._check_online_run_ut() @@ -65,7 +66,10 @@ class TensorConfig(BaseConfig): check_file_or_directory_path(self.tls_path, isdir=True) check_file_or_directory_path(os.path.join(self.tls_path, "client.key")) check_file_or_directory_path(os.path.join(self.tls_path, "client.crt")) - check_crt_valid(os.path.join(self.tls_path, "client.crt")) + check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt")) + crl_path = os.path.join(self.tls_path, "crl.pem") + if os.path.exists(crl_path): + check_file_or_directory_path(crl_path) if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host): raise Exception(f"host: {self.host} is invalid.") @@ -80,9 +84,8 @@ class StatisticsConfig(BaseConfig): self.check_config() self._check_summary_mode() - def _check_summary_mode(self): - if self.summary_mode and self.summary_mode not in ["statistics", "md5"]: - raise Exception("summary_mode is invalid") + self.tensor_list = json_config.get("tensor_list", []) + self._check_str_list_config(self.tensor_list, "tensor_list") class OverflowCheckConfig(BaseConfig): @@ -95,6 +98,8 @@ class OverflowCheckConfig(BaseConfig): def check_overflow_config(self): if self.overflow_nums is not None and not is_int(self.overflow_nums): raise Exception("overflow_num is invalid") + if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0: + raise Exception("overflow_nums should be -1 or positive integer") if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]: raise Exception("check_mode is invalid") @@ -148,7 +153,7 @@ class FreeBenchmarkCheckConfig(BaseConfig): self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST ): msg = ( - f"You neet to and can only set fuzz_device as {DeviceType.CPU} " + f"You need to and can only set fuzz_device as {DeviceType.CPU} " f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}" ) logger.error_log_with_exp( @@ -252,6 +257,8 @@ class RunUTConfig(BaseConfig): self.port = json_config.get("port", -1) self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST) self.tls_path = json_config.get("tls_path", "./") + self.master_ip = json_config.get("master_ip", "127.0.0.1") + self.master_port = json_config.get("master_port", "8888") self.check_run_ut_config() @classmethod @@ -271,13 +278,26 @@ class RunUTConfig(BaseConfig): @classmethod def check_nfs_path_config(cls, nfs_path): - if nfs_path and not os.path.exists(nfs_path): - raise Exception("nfs_path: %s does not exist" % nfs_path) + if nfs_path: + FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check() @classmethod def check_tls_path_config(cls, tls_path): - if tls_path and not os.path.exists(tls_path): - raise Exception("tls_path: %s does not exist" % tls_path) + if tls_path: + FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check() + + @classmethod + def check_master_ip_config(cls, master_ip): + if not re.match(Const.ipv4_pattern, master_ip): + raise Exception("master_ip: %s is invalid" % master_ip) + + @classmethod + def check_master_port_config(cls, master_port): + if not isinstance(master_port, str) or not master_port.isdigit(): + raise Exception(f"port: {master_port} is invalid. Port must be a numeric string.") + port_number = int(master_port) + if not (0 < port_number <= 65535): + raise Exception(f"port: {master_port} is invalid. Port range must be between 1 and 65535.") def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) @@ -285,6 +305,8 @@ class RunUTConfig(BaseConfig): RunUTConfig.check_error_data_path_config(self.error_data_path) RunUTConfig.check_nfs_path_config(self.nfs_path) RunUTConfig.check_tls_path_config(self.tls_path) + RunUTConfig.check_master_ip_config(self.master_ip) + RunUTConfig.check_master_port_config(self.master_port) class GradToolConfig(BaseConfig): diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8b306eb69c8376baeda91b4ad711a028119273 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.common.utils import Const +from msprobe.core.service import BaseService +from msprobe.pytorch.attl_manager import ATTLManager +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2 +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func +from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager +from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook + +if torch_version_above_or_equal_2: + from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch + + +class PytorchService(BaseService): + @property + def _get_framework_type(self): + return Const.PT_FRAMEWORK + + @staticmethod + def _get_current_rank(): + return get_rank_if_initialized() + + def _init_specific_components(self): + self.logger = logger + self.api_register = get_api_register() + self.module_processor = ModuleProcesser(self.data_collector.scope) + self.attl_manager = ATTLManager(self.config) + self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager) + self.api_template = ApiTemplate + + def _register_hook(self): + self.attl_manager.attl_init() + if self._is_mix_level: + register_optimizer_hook(self.data_collector) + + def _register_api_hook(self): + super()._register_api_hook() + wrap_jit_script_func() + redirect_wait() + + def _register_module_hook(self): + ModuleProcesser.enable_module_dump = True + self.module_processor.register_module_hook(self.model, self.build_hook) + self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") + + def _run_ut_dispatch(self, status): + if torch_version_above_or_equal_2: + run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute) + + def _reset_status(self): + super()._reset_status() + ModuleProcesser.reset_module_stats() + HOOKModule.reset_module_stats() diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py deleted file mode 100644 index fd81a7f1cf064506a4fb91481429828c97113509..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ /dev/null @@ -1,470 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import os -from collections import namedtuple, defaultdict - -import torch -from msprobe.core.common.const import Const -from msprobe.core.common.exceptions import DistributedNotInitializedError -from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation -from msprobe.core.data_dump.data_collector import build_data_collector -from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from msprobe.core.data_dump.scope import BaseScope -from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation -from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json -from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser -from msprobe.pytorch.hook_module.api_registry import api_register -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook - -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if torch_version_above_or_equal_2: - from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch - -HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2']) - - -class Service: - def __init__(self, config): - self.model = None - self.config = config - self.data_collector = build_data_collector(config) - self.module_processor = ModuleProcesser(self.data_collector.scope) - self.switch = False - self.inner_switch = False - self.current_iter = 0 - self.first_start = True - self.current_rank = None - self.dump_iter_dir = None - self.should_stop_service = False - self.attl = None - self.params_grad_info = {} - self.hook_handle_dict = {} - # 提前注册,确保注册尽可能多的API hook - self.register_api_hook() - self.init_for_debug_level() - - def build_hook(self, module_type, name): - def pre_hook(api_or_module_name, module, args, kwargs): - if not self.should_execute_hook(module_type, module, True): - return args, kwargs - is_recompute = is_recomputation() - - self.inner_switch = True - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - else: - module.forward_data_collected = True - HOOKModule.add_module_count(name) - self.data_collector.update_api_or_module_name(api_or_module_name) - - if self.config.online_run_ut: - self.inner_switch = False - return None, None - if self.data_collector: - module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) - self.data_collector.forward_input_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - - self.inner_switch = False - return args, kwargs - - def grad_hook(module, ori_name, param_name): - def hook_fn(grad): - if not self.should_execute_hook(module_type, module, False): - return grad - self.inner_switch = True - self.data_collector.params_data_collect(ori_name, param_name, pid, grad) - self.inner_switch = False - return grad - - return hook_fn - - def register_param_hook(ori_name, module, params_dict): - ''' - 注册参数hook - ''' - # data_mode为forward时,不注册参数hook - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - for param_name, param in params_dict.items(): - if param.requires_grad: - name = ori_name + Const.SEP + param_name - old_handle = self.hook_handle_dict.get(name) - if old_handle and hasattr(old_handle, "remove"): - old_handle.remove() - handle = param.register_hook(grad_hook(module, ori_name, param_name)) - self.hook_handle_dict[name] = handle - - def init_params_grad_info(module, params_dict): - ''' - 初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位 - ''' - if not params_dict: - return - if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): - grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None - # 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中 - if not self.params_grad_info.get(grad_name): - data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}} - # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 - if data_info.get(grad_name): - # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, - flush=self.data_collector.data_processor.is_terminated) - # 记录当前模块的参数梯度信息已占位 - self.params_grad_info[grad_name] = True - - def forward_hook(api_or_module_name, module, args, kwargs, output): - if not self.should_execute_hook(module_type, module, True): - return None - is_recompute = is_recomputation() - - self.inner_switch = True - if self.config.online_run_ut: - self.data_collector.update_api_or_module_name(api_or_module_name) - if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): - return None - api_data = ApiData( - api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)], - args, - kwargs, - output, - self.current_iter, - self.current_rank - ) - self.attl_send(api_data) - self.inner_switch = False - return None - - module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - self.data_collector.update_api_or_module_name(api_or_module_name) - params_dict = {} - if self.config.task != Const.STRUCTURE: - params_dict = { - key.split(Const.SEP)[-1]: value - for key, value in module.named_parameters(recurse=False) - } - setattr(module_input_output, Const.PARAMS, params_dict) - # 判断是否需要注册参数hook - if params_dict: - ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0] - grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD - # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook - setattr(module, 'params_grad_name', grad_name) - register_param_hook(ori_name, module, params_dict) - self.data_collector.forward_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - init_params_grad_info(module, params_dict) - else: - self.data_collector.update_api_or_module_name(api_or_module_name) - self.data_collector.forward_output_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - - if self.data_collector.if_return_forward_new_output(): - forward_new_output = self.data_collector.get_forward_new_output() - self.inner_switch = False - return forward_new_output - self.inner_switch = False - return output - - def forward_hook_torch_version_below_2(api_or_module_name, module, args, output): - return forward_hook(api_or_module_name, module, args, {}, output) - - def backward_hook(api_or_module_name, module, grad_input, grad_output): - if not self.should_execute_hook(module_type, module, False): - return - is_recompute = is_recomputation() - - self.inner_switch = True - if module_type == BaseScope.Module_Type_Module: - api_or_module_name = module.mindstudio_reserved_name[-1] - self.data_collector.update_api_or_module_name(api_or_module_name) - - if self.config.online_run_ut: - self.inner_switch = False - return - - if self.data_collector: - # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) - self.data_collector.backward_data_collect( - api_or_module_name, - module, - pid, - module_input_output, - is_recompute - ) - self.inner_switch = False - - pid = os.getpid() - full_forward_name = None - full_backward_name = None - if module_type == BaseScope.Module_Type_API: - full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD - full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD - pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name) - forward_hook_fn = functools.partial(forward_hook, full_forward_name) - backward_hook_fn = functools.partial(backward_hook, full_backward_name) - forward_hook_torch_version_below_2_fn = functools.partial( - forward_hook_torch_version_below_2, - full_forward_name - ) - return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) - - def start(self, model): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.need_stop_service(): - return - - self.model = model - if self.first_start: - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - self.attl_init() - - if self.config.rank and self.current_rank not in self.config.rank: - return - self.register_module_hook() - if self.config.level == Const.LEVEL_MIX: - register_optimizer_hook(self.data_collector) - self.first_start = False - if self.config.online_run_ut and torch_version_above_or_equal_2: - run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute) - self.switch = True - logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") - if not self.config.online_run_ut: - self.create_dirs() - logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") - - def stop(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - self.switch = False - if self.config.level == Const.LEVEL_L2: - return - if self.config.online_run_ut and torch_version_above_or_equal_2: - run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute) - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - - def step(self): - if self.config.level == Const.LEVEL_DEBUG: - return - if self.should_stop_service: - return - if self.config.async_dump: - self.data_collector.fill_stack_tensor_data() - if self.config.task == Const.TENSOR: - self.data_collector.data_processor.dump_async_data() - self.data_collector.write_json() - self.current_iter += 1 - self.data_collector.update_iter(self.current_iter) - self.reset_status() - - def need_stop_service(self): - if self.should_stop_service: - return True - end_service = self.config.step and self.current_iter > max(self.config.step) or \ - self.data_collector and self.data_collector.data_processor.is_terminated - if end_service: - if self.config.online_run_ut: - # send stop signal if online_run_ut - self.attl_stop() - self.switch = False - self.should_stop_service = True - print_tools_ends_info() - return True - if self.config.step and self.current_iter not in self.config.step: - return True - return False - - def should_execute_hook(self, hook_type, module, is_forward): - is_module_hook = hook_type == BaseScope.Module_Type_Module - if is_module_hook and not self.switch: - return False - elif not is_module_hook and is_forward and not self.switch: - return False - elif not is_module_hook and not is_forward and not module.forward_data_collected: - return False - - if self.inner_switch: - return False - if not self.data_collector or self.data_collector.data_processor.is_terminated: - return False - return True - - def create_dirs(self): - create_directory(self.config.dump_path) - self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") - cur_rank = self.current_rank if self.current_rank is not None else '' - if self.config.level == Const.LEVEL_L2: - create_directory(self.dump_iter_dir) - kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) - self.config.kernel_config_path = kernel_config_path - return - - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) - - def register_api_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: - logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.") - api_register.initialize_hook( - functools.partial(self.build_hook, BaseScope.Module_Type_API), - self.config.online_run_ut - ) - api_register.api_modularity() - - def register_module_hook(self): - if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: - logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") - self.module_processor.register_module_hook(self.model, self.build_hook) - - def attl_init(self): - if self.config.online_run_ut: - from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL - attl_config = ATTLConfig(is_benchmark_device=False, - connect_ip=self.config.host, - connect_port=self.config.port, - nfs_path=self.config.nfs_path, - tls_path=self.config.tls_path) - need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank - self.attl = ATTL('npu', attl_config, need_dump=need_dump) - if self.config.nfs_path: - self.attl.upload("start") - - def attl_send(self, api_data): - logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}") - api_type, _, _ = api_data.name.split(Const.SEP) - if api_type in [Const.DISTRIBUTED]: - logger.info(f"api {api_data.name} is not supported, skip") - return - if self.config.nfs_path: - self.attl.upload(api_data) - else: - self.attl.send(api_data) - - def attl_stop(self): - if self.config.nfs_path: - self.attl.upload("end") - elif self.attl.socket_manager is not None: - logger.info(f"pid: {os.getpid()} finished, start send STOP signal.") - self.attl.socket_manager.send_stop_signal() - - def reset_status(self): - ModuleProcesser.reset_module_stats() - HOOKModule.reset_module_stats() - self.data_collector.reset_status() - self.params_grad_info.clear() - - if self.config.level == Const.LEVEL_L2: - self.data_collector.data_processor.reset_status() - return - if self.config.step and self.current_iter not in self.config.step: - return - if self.config.rank and self.current_rank not in self.config.rank: - return - - def init_for_debug_level(self): - if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]): - return - try: - self.current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - self.current_rank = None - - # dir: dump_path -- rank{} -- debug.json - self.dump_iter_dir = self.config.dump_path - cur_rank = self.current_rank if self.current_rank is not None else '' - dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") - create_directory(dump_dir) - if self.config.task in self.data_collector.tasks_need_tensor_data: - dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") - create_directory(dump_data_dir) - else: - dump_data_dir = None - - dump_path_aggregation = DumpPathAggregation() - dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) - - self.debug_variable_counter = defaultdict(int) - - def save(self, variable, name, save_backward): - if self.config.level != Const.LEVEL_DEBUG: - return - count = self.debug_variable_counter[name] - self.debug_variable_counter[name] += 1 - - name_with_count = f"{name}.{count}" - grad_name_with_count = f"{name}_grad.{count}" - - # forward save - self.data_collector.debug_data_collect_forward(variable, name_with_count) - - # backward save - if save_backward: - self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py b/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py new file mode 100644 index 0000000000000000000000000000000000000000..665d17c21e743fb5ffe6a0d9e014fe0a2da4af99 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mindspore import Tensor +import torch + + +def create_msa_tensor(data, dtype=None): + return Tensor(data, dtype) + + +tensor_tensor = torch.tensor +setattr(torch, 'tensor', create_msa_tensor) + + +def reset_torch_tensor(): + setattr(torch, 'tensor', tensor_tensor) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py new file mode 100644 index 0000000000000000000000000000000000000000..9fa072dde7644a670f193d7c61d2c2ed5fa5b748 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from unittest import TestCase +from unittest.mock import MagicMock + +import mindspore as ms +from mindspore import mint + +try: + from mint import distributed +except ImportError: + distributed = MagicMock() + setattr(mint, 'distributed', distributed) + +# ensure not to import torch_npu +from msprobe.mindspore import mindspore_service +from msprobe.mindspore.monitor import common_func + +from .mindtorch import reset_torch_tensor +from msprobe.mindspore.common import utils +from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions + +utils.mindtorch_check_result = None +importlib.reload(mindspore_service) +importlib.reload(common_func) +reset_torch_tensor() + + +def register_backward_pre_hook(*args, **kwargs): + pass + + +register_backward_hook_functions['full'] = ms.nn.Cell.register_backward_hook +register_backward_hook_functions["pre"] = register_backward_pre_hook + + +class SetUp(TestCase): + def test_case(self): + self.assertTrue(hasattr(mint, 'distributed')) + self.assertTrue(is_mindtorch()) + utils.mindtorch_check_result = None diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json new file mode 100644 index 0000000000000000000000000000000000000000..63a062d8ffa264a0254fc2bab0208dcf951ae094 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json @@ -0,0 +1,3 @@ +{ + "task": "tensor" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..b223c74b2315af1b9454e5f1e70c29502d449c56 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "mindspore.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..2444ae1fd4096b083a9e8a0e51c9166bb990f51f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "torch.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py index 9ed13f78aed57fd4d8153e2f005ea14d4fb33643..318313bba82197be77c36a9b4f9620f36c409b03 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_file_utils.py @@ -1,7 +1,8 @@ +import unittest from unittest.mock import patch, mock_open, MagicMock +from zipfile import ZipFile, ZipInfo +import tempfile -import numpy as np -import pandas as pd import pytest from msprobe.core.common.file_utils import * @@ -248,12 +249,21 @@ class TestFileOperations: assert mock_flock.call_count == 2 mock_dump.assert_called_once_with(test_data, mock_file(), sort_keys=False) - def test_save_excel(self): + def test_save_excel_tiny(self): df = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) with patch('pandas.DataFrame.to_excel') as mock_to_excel, \ + patch('pandas.ExcelWriter') as mock_writer, \ patch('os.chmod') as mock_chmod: save_excel(self.excel_file, df) - mock_to_excel.assert_called_once_with(str(self.excel_file), index=False) + mock_to_excel.assert_called_once_with(mock_writer().__enter__(), sheet_name='Sheet1', index=False) + + def test_save_excel_large(self): + df = pd.DataFrame({'col1': list(range(1500000)), 'col2': list(range(1500000, 0, -1))}) + with patch('pandas.DataFrame.to_excel') as mock_to_excel, \ + patch('pandas.ExcelWriter') as mock_writer, \ + patch('os.chmod') as mock_chmod: + save_excel(self.excel_file, df) + mock_to_excel.assert_called_with(mock_writer().__enter__(), sheet_name='part_1', index=False) def test_move_file(self): dst_file = self.test_dir / "moved_file" @@ -439,18 +449,19 @@ class TestUtilityOperations: def test_remove_path(self): # Test remove file with patch('os.path.exists', return_value=True), \ - patch('os.path.islink', return_value=True), \ + patch('os.path.islink', return_value=False), \ + patch('os.path.isfile', return_value=True), \ patch('os.remove') as mock_remove: - remove_path(str(self.test_file)) - mock_remove.assert_called_once_with(str(self.test_file)) + remove_path("/test_remove_path/test/test.txt") + mock_remove.assert_called_once_with("/test_remove_path/test/test.txt") # Test remove directory with patch('os.path.exists', return_value=True), \ patch('os.path.islink', return_value=False), \ patch('os.path.isfile', return_value=False), \ patch('shutil.rmtree') as mock_rmtree: - remove_path(str(self.test_dir)) - mock_rmtree.assert_called_once_with(str(self.test_dir)) + remove_path("/test_remove_path/test") + mock_rmtree.assert_called_once_with("/test_remove_path/test") def test_get_json_contents(self): json_content = '{"key": "value"}' @@ -495,24 +506,6 @@ class TestUtilityOperations: assert result[0]['file'] == 'file1.txt' -class TestCertificateOperations: - @pytest.fixture(autouse=True) - def setup(self, tmp_path): - self.cert_file = tmp_path / "test.pem" - self.mock_cert = MagicMock() - self.mock_cert.get_notBefore.return_value = b'20230101000000Z' - self.mock_cert.get_notAfter.return_value = b'20250101000000Z' - self.mock_cert.has_expired.return_value = False - - def test_check_crt_valid(self): - # Test expired certificate - self.mock_cert.has_expired.return_value = True - with patch('OpenSSL.crypto.load_certificate', return_value=self.mock_cert), \ - patch('builtins.open', mock_open(read_data='cert data')), \ - pytest.raises(RuntimeError): - check_crt_valid(self.cert_file) - - class TestDirectoryChecks: @pytest.fixture(autouse=True) def setup(self, tmp_path): @@ -533,4 +526,59 @@ class TestDirectoryChecks: # Test file path check_file_or_directory_path(self.test_file, isdir=False) # Test directory path - check_file_or_directory_path(self.test_dir, isdir=True) \ No newline at end of file + check_file_or_directory_path(self.test_dir, isdir=True) + + +cur_dir = os.path.dirname(os.path.realpath(__file__)) +zip_dir = os.path.join(cur_dir, 'test_temp_zip_file') + + +class TestCheckZipFile(unittest.TestCase): + def setUp(self): + os.makedirs(zip_dir, mode=0o750, exist_ok=True) + + def tearDown(self): + if os.path.exists(zip_dir): + shutil.rmtree(zip_dir) + + @staticmethod + def create_fake_zip_with_sizes(file_sizes): + """创建临时 zip 文件,file_sizes 为每个文件的大小列表,伪造一个具有 file_size=size 的 ZIP 条目""" + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip", dir=zip_dir) + os.close(tmp_fd) + with ZipFile(tmp_path, 'w', allowZip64=True) as zipf: + for i, size in enumerate(file_sizes): + info = ZipInfo(f"file_{i}.bin") + zipf.writestr(info, b'') # 实际内容为空,但声明文件大小为 size + info.file_size = size + return tmp_path + + def test_valid_zip(self): + file_sizes = [100, 200, 300] + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + check_zip_file(zip_path) + finally: + os.remove(zip_path) + + def test_single_file_too_large(self): + file_sizes = [FileCheckConst.MAX_FILE_IN_ZIP_SIZE + 1] + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + with self.assertRaises(ValueError) as cm: + check_zip_file(zip_path) + self.assertIn("is too large to extract", str(cm.exception)) + finally: + os.remove(zip_path) + + def test_total_size_too_large(self): + count = 20 + size_each = (FileCheckConst.MAX_ZIP_SIZE // count) + 1 + file_sizes = [size_each] * count + zip_path = self.create_fake_zip_with_sizes(file_sizes) + try: + with self.assertRaises(ValueError) as cm: + check_zip_file(zip_path) + self.assertIn("Total extracted size exceeds the limit", str(cm.exception)) + finally: + os.remove(zip_path) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index 3472ca9018e189ffb48e4d26cfeb79e1ba1ff16d..59f7aa5c589e558e7c43db9e50dac008eac55eb4 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,12 +17,12 @@ import json import os import tempfile -from datetime import datetime, timezone +import unittest from unittest import TestCase from unittest.mock import MagicMock, mock_open, patch -import OpenSSL import numpy as np +from pathlib import Path from msprobe.core.common.const import Const from msprobe.core.common.file_utils import ( @@ -30,7 +30,6 @@ from msprobe.core.common.file_utils import ( FileCheckException, check_file_or_directory_path, check_file_size, - check_crt_valid, get_file_content_bytes, get_json_contents, save_json, @@ -45,15 +44,18 @@ from msprobe.core.common.utils import (CompareException, check_regex_prefix_format_valid, set_dump_path, get_dump_mode, - get_real_step_or_rank, - get_step_or_rank_from_string, + get_real_step_or_rank, + get_step_or_rank_from_string, get_stack_construct_by_dump_json_path, check_seed_all, safe_get_value, - recursion_depth_decorator, MsprobeBaseException, check_str_param, - is_json_file) + is_json_file, + detect_framework_by_dump_json, + is_save_variable_valid, + check_dump_json_key) +from msprobe.core.common.decorator import recursion_depth_decorator class TestUtils(TestCase): @@ -203,7 +205,7 @@ class TestUtils(TestCase): with self.assertRaises(CompareException) as context: set_dump_path(input_param) self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) - mock_error.assert_called_with("Please check the json path is valid. npu_path: None, bench_path: bench_path") + mock_error.assert_called_with("Please check the json path is valid and ensure that neither npu_path nor bench_path is None.") @patch.object(logger, "error") def test_get_dump_mode(self, mock_error): @@ -214,7 +216,7 @@ class TestUtils(TestCase): npu_json = { "task": Const.TENSOR, "dump_data_dir": "dump_data_dir", - "data": "data" + "data": {"api": "value"} } input_param["npu_json_path"] = "npu_path" @@ -334,7 +336,7 @@ class TestUtils(TestCase): def test_recursion_depth_decorator(self, mock_error): # 测试递归深度限制函数 recursion_list = [[]] - temp_list = recursion_list[0] + temp_list = recursion_list[0] for _ in range(Const.MAX_DEPTH): temp_list.append([]) temp_list = temp_list[0] @@ -436,55 +438,125 @@ class TestUtils(TestCase): self.assertFalse(is_json_file(file_path_false)) -class TestCheckCrtValid(TestCase): - """ - Test the check_crt_valid function. - """ +class TestDetectFrameworkByDumpJson(unittest.TestCase): + @patch('msprobe.core.common.utils.load_json') + def test_valid_pytorch_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": Const.PT_FRAMEWORK} + + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.PT_FRAMEWORK) + + @patch('msprobe.core.common.utils.load_json') + def test_valid_mindspore_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": Const.MS_FRAMEWORK} + + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.MS_FRAMEWORK) + + def test_detect_framework_in_file(self): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/pt_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.PT_FRAMEWORK) + + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/ms_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.MS_FRAMEWORK) + + @patch("msprobe.core.common.utils.logger") + def test_detect_framework_exception(self, mock_logger): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/dump_no_pt_no_ms.json" + with self.assertRaises(CompareException) as context: + result = detect_framework_by_dump_json(file_path) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_logger.error.assert_called_once_with(f"{file_path} must be based on the MindSpore or PyTorch framework.") + + +class TestIsSaveVariableValid(unittest.TestCase): def setUp(self): - self.cert_file_path = "cert_file_path.pem" - if not os.path.exists(self.cert_file_path): - with open(self.cert_file_path, 'w') as f: - f.write("This is a test certificate.") - - def tearDown(self): - if os.path.exists(self.cert_file_path): - os.remove(self.cert_file_path) - - @patch('msprobe.core.common.file_utils.datetime') - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_success(self, mock_open_, mock_load_certificate, mock_datetime): - mock_cert = MagicMock() - mock_cert.get_notBefore.return_value = b'20220101' - mock_cert.get_notAfter.return_value = b'20230101' - mock_cert.has_expired.return_value = False - mock_load_certificate.return_value = mock_cert - mock_datetime.now.return_value = datetime(2022, 10, 1) - - check_crt_valid(self.cert_file_path) - mock_load_certificate.assert_called_once_with(OpenSSL.crypto.FILETYPE_PEM, 'cert_data') - - @patch('datetime.datetime') - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_expired(self, mock_open_, mock_load_certificate, mock_datetime): - mock_cert = MagicMock() - mock_cert.get_notBefore.return_value = b'20220101' - mock_cert.get_notAfter.return_value = b'20230101' - mock_cert.has_expired.return_value = True - mock_load_certificate.return_value = mock_cert - mock_datetime.now.return_value = datetime(2022, 10, 1, tzinfo=timezone.utc) - - with self.assertRaises(RuntimeError) as context: - check_crt_valid(self.cert_file_path) - self.assertIn('The SSL certificate has expired and needs to be replaced', str(context.exception)) - - @patch('OpenSSL.crypto.load_certificate') - @patch('builtins.open', new_callable=mock_open, read_data="cert_data") - def test_check_crt_valid_exception(self, mock_open_, mock_load_certificate): - mock_load_certificate.side_effect = Exception('Test Exception') - - with self.assertRaises(RuntimeError) as context: - check_crt_valid(self.cert_file_path) - self.assertIn('The SSL certificate is invalid', str(context.exception)) + self.valid_special_types = (int, float, str, bool) + + def test_is_save_variable_valid_DepthExceeded_ReturnsFalse(self): + # 创建一个深度超过 Const.DUMP_MAX_DEPTH 的嵌套结构 + nested_structure = [0] * Const.DUMP_MAX_DEPTH + for _ in range(Const.DUMP_MAX_DEPTH): + nested_structure = [nested_structure] + self.assertFalse(is_save_variable_valid(nested_structure, self.valid_special_types)) + + def test_is_save_variable_valid_ValidSpecialTypes_ReturnsTrue(self): + for valid_type in self.valid_special_types: + self.assertTrue(is_save_variable_valid(valid_type(0), self.valid_special_types)) + + def test_is_save_variable_valid_ListWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid([1, 2, 3], self.valid_special_types)) + + def test_is_save_variable_valid_ListWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid([1, "test", [1, slice(1)]], self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid((1, 2, 3), self.valid_special_types)) + + def test_is_save_variable_valid_TupleWithInvalidElement_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid((1, "test", [1, slice(1)]), self.valid_special_types)) + + def test_is_save_variable_valid_DictWithValidElements_ReturnsTrue(self): + self.assertTrue(is_save_variable_valid({"a": 1, "b": "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidKey_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({1: "test"}, self.valid_special_types)) + + def test_is_save_variable_valid_DictWithInvalidValue_ReturnsFalse(self): + self.assertFalse(is_save_variable_valid({"a": [1, slice(1)]}, self.valid_special_types)) + + +class TestCheckDumpJsonKey(unittest.TestCase): + def test_valid_input(self): + json_data = { + "task": "tensor", + "data": {"api1": "value1"} + } + task, api_data = check_dump_json_key(json_data, "NPU") + self.assertEqual(task, "tensor") + self.assertEqual(api_data, {"api1": "value1"}) + + @patch("msprobe.core.common.utils.logger") + def test_missing_task(self, mock_logger): + json_data = { + "data": {"api1": "value1"} + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "bench") + self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) + mock_logger.error.assert_called_once_with( + "Task for bench is empty, please check." + ) + + @patch("msprobe.core.common.utils.logger") + def test_missing_data(self, mock_logger): + json_data = { + "task": "tensor" + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "npu") + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) + mock_logger.error.assert_called_once_with( + "Missing 'data' in dump.json, please check dump.json of npu." + ) + + @patch("msprobe.core.common.utils.logger") + def test_wrong_data_type(self, mock_logger): + json_data = { + "task": "tensor", + "data": [1] + } + with self.assertRaises(CompareException) as context: + check_dump_json_key(json_data, "npu") + self.assertEqual(context.exception.code, CompareException.INVALID_DATA_ERROR) + mock_logger.error.assert_called_once_with( + "Invalid type for 'data': expected a dict. Please check dump.json of npu." + ) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index b4566fcfe6f48d9040feb4dc22f3a96cd08719a7..ee15d9b06e530f32c5759492a9de40a2ab9cbf46 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -6,15 +6,49 @@ import threading import unittest from unittest.mock import patch +import numpy as np import pandas as pd import torch +from msprobe.core.common.file_utils import load_json from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import Comparator, ModeConfig, get_bench_data_name -from msprobe.core.compare.highlight import find_error_rows, find_compare_result_error_rows, ApiBatch -from msprobe.core.compare.utils import get_accuracy -from msprobe.pytorch.compare.pt_compare import PTComparator +from msprobe.core.compare.acc_compare import ModeConfig, MappingConfig, MappingDict, Comparator, ParseData, ProcessDf, \ + Match, CreateTable, CalcStatsDiff + +npu_op_item_data_fuzzy = { + 'op_name': 'Functional.conv2d.0.forward.input.0', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.0.forward.input.0.pt', + 'compare_key': 'Functional.conv2d.0.forward.input.0', + 'compare_shape': [1, 1, 28, 28], +} +npu_op_item_fuzzy = pd.Series(npu_op_item_data_fuzzy) +npu_op_item_data_fuzzy_2 = { + 'op_name': 'Functional.conv2d.0.forward.input.1', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.0.forward.input.1.pt', + 'compare_key': 'Functional.conv2d.0.forward.input.1', + 'compare_shape': [1, 1, 28, 28], +} +npu_op_item_fuzzy_2 = pd.Series(npu_op_item_data_fuzzy_2) +bench_op_item_data_fuzzy = { + 'op_name': 'Functional.conv2d.1.forward.input.0', + 'dtype': 'torch.float32', + 'shape': [1, 1, 28, 28], + 'summary': [3.029174327850342, -2.926689624786377, -0.06619918346405029], + 'stack_info': [], + 'data_name': 'Functional.conv2d.1.forward.input.0.pt', + 'compare_key': 'Functional.conv2d.1.forward.input.0', + 'compare_shape': [1, 1, 28, 28], +} +bench_op_item_fuzzy = pd.Series(bench_op_item_data_fuzzy) npu_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.output'], @@ -159,50 +193,21 @@ aten_result = [ -10.640625, -0.008758544921875, 5.397906303405762, -5.796811580657959, 2.5283952709287405e-10, 'Warning', 'Need double check api accuracy.', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.1', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', + 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.2', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.3', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.4', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None']] highlight_dict = {'red_rows': [], 'yellow_rows': []} -num_0, num_1, num_2, num_3 = 0, 1, 2, 3 -summary_line_input = ['Functional_batch_norm_0_forward.input.0', 'Functional_batch_norm_0_forward.input.0', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.01, 0, 0, 0, 1, 1, 1, 1, 1.01, 1, 1, 1, - 'Yes', ''] -summary_line_1 = ['Functional_batch_norm_0_forward.output.0', 'Functional_batch_norm_0_forward.output.0', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 10, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, - 'Warning', ''] -summary_line_2 = ['Functional_batch_norm_0_forward.output.1', 'Functional_batch_norm_0_forward.output.1', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.02, 0, 0, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, - 'Warning', ''] -summary_line_3 = ['Functional_batch_norm_0_forward.output.2', 'Functional_batch_norm_0_forward.output.2', - 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, - 'Warning', ''] -line_input = ['Functional.batch.norm.0.forward.input.0', 'Functional.batch.norm.0.forward.input.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 1, 1, 0.95, 1, 1, 1, 1, 1, 1.01, 1, 1, 1, - 'Yes', ''] -line_1 = ['Functional.batch.norm.0.forward.output.0', 'Functional.batch.norm.0.forward.output.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1, 1, 0.59, 1, 'nan', 0, 1, 1, 19, 1, 1, 1, - 'Warning', ''] -line_2 = ['Functional.batch.norm.0.forward.output.1', 'Functional.batch.norm.0.forward.output.1', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 1, 1, 0.8, 1, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, - 'Warning', ''] -line_3 = ['Functional.batch.norm.0.forward.output.2', 'Functional.batch.norm.0.forward.output.2', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1.1e+10, 1, 0.85, 1, 9, 0.12, 0, 1, 1, 0.1, 1, - 1, 1, 'Warning', ''] - op_data = { 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, @@ -263,6 +268,33 @@ def generate_dump_json(base_dir): json.dump(data, json_file) +def generate_dump_json_md5(base_dir): + data_path = os.path.join(base_dir, 'dump_md5.json') + data = { + 'task': 'statistics', + 'level': 'L1', + 'dump_data_dir': '', + 'data': { + 'Functional.linear.0.forward': { + 'input_args': [ + {'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [2, 2], + 'Max': 2, + 'Min': 0, + 'Mean': 1, + 'Norm': 1, + 'requires_grad': False, + 'md5': 123456 + } + ] + } + } + } + with open(data_path, 'w') as json_file: + json.dump(data, json_file) + + def generate_stack_json(base_dir): data_path = os.path.join(base_dir, 'stack.json') data = {'Functional.linear.0.forward': ['File']} @@ -296,145 +328,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir3): shutil.rmtree(base_dir3) - def test_get_accuracy_graph_mode(self): - result = [] - get_accuracy(result, npu_dict_aten, bench_dict_functional, dump_mode=Const.SUMMARY) - self.assertEqual(result, aten_result) - - def test_find_error_rows(self): - api_batch = ApiBatch("Functional_batch_norm_0_forward", 0) - api_batch.input_len = 1 - api_batch.output_end_index = 4 - api_batch.params_end_index = 4 - summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] - highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - find_error_rows(summary_result, api_batch, highlight_dict_test, dump_mode=Const.SUMMARY) - self.assertEqual(highlight_dict_test, - {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}) - - def test_find_compare_result_error_rows(self): - result = [line_input, line_1, line_2, line_3] - result_df = pd.DataFrame(result) - highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} - find_compare_result_error_rows(result_df, highlight_dict_test, dump_mode=Const.ALL) - self.assertEqual(highlight_dict_test, { - "red_rows": {1, 3}, - "yellow_rows": {2}, - "red_lines": [ - (1, ["maximum or minimum is nan, -inf, or inf"]), - (3, ["maximum absolute error exceeds 1e+10"]) - ], - "yellow_lines": [ - (2, ["The output's one thousandth err ratio decreases by more than 0.1 compared to the input/parameters's"]), - (3, [ - "maximum absolute error of both input/parameters and output exceed 1, " - "with the output larger by an order of magnitude", - "The output's cosine decreases by more than 0.1 compared to the input/parameters's"]) - ] - }) - - def test_calculate_summary_data(self): - npu_summary_data = [1, 1, 1, 1] - bench_summary_data = [2, 2, 2, 2] - result_item = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - comparator = Comparator(mode_config) - comparator.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - self.assertEqual(result_item, - ['', '', '', '', '', '', -1, -1, -1, -1, '50.0%', '50.0%', '50.0%', '50.0%', '', '']) - - bench_summary_data = [0, 0, 0, 0] - result_item = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] - - comparator.calculate_summary_data(npu_summary_data, bench_summary_data, result_item) - self.assertEqual(result_item, ['', '', '', '', '', '', 1, 1, 1, 1, 'N/A', 'N/A', 'N/A', 'N/A', 'Warning', - 'Need double check api accuracy.']) - - def test_make_result_table_stack_mode_True(self): - result_md5 = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', 'File']] - result_summary = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File']] - result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File', '-1']] - columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] - result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype=object) - columns_summary_stack_mode_true = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] - result_table_summary_true = pd.DataFrame(result_summary, columns=columns_summary_stack_mode_true, dtype=object) - columns_all_stack_mode_true = CompareConst.COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] + ['Data_name'] - result_table_all_true = pd.DataFrame(result_all, columns=columns_all_stack_mode_true, dtype=object) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - - dump_mode = Const.MD5 - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_md5) - self.assertTrue(result_df.equals(result_table_md5_true)) - - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_summary) - self.assertTrue(result_df.equals(result_table_summary_true)) - - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_all) - self.assertTrue(result_df.equals(result_table_all_true)) - - def test_make_result_table_stack_mode_False(self): - result_md5_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '']] - result_md5 = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '']] - result_summary_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '']] - result_summary = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '']] - result_all_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '', '-1']] - result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', - 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] - columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER - result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype='object') - columns_summary_stack_mode_true = CompareConst.SUMMARY_COMPARE_RESULT_HEADER - result_table_summary_true = pd.DataFrame(result_summary, columns=columns_summary_stack_mode_true, - dtype='object') - columns_all_stack_mode_true = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] - result_table_all_true = pd.DataFrame(result_all, columns=columns_all_stack_mode_true, dtype='object') - - stack_mode = False - auto_analyze = True - fuzzy_match = False - - dump_mode = Const.MD5 - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_md5_test) - self.assertTrue(result_df.equals(result_table_md5_true)) - - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_summary_test) - self.assertTrue(result_df.equals(result_table_summary_true)) - - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result_df = Comparator(mode_config).make_result_table(result_all_test) - self.assertTrue(result_df.equals(result_table_all_true)) - def test_gen_merge_list(self): op_data = { 'input_args': [ @@ -465,294 +358,533 @@ class TestUtilsMethods(unittest.TestCase): dump_mode = Const.SUMMARY mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - result = Comparator(mode_config).gen_merge_list(json_data, op_name, stack_json_data) + result = ParseData(mode_config).gen_merge_list(json_data, op_name, stack_json_data) self.assertEqual(result, merge_list) - def test_check_op_fuzzy_false(self): - stack_mode = False - auto_analyze = True - dump_mode = Const.SUMMARY - - fuzzy_match = False - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - result = pt_comparator.check_op(npu_dict, bench_dict) - self.assertEqual(result, True) - - def test_check_op_fuzzy_true(self): + def test_check_op_item_fuzzy(self): stack_mode = False auto_analyze = True dump_mode = Const.SUMMARY fuzzy_match = True mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mapping_config = MappingConfig() - pt_comparator = PTComparator(mode_config) - result = pt_comparator.check_op(npu_dict2, bench_dict) + match = Match(mode_config, mapping_config, cross_frame=False) + result = match.check_op_item(npu_op_item_fuzzy, bench_op_item_fuzzy) self.assertEqual(result, True) - def test_match_op_both_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict], [bench_dict]) - self.assertEqual(a, 0) - self.assertEqual(b, 0) - - def test_match_op_only_npu_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict], [bench_dict, 1]) - self.assertEqual(a, 0) - self.assertEqual(b, 0) - - def test_match_op_only_bench_last_element(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - a, b = pt_comparator.match_op([npu_dict, npu_dict2], [bench_dict]) - self.assertEqual(a, 0) - self.assertEqual(b, 0) - - def test_compare_process(self): + def test_compare_statistics(self): generate_dump_json(base_dir) generate_stack_json(base_dir) - file_lists = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), - os.path.join(base_dir, 'stack.json')] + file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), + os.path.join(base_dir, 'stack.json')] stack_mode = True auto_analyze = True fuzzy_match = False dump_mode = Const.SUMMARY mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mapping_config = MappingConfig() - result = PTComparator(mode_config).compare_process(file_lists) + from msprobe.pytorch.compare.pt_compare import read_real_data + comparator = Comparator(read_real_data, mode_config, mapping_config) + result = comparator.compare_statistics(file_list) o_data = [ ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 0, 0, 0, 0, '0.0%', 'N/A', '0.0%', '0.0%', + 'torch.float32', 'torch.float32', '[2, 2]', '[2, 2]', 0, 0, 0, 0, '0.0%', 'N/A', '0.0%', '0.0%', 2, 0, 1, 1, 2, 0, 1, 1, '', '', ['File'] ] ] columns = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] o_result = pd.DataFrame(o_data, columns=columns, dtype=object) - self.assertTrue(result.equals(o_result)) + self.assertTrue(np.array_equal(result.to_numpy(), o_result.to_numpy())) - def test_merge_data(self): - op_data = { - 'input_args': [ - { - 'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [2, 2], - 'Max': 1, 'Min': 1, 'Mean': 1, 'Norm': 1, 'requires_grad': False, - 'data_name': 'Functional.linear.0.forward.input.0.pt', - 'full_op_name': 'Functional.linear.0.forward.input.0' - } - ] - } - json_data = {'data': {'Functional.linear.0.forward': op_data}} - stack_json_data = {'Functional.linear.0.forward': ['File']} + +class TestParseData(unittest.TestCase): + + def setUp(self): + os.makedirs(base_dir, mode=0o750, exist_ok=True) + generate_dump_json(base_dir) + generate_dump_json_md5(base_dir) + generate_stack_json(base_dir) + + self.lock = threading.Lock() + + def tearDown(self): + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_parse(self): + file_list = [os.path.join(base_dir, 'dump.json'), os.path.join(base_dir, 'dump.json'), + os.path.join(base_dir, 'stack.json')] stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config) + npu_df, bench_df = parse_data.parse(file_list) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File']]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info'] + ) + self.assertTrue(npu_df.equals(target_df)) + self.assertTrue(bench_df.equals(target_df)) + + def test_gen_data_df_summary(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) - result = Comparator(mode_config).merge_data(json_data, stack_json_data) - ops_all = { - 'Functional.linear.0.forward.input.0': { - 'data_name': None, 'stack_info': [['File']], - 'struct': ('torch.float32', [2, 2]), 'summary': [1, 1, 1, 1] - } + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config) + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File']]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info'] + ) + self.assertTrue(npu_df.equals(target_df)) + + def test_gen_data_df_all(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) + + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.ALL) + parse_data = ParseData(mode_config) + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File'], 'Functional.linear.0.forward.input.0.pt']], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'data_name'] + ) + self.assertTrue(npu_df.equals(target_df)) + + def test_gen_data_df_md5(self): + npu_json_path = os.path.join(base_dir, 'dump_md5.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) + + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=Const.MD5) + parse_data = ParseData(mode_config) + npu_df = parse_data.gen_data_df(npu_json_data, stack_json_data) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File'], 123456]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'md5'] + ) + self.assertTrue(npu_df.equals(target_df)) + + def test_gen_merge_list(self): + npu_json_path = os.path.join(base_dir, 'dump.json') + stack_json_path = os.path.join(base_dir, 'stack.json') + npu_json_data = load_json(npu_json_path) + stack_json_data = load_json(stack_json_path) + + stack_mode = True + mode_config = ModeConfig(stack_mode=stack_mode) + parse_data = ParseData(mode_config) + merge_list = parse_data.gen_merge_list(npu_json_data, 'Functional.linear.0.forward', stack_json_data) + + target_dict = { + 'input_struct': [('torch.float32', [2, 2])], + 'op_name': ['Functional.linear.0.forward.input.0'], + 'output_struct': [], + 'params_grad_struct': [], + 'params_struct': [], + 'stack_info': [['File']], + 'summary': [[2, 0, 1, 1]] } - self.assertEqual(result, ops_all) - - def test_compare_core_basic(self): - generate_dump_json(base_dir2) - generate_stack_json(base_dir2) - input_params = { - "npu_json_path": os.path.join(base_dir2, "dump.json"), - "bench_json_path": os.path.join(base_dir2, "dump.json"), - "stack_json_path": os.path.join(base_dir2, "stack.json"), + self.assertEqual(merge_list, target_dict) + + +class TestProcessDf(unittest.TestCase): + + def test_get_api_name_success(self): + api_list = ['Functional', 'linear', '0', 'forward', 'input', '0'] + + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + api_name = process_df.get_api_name(api_list) + + target_api_name = 'Functional.linear' + self.assertEqual(api_name, target_api_name) + + @patch('msprobe.core.compare.acc_compare.logger') + def test_get_api_name_index_error(self, mock_logger): + api_list = ['Functional'] + with self.assertRaises(CompareException) as context: + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + api_name = process_df.get_api_name(api_list) + self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + mock_logger.error.assert_called_once_with('Failed to retrieve API name, please check if the dump data is reasonable') + + def test_process_compare_key_and_shape(self): + npu_df_o = bench_df_o = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File']]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info'] + ) + + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + npu_df, bench_df = process_df.process_compare_key_and_shape(npu_df_o, bench_df_o) + + target_df = pd.DataFrame( + [['Functional.linear.0.forward.input.0', 'torch.float32', [2, 2], [2, 0, 1, 1], ['File'], 'Functional.linear.0.forward.input.0', [2, 2]]], + columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key', 'compare_shape'] + ) + self.assertTrue(npu_df.equals(target_df)) + self.assertTrue(bench_df.equals(target_df)) + + def test_process_internal_api_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + # mint to torch + npu_op_name = 'Mint.mean.0.input.0' + target_name = 'Torch.mean.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # mintfunctional to functional + npu_op_name = 'MintFunctional.mean.0.input.0' + target_name = 'Functional.mean.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # inner mapping exists + npu_op_name = 'Functional.abs.0.input.0' + mapping_dict.ms_to_pt_mapping = {'Functional.abs': 'Torch.abs'} + target_name = 'Torch.abs.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + # inner mapping not found + npu_op_name = 'Functional.abs.0.input.0' + mapping_dict.ms_to_pt_mapping = {} + target_name = 'Functional.abs.0.input.0' + name = process_df.process_internal_api_mapping(npu_op_name) + self.assertEqual(name, target_name) + + def test_modify_compare_data_with_user_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + mapping_dict.api_mapping_dict = [{ + 'ms_api': 'Functional.conv2d', + 'pt_api': 'Torch.conv2d', + 'ms_args': [0], + 'pt_args': [0] + }] + + npu_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + bench_df = pd.DataFrame([ + ['Torch.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.conv2d.0.forward.input.0'], + ['Torch.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Torch.amax.0.forward.input.0'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + + process_df.modify_compare_data_with_user_mapping(npu_df, bench_df) + + def test_get_api_indices_dict(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + op_name_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0'], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0'] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'compare_key']) + + api_indices_dict = process_df.get_api_indices_dict(op_name_df) + expected = { + 'Functional.conv2d': [0], + 'Functional.amax': [1] } - output_path = base_dir2 + self.assertEqual(api_indices_dict, expected) + + def test_process_cell_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + # not name + npu_op_name = None + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, CompareConst.N_A) + + # not params_grad + npu_op_name = 'MintFunctional.embedding.0.input.0' + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, CompareConst.N_A) + + # default replace + npu_op_name = 'Cell.network_with_loss.module.GPTModel.forward.1.input.0' + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, 'Module.network_with_loss.module.GPTModel.forward.1.input.0') + + # mapping_dict + npu_op_name = 'Cell.fc1.Dense.forward.0.input.0' + mapping_dict.cell_mapping_dict = {'fc1.Dense': 'module.name'} + name = process_df.process_cell_mapping(npu_op_name) + self.assertEqual(name, 'Module.module.name.forward.0.input.0') + + def test_process_data_mapping(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + mapping_dict = MappingDict(mapping_config) + process_df = ProcessDf(mode_config, mapping_config, mapping_dict) + + npu_op_name = 'Functional.flash_attention_score.4.forward.input.0' + mapping_dict.data_mapping_dict = {'Functional.flash_attention_score.4.forward.input.0': 'NPU.npu_fusion_attention.4.forward.input.0'} + name = process_df.process_data_mapping(npu_op_name) + self.assertEqual(name, 'NPU.npu_fusion_attention.4.forward.input.0') + + +class TestMatch(unittest.TestCase): + + def test_put_unmatched_in_table(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'op', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x', + 'compare_key', 'compare_shape'] + ) + match_result = match.put_unmatched_in_table(match_result, npu_op_item) + target_match_result = pd.DataFrame([['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'op', [1, 2], + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A']], + columns=CompareConst.MATCH_RESULT_COLUMNS) + self.assertTrue(match_result.equals(target_match_result)) + + def test_put_matched_in_table(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + match_result = pd.DataFrame(columns=CompareConst.MATCH_RESULT_COLUMNS) + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'op', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x', + 'compare_key', 'compare_shape'] + ) + bench_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'op', [1, 2]], + index=['op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y', + 'compare_key', 'compare_shape'] + ) + match_result = match.put_matched_in_table(match_result, npu_op_item, bench_op_item) + target_match_result = pd.DataFrame([['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'op', [1, 2], + 'op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name']], + columns=CompareConst.MATCH_RESULT_COLUMNS) + self.assertTrue(match_result.equals(target_match_result)) + + def test_rename_api(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + op_name_1 = 'Functional.linear.0.forward.input.0' + result_1 = match.rename_api(op_name_1) + self.assertTrue(result_1, 'Functional.linear.input.0') + + op_name_2 = 'Functional.linear.0.backward.input.0' + result_2 = match.rename_api(op_name_2) + self.assertTrue(result_2, 'Functional.linear.input.0') + + op_name_3 = 'Functional.linear.0.x.input.0' + result_3 = match.rename_api(op_name_3) + self.assertTrue(result_3, 'Functional.linear.0.x.input.0') + + def test_check_op_item(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + npu_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'Functional.linear.0.forward.input.0', [1, 2]], + index=['op_name_x', 'dtype_x', 'shape_x', 'summary_x', 'stack_info_x', 'data_name_x', + 'compare_key', 'compare_shape'] + ) + bench_op_item = pd.Series(['op', 'float32', [1, 2], 'summary', 'stack_info', 'data_name', 'Functional.linear.1.forward.input.0', [1, 2]], + index=['op_name_y', 'dtype_y', 'shape_y', 'summary_y', 'stack_info_y', 'data_name_y', + 'compare_key', 'compare_shape'] + ) + result = match.check_op_item(npu_op_item, bench_op_item) + self.assertTrue(result) + + def test_process_fuzzy_match(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=False) + + npu_df = pd.DataFrame([ + ['Functional.conv2d.3.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.3.forward.input.0.pt', 'Functional.conv2d.3.forward.input.0', [1, 2]], + ['Functional.amax.1.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0.pt', 'Functional.amax.1.forward.input.0', [1, 2]] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'data_name', 'compare_key', 'compare_shape']) + bench_df = pd.DataFrame([ + ['Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0.pt', 'Functional.conv2d.0.forward.input.0', [1, 2]], + ['Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0.pt', 'Functional.amax.0.forward.input.0', [1, 2]] + ], columns=['op_name', 'dtype', 'shape', 'summary', 'stack_info', 'data_name', 'compare_key', 'compare_shape']) + + match_result = match.process_fuzzy_match(npu_df, bench_df) + expected = pd.DataFrame( + [ + ['Functional.conv2d.3.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.3.forward.input.0.pt', 'Functional.conv2d.3.forward.input.0', [1, 2], 'Functional.conv2d.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.conv2d.0.forward.input.0.pt'], + ['Functional.amax.1.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0.pt', 'Functional.amax.1.forward.input.0', [1, 2], 'Functional.amax.0.forward.input.0', 'float32', [1, 2], 'summary', 'stack_info', 'Functional.amax.0.forward.input.0.pt'] + ] + , columns=CompareConst.MATCH_RESULT_COLUMNS) - stack_mode = True + self.assertTrue(match_result.equals(expected)) + + def test_match_op_both_last_element(self): + stack_mode = False auto_analyze = True fuzzy_match = False dump_mode = Const.SUMMARY mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mapping_config = MappingConfig() - PTComparator(mode_config).compare_core(input_params, output_path) - - output_files = os.listdir(output_path) - self.assertTrue(any(f.endswith(".xlsx") for f in output_files)) - - def test_compare_ops(self): - generate_dump_json(base_dir3) - generate_stack_json(base_dir3) - generate_pt(pt_dir) - dump_path = os.path.join(base_dir3, 'dump.json') - stack_path = os.path.join(base_dir3, 'stack.json') - input_param = {'npu_json_path': dump_path, 'bench_json_path': dump_path, 'stack_json_path': stack_path, - 'is_print_compare_log': True, 'npu_dump_data_dir': pt_dir, 'bench_dump_data_dir': pt_dir} - dump_path_dict = {'Functional.linear.0.forward.input.0': ['Functional.linear.0.forward.input.0.pt', - 'Functional.linear.0.forward.input.0.pt']} - result_df = pd.DataFrame({ - 'NPU Name': ['Functional.linear.0.forward.input.0'], - 'Bench Name': ['Functional.linear.0.forward.input.0'] - }) + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy], [bench_op_item_fuzzy]) + self.assertEqual(a, 0) + self.assertEqual(b, 0) - stack_mode = True + def test_match_op_only_npu_last_element(self): + stack_mode = False auto_analyze = True fuzzy_match = False - dump_mode = Const.ALL + dump_mode = Const.SUMMARY mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mapping_config = MappingConfig() - pt_comparator = PTComparator(mode_config) - updated_df = pt_comparator.compare_ops(idx=0, dump_path_dict=dump_path_dict, result_df=result_df, - lock=self.lock, input_param=input_param) - - self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 1.0) - self.assertEqual(updated_df.loc[0, CompareConst.MAX_ABS_ERR], 0) - - def test_do_multi_process(self): - data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] - o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 'unsupported', 'unsupported', 'unsupported', - 'unsupported', 'unsupported', - 1, 1, 1, 1, 1, 1, 1, 1, 'None', 'No bench data matched.', '-1']] - columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] - result_df = pd.DataFrame(data, columns=columns) - o_result = pd.DataFrame(o_data, columns=columns) - generate_dump_json(base_dir) - input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')} + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy], [bench_op_item_fuzzy, 1]) + self.assertEqual(a, 0) + self.assertEqual(b, 0) - stack_mode = True + def test_match_op_only_bench_last_element(self): + stack_mode = False auto_analyze = True fuzzy_match = False - dump_mode = Const.ALL + dump_mode = Const.SUMMARY mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + mapping_config = MappingConfig() - comparator = Comparator(mode_config) - result = comparator.do_multi_process(input_param, result_df) - self.assertTrue(result.equals(o_result)) + match = Match(mode_config, mapping_config, cross_frame=False) + a, b = match.match_op([npu_op_item_fuzzy, npu_op_item_data_fuzzy_2], [bench_op_item_fuzzy]) + self.assertEqual(a, 0) + self.assertEqual(b, 0) - def test_compare_by_op_1(self): - npu_op_name = 'Functional.linear.0.forward.input.0' - bench_op_name = 'N/A' - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [-1, -1]} - input_param = {} + def test_gen_dtype_condition(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=True) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) + # data mapping + mapping_config.data_mapping = True + match_result = pd.DataFrame([1, 2, 3]) + result = match.gen_dtype_condition(match_result) + expected = pd.Series([True, True, True]) + self.assertTrue(result.equals(expected)) - pt_comparator = PTComparator(mode_config) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) + # normal + mapping_config.data_mapping = None + match_result = pd.DataFrame([['Float16', 'Float32'], ['torch.float32', 'torch.bfloat16']], columns=['dtype_x', 'dtype_y']) + result = match.gen_dtype_condition(match_result) + expected = pd.Series([True, True]) + self.assertTrue(result.equals(expected)) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'No bench data matched.']) + def test_process_cross_frame_dtype(self): + mode_config = ModeConfig() + mapping_config = MappingConfig() + match = Match(mode_config, mapping_config, cross_frame=True) - def test_compare_by_op_2(self): - npu_op_name = 'Functional.linear.0.forward.input.0' - bench_op_name = 'Functional.linear.0.forward.input.0' + dtype_o = pd.Series(['Int8', 'Float16', 'torch.bool', 'Complex64', 'unknown']) + dtype = match.process_cross_frame_dtype(dtype_o) + self.assertTrue(dtype.equals(pd.Series(['int', 'float', 'bool', 'complex', 'unknown']))) - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - pt_comparator = PTComparator(mode_config) +class TestCreateTable(unittest.TestCase): - pt_name = '-1' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} - input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - f'Dump file: {pt_path} not found.']) + def test_process_data_name(self): + mode_config = ModeConfig() + create_table = CreateTable(mode_config) - pt_name = 'Functional.linear.0.forward.input.0.pt' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} - input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) - self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'Bench does not have data file.']) + data = { + 'data_name_x': ['A', 'B', 'C'], + 'data_name_y': ['X', 'Y', 'Z'] + } + result_o = pd.DataFrame(data) + result = create_table.process_data_name(result_o) + target_data = { + 'data_name_x': [['A', 'X'], ['B', 'Y'], ['C', 'Z']], + 'data_name_y': ['X', 'Y', 'Z'] + } + target_result = pd.DataFrame(target_data) + self.assertTrue(result.equals(target_result)) - generate_pt(base_dir) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) - self.assertEqual(result, [1.0, 0.0, 0.0, 1.0, 1.0, '']) + def test_set_summary(self): + mode_config = ModeConfig() + create_table = CreateTable(mode_config) - def test_get_bench_data_name_input(self): - bench_op_name = "Functional.linear.0.forward.input.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [{"data_name": "Functional.linear.0.forward.input.0.pt"}], "input_kwargs": {}, "output": []}} - result = get_bench_data_name(bench_op_name, bench_data) + # all nan + result = create_table.set_summary(['nan', 'NaN', 'nAn']) + expected = [CompareConst.NAN, CompareConst.NAN, CompareConst.NAN] + self.assertEqual(result, expected) - self.assertEqual(result, "Functional.linear.0.forward.input.0.pt") + # mixed values + result = create_table.set_summary([1, 'nan', 2.0, 'NaN']) + expected = [1, CompareConst.NAN, 2.0, CompareConst.NAN] + self.assertEqual(result, expected) - def test_get_bench_data_name_output(self): - bench_op_name = "Functional.linear.0.forward.output.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [], "input_kwargs": {}, "output": [{"data_name": "Functional.linear.0.forward.output.0.pt"}]}} - result = get_bench_data_name(bench_op_name, bench_data) + # NA case + result = create_table.set_summary(CompareConst.N_A) + expected = [CompareConst.N_A, CompareConst.N_A, CompareConst.N_A, CompareConst.N_A] + self.assertEqual(result, expected) - self.assertEqual(result, "Functional.linear.0.forward.output.0.pt") + # empty input + result = create_table.set_summary([]) + expected = [] + self.assertEqual(result, expected) -class TestComparator(unittest.TestCase): - def setUp(self): - mode_config = ModeConfig(dump_mode=Const.MD5) - self.comparator = Comparator(mode_config=mode_config) - self.npu_ops_all = { - 'op1': {'struct': ['float32', [1, 96, 2], '83dcefb7']}, - } - self.bench_ops_all = { - 'op1': {'struct': ['float32', [1, 96, 2], '83dcefb7']}, - } +class TestCalcStatsDiff(unittest.TestCase): - def test_normal(self): - expected_result = ['op1', 'op1', 'float32', 'float32', [1, 96, 2], [1, 96, 2], '83dcefb7', '83dcefb7', - CompareConst.PASS, CompareConst.NONE] - result = self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all) - self.assertEqual(result, expected_result) + def test_type_check(self): + mode_config = ModeConfig() + calc_stats_diff = CalcStatsDiff(mode_config) - @patch('msprobe.core.compare.acc_compare.logger') - def test_length_exception(self, mock_logger): - self.npu_ops_all['op1']['struct'] = ['npu_val1', 'npu_val2'] - with self.assertRaises(CompareException) as context: - self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - mock_logger.error.assert_called_once_with("The length of npu_struct and bench_struct must be >= 3, " - "but got npu_struct=2 and bench_struct=3. Please check!") - - def test_with_extra_args(self): - expected_result = ['op1', 'op1', 'float32', 'float32', [1, 96, 2], [1, 96, 2], '83dcefb7', '83dcefb7', - CompareConst.PASS, 'extra_data'] - result = self.comparator.get_result_md5_compare('op1', 'op1', - self.npu_ops_all, self.bench_ops_all, True, ['extra_data']) - self.assertEqual(result, expected_result) + series = pd.Series([float('nan'), 5, 'nan', 10, 'abc', None]) + result = calc_stats_diff.type_check(series) + expected = pd.Series([True, True, True, True, False, False]) + self.assertTrue(result.equals(expected)) + + def test_get_number(self): + mode_config = ModeConfig() + calc_stats_diff = CalcStatsDiff(mode_config) + + series = pd.Series([1, '2', 3.5, 'text', None]) + result = calc_stats_diff.get_number(series) + expected = pd.Series([1, 2, 3.5, float('nan'), float('nan')]) + self.assertTrue(result.equals(expected)) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py index a1e5f8eee1bce9b170e6f4f7fdfeda65d47252c9..fdfd124222f03599f914c77eb16c42c8d3578a7b 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py @@ -1,7 +1,6 @@ # coding=utf-8 import unittest -from msprobe.core.compare.check import check_struct_match, check_type_shape_match, check_graph_mode, fuzzy_check_op, \ - fuzzy_check_name, check_dump_json_str, check_json_key_value, valid_key_value, check_stack_json_str +from msprobe.core.compare.check import check_dump_json_str, check_json_key_value, valid_key_value, check_stack_json_str from msprobe.core.common.utils import CompareException @@ -66,86 +65,6 @@ op_name = 'Functional.conv2d.0.backward.input.0' class TestUtilsMethods(unittest.TestCase): - def test_check_struct_match_success(self): - result = check_struct_match(npu_dict, bench_dict) - self.assertTrue(result) - - def test_check_struct_match_fail(self): - npu_dict2 = {'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), - ('torch.float32', [16])], - 'output_struct': [('torch.float32', [1, 16, 28, 28])] - } - - bench_dict2 = {'input_struct': [('torch.float32', [2, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), - ('torch.float32', [16])], - 'output_struct': [('torch.float32', [1, 16, 28, 28])] - } - result = check_struct_match(npu_dict2, bench_dict2) - self.assertFalse(result) - - def test_check_struct_index_error(self): - npu_dict3 = {'input_struct': [('a'), ('torch.float32'), - ('torch.float32')], - 'output_struct': [('torch.float32')] - } - - bench_dict3 = {'input_struct': [('torch.float32'), ('torch.float32'), - ('torch.float32')], - 'output_struct': [('torch.float32')] - } - with self.assertRaises(CompareException) as context: - result = check_struct_match(npu_dict3, bench_dict3) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - def test_check_type_shape_match_success(self): - result = check_type_shape_match(npu_struct, bench_struct) - self.assertTrue(result) - - def test_check_type_shape_match_index_error(self): - npu_struct2 = [('a'), ('torch.float32'), ('torch.float32')] - bench_struct2 = [('torch.float32'), ('torch.float32'), ('torch.float32')] - with self.assertRaises(CompareException) as context: - result = check_type_shape_match(npu_struct2, bench_struct2) - self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - - def test_check_graph_mode(self): - op1 = "Aten" - op2 = "torch" - self.assertTrue(check_graph_mode(op1, op2)) - self.assertTrue(check_graph_mode(op2, op1)) - self.assertFalse(check_graph_mode(op1, op1)) - self.assertFalse(check_graph_mode(op2, op2)) - - def test_fuzzy_check_op_1(self): - npu_name_list = [] - bench_name_list = [] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertFalse(result) - - def test_fuzzy_check_op_2(self): - npu_name_list = [] - bench_name_list = ['Functional.conv2d.0.forward.input.0'] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertFalse(result) - - def test_fuzzy_check_op_3(self): - npu_name_list = ['Functional.conv2d.0.forward.input.0'] - bench_name_list = ['Functional.conv2d.1.forward.input.0'] - result = fuzzy_check_op(npu_name_list, bench_name_list) - self.assertTrue(result) - - def test_fuzzy_check_name_1(self): - npu_name = 'Functional.conv2d.0.backward.input.0' - bench_name = 'Functional.conv2d.1.backward.input.0' - result = fuzzy_check_name(npu_name, bench_name) - self.assertTrue(result) - - def test_fuzzy_check_name_2(self): - npu_name = 'Functional.conv2d.0.backward.input.0' - bench_name = 'Functional.conv2d.1.backward.input.1' - result = fuzzy_check_name(npu_name, bench_name) - self.assertFalse(result) - def test_check_dump_json_str(self): with self.assertRaises(CompareException) as context: check_dump_json_str(op_data, op_name) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py index aec6cdc51173ae817f32dd76455bec645659b45c..a30d693f7b32a806dee8667e42794259e7785545 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py @@ -20,7 +20,7 @@ from unittest.mock import patch from msprobe.core.common.const import CompareConst from msprobe.core.compare.npy_compare import handle_inf_nan, reshape_value, get_error_flag_and_msg, \ npy_data_check, statistics_data_check, get_relative_err, GetCosineSimilarity, GetMaxAbsErr, GetMaxRelativeErr, \ - GetErrRatio, error_value_process, compare_ops_apply + GetErrRatio, error_value_process, compare_ops_apply, GetEuclideanDistance op_name = 'Functional.conv2d.0.backward.input.0' @@ -113,7 +113,7 @@ class TestUtilsMethods(unittest.TestCase): n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag) self.assertFalse(error_flag) - self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', " + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', " "'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") def test_get_error_flag_and_msg_shape_unmatch(self): @@ -239,15 +239,17 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array(1) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") n_value_2 = np.array([1, 2]) b_value_2 = np.array([1, 2]) relative_err = get_relative_err(n_value_2, b_value_2) n_value_2, b_value_2 = reshape_value(n_value_2, b_value_2) - result, err_msg = op.apply(n_value_2, b_value_2, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_2, b_value_2, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -255,7 +257,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_3 = np.array([0, 0]) relative_err = get_relative_err(n_value_3, b_value_3) n_value_3, b_value_3 = reshape_value(n_value_3, b_value_3) - result, err_msg = op.apply(n_value_3, b_value_3, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_3, b_value_3, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -263,7 +266,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_4 = np.array([1, 2]) relative_err = get_relative_err(n_value_4, b_value_4) n_value_4, b_value_4 = reshape_value(n_value_4, b_value_4) - result, err_msg = op.apply(n_value_4, b_value_4, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_4, b_value_4, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.') @@ -271,7 +275,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_5 = np.array([0, 0]) relative_err = get_relative_err(n_value_5, b_value_5) n_value_5, b_value_5 = reshape_value(n_value_5, b_value_5) - result, err_msg = op.apply(n_value_5, b_value_5, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_5, b_value_5, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.') @@ -282,7 +287,9 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array([1]) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "" + + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) self.assertEqual(err_msg, "This is a 1-d tensor of length 1.") @@ -294,8 +301,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by Cosine Similarity, the dump data has NaN.") @@ -319,8 +327,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([0, 0]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 2.0) self.assertEqual(err_msg, "") @@ -333,8 +342,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data.") @@ -347,8 +357,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -361,8 +372,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.") @@ -375,8 +387,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -387,11 +400,12 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array(1) # 标量 b_value = np.array(1) relative_err = np.array(0) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") def test_GetThousandErrRatio_not_size(self): op = GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD) @@ -399,8 +413,9 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array([1, 2]) b_value = np.array([1, 2]) relative_err = np.array([]) # 空数组 + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "") @@ -412,8 +427,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -438,7 +454,7 @@ class TestUtilsMethods(unittest.TestCase): result, err_msg = error_value_process(n_value) - self.assertEqual(result, 0) + self.assertEqual(result, CompareConst.UNSUPPORTED) self.assertEqual(err_msg, "") def test_error_value_process_shape_unmatch(self): @@ -471,5 +487,34 @@ class TestUtilsMethods(unittest.TestCase): error_flag = False err_msg = '' a, b = compare_ops_apply(n_value, b_value, error_flag, err_msg) - self.assertEqual(a, [1.0, 0.0, 0.0, 1.0, 1.0]) + self.assertEqual(a, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) self.assertEqual(b, '') + + +class TestGetEuclideanDistance(unittest.TestCase): + + def setUp(self): + self.euc_distance = GetEuclideanDistance() + + def test_euclidean_distance_normal(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array([1, 2, 3]) + b_value = np.array([4, 5, 6]) + relative_err = None + err_msg = "" + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + expected_distance = np.linalg.norm(n_value - b_value) + self.assertEqual(result, expected_distance) + self.assertEqual(msg, '') + + def test_euclidean_distance_0d_tensor(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array(1) + b_value = np.array(1) + relative_err = None + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + self.assertEqual(result, CompareConst.UNSUPPORTED) + self.assertEqual(msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index ab8703dcd353ff32dc0722fc314ade6042d6f567..6265e31cfccbd8a741435250de9438c02374e721 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -12,9 +12,9 @@ import numpy as np from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.compare.utils import ApiItemInfo, _compare_parser, check_and_return_dir_contents, extract_json, \ - count_struct, get_accuracy, append_stack_info, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ - op_item_parse, read_op, rename_api, resolve_api_special_parameters, result_item_init, stack_column_process, \ - table_value_is_valid, get_name_and_state, reorder_op_name_list, reorder_op_x_list, gen_op_item + count_struct, get_accuracy, get_rela_diff_summary_mode, merge_tensor, op_item_parse, read_op, result_item_init, \ + stack_column_process, table_value_is_valid, get_name_and_state, reorder_op_name_list, reorder_op_x_list, \ + gen_op_item, ApiBatch # test_read_op_1 op_data = { @@ -221,28 +221,34 @@ o_result_unmatch_2 = [ 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'] ] o_result_unmatch_3 = [ - ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', - 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', - 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'] + ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']] ] # test_merge_tensor @@ -344,18 +350,6 @@ class TestUtilsMethods(unittest.TestCase): result = check_and_return_dir_contents(base_dir2, 'rank') self.assertEqual(set(result), set(['rank0', 'rank1'])) - def test_rename_api_1(self): - test_name_1 = "Distributed.broadcast.0.forward.input.0" - expect_name_1 = "Distributed.broadcast.input.0" - actual_name_1 = rename_api(test_name_1, "forward") - self.assertEqual(actual_name_1, expect_name_1) - - def test_rename_api_2(self): - test_name_2 = "Torch.sum.0.backward.output.0" - expect_name_2 = "Torch.sum.output.0" - actual_name_2 = rename_api(test_name_2, "backward") - self.assertEqual(actual_name_2, expect_name_2) - def test_read_op(self): result = read_op(op_data, op_name) self.assertEqual(result, op_result) @@ -373,11 +367,6 @@ class TestUtilsMethods(unittest.TestCase): op_item_parse(parse_item, parse_op_name, depth=11) self.assertEqual(context.exception.code, CompareException.RECURSION_LIMIT_ERROR) - def test_resolve_api_special_parameters(self): - item_list = [] - resolve_api_special_parameters(data_dict, full_op_name, item_list) - self.assertEqual(item_list, o_result_api_special) - def test_get_rela_diff_summary_mode_float_or_int(self): result_item = [0] * 14 err_msg = '' @@ -443,57 +432,6 @@ class TestUtilsMethods(unittest.TestCase): get_accuracy(result, npu_dict, bench_dict, dump_mode=Const.SUMMARY) self.assertEqual(result, o_result) - def test_append_stack_info_stack_exist_index_0(self): - result_item = ['item1'] - npu_stack_info = ['stack_info1'] - index = 0 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', 'stack_info1']) - - def test_append_stack_info_stack_exist_index_not_0(self): - result_item = ['item1'] - npu_stack_info = ['stack_info1'] - index = 1 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_append_stack_info_stack_empty_index_0(self): - result_item = ['item1'] - npu_stack_info = [] - index = 0 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_append_stack_info_stack_empty_index_not_0(self): - result_item = ['item1'] - npu_stack_info = [] - index = 1 - - append_stack_info(result_item, npu_stack_info, index) - - self.assertEqual(result_item, ['item1', CompareConst.NONE]) - - def test_get_un_match_accuracy_md5(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.MD5) - self.assertEqual(result, o_result_unmatch_1) - - def test_get_un_match_accuracy_summary(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.SUMMARY) - self.assertEqual(result, o_result_unmatch_2) - - def test_get_un_match_accuracy_all(self): - result = [] - get_un_match_accuracy(result, npu_dict, dump_mode=Const.ALL) - self.assertEqual(result, o_result_unmatch_3) - def test_merge_tensor_summary(self): op_dict = merge_tensor(tensor_list, dump_mode=Const.SUMMARY) self.assertEqual(op_dict, result_op_dict) @@ -558,7 +496,7 @@ class TestUtilsMethods(unittest.TestCase): dump_mode = Const.ALL result_item = result_item_init(n_info, b_info, dump_mode) self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', - 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ']) + 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ', ' ']) dump_mode = Const.SUMMARY result_item = result_item_init(n_info, b_info, dump_mode) @@ -848,3 +786,85 @@ class TestGenOpItem(unittest.TestCase): expected_md5 = f"{zlib.crc32(str(op_data['value']).encode()):08x}" self.assertEqual(result['md5'], expected_md5) + + +class TestApiBatch(unittest.TestCase): + def test_ApiBatch_increment_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_kwargs(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.KWARGS) + + self.assertEqual(api_batch._state, Const.KWARGS) + self.assertEqual(api_batch.input_len, 2) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_params(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.PARAMS) + + self.assertEqual(api_batch._state, Const.PARAMS) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_multiple_input(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 3) + self.assertEqual(api_batch.params_end_index, 5) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) + + def test_ApiBatch_increment_multiple_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..5c5a2690780bd0c1ada134f48746430da337d48d --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_first_diff_analyze.py @@ -0,0 +1,106 @@ +import unittest +from unittest.mock import patch + +import pandas as pd + +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze + + +class TestFirstDiffAnalyze(unittest.TestCase): + def setUp(self): + self.header = ['NPU name', 'L2norm diff', + 'MaxRelativeErr', 'MinRelativeErr', 'MeanRelativeErr', 'NormRelativeErr'] + self.data = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%'] + ] + self.result_df = pd.DataFrame(self.data, columns=self.header) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_true(self): + first_diff_analyze = FirstDiffAnalyze() + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '60.0%') + self.assertTrue(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5]}) + def test_single_metric_diff_check_false(self): + first_diff_analyze = FirstDiffAnalyze() + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertFalse(result) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'NormRelativeErr': [0.5]}) + def test_single_metric_diff_check_miss_threshold(self): + first_diff_analyze = FirstDiffAnalyze() + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.MISSING_THRESHOLD_ERROR) + + @patch('msprobe.core.compare.diff_analyze.first_diff_analyze.thresholds', + {'compare_metrics': ['MaxRelativeErr', 'NormRelativeErr'], 'MaxRelativeErr': [0.5, 1.0]}) + def test_single_metric_diff_check_wrong_threshold(self): + first_diff_analyze = FirstDiffAnalyze() + with self.assertRaises(CompareException) as context: + result = first_diff_analyze.single_metric_diff_check('MaxRelativeErr', '30.0%') + self.assertEqual(context.exception.code, CompareException.WRONG_THRESHOLD_ERROR) + + def test_single_api_check_within_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '0.0%', '0.0%', '0.0%', '0.0%'], + ['Functional.conv2d.0.forward.input.1', 1, '0.1%', '0.1%', '0.1%', '0.1%'] + ] + expected_result = { + 'is_same': True, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '0.1%', 'MinRelativeErr': '0.1%', + 'MeanRelativeErr': '0.1%', 'NormRelativeErr': '0.1%'} + ] + } + first_diff_analyze = FirstDiffAnalyze() + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_single_api_check_exceed_threshold(self): + result_slice = [ + ['Functional.conv2d.0.forward.input.0', 1, '88.0%', '88.0%', '88.0%', '88.0%'], + ['Functional.conv2d.0.forward.input.1', 1, '99.0%', '99.0%', '99.0%', '99.0%'] + ] + expected_result = { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '88.0%', 'MinRelativeErr': '88.0%', + 'MeanRelativeErr': '88.0%', 'NormRelativeErr': '88.0%'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%'}, + ] + } + first_diff_analyze = FirstDiffAnalyze() + result = first_diff_analyze.single_api_check(result_slice, self.header) + self.assertEqual(result, expected_result) + + def test_check(self): + expected_result = { + 'Functional.conv2d.0.forward': { + 'is_same': False, + 'op_items': [ + {'NPU name': 'Functional.conv2d.0.forward.input.0', 'L2norm diff': 1, + 'MaxRelativeErr': '0.0%', 'MinRelativeErr': '0.0%', + 'MeanRelativeErr': '0.0%', 'NormRelativeErr': '0.0%'}, + {'NPU name': 'Functional.conv2d.0.forward.input.1', 'L2norm diff': 1, + 'MaxRelativeErr': '99.0%', 'MinRelativeErr': '99.0%', + 'MeanRelativeErr': '99.0%', 'NormRelativeErr': '99.0%'}, + ] + } + } + first_diff_analyze = FirstDiffAnalyze() + result = first_diff_analyze.check(self.result_df) + self.assertEqual(result, expected_result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index f561a3e05ec84c3ee75dac50ed5aec2a2af7f7b5..5d01c3fdcbee48bad403c4749921b2936cdfc83d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -12,12 +12,45 @@ import openpyxl from openpyxl import load_workbook from openpyxl.styles import PatternFill - from msprobe.core.common.const import CompareConst, Const -from msprobe.core.compare.highlight import ApiBatch, CheckMaxRelativeDiff, CheckOrderMagnitude, \ - CheckOneThousandErrorRatio, CheckCosineSimilarity, add_highlight_row_info, compare_result_df_convert, \ - df_malicious_value_check, find_error_rows, highlight_rows_xlsx, update_highlight_err_msg, value_check - +from msprobe.core.compare.highlight import CheckMaxRelativeDiff, CheckOrderMagnitude, \ + CheckOneThousandErrorRatio, CheckCosineSimilarity, add_highlight_row_info, HighLight +from msprobe.core.compare.config import ModeConfig +from msprobe.core.compare.utils import ApiBatch + + +summary_line_input = ['Functional_batch_norm_0_forward.input.0', 'Functional_batch_norm_0_forward.input.0', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.01, 0, 0, 0, 1, 1, 1, 1, 1.01, 1, 1, 1, + 'Yes', ''] +summary_line_1 = ['Functional_batch_norm_0_forward.output.0', 'Functional_batch_norm_0_forward.output.0', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 10, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, + 'Warning', ''] +summary_line_2 = ['Functional_batch_norm_0_forward.output.1', 'Functional_batch_norm_0_forward.output.1', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.02, 0, 0, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, + 'Warning', ''] +summary_line_3 = ['Functional_batch_norm_0_forward.output.2', 'Functional_batch_norm_0_forward.output.2', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, + 'Warning', ''] +line_input = ['Functional.batch.norm.0.forward.input.0', 'Functional.batch.norm.0.forward.input.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 0.5, 1, 1, 0.95, 1, + 1, 1, 1, 1, 1.01, 1, 1, 1, + 'Yes', ''] +line_1 = ['Functional.batch.norm.0.forward.output.0', 'Functional.batch.norm.0.forward.output.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 0.5, 1, 1, 0.59, 1, + 'nan', 0, 1, 1, 19, 1, 1, 1, + 'Yes', ''] +line_2 = ['Functional.batch.norm.0.forward.output.1', 'Functional.batch.norm.0.forward.output.1', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 0.5, 1, 1, 0.8, 1, + 0, 0.12, 0, 1, 1, 0.1, 1, 1, + 'Yes', ''] +line_3 = ['Functional.batch.norm.0.forward.output.2', 'Functional.batch.norm.0.forward.output.2', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 0.5, 1.1e+10, 1, 0.85, 1, + 9, 0.12, 0, 1, 1, 0.1, 1, 1, + 'Yes', ''] base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_highlight') @@ -26,7 +59,7 @@ def generate_result_xlsx(base_dir): data_path = os.path.join(base_dir, 'target_result.xlsx') data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -101,8 +134,8 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(result, None) def test_CheckOneThousandErrorRatio_str(self): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] + api_in = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, "unsupported"] + api_out = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, "unsupported"] info = (api_in, api_out, 1) color_columns = () dump_mode = Const.ALL @@ -113,8 +146,8 @@ class TestUtilsMethods(unittest.TestCase): @patch("msprobe.core.compare.highlight.add_highlight_row_info") def test_CheckOneThousandErrorRatio_red(self, mock_add_highlight_row_info): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0.5] + api_in = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, 1] + api_out = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, 0.5] info = (api_in, api_out, 1) ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) color_columns = ColorColumns(red=[], yellow=[]) @@ -161,7 +194,7 @@ class TestUtilsMethods(unittest.TestCase): num = 1 info = (api_in, api_out, num) CheckMaxRelativeDiff().apply(info, color_columns, dump_mode=Const.SUMMARY) - red_lines, yellow_lines = [], [(1, ["The output's maximum relative error exceeds 0.1, while the input/parameters's is below 0.01"])] + red_lines, yellow_lines = [], [(1, ["The output's maximum relative error exceeds 0.1, while the input/parameter's is below 0.01"])] target_color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) self.assertEqual(color_columns, target_color_columns) @@ -198,9 +231,10 @@ class TestUtilsMethods(unittest.TestCase): api_batch.output_end_index = 4 api_batch.params_end_index = 4 highlight_dict = {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()} - dump_mode = Const.ALL - find_error_rows(compare_result, api_batch, highlight_dict, dump_mode) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(highlight_dict, {"red_lines": [], "red_rows": set(), "yellow_lines": [], "yellow_rows": set()}) @@ -211,92 +245,13 @@ class TestUtilsMethods(unittest.TestCase): api_batch.output_end_index = 1 api_batch.params_end_index = 1 highlight_dict = {} - dump_mode = Const.MD5 - result = find_error_rows(compare_result, api_batch, highlight_dict, dump_mode) + mode_config = ModeConfig(dump_mode=Const.MD5) + highlight = HighLight(mode_config) + result = highlight.find_error_rows(compare_result, api_batch, highlight_dict) self.assertEqual(result, None) - def test_ApiBatch_increment_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_kwargs(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.KWARGS) - - self.assertEqual(api_batch._state, Const.KWARGS) - self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_params(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.PARAMS) - - self.assertEqual(api_batch._state, Const.PARAMS) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 4) - self.assertEqual(api_batch.output_end_index, 4) - self.assertEqual(api_batch.params_grad_end_index, 4) - - def test_ApiBatch_increment_multiple_input(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.INPUT) - api_batch.increment(Const.INPUT) - - self.assertEqual(api_batch._state, Const.INPUT) - self.assertEqual(api_batch.input_len, 3) - self.assertEqual(api_batch.params_end_index, 5) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - - def test_ApiBatch_increment_multiple_output(self): - api_name = "functional.conv2d" - start = 2 - api_batch = ApiBatch(api_name, start) - - api_batch.increment(Const.OUTPUT) - api_batch.increment(Const.OUTPUT) - - self.assertEqual(api_batch._state, Const.OUTPUT) - self.assertEqual(api_batch.input_len, 1) - self.assertEqual(api_batch.params_end_index, 3) - self.assertEqual(api_batch.output_end_index, 5) - self.assertEqual(api_batch.params_grad_end_index, 5) - @patch("msprobe.core.compare.highlight.logger") def test_value_check(self, mock_logger): value = "=functional.conv2d" @@ -304,7 +259,9 @@ class TestUtilsMethods(unittest.TestCase): i = 1 result_df_columns = CompareConst.COMPARE_RESULT_HEADER - value_check(value, api_name, i, result_df_columns) + mode_config = ModeConfig() + highlight = HighLight(mode_config) + highlight.value_check(value, api_name, i, result_df_columns) mock_logger.error.assert_called_once_with( "Malicious value [=functional.conv2d] at api_name [=functional.conv2d], column [Bench Name], " @@ -315,40 +272,52 @@ class TestUtilsMethods(unittest.TestCase): columns = CompareConst.COMPARE_RESULT_HEADER data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', ''] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', ''] ] result_df = pd.DataFrame(data, columns=columns) - df_malicious_value_check(result_df, columns) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.df_malicious_value_check(result_df, columns) def test_compare_result_df_convert(self): value = float("nan") - result = compare_result_df_convert(value) + mode_config = ModeConfig() + highlight = HighLight(mode_config) + result = highlight.compare_result_df_convert(value) self.assertEqual(result, "nan\t") def test_highlight_rows_xlsx_red(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) highlight_dict = {'red_rows': [0]} file_path = os.path.join(base_dir, 'result.xlsx') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + generate_result_xlsx(base_dir) self.assertTrue(compare_excel_files_with_highlight(file_path, os.path.join(base_dir, 'target_result.xlsx'))) def test_highlight_rows_xlsx_yellow(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) highlight_dict = {'yellow_rows': [0]} file_path = os.path.join(base_dir, 'result.xlsx') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) + generate_result_xlsx(base_dir) self.assertTrue(compare_excel_files_with_highlight(file_path, os.path.join(base_dir, 'target_result_yellow.xlsx'))) @@ -356,7 +325,7 @@ class TestUtilsMethods(unittest.TestCase): def test_highlight_rows_xlsx_malicious_columns(self, mock_save_book): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['=Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -366,7 +335,9 @@ class TestUtilsMethods(unittest.TestCase): temp_output_file = 'temp_output.txt' sys.stdout = open(temp_output_file, 'w') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: output = f.read() @@ -378,10 +349,10 @@ class TestUtilsMethods(unittest.TestCase): def test_highlight_rows_xlsx_malicious_type(self, mock_save_book): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', '=torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', '=torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -391,7 +362,9 @@ class TestUtilsMethods(unittest.TestCase): temp_output_file = 'temp_output.txt' sys.stdout = open(temp_output_file, 'w') - highlight_rows_xlsx(result_df, highlight_dict, file_path) + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path) with open(temp_output_file, 'r') as f: output = f.read() @@ -416,10 +389,10 @@ class TestUtilsMethods(unittest.TestCase): def test_update_highlight_err_msg(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -429,14 +402,17 @@ class TestUtilsMethods(unittest.TestCase): 'red_lines': [(0, ['a', 'b'])], 'yellow_lines': [(0, ['c']), (1, ['d'])] } - update_highlight_err_msg(result_df, highlight_dict) + + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.update_highlight_err_msg(result_df, highlight_dict) t_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'a\nb', '-1'], + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'a\nb', '-1'], ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'd', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'd', '-1'] ] target_result_df = pd.DataFrame(t_data, columns=columns) self.assertTrue(result_df.equals(target_result_df)) @@ -449,7 +425,9 @@ class TestUtilsMethods(unittest.TestCase): result_df = pd.DataFrame(data, columns=columns) highlight_dict = {} - result = update_highlight_err_msg(result_df, highlight_dict) + mode_config = ModeConfig(dump_mode=Const.MD5) + highlight = HighLight(mode_config) + result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) @@ -466,5 +444,43 @@ class TestUtilsMethods(unittest.TestCase): 'red_lines': [(0, ['a', 'b'])], 'yellow_lines': [(0, ['c']), (1, ['d'])] } - result = update_highlight_err_msg(result_df, highlight_dict) + mode_config = ModeConfig() + highlight = HighLight(mode_config) + result = highlight.update_highlight_err_msg(result_df, highlight_dict) self.assertEqual(result, None) + + def test_find_error_rows(self): + api_batch = ApiBatch("Functional_batch_norm_0_forward", 0) + api_batch.input_len = 1 + api_batch.output_end_index = 4 + api_batch.params_end_index = 4 + summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] + highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + mode_config = ModeConfig() + highlight = HighLight(mode_config) + highlight.find_error_rows(summary_result, api_batch, highlight_dict_test) + self.assertEqual(highlight_dict_test, + {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}) + + def test_find_compare_result_error_rows(self): + result = [line_input, line_1, line_2, line_3] + result_df = pd.DataFrame(result) + highlight_dict_test = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []} + mode_config = ModeConfig(dump_mode=Const.ALL) + highlight = HighLight(mode_config) + highlight.find_compare_result_error_rows(result_df, highlight_dict_test) + self.assertEqual(highlight_dict_test, { + "red_rows": {1, 3}, + "yellow_rows": {2}, + "red_lines": [ + (1, ["maximum or minimum is nan, -inf, or inf"]), + (3, ["maximum absolute error exceeds 1e+10"]) + ], + "yellow_lines": [ + (2, ["The output's one thousandth err ratio decreases by more than 0.1 compared to the input/parameter's"]), + (3, [ + "maximum absolute error of both input/parameters and output exceed 1, " + "with the output larger by an order of magnitude", + "The output's cosine decreases by more than 0.1 compared to the input/parameter's"]) + ] + }) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py index 9c2dea835fea13af7902bf796d9ab06c9eb6a61b..dcf2d5621e3445583e83948e5b167d82cfcb6cbc 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py @@ -7,125 +7,231 @@ import unittest import pandas as pd -from msprobe.core.common.const import CompareConst, Const +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import Comparator, ModeConfig -from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result, \ - check_accuracy, read_dump_data -from test_acc_compare import generate_dump_json +from msprobe.core.compare.acc_compare import ModeConfig +from msprobe.core.compare.multiprocessing_compute import check_accuracy, CompareRealData, ComparisonResult +from msprobe.pytorch.compare.pt_compare import read_real_data +from test_acc_compare import generate_dump_json, generate_pt, generate_stack_json data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, - 'Yes', '', '-1']] + 'Yes', '', ['-1', '-1']]] o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 1, 1, 1, 1, 1, 1, 1, 1, - 'None', 'No bench data matched.', '-1']] + 'None', 'No bench data matched.', ['-1', '-1']]] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) o_result = pd.DataFrame(o_data, columns=columns) base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_cmp_multiprocessing_compute') +base_dir3 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_acc_compare_data3') +pt_dir = os.path.join(base_dir3, f'dump_data_dir') class TestUtilsMethods(unittest.TestCase): + def test_check_accuracy(self): + max_abs_err = '' + + cos_1 = CompareConst.SHAPE_UNMATCH + result_1 = check_accuracy(cos_1, max_abs_err) + self.assertEqual(result_1, CompareConst.ACCURACY_CHECK_UNMATCH) + + cos_2 = CompareConst.NONE + result_2 = check_accuracy(cos_2, max_abs_err) + self.assertEqual(result_2, CompareConst.NONE) + + cos_3 = 'N/A' + result_3 = check_accuracy(cos_3, max_abs_err) + self.assertEqual(result_3, CompareConst.ACCURACY_CHECK_NO) + + cos_4 = '' + result_4 = check_accuracy(cos_4, max_abs_err) + self.assertEqual(result_4, CompareConst.NONE) + + cos_5 = 0.95 + max_abs_err = 0.002 + result_5 = check_accuracy(cos_5, max_abs_err) + self.assertEqual(result_5, CompareConst.ACCURACY_CHECK_NO) + + cos_6 = 0.85 + max_abs_err = 2 + result_6 = check_accuracy(cos_6, max_abs_err) + self.assertEqual(result_6, CompareConst.ACCURACY_CHECK_NO) + + cos_7 = 0.95 + max_abs_err = 0.001 + result_7 = check_accuracy(cos_7, max_abs_err) + self.assertEqual(result_7, CompareConst.ACCURACY_CHECK_YES) + + +class TestCompareRealData(unittest.TestCase): + def setUp(self): self.result_df = pd.DataFrame(columns=[ - CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ERROR_MESSAGE, CompareConst.ACCURACY, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO + CompareConst.COSINE, CompareConst.EUC_DIST, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.ACCURACY, CompareConst.ERROR_MESSAGE ]) os.makedirs(base_dir, mode=0o750, exist_ok=True) + os.makedirs(base_dir3, mode=0o750, exist_ok=True) + os.makedirs(pt_dir, mode=0o750, exist_ok=True) self.lock = threading.Lock() def tearDown(self): if os.path.exists(base_dir): shutil.rmtree(base_dir) - - def test_handle_multi_process(self): - stack_mode = False - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - func = Comparator(mode_config).compare_ops - generate_dump_json(base_dir) - input_parma = {'bench_json_path': os.path.join(base_dir, 'dump.json')} - lock = multiprocessing.Manager().RLock() - result = _handle_multi_process(func, input_parma, result_df, lock) - self.assertTrue(result.equals(o_result)) + if os.path.exists(pt_dir): + shutil.rmtree(pt_dir) + if os.path.exists(base_dir3): + shutil.rmtree(base_dir3) def test_read_dump_data(self): - result = read_dump_data(result_df) + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + # normal + result = compare_real_data.read_dump_data(result_df) self.assertEqual(result, {'Functional.linear.0.forward.input.0': ['-1', '-1']}) + # index error with self.assertRaises(CompareException) as context: - result = read_dump_data(pd.DataFrame()) + result = compare_real_data.read_dump_data(pd.DataFrame()) self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) def test_save_cmp_result_success(self): + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + comparison_result = ComparisonResult( cos_result=[0.99, 0.98], max_err_result=[0.01, 0.02], max_relative_err_result=[0.001, 0.002], - err_msgs=['', 'Error in comparison'], + euc_dist_result=[0.5, 0.49], one_thousand_err_ratio_result=[0.1, 0.2], - five_thousand_err_ratio_result=[0.05, 0.1] + five_thousand_err_ratio_result=[0.05, 0.1], + err_msgs=['', 'Error in comparison'] ) offset = 0 - updated_df = _save_cmp_result(offset, comparison_result, self.result_df, self.lock) + updated_df = compare_real_data._save_cmp_result(offset, comparison_result, self.result_df, self.lock) self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 0.99) self.assertEqual(updated_df.loc[1, CompareConst.COSINE], 0.98) self.assertEqual(updated_df.loc[1, CompareConst.ERROR_MESSAGE], 'Error in comparison') def test_save_cmp_result_index_error(self): + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + comparison_result = ComparisonResult( cos_result=[0.99], max_err_result=[], max_relative_err_result=[0.001], - err_msgs=[''], + euc_dist_result=[0.5], one_thousand_err_ratio_result=[0.1], - five_thousand_err_ratio_result=[0.05] + five_thousand_err_ratio_result=[0.05], + err_msgs=[''] ) with self.assertRaises(CompareException) as context: - _save_cmp_result(0, comparison_result, self.result_df, self.lock) + compare_real_data._save_cmp_result(0, comparison_result, self.result_df, self.lock) self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR) - def test_check_accuracy(self): - max_abs_err = '' - - cos_1 = CompareConst.SHAPE_UNMATCH - result_1 = check_accuracy(cos_1, max_abs_err) - self.assertEqual(result_1, CompareConst.ACCURACY_CHECK_UNMATCH) - - cos_2 = CompareConst.NONE - result_2 = check_accuracy(cos_2, max_abs_err) - self.assertEqual(result_2, CompareConst.NONE) - - cos_3 = 'N/A' - result_3 = check_accuracy(cos_3, max_abs_err) - self.assertEqual(result_3, CompareConst.ACCURACY_CHECK_NO) - - cos_4 = '' - result_4 = check_accuracy(cos_4, max_abs_err) - self.assertEqual(result_4, CompareConst.NONE) + def test_compare_by_op_bench_normal(self): + npu_op_name = 'Functional.linear.0.forward.input.0' + bench_op_name = 'Functional.linear.0.forward.input.0' + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + pt_name = '-1' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} + input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', 'No bench data matched.']) + + pt_name = 'Functional.linear.0.forward.input.0.pt' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} + input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', "Dump file: ['Functional.linear.0.forward.input.0.pt', 'Functional.linear.0.forward.input.0.pt'] not found or read failed."]) + + generate_pt(base_dir) + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, '']) + + def test_compare_by_op_bench_na(self): + npu_op_name = 'Functional.linear.0.forward.input.0' + bench_op_name = 'N/A' + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [-1, -1]} + input_param = {} + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + result = compare_real_data.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 'unsupported', 'No bench data matched.']) + + def test_compare_ops(self): + generate_dump_json(base_dir3) + generate_stack_json(base_dir3) + generate_pt(pt_dir) + dump_path = os.path.join(base_dir3, 'dump.json') + stack_path = os.path.join(base_dir3, 'stack.json') + input_param = {'npu_json_path': dump_path, 'bench_json_path': dump_path, 'stack_json_path': stack_path, + 'is_print_compare_log': True, 'npu_dump_data_dir': pt_dir, 'bench_dump_data_dir': pt_dir} + dump_path_dict = {'Functional.linear.0.forward.input.0': ['Functional.linear.0.forward.input.0.pt', + 'Functional.linear.0.forward.input.0.pt']} + result_df = pd.DataFrame({ + 'NPU Name': ['Functional.linear.0.forward.input.0'], + 'Bench Name': ['Functional.linear.0.forward.input.0'] + }) + + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) + + updated_df = compare_real_data.compare_ops(idx=0, dump_path_dict=dump_path_dict, result_df=result_df, + lock=self.lock, input_param=input_param) + + self.assertEqual(updated_df.loc[0, CompareConst.COSINE], 1.0) + self.assertEqual(updated_df.loc[0, CompareConst.MAX_ABS_ERR], 0) + + def test_do_multi_process(self): + data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', ['-1', '-1']]] + o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], + 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 1, 1, 1, 1, 1, 1, 1, 1, 'None', 'No bench data matched.', ['-1', '-1']]] + columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] + result_df = pd.DataFrame(data, columns=columns) + o_result = pd.DataFrame(o_data, columns=columns) + generate_dump_json(base_dir) + input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')} - cos_5 = 0.95 - max_abs_err = 0.002 - result_5 = check_accuracy(cos_5, max_abs_err) - self.assertEqual(result_5, CompareConst.ACCURACY_CHECK_NO) + file_reader = read_real_data + mode_config = ModeConfig(dump_mode=Const.ALL) + cross_frame = False + compare_real_data = CompareRealData(file_reader, mode_config, cross_frame) - cos_6 = 0.85 - max_abs_err = 2 - result_6 = check_accuracy(cos_6, max_abs_err) - self.assertEqual(result_6, CompareConst.ACCURACY_CHECK_NO) - - cos_7 = 0.95 - max_abs_err = 0.001 - result_7 = check_accuracy(cos_7, max_abs_err) - self.assertEqual(result_7, CompareConst.ACCURACY_CHECK_YES) + result = compare_real_data.do_multi_process(input_param, result_df) + self.assertTrue(result.equals(o_result)) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py index 9cb33eb277848fa96bdf5b7456867d8579359723..f3623da772d1cbde684aa53119639faa93e4f068 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_postprocess_pass.py @@ -14,9 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +from dataclasses import dataclass from unittest import TestCase -from msprobe.core.compare.layer_mapping.postprocess_pass import extract_next_item_last_number -from msprobe.core.compare.layer_mapping.postprocess_pass import replace_next_item_index +from msprobe.core.compare.layer_mapping.postprocess_pass import extract_next_item_last_number, \ + replace_next_item_index, renumber_index_pass + + +@dataclass +class DataItem: + """Class for keeping track of an item in inventory""" + type_name: str + full_scope: str + layer_scope: str class TestPostProcessPass(TestCase): @@ -46,3 +55,12 @@ class TestPostProcessPass(TestCase): replace_result = replace_next_item_index(input_data, prefix, inf_value) self.assertEqual(replace_result, input_data) + def test_renumber_index_pass(self): + a = DataItem("ParallelTransformer", "fake_data.layers.10", "fake_data.layers") + b = DataItem("ParallelTransformer", "fake_data.layers.12", "fake_data.layers") + c = DataItem("FakeLayer", "fake_data.layers.10.a.b.c", "fake_data.layers.a.b") + data_items = [a, b, c] + renumber_index_pass(data_items, "ParallelTransformer") + self.assertEqual(a.full_scope, "fake_data.layers.0") + self.assertEqual(b.full_scope, "fake_data.layers.2") + self.assertEqual(c.full_scope, "fake_data.layers.0.a.b.c") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 8ff89437646ee203aaa4a3fac5bbfea1538e9409..f9b6bd4d8a2266e0f449239a7df87d5caf9d1b10 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -70,31 +70,23 @@ class TestBaseDataProcessor(unittest.TestCase): @patch('inspect.stack') def test_analyze_api_call_stack(self, mock_stack): mock_stack.return_value = [ - (None, 'file0.py', 0, 'function0', ['code line 0'], None), - (None, 'file1.py', 10, 'function1', ['code line 1'], None), - (None, 'file2.py', 20, 'function2', ['code line 2'], None), (None, 'file3.py', 30, 'function3', ['code line 3'], None), - (None, 'file4.py', 40, 'function4', ['code line 4'], None), - (None, 'file5.py', 50, 'function5', ['code line 5'], None), - (None, 'file6.py', 60, 'function6', ['code line 6'], None), - (None, 'file7.py', 70, 'function7', ['code line 7'], None), + (None, 'file1.py', 40, 'function1', ['code line 1'], None), + (None, 'file2.py', 50, 'function2', ['code line 2'], None), + (None, 'file3.py', 60, 'function3', ['code line 3'], None), + (None, 'file1.py', 70, 'function1', ['code line 1'], None), + (None, 'file1.py', 80, 'function1', ['code line 1'], None), + (None, 'file2.py', 90, 'function2', ['code line 2'], None), + (None, 'file3.py', 100, 'function3', ['code line 3'], None) ] result = BaseDataProcessor.analyze_api_call_stack('test_stack') - expected_output = { - 'test_stack': [ - 'File file5.py, line 50, in function5, \n code line 5', - 'File file6.py, line 60, in function6, \n code line 6', - 'File file7.py, line 70, in function7, \n code line 7', - ] - } - self.assertEqual(result, expected_output) + expected_output = ( + 'File file1.py, line 80, in function1, \n code line 1', + 'File file2.py, line 90, in function2, \n code line 2', + 'File file3.py, line 100, in function3, \n code line 3', + ) - def test_convert_numpy_to_builtin(self): - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.int32(5)), (5, 'int32')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.float64(3.14)), (3.14, 'float64')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.bool_(True)), (True, 'bool_')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(np.str_('test')), ('test', 'str_')) - self.assertEqual(BaseDataProcessor._convert_numpy_to_builtin(5), (5, '')) + self.assertEqual(result, expected_output) def test_analyze_builtin(self): result = self.processor._analyze_builtin(slice(1, 10, 2)) @@ -113,12 +105,37 @@ class TestBaseDataProcessor(unittest.TestCase): expected = {'type': 'int', 'value': 1} self.assertEqual(result, expected) + def test_analyze_numpy(self): + result = BaseDataProcessor._analyze_numpy(np.int32(5)) + expected = {"type": 'int32', "value": 5} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.float32(3.14)) + expected = {"type": 'float32', "value": 3.140000104904175} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.bool_(True)) + expected = {"type": 'bool_', "value": True} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.str_("abc")) + expected = {"type": 'str_', "value": "abc"} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.byte(1)) + expected = {"type": 'int8', "value": 1} + self.assertEqual(result, expected) + + result = BaseDataProcessor._analyze_numpy(np.complex128(1 + 2j)) + expected = {"type": 'complex128', "value": (1 + 2j)} + self.assertEqual(result, expected) + def test_get_special_types(self): self.assertIn(int, BaseDataProcessor.get_special_types()) - def test_analyze_numpy(self): + def test_analyze_ndarray(self): ndarray = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - result = BaseDataProcessor._analyze_numpy(ndarray, 'numpy.ndarray') + result = BaseDataProcessor._analyze_ndarray(ndarray, 'numpy.ndarray') expected_result = { 'type': 'numpy.ndarray', 'dtype': 'int32', @@ -126,7 +143,20 @@ class TestBaseDataProcessor(unittest.TestCase): 'Max': 6, 'Min': 1, 'Mean': 3.5, - 'Norm':9.539392014169456 + 'Norm': 9.539392014169456 + } + self.assertEqual(result, expected_result) + + ndarray = np.array([], dtype=np.int32) + result = BaseDataProcessor._analyze_ndarray(ndarray, 'numpy.ndarray') + expected_result = { + 'type': 'numpy.ndarray', + 'dtype': 'int32', + 'shape': (0,), + 'Max': None, + 'Min': None, + 'Mean': None, + 'Norm': None } self.assertEqual(result, expected_result) @@ -134,6 +164,7 @@ class TestBaseDataProcessor(unittest.TestCase): transform = lambda x, _: x * 2 Test = namedtuple("Test", ['a']) myNamedTuple = Test(1) + @dataclass class MyDataClass: last_hidden_state: int = None @@ -145,7 +176,7 @@ class TestBaseDataProcessor(unittest.TestCase): hidden_states=(2, 3), attentions=(4, 5) ) - expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': [4, 6], 'attentions': [8,10]} + expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': [4, 6], 'attentions': [8, 10]} self.assertEqual(BaseDataProcessor.recursive_apply_transform(2, transform), 4) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myData, transform), expected_dataclass_res) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myNamedTuple, transform), {'a': 2}) @@ -280,9 +311,9 @@ class TestBaseDataProcessor(unittest.TestCase): self.assertEqual(dst_data_structure, excepted_result) def test_analyze_element_to_all_none(self): - element = {"key1": [12, 3, {"key2": 10, "key3":["12"]}]} + element = {"key1": [12, 3, {"key2": 10, "key3": ["12"]}]} result = self.processor.analyze_element_to_all_none(element) - excepted_result = {"key1": [None, None, {"key2": None, "key3":[None]}]} + excepted_result = {"key1": [None, None, {"key2": None, "key3": [None]}]} self.assertEqual(result, excepted_result) @patch.object(MindsporeDataProcessor, "is_hookable_element", return_value=True) @@ -327,4 +358,4 @@ class TestBaseDataProcessor(unittest.TestCase): nested_data_structure, ["grad_name_1", "layer1", "layer2"], "grad_data_info" ) self.assertIsNone(self.processor.save_name) - self.assertEqual(result, grad) \ No newline at end of file + self.assertEqual(result, grad) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py index b593d34c5d86c7fb3b4a0e8a3ff548c55555e09d..46cc3b44747a548bee655927c7b5cef69b84d586 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py @@ -19,7 +19,7 @@ from unittest.mock import patch, MagicMock import zlib import mindspore as ms -from mindspore import Tensor +from mindspore import Tensor, ops, mint import numpy as np from msprobe.core.data_dump.data_processor.base import BaseDataProcessor @@ -32,6 +32,13 @@ from msprobe.core.data_dump.data_processor.mindspore_processor import ( from msprobe.mindspore.common.log import logger +def patch_norm(value): + return ops.norm(value) + + +setattr(mint, "norm", patch_norm) + + class TestMindsporeDataProcessor(unittest.TestCase): def setUp(self): self.config = MagicMock() @@ -66,15 +73,6 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2.0) self.assertEqual(result.norm, ms.ops.norm(tensor).item()) - def test_get_stat_info_float_async(self): - self.config.async_dump = True - tensor = ms.tensor([1.0, 2.0, 3.0]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), ms.ops.norm(tensor).item()) - def test_get_stat_info_int(self): self.config.async_dump = False tensor = ms.Tensor([1, 2, 3], dtype=ms.int32) @@ -84,13 +82,6 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2) self.assertEqual(result.norm, ms.ops.norm(tensor).item()) - def test_get_stat_info_int_async(self): - self.config.async_dump = True - tensor = ms.tensor([1, 2, 3]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - def test_get_stat_info_bool(self): self.config.async_dump = False tensor = ms.Tensor([True, False, True]) @@ -100,64 +91,6 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertIsNone(result.mean) self.assertIsNone(result.norm) - def test_get_stat_info_bool_async(self): - self.config.async_dump = True - tensor = ms.Tensor([True, False, True]) - result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), True) - self.assertEqual(result[1].item(), False) - - @patch.object(MindsporeDataProcessor, 'get_md5_for_tensor') - def test__analyze_tensor(self, get_md5_for_tensor): - get_md5_for_tensor.return_value = "test_md5" - tensor = ms.Tensor(np.array([1, 2, 3], dtype=np.int32)) - self.config.summary_mode = 'md5' - self.config.async_dump = False - suffix = "test_tensor" - expected_result = { - 'type': 'mindspore.Tensor', - 'dtype': 'Int32', - 'shape': (3,), - 'Max': 3, - 'Min': 1, - 'Mean': 2, - 'Norm': ms.ops.norm(tensor).item(), - 'md5': 'test_md5', - } - result = self.processor._analyze_tensor(tensor, suffix) - self.assertEqual(result, expected_result) - - -class TestTensorDataProcessor(unittest.TestCase): - - def setUp(self): - self.config = MagicMock() - self.data_writer = MagicMock() - self.processor = TensorDataProcessor(self.config, self.data_writer) - self.data_writer.dump_tensor_data_dir = "./dump_data" - self.processor.current_api_or_module_name = "test_api" - self.processor.api_data_category = "input" - - @patch('msprobe.core.data_dump.data_processor.mindspore_processor.save_tensor_as_npy') - def test_analyze_tensor(self, mock_save): - self.config.framework = "mindspore" - self.config.async_dump = False - tensor = ms.Tensor([1.0, 2.0, 3.0]) - suffix = 'suffix' - result = self.processor._analyze_tensor(tensor, suffix) - mock_save.assert_called_once() - expected = { - 'type': 'mindspore.Tensor', - 'dtype': str(tensor.dtype), - 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': ms.ops.norm(tensor).item(), - 'data_name': 'test_api.input.suffix.npy' - } - self.assertEqual(expected, result) - class TestOverflowCheckDataProcessor(unittest.TestCase): def setUp(self): @@ -218,57 +151,6 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): self.data_processor.overflow_nums = 3 self.assertFalse(self.data_processor.is_terminated) - def test__analyze_maybe_overflow_tensor(self): - self.data_processor.has_overflow = False - tensor_json = {"Max": None, "Min": 0} - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertFalse(self.data_processor.has_overflow) - tensor_json.update({"Max": -np.inf}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - self.data_processor.has_overflow = False - tensor_json.update({"Max": np.inf}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - self.data_processor.has_overflow = False - tensor_json.update({"Max": np.nan}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - tensor_json.update({"Max": 0}) - self.data_processor.has_overflow = False - tensor_json.update({"Min": -np.inf}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - self.data_processor.has_overflow = False - tensor_json.update({"Min": np.inf}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - self.data_processor.has_overflow = False - tensor_json.update({"Min": np.nan}) - self.data_processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.data_processor.has_overflow) - - @patch("msprobe.core.data_dump.data_processor.mindspore_processor.logger.warning") - @patch.object(OverflowCheckDataProcessor, "get_save_file_path") - @patch.object(MindsporeDataProcessor, "_analyze_tensor") - def test__analyze_tensor(self, mock_super, mock_get_file_path, mock_warning): - mock_get_file_path.return_value = ("dump_data_name", "file_path") - single_arg = {"Max": None} - mock_super.return_value = single_arg - - with patch("msprobe.core.data_dump.data_processor.mindspore_processor.path_len_exceeds_limit", - return_value=False): - ret = self.data_processor._analyze_tensor("tensor", "suffix") - self.assertEqual(self.data_processor.cached_tensors_and_file_paths, {"file_path": "tensor"}) - mock_warning.assert_not_called() - mock_super.assert_called_with("tensor", "suffix") - self.assertEqual(ret.get("Max"), None) - self.assertEqual(ret.get("data_name"), "dump_data_name") - - with patch("msprobe.core.data_dump.data_processor.mindspore_processor.path_len_exceeds_limit", - return_value=True): - self.data_processor._analyze_tensor("tensor", "suffix") - mock_warning.assert_called_with("The file path file_path length exceeds limit.") class TestKernelDumpDataProcessor(unittest.TestCase): def setUp(self): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 34064e7cc2b9d0aa5c0c2e98806b8993137a589c..43847ddc751143990a1588d1234b3788738c3aad 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import hashlib import os import sys @@ -19,6 +35,7 @@ from msprobe.core.data_dump.data_processor.pytorch_processor import ( KernelDumpDataProcessor ) from torch import distributed as dist +from torch._subclasses import FakeTensorMode class TestPytorchDataProcessor(unittest.TestCase): @@ -62,6 +79,15 @@ class TestPytorchDataProcessor(unittest.TestCase): result = PytorchDataProcessor.get_stat_info(mock_data) self.assertIsInstance(result, TensorStatInfo) + def test_get_stat_info_with_fake_tensor(self): + with FakeTensorMode() as fake_tensor_mode: + fake_tensor = fake_tensor_mode.from_tensor(torch.randn(1, 2, 3)) + result = PytorchDataProcessor.get_stat_info(fake_tensor) + self.assertIsNone(result.max) + self.assertIsNone(result.min) + self.assertIsNone(result.mean) + self.assertIsNone(result.norm) + def test_get_stat_info_float(self): tensor = torch.tensor([1.0, 2.0, 3.0]) result = self.processor.get_stat_info(tensor) @@ -70,14 +96,6 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2.0) self.assertEqual(result.norm, torch.norm(tensor).item()) - def test_get_stat_info_float_async(self): - tensor = torch.tensor([1.0, 2.0, 3.0]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), torch.norm(tensor).item()) - def test_get_stat_info_int(self): tensor = torch.tensor([1, 2, 3], dtype=torch.int32) result = self.processor.get_stat_info(tensor) @@ -86,13 +104,6 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2) self.assertEqual(result.norm, torch.norm(tensor.float()).item()) - def test_get_stat_info_int_async(self): - tensor = torch.tensor([1, 2, 3]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), 3.0) - self.assertEqual(result[1].item(), 1.0) - self.assertEqual(result[2].item(), 2.0) - self.assertEqual(result[3].item(), torch.norm(tensor.float()).item()) def test_get_stat_info_empty(self): tensor = torch.tensor([]) @@ -110,12 +121,6 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertIsNone(result.mean) self.assertIsNone(result.norm) - def test_get_stat_info_bool_async(self): - tensor = torch.tensor([True, False, True]) - result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] - self.assertEqual(result[0].item(), True) - self.assertEqual(result[1].item(), False) - def test_get_stat_info_with_scalar_tensor(self): scalar_tensor = torch.tensor(42.0) result = PytorchDataProcessor.get_stat_info(scalar_tensor) @@ -196,7 +201,7 @@ class TestPytorchDataProcessor(unittest.TestCase): dist.init_process_group(backend='gloo', world_size=1, rank=0) process_group_element = dist.group.WORLD result = self.processor.process_group_hash(process_group_element) - expected = hashlib.md5('[0]'.encode('utf-8')).hexdigest() + expected = f"{zlib.crc32(str([0]).encode('utf-8')):08x}" self.assertEqual(result, expected) def test_analyze_torch_size(self): @@ -222,7 +227,7 @@ class TestPytorchDataProcessor(unittest.TestCase): expected = { 'type': 'torch.ProcessGroup', 'group_ranks': [0], - 'group_id': hashlib.md5('[0]'.encode('utf-8')).hexdigest() + 'group_id': f"{zlib.crc32(str([0]).encode('utf-8')):08x}" } self.assertEqual(result, expected) @@ -268,11 +273,35 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result, self.processor._analyze_process_group(process_group_element)) def test_analyze_single_element_numpy_conversion(self): - numpy_element = np.int64(1) - converted_numpy, numpy_type = self.processor._convert_numpy_to_builtin(numpy_element) + numpy_element = np.int32(5) result = self.processor.analyze_single_element(numpy_element, []) - expected_result = {"type": numpy_type, "value": converted_numpy} - self.assertEqual(result, expected_result) + expected = {"type": 'int32', "value": 5} + self.assertEqual(result, expected) + + numpy_element = np.float32(3.14) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'float32', "value": 3.140000104904175} + self.assertEqual(result, expected) + + numpy_element = np.bool_(True) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'bool_', "value": True} + self.assertEqual(result, expected) + + numpy_element = np.str_("abc") + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'str_', "value": "abc"} + self.assertEqual(result, expected) + + numpy_element = np.byte(1) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'int8', "value": 1} + self.assertEqual(result, expected) + + numpy_element = np.complex128(1+2j) + result = self.processor.analyze_single_element(numpy_element, []) + expected = {"type": 'complex128', "value": (1+2j)} + self.assertEqual(result, expected) def test_analyze_single_element_tensor(self): tensor_element = torch.tensor([1, 2, 3]) @@ -291,39 +320,15 @@ class TestPytorchDataProcessor(unittest.TestCase): expected_result = self.processor._analyze_builtin(Ellipsis) self.assertEqual(result, expected_result) - @patch.object(PytorchDataProcessor, 'get_md5_for_tensor') - def test_analyze_tensor(self, get_md5_for_tensor): - get_md5_for_tensor.return_value = 'mocked_md5' - tensor = torch.tensor([1.0, 2.0, 3.0]) - self.config.summary_mode = 'md5' - self.config.async_dump = False - result = self.processor._analyze_tensor(tensor, 'suffix') - expected = { - 'type': 'torch.Tensor', - 'dtype': str(tensor.dtype), - 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': torch.norm(tensor).item(), - 'requires_grad': tensor.requires_grad, - 'md5': 'mocked_md5' - } - self.assertDictEqual(expected, result) - - def test_analyze_tensor_with_empty_tensor(self): - tensor = torch.tensor([]) - result = self.processor._analyze_tensor(tensor, 'suffix') - self.assertEqual(result['Max'], None) - self.assertEqual(result['Min'], None) - self.assertEqual(result['Mean'], None) - self.assertEqual(result['Norm'], None) + def test_cast_to_float_if_fp8(self): + tensor = MagicMock() + tensor.dtype = "torch.float8_e5m2" + _, dtype = self.processor._cast_to_float_if_fp8(tensor) + self.assertEqual(dtype, "torch.float8_e5m2") - def test_analyze_tensor_with_inf_and_nan(self): - tensor = torch.tensor([1.0, float('inf'), float('nan'), -float('inf')]) - result = self.processor._analyze_tensor(tensor, 'suffix') - self.assertEqual(result['Max_except_inf_nan'], 1.0) - self.assertEqual(result['Min_except_inf_nan'], 1.0) + tensor.dtype = "torch.float8_e4m3fn" + _, dtype = self.processor._cast_to_float_if_fp8(tensor) + self.assertEqual(dtype, "torch.float8_e4m3fn") class TestTensorDataProcessor(unittest.TestCase): @@ -336,27 +341,6 @@ class TestTensorDataProcessor(unittest.TestCase): self.processor.current_api_or_module_name = "test_api" self.processor.api_data_category = "input" - @patch('torch.save') - def test_analyze_tensor(self, mock_save): - self.config.framework = "pytorch" - self.config.async_dump = False - tensor = torch.tensor([1.0, 2.0, 3.0]) - suffix = 'suffix' - result = self.processor._analyze_tensor(tensor, suffix) - mock_save.assert_called_once() - expected = { - 'type': 'torch.Tensor', - 'dtype': 'torch.float32', - 'shape': tensor.shape, - 'Max': 3.0, - 'Min': 1.0, - 'Mean': 2.0, - 'Norm': torch.norm(tensor).item(), - 'requires_grad': False, - 'data_name': 'test_api.input.suffix.pt' - } - self.assertEqual(expected, result) - class TestOverflowCheckDataProcessor(unittest.TestCase): @@ -448,33 +432,6 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): self.processor._is_support_inf_nan() self.assertTrue(self.processor.support_inf_nan) - def test_analyze_maybe_overflow_tensor(self): - tensor_json = {'Max': None, 'Min': None} - self.processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertFalse(self.processor.has_overflow) - - tensor_json = {'Max': float('inf'), 'Min': 1.0} - self.processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.processor.has_overflow) - - tensor_json = {'Max': 1.0, 'Min': float('inf')} - self.processor._analyze_maybe_overflow_tensor(tensor_json) - self.assertTrue(self.processor.has_overflow) - - @patch('msprobe.core.common.file_utils.path_len_exceeds_limit', return_value=False) - @patch.object(BaseDataProcessor, 'get_save_file_path', - return_value=['test_api_name', 'test_api_name.0.forward.input.pt']) - def test_analyze_tensor(self, mock_path_len_exceeds_limit, _): - tensor = torch.tensor([1.0, 2.0, 3.0]) - suffix = 'suffix' - expected = {'Max': 3.0, 'Min': 1.0, 'data_name': 'test_api_name'} - with patch.object(PytorchDataProcessor, '_analyze_tensor', - return_value={'Max': 3.0, 'Min': 1.0}) as mock_super_analyze_tensor: - result = self.processor._analyze_tensor(tensor, suffix) - mock_super_analyze_tensor.assert_called_once_with(tensor, suffix) - mock_path_len_exceeds_limit.assert_called_once() - self.assertEqual(expected, result) - class TestFreeBenchmarkDataProcessor(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..c67c5d8ee9efd201cdcf09bc82471cac1f6607c3 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import TestCase +from unittest.mock import patch + +import torch + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.api_registry import _get_attr, ApiWrapper + + +class TestFunctions(TestCase): + def test__get_attr(self): + module = torch + + attr_name = 'linalg.norm' + target_value = torch.linalg.norm + actual_value = _get_attr(module, attr_name) + self.assertEqual(target_value, actual_value) + + attr_name = 'norm' + target_value = torch.norm + actual_value = _get_attr(module, attr_name) + self.assertEqual(target_value, actual_value) + + +class TestApiWrapper(TestCase): + api_types = { + Const.PT_FRAMEWORK: { + Const.PT_API_TYPE_TORCH: (torch, torch), + } + } + supported_api_list_path = (Const.SUPPORT_API_FILE_NAME,) + yaml_value = {'torch': ['linalg.norm', 'norm']} + api_names = {Const.PT_FRAMEWORK: {'torch': {'linalg.norm', 'norm'}}} + + def test___init__(self): + with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path) + self.assertEqual(api_wrapper.api_types, self.api_types) + self.assertEqual(api_wrapper.api_list_paths, self.supported_api_list_path) + self.assertEqual(api_wrapper.api_names, self.api_names) + self.assertEqual(api_wrapper.wrapped_api_functions, {}) + + api_wrapper = ApiWrapper(self.api_types, Const.SUPPORT_API_FILE_NAME) + self.assertEqual(api_wrapper.api_list_paths, list(self.supported_api_list_path)) + + with self.assertRaises(Exception) as context: + api_wrapper = ApiWrapper(self.api_types, (Const.SUPPORT_API_FILE_NAME, Const.SUPPORT_API_FILE_NAME)) + self.assertEqual(str(context.exception), + "The number of api_list_paths must be equal to the number of frameworks in 'api_types', " + "when api_list_paths is a list or tuple.") + + def test__get_api_names(self): + target_value = self.api_names + with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value): + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path) + actual_value = api_wrapper._get_api_names() + self.assertEqual(target_value, actual_value) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index b9d2e7abef7244fc12dc71e3113c26af52529ce9..b2de545c649ecf54e59e933f2c552ee0c0725883 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" +import os import unittest from unittest.mock import patch, mock_open, MagicMock @@ -22,9 +21,6 @@ from msprobe.core.common.utils import Const from msprobe.core.data_dump.data_collector import DataCollector from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.pt_config import parse_json_config -from msprobe.core.data_dump.json_writer import DataWriter -from msprobe.core.data_dump.data_processor.base import BaseDataProcessor -from msprobe.core.data_dump.data_processor.pytorch_processor import StatisticsDataProcessor class TestDataCollector(unittest.TestCase): @@ -38,6 +34,110 @@ class TestDataCollector(unittest.TestCase): config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") self.data_collector = DataCollector(config) + def test_dump_data_dir(self): + self.assertEqual(self.data_collector.dump_data_dir, None) + + self.data_collector.data_writer.dump_tensor_data_dir = "./test_dump" + self.assertEqual(self.data_collector.dump_data_dir, "./test_dump") + + def test_dump_file_path(self): + self.assertEqual(self.data_collector.dump_file_path, None) + + self.data_collector.data_writer.dump_file_path = "./test_dump/dump.json" + self.assertEqual(self.data_collector.dump_file_path, "./test_dump/dump.json") + + def test_scope_none_and_pid_match(self): + mock_name = "test_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(None, mock_name, current_pid) + self.assertTrue(result) + + def test_scope_valid_and_pid_match(self): + mock_scope = MagicMock() + mock_scope.check.return_value = True + mock_name = "valid_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, current_pid) + self.assertTrue(result) + mock_scope.check.assert_called_once_with(mock_name) + + def test_scope_invalid_and_pid_match(self): + mock_scope = MagicMock() + mock_scope.check.return_value = False + mock_name = "invalid_module" + current_pid = os.getpid() + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, current_pid) + self.assertFalse(result) + + def test_scope_valid_but_pid_mismatch(self): + mock_scope = MagicMock() + mock_scope.check.return_value = True + mock_name = "valid_module" + fake_pid = os.getpid() + 1 + result = self.data_collector.check_scope_and_pid(mock_scope, mock_name, fake_pid) + self.assertFalse(result) + + def test_scope_none_but_pid_mismatch(self): + mock_name = "test_module" + fake_pid = os.getpid() + 1 + result = self.data_collector.check_scope_and_pid(None, mock_name, fake_pid) + self.assertFalse(result) + + def test_normal_case(self): + data_info = {"key1": {"other_field": "value"}} + self.data_collector.set_is_recomputable(data_info, True) + self.assertTrue(data_info["key1"]["is_recompute"]) + + self.data_collector.set_is_recomputable(data_info, False) + self.assertFalse(data_info["key1"]["is_recompute"]) + + def test_empty_data_info(self): + data_info = {} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, True) + self.assertEqual(data_info, original_data) + + def test_data_info_length_not_one(self): + data_info = {"key1": {}, "key2": {}} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, True) + self.assertEqual(data_info, original_data) + + def test_is_recompute_none(self): + data_info = {"key1": {}} + original_data = data_info.copy() + self.data_collector.set_is_recomputable(data_info, None) + self.assertEqual(data_info, original_data) + + def test_nested_structure(self): + data_info = {"layer1": {"sub_layer": {"value": 1}}} + self.data_collector.set_is_recomputable(data_info, True) + self.assertTrue(data_info["layer1"]["is_recompute"]) + self.assertEqual(data_info["layer1"]["sub_layer"]["value"], 1) + + def test_reset_status(self): + self.data_collector.optimizer_status = "test_optimizer_status" + self.data_collector.reset_status() + + self.assertEqual(self.data_collector.optimizer_status, "") + self.assertEqual( + self.data_collector.optimizer_status_first_start, + {Const.OPTIMIZER: True, Const.CLIP_GRAD: True} + ) + self.assertEqual(self.data_collector.backward_module_names, {}) + + def test_update_api_or_module_name(self): + self.assertEqual(self.data_collector.data_processor.current_api_or_module_name, None) + + self.data_collector.update_api_or_module_name("test_api_name") + self.assertEqual(self.data_collector.data_processor.current_api_or_module_name, "test_api_name") + + def test_write_json(self): + self.data_collector.data_writer = MagicMock() + + self.data_collector.write_json() + self.data_collector.data_writer.write_json.assert_called_once() + def test_update_data(self): self.data_collector.config.task = Const.OVERFLOW_CHECK self.data_collector.data_processor.has_overflow = True @@ -59,6 +159,82 @@ class TestDataCollector(unittest.TestCase): mock_warning.assert_not_called() mock_debug.assert_called_once_with("msprobe is collecting data on Tensor.add.") + def test_call_stack_collect(self): + self.data_collector.data_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + + test_name = "test_api" + mock_stack = ["func1", "func2", "func3"] + self.data_collector.data_processor.analyze_api_call_stack.return_value = mock_stack + + self.data_collector.call_stack_collect(test_name) + + self.data_collector.data_processor.analyze_api_call_stack.assert_called_once_with(test_name) + self.data_collector.data_writer.update_stack.assert_called_once_with(test_name, mock_stack) + + def test_update_construct_without_construct(self): + self.data_collector.data_writer = MagicMock() + + self.data_collector.config.level = "L1" + self.data_collector.update_construct("test") + self.data_collector.data_writer.update_construct.assert_not_called() + + def test_update_construct_with_first_start(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "L0" + self.data_collector.optimizer_status = "optimizer" + self.data_collector.optimizer_status_first_start = {"optimizer": True} + + self.data_collector.update_construct("test_name") + calls = [ + unittest.mock.call({"optimizer": None}), + unittest.mock.call({"test_name": "optimizer"}), + unittest.mock.call(self.data_collector.module_processor.module_node) + ] + self.data_collector.data_writer.update_construct.assert_has_calls(calls) + + def test_update_construct_with_not_first_start(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "L0" + self.data_collector.optimizer_status = "clip_grad" + self.data_collector.optimizer_status_first_start = {"clip_grad": False} + + self.data_collector.update_construct("test_name") + calls = [ + unittest.mock.call({"test_name": "clip_grad"}), + unittest.mock.call(self.data_collector.module_processor.module_node) + ] + self.data_collector.data_writer.update_construct.assert_has_calls(calls) + + def test_update_construct_with_module_prefix(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "mix" + self.data_collector.optimizer_status = "other_status" + test_name = "Module_test_name" + + self.data_collector.update_construct(test_name) + self.data_collector.data_writer.update_construct.assert_called_with( + self.data_collector.module_processor.module_node + ) + + def test_update_construct_without_module_prefix(self): + self.data_collector.module_processor = MagicMock() + self.data_collector.data_writer = MagicMock() + self.data_collector.config.level = "mix" + self.data_collector.optimizer_status = "other_status" + self.data_collector.module_processor.api_parent_node = "parent_node" + test_name = "api_name" + + self.data_collector.update_construct(test_name) + calls = [ + unittest.mock.call({"api_name": "parent_node"}), + unittest.mock.call(self.data_collector.module_processor.module_node) + ] + self.data_collector.data_writer.update_construct.assert_has_calls(calls) + def test_handle_data(self): with patch.object(DataCollector, "update_data") as mock_update_data, \ patch.object(DataCollector, "write_json") as mock_write_json, \ @@ -76,44 +252,212 @@ class TestDataCollector(unittest.TestCase): mock_flush.assert_not_called() mock_write_json.assert_called() - @patch.object(DataCollector, "update_construct") - @patch.object(DataWriter, "update_stack") - @patch.object(BaseDataProcessor, "analyze_api_call_stack") - @patch.object(DataCollector, "handle_data") - def test_forward_data_collect(self, mock_handle_data, _, __, ___): - with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ - patch.object(StatisticsDataProcessor, "analyze_forward", return_value={}): - with patch.object(StatisticsDataProcessor, "is_terminated", new=True): - self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=True) - - self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=False) - - @patch.object(DataCollector, "update_construct") - @patch.object(DataCollector, "handle_data") - def test_backward_data_collect(self, mock_handle_data, _): - with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ - patch.object(StatisticsDataProcessor, "analyze_backward", return_value={}): - with patch.object(StatisticsDataProcessor, "is_terminated", new=True): - self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=True) - - self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") - mock_handle_data.assert_called_with("name", {}, flush=False) - - @patch.object(DataWriter, "update_debug") - @patch.object(BaseDataProcessor, "analyze_debug_forward", return_value="data_info") - def test_debug_data_collect_forward(self, _, mock_update_debug): - self.data_collector.debug_data_collect_forward("variable", "name_with_count") - mock_update_debug.assert_called_with({"name_with_count": "data_info"}) - - @patch.object(DataWriter, "update_debug") - @patch.object(BaseDataProcessor, "analyze_debug_backward") - @patch.object(BaseDataProcessor, "analyze_element_to_all_none", return_value = "all_none_data_info") - def test_debug_data_collect_backward(self, _, mock_analyze_debug_backward, mock_update_debug): - self.data_collector.data_writer.cache_debug = {"data": None} - self.data_collector.debug_data_collect_backward("variable", "name_with_count") - mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) - mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug['data']) - self.data_collector.data_writer.cache_debug = None + +class TestForwardDataCollect(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./test_fwd_dump", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./test_fwd_dump", "L1") + self.data_collector = DataCollector(config) + + self.data_collector.update_construct = MagicMock() + self.data_collector.config = MagicMock() + self.data_collector.data_processor = MagicMock() + self.data_collector.scope = "test_scope" + self.data_collector.check_scope_and_pid = MagicMock() + self.data_collector.set_is_recomputable = MagicMock() + self.data_collector.handle_data = MagicMock() + self.data_collector.call_stack_collect = MagicMock() + + self.Const = MagicMock() + self.Const.FREE_BENCHMARK = "free_benchmark" + self.Const.TENSOR = "tensor" + self.Const.FORWARD = "forward" + self.Const.BACKWARD = "backward" + self.Const.STRUCTURE = "structure" + self.Const.LEVEL_L2 = "L2" + + def test_forward_input_with_free_benchmark_task(self): + self.data_collector.config.task = self.Const.FREE_BENCHMARK + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "forward_test", + "module1", + 123, + "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_called_once_with( + "backward_test", + "module1", + "input_output" + ) + + def test_forward_input_with_scope_pid_check_fail(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.check_scope_and_pid.return_value = False + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_not_called() + + def test_forward_input_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.data_processor.analyze_forward_input.assert_not_called() + self.data_collector.set_is_recomputable.assert_called_once_with({}, None) + + def test_forward_input_with_level_l2(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.check_scope_and_pid.return_value = True + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output" + ) + + self.data_collector.handle_data.assert_not_called() + + def test_forward_input_with_recompute(self): + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = "L1" + self.data_collector.check_scope_and_pid.return_value = True + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward_input.return_value = mock_data + + self.data_collector.forward_input_data_collect( + "test", "module1", 123, "input_output", is_recompute=True + ) + + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, True) + self.data_collector.handle_data.assert_called_once_with( + "test", mock_data, flush=self.data_collector.data_processor.is_terminated + ) + + def test_forward_output_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward_output.assert_not_called() + + def test_forward_output_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward_output.assert_not_called() + + def test_forward_output_with_level_l2(self): + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.forward_output_data_collect("test", "module", 123, "data") + self.data_collector.handle_data.assert_not_called() + + def test_forward_output_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward_output.return_value = mock_data + self.data_collector.forward_output_data_collect("test", "module", 123, "data", True) + self.data_collector.handle_data.assert_called_once_with( + "test", + mock_data, + flush=self.data_collector.data_processor.is_terminated + ) + + def test_forward_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.forward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward.assert_not_called() + + def test_forward_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.forward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_forward.assert_not_called() + + def test_forward_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_forward.return_value = mock_data + self.data_collector.forward_data_collect("test", "module", 123, "data", False) + self.data_collector.call_stack_collect.assert_called_once_with("test") + self.data_collector.handle_data.assert_called_once_with( + "test", + mock_data, + flush=self.data_collector.data_processor.is_terminated + ) + + +class TestBackwardDataCollector(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./test_bwd_dump", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./test_bwd_dump", "L1") + self.data_collector = DataCollector(config) + + self.data_collector.config = MagicMock() + self.data_collector.data_processor = MagicMock() + self.data_collector.scope = "test_scope" + self.data_collector.check_scope_and_pid = MagicMock(return_value=True) + self.data_collector.set_is_recomputable = MagicMock() + self.data_collector.handle_data = MagicMock() + self.data_collector.update_construct = MagicMock() + self.data_collector.backward_module_names = {} + + self.Const = MagicMock() + self.Const.STRUCTURE = "structure" + self.Const.TENSOR = "tensor" + self.Const.LEVEL_L2 = "L2" + self.Const.SEP = "." + self.Const.MODULE_PREFIX = ["module"] + + def test_backward_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.backward_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward.assert_not_called() + + def test_backward_with_level_l2(self): + self.data_collector.config.level = self.Const.LEVEL_L2 + self.data_collector.backward_data_collect("test", "module", 123, "data") + self.data_collector.handle_data.assert_not_called() + + def test_backward_data_module_prefix_match(self): + self.data_collector.check_scope_and_pid.return_value = True + self.data_collector.config.task = self.Const.TENSOR + self.data_collector.config.level = "L1" + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward.return_value = mock_data + test_name = "Module.layer1.backward" + self.data_collector.backward_data_collect(test_name, "module", 123, "data") + self.assertEqual(self.data_collector.backward_module_names, {"Module": True}) + + def test_backward_input_with_structure_task(self): + self.data_collector.config.task = self.Const.STRUCTURE + self.data_collector.backward_input_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward_input.assert_not_called() + + def test_backward_input_with_normal(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward_input.return_value = mock_data + self.data_collector.backward_input_data_collect("test", "module", 123, "data", True) + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, True) + + def test_backward_output_with_scope_check_fail(self): + self.data_collector.check_scope_and_pid.return_value = False + self.data_collector.backward_output_data_collect("test", "module", 123, "data") + self.data_collector.data_processor.analyze_backward_output.assert_not_called() + + def test_backward_output_with_recompute(self): + mock_data = {"key": "value"} + self.data_collector.data_processor.analyze_backward_output.return_value = mock_data + self.data_collector.backward_output_data_collect("test", "module", 123, "data", False) + self.data_collector.set_is_recomputable.assert_called_once_with(mock_data, False) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index 9b20ffb2197882e16c1550cf013d1ba132096063..5fa3f1c254dd357b4f7a8e0c875480dbb6330563 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -117,8 +117,9 @@ class TestDataWriter(unittest.TestCase): self.assertEqual(self.data_writer.cache_data, expected) def test_update_stack(self): - self.data_writer.update_stack(self.data_content) - self.assertEqual(self.data_writer.cache_stack, self.data_content) + self.data_writer.cache_stack = {"stack1": ["test1"]} + self.data_writer.update_stack("test2", "stack1") + self.assertEqual(self.data_writer.cache_stack, {"stack1": ["test1", "test2"]}) def test_update_construct(self): self.data_writer.update_construct(self.data_content) @@ -136,13 +137,13 @@ class TestDataWriter(unittest.TestCase): os.remove(file_path) def test_write_stack_info_json(self): - self.data_writer.cache_stack = self.data_content + self.data_writer.cache_stack = {("api1", "api2"): ["stack1"]} file_path = os.path.join(self.cur_path, "stack.json") self.data_writer.write_stack_info_json(file_path) load_result = load_json(file_path) try: - self.assertEqual(load_result, self.data_content) + self.assertEqual(load_result, {"0": [["stack1"], ["api1", "api2"]]}) finally: os.remove(file_path) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py similarity index 46% rename from debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py rename to debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py index 904be210a3771f1757e4410b5e0fa0f2ad6152f2..2511d60caa823366c778761ccb8fb9bca747d2f5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py @@ -1,14 +1,282 @@ import os import unittest +from unittest import TestCase from unittest.mock import patch, MagicMock -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData +from msprobe.core.monitor.anomaly_processor import ScanRule, AnomalyTurbulence, AnomalyNan, AnomalyScanner, \ + AnomalyDataFactory, GradAnomalyData, AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ + _get_step_and_stop, _anomaly_analyse, _get_parse_args -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ - _get_parse_args, _get_step_and_stop, _anomaly_analyse +class TestScanRule(TestCase): + def test_apply_not_implemented(self): + scan_rule = ScanRule() + with self.assertRaises(Exception) as context: + scan_rule.apply(None, None) + + self.assertEqual(str(context.exception), "abstract method apply is not implemented") + + +class TestAnomalyTurbulence(TestCase): + + def setUp(self) -> None: + self.threshold = 0.2 + self.rule = AnomalyTurbulence(self.threshold) + + def test_apply_with_positive_baseline(self): + history = 12 + cur = 16 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_non_positive_baseline(self): + history = 0 + cur = -1 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + history = 0 + cur = 0 + result = self.rule.apply(cur, history=history) + self.assertFalse(result) + + +class TestAnomalyNan(TestCase): + + def setUp(self) -> None: + self.threshold = 1e10 + self.rule = AnomalyNan(self.threshold) + + def test_apply_with_nan(self): + cur = float("nan") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_big_value(self): + cur = float("1e30") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + cur = 0.5 + result = self.rule.apply(cur) + self.assertFalse(result) + + +class TestAnomalyScanner(TestCase): + + def test_load_rules_with_valied_spec(self): + specs = [ + {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}} + ] + rules = AnomalyScanner.load_rules(specs) + + self.assertEqual(len(rules), 1) + self.assertIsInstance(rules[0], AnomalyTurbulence) + self.assertEqual(rules[0].threshold, 0.2) + + rules = AnomalyScanner.load_rules(None) + self.assertEqual(len(rules), 0) + + @patch("msprobe.core.monitor.anomaly_processor.logger") + def test_load_rules_with_missing_keys(self, mock_logger): + specs = [ + {"rule_name": "AnomalyTurbulence"} + ] + rules = AnomalyScanner.load_rules(specs) -class TestAnomalyDataWriter(unittest.TestCase): + self.assertEqual(len(rules), 0) + mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}") + + def test_load_rules_with_invalid_rule(self): + # test invalid rule_name + specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + # test invalid args + specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + def test_scan(self): + ad_rules = [AnomalyTurbulence(0.2)] + # test scan with anomaly + expected = True, "AnomalyTurbulence" + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected) + # test scan with no anomaly + expected = False, None + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected) + + +class TestAnomalyDataFactory(TestCase): + + def setUp(self) -> None: + rank = 0 + pp_stage = 0 + group_mates = [0] + self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates) + + def test_set_call_id(self): + name2callid = {'param_name': 0} + self.AnomalyDataFactory.set_call_id(name2callid) + + self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0}) + + def test_create_success(self): + tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." + step = 2 + result = self.AnomalyDataFactory.create(tag, message, step) + + self.assertEqual(result.step, step) + self.assertEqual(result.tag_name, tag[0]) + self.assertEqual(result.message, message) + self.assertEqual(result.vpp_stage, 0) + + # test no vpp_stage + tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min') + result = self.AnomalyDataFactory.create(tag, message, step) + self.assertEqual(result.vpp_stage, 0) + + def test_create_failed(self): + error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output' + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." + step = 2 + with self.assertRaises(Exception) as context: + self.AnomalyDataFactory.create(error_tag, message, step) + self.assertEqual(str(context.exception), "tag must be a tuple with length 2") + + +class TestGradAnomalyData(TestCase): + + def setUp(self) -> None: + tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv" + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2." + group_mates = [0] + self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates) + + def test_get_train_stage(self): + tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""] + expected_train_stage_list = [0, 1, 2, -1] + for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list): + train_stage = GradAnomalyData.get_train_stage(tag_name) + self.assertEqual(train_stage, expected_train_stage) + + def test_to_dict(self): + expected = { + 'rank': 0, + 'step': 0, + 'micro_step': 0, + 'pp_stage': 0, + 'vpp_stage': 0, + 'call_id': 0, + 'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv", + 'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.", + 'group_mates': [0] + } + + self.assertEqual(self.GradAnomalyData.to_dict(), expected) + + def test_get_key(self): + expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0" + + self.assertEqual(self.GradAnomalyData.get_key(), expected) + + def test_lt_different_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_different_micro_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_different_vpp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_data(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertGreaterEqual(data1, data2) + self.assertLessEqual(data1, data2) + + def test_lt_not_instance(self): + data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0) + not_instance = "not an instance of GradAnomalyData" + self.assertEqual(data.__lt__(not_instance), NotImplemented) + + def test_le_same_instance(self): + # 测试相同实例的情况 + data1 = GradAnomalyData() + self.assertTrue(data1 <= data1) + + def test_le_different_instance(self): + # 测试不同实例的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + self.assertTrue(data1 <= data2) + + def test_le_not_instance(self): + # 测试非GradAnomalyData实例的情况 + data = GradAnomalyData() + not_instance = "Not an instance of GradAnomalyData" + self.assertEqual(data.__le__(not_instance), NotImplemented) + + def test_le_different_instance_not_equal(self): + # 测试不同实例且不相等的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + data2.some_attribute = "some value" + self.assertTrue(data1 <= data2) + + +class TestAnomalyDataWriter(TestCase): def test_get_anomaly_dict(self): # 测试 get_anomaly_dict 方法 @@ -29,9 +297,9 @@ class TestAnomalyDataWriter(unittest.TestCase): } self.assertEqual(result, expected) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.create_directory') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.create_directory') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_init_detected_json(self, mock_save_json, mock_create_directory, mock_exists): # 模拟路径检查 mock_exists.side_effect = [False, False, False] # dump_path, dump_rank_dir, json_path @@ -42,16 +310,15 @@ class TestAnomalyDataWriter(unittest.TestCase): writer.init_detected_json() # 检查是否创建了目录 - mock_create_directory.assert_any_call('/tmp/dump') mock_create_directory.assert_any_call('/tmp/dump/rank0') # 检查是否初始化了 JSON 文件 mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.remove_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.remove_path') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_init_detected_json_existing_file(self, mock_logger, mock_save_json, mock_remove_path, mock_check_path): # 设置测试参数 dump_path = 'test/dump_path' @@ -72,9 +339,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_logger.warning.assert_called_once_with(f"The existing file will be deleted: {writer.json_path}.") mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_write_detected_json(self, mock_save_json, mock_load_json, mock_exists): mock_exists.side_effect = [True, True] # json_path 存在 @@ -101,9 +368,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_save_json.assert_called_once_with(writer.json_path, expected_data, indent=1) -class TestAnomalyDataLoader(unittest.TestCase): +class TestAnomalyDataLoader(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 def test_create_instances_from_dict(self, mock_GradAnomalyData): # 模拟 GradAnomalyData 的构造函数 def mock_constructor(**kwargs): @@ -122,11 +389,11 @@ class TestAnomalyDataLoader(unittest.TestCase): # 确保创建了两个实例,第三个因缺少 key2 被捕获 self.assertEqual(len(instances), 2) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.listdir') - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') + @patch('msprobe.core.monitor.anomaly_processor.os.listdir') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') def test_get_anomalies_from_jsons(self, mock_GradAnomalyData, mock_check_path, mock_load_json, mock_exists, mock_listdir): mock_check_path.return_value = None @@ -146,7 +413,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_GradAnomalyData.side_effect = mock_constructor # 假设构造成功 loader = AnomalyDataLoader('/tmp/data') - with patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.isdir', return_value=True): + with patch('msprobe.core.monitor.anomaly_processor.os.path.isdir', return_value=True): anomalies = loader.get_anomalies_from_jsons() # 确保从 rank0 读取了异常数据 @@ -155,7 +422,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_load_json.assert_called_once_with('/tmp/data/rank0/anomaly.json') -class TestAnomalyAnalyse(unittest.TestCase): +class TestAnomalyAnalyse(TestCase): def setUp(self): self.anomaly_analyse = AnomalyAnalyse() @@ -189,10 +456,10 @@ class TestAnomalyAnalyse(unittest.TestCase): self.assertEqual(len(result), 3) self.assertEqual(result, [anomalies[1], anomalies[0], anomalies[2]]) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataWriter.get_anomaly_dict') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataWriter.get_anomaly_dict') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies(self, mock_logger, mock_save_json, mock_get_anomaly_dict, mock_exists): # 设置 mock mock_exists.return_value = False @@ -202,7 +469,7 @@ class TestAnomalyAnalyse(unittest.TestCase): # 调用方法 self.anomaly_analyse.sorted_anomalies = self.anomalies - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证调用 @@ -214,17 +481,17 @@ class TestAnomalyAnalyse(unittest.TestCase): ) mock_logger.info.assert_called_once_with("anomaly_analyse.json is at output_path.") - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies_file_exists(self, mock_logger, mock_exists): # 模拟文件已经存在的情况 mock_exists.return_value = True output_path = 'output_path' # 调用方法 - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.remove_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.save_json", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.remove_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.save_json", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证日志警告 @@ -232,35 +499,7 @@ class TestAnomalyAnalyse(unittest.TestCase): f"The existing file will be deleted: output_path/anomaly_analyse.json.") -class TestParseArgs(unittest.TestCase): - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', - new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) - def test_parse_args_with_all_arguments(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, 'path/to/output') - self.assertEqual(args.top_k_number, 5) - self.assertEqual(args.step_list, '[1,2,3]') - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', new=['script_name', '-d', 'path/to/data']) - def test_parse_args_with_required_argument_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 8) # 默认值 - self.assertEqual(args.step_list, '[]') # 默认值 - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) - def test_parse_args_with_topk_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 10) # 提供的值 - self.assertEqual(args.step_list, '[]') # 默认值 - - -class TestGetStepAndStop(unittest.TestCase): +class TestGetStepAndStop(TestCase): def test_valid_step_list_and_top_k(self): # 构造有效的 args 对象 @@ -318,13 +557,13 @@ class TestGetStepAndStop(unittest.TestCase): self.assertEqual(str(context.exception), "The top k number must be greater than 0.") -class TestAnomalyAnalyseFunction(unittest.TestCase): +class TestAnomalyAnalyseFunction(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_parse_args') # 模拟命令行参数解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_step_and_stop') # 模拟步骤和顶级数字解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataLoader') # 模拟数据加载器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyAnalyse') # 模拟异常分析器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') # 模拟日志记录 + @patch('msprobe.core.monitor.anomaly_processor._get_parse_args') # 模拟命令行参数解析 + @patch('msprobe.core.monitor.anomaly_processor._get_step_and_stop') # 模拟步骤和顶级数字解析 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataLoader') # 模拟数据加载器 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyAnalyse') # 模拟异常分析器 + @patch('msprobe.core.monitor.anomaly_processor.logger') # 模拟日志记录 def test_anomaly_analyse(self, mock_logger, mock_anomaly_analyse, mock_anomaly_data_loader, mock_get_step_and_stop, mock_get_parse_args): # 模拟命令行参数 @@ -376,5 +615,33 @@ class TestAnomalyAnalyseFunction(unittest.TestCase): mock_logger.info.assert_any_call("1: Top Anomaly 2") +class TestParseArgs(TestCase): + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', + new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) + def test_parse_args_with_all_arguments(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, 'path/to/output') + self.assertEqual(args.top_k_number, 5) + self.assertEqual(args.step_list, '[1,2,3]') + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data']) + def test_parse_args_with_required_argument_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 8) # 默认值 + self.assertEqual(args.step_list, '[]') # 默认值 + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) + def test_parse_args_with_topk_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 10) # 提供的值 + self.assertEqual(args.step_list, '[]') # 默认值 + + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c49a58f9838d355e14db67536f61f11fa43af6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_hook_manager.py @@ -0,0 +1,188 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from msprobe.core.common.const import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.hook_manager import BaseHookManager + + +class TestBaseHookManager(unittest.TestCase): + class MockBaseHookManager(BaseHookManager): + @property + def _is_recompute(self): + return False + + @staticmethod + def _no_grad_context(): + return MagicMock() + + @staticmethod + def _add_count(name): + pass + + @staticmethod + def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): + return {"kwargs": kwargs_or_output}, output_or_kwargs + + def build_hook(self): + pass + + def _get_params_dict(self, module): + return {} + + def _need_exchange(self, module): + return False + + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.mock_attl_manager = MagicMock() + self.manager = self.MockBaseHookManager( + self.mock_data_collector, + self.mock_config, + self.mock_attl_manager + ) + BaseHookManager.inner_switch = False + BaseHookManager.hook_handle_dict = {} + BaseHookManager.params_grad_info = {} + + def test_init(self): + self.assertEqual(self.manager.data_collector, self.mock_data_collector) + self.assertEqual(self.manager.config, self.mock_config) + self.assertEqual(self.manager.attl_manager, self.mock_attl_manager) + + def test_should_execute_hook_conditions(self): + module = MagicMock() + module.forward_data_collected = True + module.async_op_dump_flag = False + Runtime.is_running = True + self.mock_data_collector.data_processor.is_terminated = False + self.assertTrue(self.manager._should_execute_hook(Const.MODULE, module, True)) + self.assertTrue(self.manager._should_execute_hook(Const.API, module, False)) + + Runtime.is_running = False + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, module, True)) + + Runtime.is_running = True + module.forward_data_collected = False + self.assertFalse(self.manager._should_execute_hook(Const.API, module, False)) + + BaseHookManager.inner_switch = True + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, module, True)) + + self.mock_data_collector.data_processor.is_terminated = True + BaseHookManager.inner_switch = False + self.assertFalse(self.manager._should_execute_hook(Const.MODULE, module, True)) + self.assertFalse(self.manager._should_execute_hook(Const.API, module, True)) + + def test_clear_input_kwargs(self): + module = MagicMock() + module.msprobe_input_kwargs = {"key": "value"} + self.manager._clear_input_kwargs(module) + self.assertFalse(hasattr(module, 'msprobe_input_kwargs')) + + def test_register_param_hook(self): + module = MagicMock() + params = {"param1": MagicMock(requires_grad=True)} + full_name = "module.forward" + + with patch.object(self.manager, '_build_grad_hook') as mock_build: + self.manager._register_param_hook(full_name, module, params) + + self.assertEqual(len(BaseHookManager.hook_handle_dict), 1) + self.assertTrue("module.param1" in BaseHookManager.hook_handle_dict) + + self.assertEqual(module.params_grad_name, "module.parameters_grad") + + def test_init_params_grad_info(self): + module = MagicMock() + module.params_grad_name = "grad_name" + params = {"param1": MagicMock(requires_grad=True)} + + self.manager._init_params_grad_info(module, params) + self.mock_data_collector.handle_data.assert_called() + self.assertTrue(BaseHookManager.params_grad_info.get("grad_name")) + + self.manager._init_params_grad_info(module, params) + self.mock_data_collector.handle_data.assert_called_once() + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_forward_pre_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + self.manager.config.online_run_ut = None + hook = self.manager._build_forward_pre_hook(Const.API, "api_name", "func_name") + module = MagicMock() + module.msprobe_input_kwargs = {"kwarg": "value"} + args = (1, 2) + + Runtime.is_running = True + module.forward_data_collected = True + self.mock_data_collector.data_processor.is_terminated = False + + with patch.object(self.manager, '_no_grad_context') as mock_ctx: + hook(module, args) + self.mock_data_collector.forward_input_data_collect.assert_called_once() + self.assertEqual(module.forward_data_collected, True) + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_forward_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + hook = self.manager._build_forward_hook(Const.MODULE, "module_name") + module = MagicMock() + args = (1, 2) + kwargs = {"kwargs": []} + output = MagicMock() + + self.manager.config.online_run_ut = True + hook(module, args, output) + self.mock_attl_manager.attl_send.assert_called_once() + + self.manager.config.online_run_ut = None + self.mock_data_collector.if_return_forward_new_output.return_value = False + with patch.object(self.manager, '_get_params_dict', return_value={}): + result = hook(module, args, kwargs, output) + self.assertEqual(result, output) + self.mock_data_collector.forward_data_collect.assert_called_once() + self.mock_data_collector.get_forward_new_output.assert_not_called() + + self.mock_data_collector.if_return_forward_new_output.return_value = True + self.mock_data_collector.get_forward_new_output.return_value = "new_output" + with patch.object(self.manager, '_get_params_dict', return_value={}): + result = hook(module, args, output) + self.assertEqual(result, "new_output") + + @patch.object(BaseHookManager, "_should_execute_hook") + def test_backward_hook_behavior(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + self.manager.config.online_run_ut = None + hook = self.manager._build_backward_hook(Const.API, "api_name") + module = MagicMock() + grad_input = (MagicMock(),) + grad_output = (MagicMock(),) + + module.forward_data_collected = True + Runtime.is_running = True + hook(module, grad_input, grad_output) + + self.mock_data_collector.backward_data_collect.assert_called_once() + + with patch.object(self.manager, '_need_exchange', return_value=True): + hook(module, grad_input, grad_output) + args, _ = self.mock_data_collector.backward_data_collect.call_args_list[1] + self.assertEqual(args[3].grad_input, grad_output) + self.assertEqual(args[3].grad_output, grad_input) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_service.py b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0a119e296adc6cdec73b016c46a47114e51bf28e --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_service.py @@ -0,0 +1,398 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +import os +import tempfile + +from msprobe.core.service import BaseService +from msprobe.core.common.utils import Const +from msprobe.core.common.runtime import Runtime +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.core.hook_manager import BaseHookManager + + +class ConcreteBaseService(BaseService): + def _init_specific_components(self): + self.logger = MagicMock() + self.api_register = MagicMock() + self.hook_manager = MagicMock() + self.api_template = MagicMock() + + def _register_hook(self): + pass + + def _register_module_hook(self): + pass + + def _get_framework_type(self): + return "TestFramework" + + @staticmethod + def _get_current_rank(): + return 0 + + def _change_jit_switch(self, status): + pass + +class TestBaseService(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.config = MagicMock() + self.config.level = Const.LEVEL_DEBUG + self.config.level_ori = self.config.level + self.config.step = [1, 3] + self.config.rank = [0, 2] + self.config.dump_path = self.temp_dir.name + self.config.task = Const.STATISTICS + self.config.async_dump = True + self.config.tensor_list = [] + self.config.online_run_ut = False + self.config.framework = "test_framwork" + + with patch('msprobe.core.service.build_data_collector'): + self.service = ConcreteBaseService(self.config) + + def tearDown(self): + self.temp_dir.cleanup() + + def test_initialization(self): + self.assertEqual(self.service.config.level, Const.LEVEL_DEBUG) + self.assertIsNone(self.service.model) + self.assertIsNotNone(self.service.data_collector) + self.assertEqual(self.service.current_iter, 0) + self.assertEqual(self.service.loop, 0) + self.assertTrue(self.service.first_start) + self.assertFalse(self.service.primitive_switch) + self.assertIsNone(self.service.current_rank) + self.assertIsNone(self.service.dump_iter_dir) + self.assertFalse(self.service.should_stop_service) + self.assertTrue(self.service.currrent_step_first_debug_save) + self.assertEqual(self.service.ori_customer_func, {}) + + def test_properties(self): + self.service.config.level = Const.LEVEL_DEBUG + self.assertTrue(self.service._is_debug_level) + + self.service.config.level = Const.LEVEL_L2 + self.assertTrue(self.service._is_l2_level) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_mix_level) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_need_module_hook) + + self.service.config.level = Const.LEVEL_MIX + self.assertTrue(self.service._is_need_api_hook) + + self.assertFalse(self.service._need_tensor_data) + + self.service.current_iter = 2 + self.assertTrue(self.service._is_no_dump_step) + + self.service.current_rank = 1 + self.assertTrue(self.service._is_no_dump_rank) + + @patch.object(BaseService, '_get_current_rank') + @patch.object(BaseService, '_process_iteration') + def test_start_debug_level(self, mock_process_iter, mock_get_rank): + self.service.config.level = Const.LEVEL_DEBUG + model_mock = MagicMock() + + self.service.start(model=model_mock) + + mock_get_rank.assert_not_called() + mock_process_iter.assert_called_once() + self.service.logger.info.assert_not_called() + self.assertFalse(Runtime.is_running) + + + @patch.object(ConcreteBaseService, '_register_hook') + @patch.object(ConcreteBaseService, '_register_module_hook') + def test_start_normal_level_first_time(self, mock_register_module_hook, mock_register_hook): + self.service.config.level = Const.LEVEL_MIX + self.service.config.step = [] + self.service.config.rank = [] + model_mock = MagicMock() + self.service.data_collector.data_processor.is_terminated = False + self.service.start(model=model_mock) + + self.assertEqual(self.service.current_rank, 0) + self.assertEqual(Runtime.current_rank, 0) + + mock_register_hook.assert_called_once() + mock_register_module_hook.assert_called_once() + + self.service.logger.info.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") + self.assertTrue(Runtime.is_running) + self.assertTrue(self.service.primitive_switch) + self.assertFalse(self.service.first_start) + + @patch.object(ConcreteBaseService, '_register_hook') + @patch.object(ConcreteBaseService, '_register_module_hook') + @patch.object(ConcreteBaseService, 'create_dirs') + def test_start_not_first_calls(self, mock_dirs, mock_register_module_hook, mock_register_hook): + self.service.config.level = Const.LEVEL_L1 + self.service.config.step = [] + self.service.config.rank = [] + self.service.data_collector.data_processor.is_terminated = False + self.service.first_start = False + model_mock = MagicMock() + + self.service.start(model=model_mock) + mock_register_hook.assert_not_called() + mock_register_module_hook.assert_not_called() + self.assertTrue(Runtime.is_running) + self.assertTrue(self.service.primitive_switch) + mock_dirs.assert_called_once() + + def test_start_with_infer_hook(self): + self.service.config.level = Const.LEVEL_L1 + self.service.config.step = [] + self.service.config.rank = [] + self.service.data_collector.data_processor.is_terminated = False + model_mock = MagicMock() + token_range = [10, 20] + + self.service.start(model=model_mock, token_range=token_range) + model_mock.register_forward_pre_hook.assert_called_once() + self.assertEqual(self.service.cur_token_id, 0) + + def test_stop_debug_level(self): + self.config.level = Const.LEVEL_DEBUG + self.service.stop() + self.service.logger.info.assert_not_called() + + @patch.object(BaseService, '_process_async_dump') + def test_stop_normal_level(self, mock_process_async_dump): + self.service.config.level = Const.LEVEL_L1 + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.stop() + self.assertFalse(Runtime.is_running) + self.assertFalse(self.service.primitive_switch) + + self.service.logger.info.assert_called_with( + f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " + "Please set debugger.start() to turn on the dump switch again. " + ) + mock_process_async_dump.assert_called_once() + self.service.data_collector.write_json.assert_called_once() + + def test_stop_no_dump_step(self): + self.config.level = Const.LEVEL_L1 + self.service.current_iter = 2 + self.service.stop() + self.service.logger.info.assert_not_called() + + def test_stop_no_dump_rank(self): + self.config.level = Const.LEVEL_L1 + self.service.current_iter = 1 + self.service.current_rank = 1 + self.service.stop() + self.service.logger.info.assert_not_called() + + @patch.object(BaseService, '_process_async_dump') + def test_step(self, mock_process_async_dump): + self.service.step() + self.assertEqual(self.service.loop, 1) + self.assertTrue(self.service.currrent_step_first_debug_save) + mock_process_async_dump.assert_called_once() + self.service.data_collector.write_json.assert_called_once() + self.service.data_collector.reset_status.assert_called_once() + + @patch.object(BaseService, '_process_async_dump') + def test_step_should_stop_service(self, mock_process_async_dump): + self.service.should_stop_service = True + self.service.step() + self.assertEqual(self.service.loop, 0) + mock_process_async_dump.assert_not_called() + + def test_save_debug_level(self): + self.service.loop = 1 + self.service.init_step = 0 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_called_with("test_var", "test_name.0") + self.service.data_collector.debug_data_collect_backward.assert_called_with("test_var", "test_name_grad.0") + + def test_save_not_debug_level(self): + self.service.config.level = Const.LEVEL_L0 + self.service.loop = 1 + self.service.init_step = 0 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_not_called() + + def test_save_no_dump_step(self): + self.config.level = Const.LEVEL_DEBUG + self.service.current_iter = 2 + self.service.save("test_var", "test_name", True) + self.service.data_collector.debug_data_collect_forward.assert_not_called() + + def test_save_first_time_in_step(self): + self.service.config.level = Const.LEVEL_DEBUG + self.service.loop = 1 + self.service.init_step = 0 + + self.service.save("test_var", "test_name", True) + + self.assertEqual(self.service.current_rank, 0) + self.assertFalse(self.service.currrent_step_first_debug_save) + self.assertEqual(self.service.debug_variable_counter, {"test_name": 1}) + + self.assertIsNotNone(self.service.dump_iter_dir) + self.assertTrue(os.path.exists(self.service.dump_iter_dir)) + + @patch.object(ApiRegistry, 'register_custom_api') + def test_register_and_restore_custom_api(self, mock_register_custom_api): + module_mock = MagicMock() + api_name = "test_api" + api_prefix = "test_prefix" + self.service.register_custom_api(module_mock, api_name, api_prefix) + key = f"{str(module_mock)}{Const.SEP}{api_name}" + self.assertIn(key, self.service.ori_customer_func) + mock_register_custom_api.assert_called_once() + self.service.restore_custom_api(module_mock, api_name) + self.assertEqual(module_mock.test_api, self.service.ori_customer_func.get(key)) + + def test_build_hook(self): + hook = self.service.build_hook("test_type", "test_name") + self.service.hook_manager.build_hook.assert_called_with("test_type", "test_name") + + def test_create_dirs_pynative_graph(self): + Runtime.run_mode = Const.PYNATIVE_GRAPH_MODE + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.create_dirs() + + expected_dir = os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, "step1", "rank0") + self.assertEqual( + self.service.dump_iter_dir, os.path.join(self.config.dump_path, Const.PYNATIVE_MODE, "step1")) + self.assertTrue(os.path.exists(expected_dir)) + + self.service.data_collector.update_dump_paths.assert_called() + self.service.data_collector.initialize_json_file.assert_called() + + def test_create_dirs_pynative_mode(self): + Runtime.run_mode = Const.PYNATIVE_MODE + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.create_dirs() + expected_dir = os.path.join(self.config.dump_path, "step1", "rank0") + self.assertEqual(self.service.dump_iter_dir, os.path.join(self.config.dump_path, "step1")) + self.assertTrue(os.path.exists(expected_dir)) + + def test_create_dirs_l2_level(self): + self.service.config.level = Const.LEVEL_L2 + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.create_dirs() + expected_dir = os.path.join(self.config.dump_path, "step1") + self.assertEqual(self.service.dump_iter_dir, expected_dir) + self.assertTrue(os.path.exists(expected_dir)) + + kernel_config_path = os.path.join(expected_dir, "kernel_config_0.json") + self.assertTrue(os.path.exists(kernel_config_path)) + self.assertEqual(self.service.config.kernel_config_path, kernel_config_path) + + def test_need_stop_service_conditions(self): + self.service.current_iter = 4 + self.service.config.step = [1, 2, 3] + self.service.config.online_run_ut = True + self.service.attl_manager = MagicMock() + self.assertTrue(self.service._need_stop_service()) + self.assertFalse(Runtime.is_running) + self.assertFalse(self.service.primitive_switch) + self.service.attl_manager.attl_stop.assert_called() + + self.service.current_iter = 1 + self.service.data_collector.data_processor.is_terminated = True + self.assertTrue(self.service._need_stop_service()) + + self.service.data_collector.data_processor.is_terminated = False + self.service.should_stop_service = False + self.service.current_iter = 1 + self.service.config.step = [1, 2, 3] + self.assertFalse(self.service._need_stop_service()) + + def test_register_api_hook(self): + self.service.config.level = Const.LEVEL_MIX + self.service._register_api_hook() + self.service.api_register.initialize_hook.assert_called() + self.service.api_register.register_all_api.assert_called() + self.service.logger.info.assert_called_with( + f"The api {self.config.task} hook function is successfully mounted to the model." + ) + + def test_register_infer_count_hook(self): + model_mock = MagicMock() + token_range = [5, 10] + + self.service._register_infer_count_hook(model_mock, token_range) + + model_mock.register_forward_pre_hook.assert_called_once() + + hook = model_mock.register_forward_pre_hook.call_args[0][0] + + self.service.cur_token_id = 4 + hook(model_mock, None) + self.assertFalse(Runtime.is_running) + + self.service.cur_token_id = 5 + hook(model_mock, None) + self.assertTrue(Runtime.is_running) + + self.service.cur_token_id = 7 + hook(model_mock, None) + self.assertTrue(Runtime.is_running) + + self.service.cur_token_id = 11 + hook(model_mock, None) + self.assertFalse(Runtime.is_running) + + def test_process_iteration(self): + self.service.loop = 5 + self.service.init_step = 10 + self.service._process_iteration() + + self.assertEqual(self.service.current_iter, 15) + self.assertEqual(Runtime.current_iter, 15) + self.service.data_collector.update_iter.assert_called_with(15) + + def test_process_async_dump(self): + self.service.config.async_dump = True + self.service.config.task = Const.STATISTICS + self.service._process_async_dump() + + self.service.data_collector.data_processor.dump_async_data.assert_called_once() + + def test_process_async_dump_not_needed(self): + self.service.config.async_dump = False + self.service._process_async_dump() + self.service.data_collector.data_processor.dump_async_data.assert_not_called() + + self.service.config.task = Const.OVERFLOW_CHECK + self.service._process_async_dump() + self.service.data_collector.data_processor.dump_async_data.assert_not_called() + + def test_reset_status(self): + self.service._reset_status() + self.service.data_collector.reset_status.assert_called_once() + self.assertEqual(BaseHookManager.params_grad_info, {}) + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py index bb4c8b197ef8362921858839ca3790224715a39a..9cfad00d8ff13e91eb84fff5f46ab434f9ed1d4d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py @@ -2,7 +2,8 @@ import unittest from unittest.mock import patch, mock_open, MagicMock import os from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import DataManager -from msprobe.core.common.const import MsCompareConst, CompareConst +from msprobe.core.common.const import CompareConst +from msprobe.mindspore.common.const import MsCompareConst class TestDataManager(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json index 5b954f6d6443c92e6321e5f55e373e99f428653d..48800c0455c6651b146600e61e636d4dc25fac31 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "mindspore", "dump_data_dir": null, "data": { "Tensor.__add__.0.forward": { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json index 150cbd43b169573e48542aa0c46c26e7df69843e..b2704185ff19b961b43453f81247236d77677d83 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "pytorch", "dump_data_dir": null, "data": { "Tensor.__add__.0.forward": { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py index b5cbff9784a837ea4d64ac9eccdf30175564f712..eafe9384618502390b41adedd7d32db172ca8188 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py @@ -1,19 +1,18 @@ # coding=utf-8 -import json import os -import random import shutil -import tempfile +import random import unittest +from unittest.mock import patch -import numpy as np import torch -import yaml +import numpy as np -from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.mindspore.compare.ms_compare import MappingConfig, MSComparator, check_cross_framework +from msprobe.mindspore.compare.ms_compare import check_cross_framework, read_real_data, ms_compare from msprobe.core.common.const import Const +from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json +from msprobe.core.common.utils import CompareException + npu_dict = {'op_name': ['Functional.conv2d.0.forward.input.0', 'Functional.conv2d.0.forward.input.1', 'Functional.conv2d.0.forward.input.2', 'Functional.conv2d.0.forward.output'], @@ -173,6 +172,8 @@ json_data_template = { 'data': {} } +base_dir1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_ms_compare1') + def gen_data(is_ms=True): type_value = 'mindspore.Tensor' if is_ms else 'torch.Tensor' @@ -188,349 +189,65 @@ def gen_data(is_ms=True): } -def gen_api_mapping_test_data(need_user_mapping=False): - result_npu = json_data_template.copy() - result_bench = json_data_template.copy() - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - ms_comparator = MSComparator(mode_config, mapping_config) - - api_mapping = ms_comparator.load_internal_api() - ms_api_list = np.random.choice(list(api_mapping.keys()), size=5, replace=False).astype(str).tolist() - ms_api_data = {} - pt_api_data = {} - user_mapping = [] - for api in ms_api_list: - call_num = random.randint(1, 10) - direction = random.choice(['forward', 'backward']) - data_name_ms = api + '.' + str(call_num) + '.' + direction - data_name_pt = api_mapping.get(api) + '.' + str(call_num) + '.' + direction - input_num = random.randint(1, 5) - output_num = random.randint(1, 5) - ms_data = {'input_args': [gen_data(True) for _ in range(input_num)], - 'output': [gen_data(True) for _ in range(output_num)]} - pt_data = {'input_args': [gen_data(False) for _ in range(input_num)], - 'output': [gen_data(False) for _ in range(output_num)]} - ms_api_data[data_name_ms] = ms_data - pt_api_data[data_name_pt] = pt_data - if need_user_mapping: - compare_num_input = random.randint(1, input_num) - compare_num_output = random.randint(1, output_num) - user_mapping_item = {'ms_api': api, - 'pt_api': api_mapping.get(api), - 'ms_args': sorted(np.random.choice(list(range(input_num)), size=compare_num_input, - replace=False).astype(int).tolist()), - 'pt_args': sorted(np.random.choice(list(range(input_num)), size=compare_num_input, - replace=False).astype(int).tolist()), - 'ms_output': sorted(np.random.choice(list(range(output_num)), size=compare_num_output, - replace=False).astype(int).tolist()), - 'pt_output': sorted(np.random.choice(list(range(output_num)), size=compare_num_output, - replace=False).astype(int).tolist())} - user_mapping.append(user_mapping_item) - ms_api_key_list = list(ms_api_data.keys()) - random.shuffle(ms_api_key_list) - result_npu['data'] = {k: ms_api_data.get(k) for k in ms_api_key_list} - pt_api_key_list = list(pt_api_data.keys()) - random.shuffle(pt_api_key_list) - result_bench['data'] = {k: pt_api_data.get(k) for k in pt_api_key_list} - return result_npu, result_bench, user_mapping +class TestUtilsMethods(unittest.TestCase): + def setUp(self): + os.makedirs(base_dir1, mode=0o750, exist_ok=True) + np.save(os.path.join(base_dir1, 'numpy_data.npy'), np.array([1, 2, 3])) + torch.save(torch.tensor([2, 3, 4]), os.path.join(base_dir1, 'torch_data.pt')) -class TestUtilsMethods(unittest.TestCase): + def tearDown(self): + if os.path.exists(base_dir1): + shutil.rmtree(base_dir1) - def test_check_op_ms(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL + @patch('msprobe.mindspore.compare.utils.detect_framework_by_dump_json') + def test_check_cross_framework_valid_pytorch(self, mock_detect_framework): + mock_detect_framework.return_value = Const.PT_FRAMEWORK - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() + result = check_cross_framework("dummy_path") - ms_comparator = MSComparator(mode_config, mapping_config) - result = ms_comparator.check_op(npu_dict, bench_dict) self.assertTrue(result) - def test_data_mapping(self): - stack_json_data = {} - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(data_mapping=data_mapping) - ms_comparator = MSComparator(mode_config, mapping_config) - - npu_ops_all = ms_comparator.merge_data(npu_json_data, stack_json_data) - npu_ops_all_correct = { - 'Functional.flash_attention_score.4.forward.input.0': { - 'struct': ('BFloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625], - 'data_name': None, - 'stack_info': [None] - }, - 'Functional.flash_attention_score.4.forward.output.0': { - 'struct': ('BFloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625], - 'data_name': None, - 'stack_info': [None] - } - } - self.assertDictEqual(npu_ops_all, npu_ops_all_correct) - - bench_ops_all = ms_comparator.merge_data(bench_json_data, stack_json_data) - bench_ops_all_correct = { - 'NPU.npu_fusion_attention.4.forward.input.0': { - 'struct': ('torch.bfloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.553794860839844e-05, 2320.0], - 'data_name': None, - 'stack_info': [None] - }, - 'NPU.npu_fusion_attention.4.forward.output.0': { - 'struct': ('torch.bfloat16', [4096, 1, 2048]), - 'summary': [4.1875, -4.4375, -4.553794860839844e-05, 2320.0], - 'data_name': None, - 'stack_info': [None] - } - } - self.assertDictEqual(bench_ops_all, bench_ops_all_correct) - - result = ms_comparator.get_accuracy(npu_ops_all, bench_ops_all) - result_correct = [['Functional.flash_attention_score.4.forward.input.0', - 'NPU.npu_fusion_attention.4.forward.input.0', - 'BFloat16', 'torch.bfloat16', [4096, 1, 2048], [4096, 1, 2048], 0.0, 0.0, - 3.512832336127758e-08, -3.620849609375, '0.0%', '0.0%', '0.07714076816099476%', - '0.1560711038523707%', 4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625, - 4.1875, -4.4375, -4.553794860839844e-05, 2320.0, '', '', None], - ['Functional.flash_attention_score.4.forward.output.0', - 'NPU.npu_fusion_attention.4.forward.output.0', - 'BFloat16', 'torch.bfloat16', [4096, 1, 2048], [4096, 1, 2048], 0.0, 0.0, - 3.512832336127758e-08, -3.620849609375, '0.0%', '0.0%', '0.07714076816099476%', - '0.1560711038523707%', 4.1875, -4.4375, -4.550282028503716e-05, 2316.379150390625, - 4.1875, -4.4375, -4.553794860839844e-05, 2320.0, '', '', None] - ] - self.assertListEqual(result, result_correct) - - def test_dm_tensor_task(self): - self.compare_process_custom(dump_mode=Const.ALL) - - def compare_process_custom(self, dump_mode): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - - with open(npu_dump_path, 'w') as n_d_f: - json.dump(npu_json_data, n_d_f) - with open(bench_dump_path, 'w') as b_d_f: - json.dump(bench_json_data, b_d_f) - with open(npu_stack_path, 'w') as n_s_f: - json.dump({}, n_s_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process_custom((npu_dump_path, bench_dump_path, npu_stack_path)) - self.assertListEqual(result_df.values.tolist(), []) - finally: - shutil.rmtree(data_path) - - def test_check_cross_framework(self): - ms_data = { - "data_name": "Cell.model.language_model.encoder.layers.5.input_norm.FusedRMSNorm.forward.0.input.0.npy", - } - pt_data = { - "data_name": "Module.module.module.language_model.encoder.layers.0.input_norm.RMSNorm.forward.0.input.0.pt", + @patch('msprobe.mindspore.compare.utils.detect_framework_by_dump_json') + def test_check_cross_framework_invalid_framework(self, mock_detect_framework): + mock_detect_framework.return_value = Const.MS_FRAMEWORK + + result = check_cross_framework("dummy_path") + + self.assertFalse(result) + + def test_read_real_data_ms(self): + n_value, b_value = read_real_data(base_dir1, 'numpy_data.npy', base_dir1, 'numpy_data.npy', False) + self.assertTrue(np.array_equal(n_value, np.array([1, 2, 3]))) + self.assertTrue(np.array_equal(b_value, np.array([1, 2, 3]))) + + def test_read_real_data_cross_frame(self): + n_value, b_value = read_real_data(base_dir1, 'numpy_data.npy', base_dir1, 'torch_data.pt', True) + self.assertTrue(np.array_equal(n_value, np.array([1, 2, 3]))) + self.assertTrue(np.array_equal(b_value, np.array([2, 3, 4]))) + + def test_ms_compare(self): + generate_dump_json(base_dir1) + generate_stack_json(base_dir1) + + dump_path = os.path.join(base_dir1, 'dump.json') + + input_param = { + 'npu_json_path': dump_path, + 'bench_json_path': dump_path, + 'is_print_compare_log': True } + output_path = base_dir1 - def check_data(data): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', encoding='utf-8', delete=True) as temp_file: - json.dump(data, temp_file, ensure_ascii=False, indent=4) - temp_file.flush() - return check_cross_framework(temp_file.name) - self.assertFalse(check_data(ms_data)) - self.assertTrue(check_data(pt_data)) - - def test_comapre_process(self): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - - npu_data, bench_data, _ = gen_api_mapping_test_data() - with open(npu_dump_path, 'w', encoding='utf8') as n_d_f: - json.dump(npu_data, n_d_f) - with open(bench_dump_path, 'w', encoding='utf8') as b_d_f: - json.dump(bench_data, b_d_f) - with open(npu_stack_path, 'w', encoding='utf8') as n_s_f: - json.dump({}, n_s_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=True) - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process((npu_dump_path, bench_dump_path, npu_stack_path)) - self.assertTrue((result_df['Bench Name'] != 'N/A').all()) - finally: - shutil.rmtree(data_path) - - def test_compare_process_with_customize_api_mapping(self): - data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp') - try: - npu_dump_path = os.path.join(data_path, 'npu_dump.json') - bench_dump_path = os.path.join(data_path, 'bench_dump.json') - npu_stack_path = os.path.join(data_path, 'npu_stack.json') - user_mapping_path = os.path.join(data_path, 'user_mapping.yaml') - - npu_data, bench_data, user_mapping = gen_api_mapping_test_data(True) - with open(npu_dump_path, 'w', encoding='utf8') as n_d_f: - json.dump(npu_data, n_d_f) - with open(bench_dump_path, 'w', encoding='utf8') as b_d_f: - json.dump(bench_data, b_d_f) - with open(npu_stack_path, 'w', encoding='utf8') as n_s_f: - json.dump({}, n_s_f) - with open(user_mapping_path, 'w', encoding='utf8') as u_m_f: - yaml.safe_dump(user_mapping, u_m_f) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=user_mapping_path) - - ms_comparator = MSComparator(mode_config, mapping_config) - result_df = ms_comparator.compare_process((npu_dump_path, bench_dump_path, npu_stack_path)) - - user_mapping_dict = {} - for i in user_mapping: - user_mapping_dict[i.get('ms_api')] = {'input': i.get('ms_args'), 'output': i.get('ms_output')} - match_set = set() - for key in npu_data.get('data').keys(): - matched_dict = user_mapping_dict.get(key.rsplit('.', 2)[0]) - match_set.update({key + '.input.' + str(i) for i in matched_dict.get('input')}) - match_set.update({key + '.output.' + str(i) for i in matched_dict.get('output')}) - - self.assertTrue((result_df.loc[result_df['NPU Name'].isin(match_set), 'Bench Name'] != 'N/A').all()) - self.assertTrue((result_df.loc[~result_df['NPU Name'].isin(match_set), 'Bench Name'] == 'N/A').all()) - finally: - shutil.rmtree(data_path) - - def test_load_internal_api(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - api_dict = ms_comparator.load_internal_api() - self.assertEqual(api_dict['Functional.abs'], 'Torch.abs') - - def test_process_cell_mapping(self): - self.base_test_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) - self.input_dir = os.path.join(self.base_test_dir, 'resources') - cell_mapping_path = os.path.join(self.input_dir, 'common', 'cell_mapping.yaml') - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.SUMMARY - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(cell_mapping=cell_mapping_path) - - ms_comparator = MSComparator(mode_config, mapping_config) - npu_op_name = ms_comparator.process_cell_mapping(npu_cell_dict.get('op_name')[0]) - self.assertEqual(npu_op_name, 'Module.fc1.Linear.forward.0.input.0') - - def test_read_npy_data(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.pt') - tensor = torch.Tensor([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - torch.save(tensor, self.temp_file.name) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=True) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - self.temp_file = tempfile.NamedTemporaryFile(suffix='.npy') - tensor = np.array([1, 2, 3]) - filename = self.temp_file.name.split('/')[-1] - np.save(self.temp_file.name, tensor) - result = ms_comparator.read_npy_data('/tmp', filename, load_pt_file=False) - self.assertTrue(np.array_equal(result, np.array([1, 2, 3]))) - self.temp_file.close() - - def test_process_internal_api_mapping(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig(api_mapping=1) - - ms_comparator = MSComparator(mode_config, mapping_config) - - npu_op_name = "Mint.addcmul.0.forward.input.0" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Torch.addcmul.0.forward.input.0") - - npu_op_name = "MintFunctional.addcmul.0.forward.input.0" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Functional.addcmul.0.forward.input.0") - - npu_op_name = "Functional.abs" - result = ms_comparator.process_internal_api_mapping(npu_op_name) - self.assertEqual(result, "Torch.abs") - - def test_get_api_name(self): - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - mapping_config = MappingConfig() - - ms_comparator = MSComparator(mode_config, mapping_config) - - api_list = ["Functional", "absolute", "0", "forward", "input", "0"] - result = ms_comparator.get_api_name(api_list) - self.assertEqual(result, "Functional.absolute") - - api_list = ["Mint"] - with self.assertRaises(CompareException): - ms_comparator.get_api_name(api_list) \ No newline at end of file + ms_compare(input_param, output_path) + output_files = os.listdir(output_path) + self.assertTrue(any(f.endswith(".xlsx") for f in output_files)) + + input_param2 = { + 'npu_json_path': '', + 'bench_json_path': dump_path, + 'is_print_compare_log': True + } + with self.assertRaises(CompareException) as context: + ms_compare(input_param2, output_path) + self.assertEqual(context.exception.code, 1) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fb5e38fb82b309caf3ab2a1b621655d7babc86 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare_utils.py @@ -0,0 +1,24 @@ +import unittest +from unittest.mock import patch + +import numpy as np + +from msprobe.core.common.file_utils import FileCheckConst +from msprobe.mindspore.compare.utils import read_npy_data + + +class TestReadNpyData(unittest.TestCase): + + @patch('msprobe.mindspore.compare.utils.load_npy') + @patch('msprobe.mindspore.compare.utils.FileChecker') + @patch('os.path.join', return_value='/fake/path/to/file.npy') + def test_read_real_data_ms(self, mock_os, mock_file_checker, mock_load_npy): + mock_file_checker.return_value.common_check.return_value = '/fake/path/to/file.npy' + + mock_load_npy.return_value = np.array([1.0, 2.0, 3.0]) + + result = read_npy_data('/fake/dir', 'file_name.npy') + + mock_file_checker.assert_called_once_with('/fake/path/to/file.npy', FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.NUMPY_SUFFIX, False) + mock_load_npy.assert_called_once_with('/fake/path/to/file.npy') + self.assertTrue(np.array_equal(result, np.array([1.0, 2.0, 3.0]))) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py index e3fd9348efe7dd4df0a6db2cd52a45f4757dae01..df22eabfa83d7a3de587125c2374938a600389fa 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py @@ -78,7 +78,7 @@ class TestMsGraphCompare(unittest.TestCase): result_correct = ( f"[['{npu_file_path}', '{bench_file_path}', dtype('float16'), dtype('float16'), (10, 10), (10, 10), " - f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 1.0, 1.0]]") + f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]]") self.assertNotEqual(len(files), 0) self.assertEqual(result, result_correct) @@ -93,7 +93,7 @@ class TestMsGraphCompare(unittest.TestCase): compare_result_db = ms_graph_comparator.compare_ops(compare_result_db, mode) result = compare_result_db.values.tolist() - op_name = 'Default_Switch-op1_kernel_graph1_Data_86.185.41.output' + op_name = 'Default_Switch-op1_kernel_graph1_Data_86.185.output.0' npu_file_path = os.path.join(self.npu_data_path, 'rank_0/mnist/0/0/statistic.csv') bench_file_path = os.path.join(self.bench_data_path, 'rank_0/mnist/0/0/statistic.csv') npu_name = f'{op_name} {npu_file_path}' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py index 033b0c1ea5769c3f1f8e19dd8b45c48918e15814..8a7195eac824485e75d8c1ba0752715c7c6a5600 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_debugger_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,14 +17,17 @@ import unittest from unittest.mock import patch from msprobe.core.common.const import Const -from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common.log import logger +from msprobe.core.common_config import CommonConfig from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.ms_config import StatisticsConfig class TestDebuggerConfig(unittest.TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_init(self, _): + def test_init(self, _, mock_logger_error): json_config = { "dump_path": "/absolute_path", "rank": [], @@ -32,12 +35,13 @@ class TestDebuggerConfig(unittest.TestCase): "level": "L2" } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) debugger_config = DebuggerConfig(common_config, task_config) self.assertEqual(debugger_config.task, Const.STATISTICS) self.assertEqual(debugger_config.file_format, "npy") self.assertEqual(debugger_config.check_mode, "all") self.assertEqual(debugger_config.overflow_nums, 1) + self.assertEqual(debugger_config.tensor_list, []) common_config.level = "L1" common_config.task = Const.FREE_BENCHMARK @@ -49,17 +53,16 @@ class TestDebuggerConfig(unittest.TestCase): task_config.handler_type = FreeBenchmarkConst.FIX task_config.pert_mode = FreeBenchmarkConst.ADD_NOISE - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): DebuggerConfig(common_config, task_config) - self.assertEqual(str(context.exception), - "pert_mode must be improve_precision or empty when handler_type is fix, " - f"but got {FreeBenchmarkConst.ADD_NOISE}.") + mock_logger_error.assert_called_with("pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {FreeBenchmarkConst.ADD_NOISE}.") + mock_logger_error.reset_mock() task_config.handler_type = FreeBenchmarkConst.FIX task_config.pert_mode = FreeBenchmarkConst.DEFAULT_PERT_TYPE task_config.fuzz_stage = Const.BACKWARD - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): DebuggerConfig(common_config, task_config) - self.assertEqual(str(context.exception), - "handler_type must be check or empty when fuzz_stage is backward, " - f"but got {task_config.handler_type}.") + mock_logger_error.assert_called_with("handler_type must be check or empty when fuzz_stage is backward, " + f"but got {task_config.handler_type}.") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py index 066ff537ce6fba12f712ae3d4681115499be35a6..5f2547775f49cde44d062632e025025fad7e643d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py @@ -16,14 +16,16 @@ import unittest from unittest.mock import patch, MagicMock -from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.const import Const, MsgConst +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from msprobe.mindspore.runtime import Runtime +from msprobe.mindspore.ms_config import StatisticsConfig +from msprobe.core.common.runtime import Runtime class TestPrecisionDebugger(unittest.TestCase): @@ -48,12 +50,12 @@ class TestPrecisionDebugger(unittest.TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) handler = Handler() mock_get_mode = MagicMock() mock_parse_json_config = MagicMock() - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -68,20 +70,20 @@ class TestPrecisionDebugger(unittest.TestCase): self.assertTrue(Handler.called) mock_get_mode.return_value = MsConst.PYNATIVE_MODE - with patch("msprobe.mindspore.debugger.precision_debugger.Service") as mock_Service, \ + with patch("msprobe.mindspore.debugger.precision_debugger.MindsporeService") as mock_Service, \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): debugger = PrecisionDebugger() debugger.start() service = mock_Service.return_value mock_Service.assert_called_with(debugger.config) - service.start.assert_called_with(None) + service.start.assert_called_with(None, None) PrecisionDebugger._instance = None with self.assertRaises(Exception) as context: debugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + with patch.object(BasePrecisionDebugger, "_parse_config_path", new=mock_parse_json_config), \ patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler), \ patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): @@ -94,10 +96,19 @@ class TestPrecisionDebugger(unittest.TestCase): self.assertTrue(Handler.called) def test_stop_step(self): + class MockConfig: + def __init__(self): + self.task = Const.TENSOR + self.execution_mode = None + self.level = None + self.level_ori = Const.LEVEL_L1 + class MockPrecisionDebugger: def __init__(self): self.task = Const.TENSOR self.service = None + self.config = MockConfig() + PrecisionDebugger._instance = None with self.assertRaises(Exception) as context: PrecisionDebugger.stop() @@ -123,20 +134,25 @@ class TestPrecisionDebugger(unittest.TestCase): mock_reset_cell.assert_called_once() def test_forward_backward_dump_end(self): - with patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)), \ + patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): debugger = PrecisionDebugger() debugger.task = "statistics" debugger.service = MagicMock() debugger.forward_backward_dump_end() debugger.service.stop.assert_called_once() - def test_is_graph_dump_level_not_kernel(self): - config = MagicMock() - config.level = "NOT_KERNEL" - config.list = ["some_value"] - result = PrecisionDebugger._is_graph_dump(config) - self.assertFalse(result) - def test_is_graph_dump_empty_list(self): config = MagicMock() config.level = MsConst.KERNEL diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py index 54c59b6409cb546384dcb50f47c7c27975fa1cb7..e760faefd31b2e2b60c24091c6eebed087f4268f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py @@ -16,11 +16,11 @@ import unittest from unittest.mock import patch -from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json class TestPtKernelConfig(unittest.TestCase): - @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_with_rank(self, mock_save_json): dump_path = "./step0" cur_rank = 0 @@ -36,7 +36,7 @@ class TestPtKernelConfig(unittest.TestCase): } mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4) - @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_without_rank(self, mock_save_json): dump_path = "./step0" cur_rank = '' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py index 1f37e18c6ef8d3facb526f3c54169a17f4616189..d1f1b48dfb9a622b10fcf39c32394339f0a8df9c 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/common/test_ms_free_benchmark_utils.py @@ -22,7 +22,7 @@ from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools, UnequalRow, make_unequal_row -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime class TestUtils(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py index 58c0a7b46ad7ca5c05a157733d57dd8828ced24d..d983073794b31800c0deb56123635a8e3fb785c7 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/handler/test_ms_check_handler.py @@ -24,7 +24,7 @@ from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime def where(*args): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py index 3469e809d3fb27f9e366d128cfa10d68c776e391..41f7ea6db9e55b4886fbf2d0b21b5c2abb2e1551 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_base_perturbation.py @@ -33,7 +33,7 @@ class TestBasePerturbation(unittest.TestCase): self.assertFalse(TestBasePerturbation.base_pert.is_fuzzed) self.assertIsNone(TestBasePerturbation.base_pert.perturbation_value) - @patch("msprobe.mindspore.service.Service.should_execute_hook", return_value=False) + @patch("msprobe.core.hook_manager.BaseHookManager._should_execute_hook", return_value=False) def test_get_fuzzed_result(self, _): params = HandlerParams() params.args = [Tensor([1.0], dtype=ms.float32), Tensor([5.0], dtype=ms.float32)] diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py index e200bb40868fab8a9618047244830aa8a74cec27..84f9766a053c9f0b1bed69f49a1a044300d9e215 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_improve_precision.py @@ -68,7 +68,7 @@ class TestImprovePrecisionPerturbation(unittest.TestCase): self.assertEqual(ret.dtype, target.dtype) self.assertFalse(self.improve_precision_pert.is_fuzzed) - @patch("msprobe.mindspore.service.Service.should_execute_hook", return_value=False) + @patch("msprobe.core.hook_manager.BaseHookManager._should_execute_hook", return_value=False) @patch.object(logger, "warning") def test_handle(self, mock_warning, _): self.improve_precision_pert.is_fuzzed = False diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py index 858e664bbaddb3506bf53ea067eeca1c9706b43b..a4458912149fc8600d32a542c398a335be5d636d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/perturbation/test_ms_perturbation_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,9 @@ # limitations under the License. import unittest +from unittest.mock import patch +from msprobe.mindspore.common.log import logger from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.common.const import FreeBenchmarkConst @@ -27,14 +29,14 @@ from msprobe.mindspore.free_benchmark.perturbation.exchange_value import Exchang class TestPerturbationFactory(unittest.TestCase): - def test_create(self): + @patch.object(logger, "error") + def test_create(self, mock_logger_error): api_name = "Functional.add.0" Config.pert_type = "UNKNOWN" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): PerturbationFactory.create(api_name) - self.assertEqual(str(context.exception), - "UNKNOWN is a invalid perturbation type") + mock_logger_error.assert_called_with("UNKNOWN is a invalid perturbation type") Config.pert_type = FreeBenchmarkConst.EXCHANGE_VALUE pert = PerturbationFactory.create(api_name) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py index e589dd4d58715d74644047f8c7e7a6ce79ccf225..ec9c66fc498603b59458468f44678463e40c0691 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py @@ -23,18 +23,24 @@ from mindspore import Tensor, mint, ops from msprobe.core.common.const import Const from msprobe.mindspore.common.const import FreeBenchmarkConst from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.free_benchmark.api_pynative_self_check import (ApiPyNativeSelfCheck, check_all_tensor, - check_self, data_pre_deal, - deal_fuzzed_and_original_result, - get_module, get_supported_ops, - get_target_arg_index, need_wrapper_func) +from msprobe.mindspore.free_benchmark.api_pynative_self_check import ( + ApiPyNativeSelfCheck, + check_all_tensor, + check_self, + data_pre_deal, + deal_fuzzed_and_original_result, + get_module, + get_supported_ops, + get_target_arg_index, + need_wrapper_func, + _api_register +) from msprobe.mindspore.free_benchmark.common.config import Config from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams from msprobe.mindspore.free_benchmark.common.utils import Tools from msprobe.mindspore.free_benchmark.handler.check_handler import CheckHandler from msprobe.mindspore.free_benchmark.handler.fix_handler import FixHandler -from msprobe.mindspore.runtime import Runtime +from msprobe.core.common.runtime import Runtime class DebuggerConfig: @@ -83,31 +89,33 @@ class TestApiPyNativeSelfCheck(TestCase): self.assertEqual(self_checker.ori_func, target_ori_func) def test_handle(self): - with patch.object(api_register, "initialize_hook") as mock_init_hook, \ - patch.object(api_register, "api_set_hook_func") as mock_set_hook: + with patch.object(_api_register, "initialize_hook") as mock_init_hook, \ + patch.object(_api_register, "register_all_api") as mock_set_hook: self.checker.handle() mock_init_hook.assert_called_with(self.checker.build_hook) mock_set_hook.assert_called_once() def test_build_hook(self): - _, forward_hook, backward_hook, _ = self.checker.build_hook("Functional.add.") + hook_set = self.checker.build_hook("Functional.add.") cell = Cell() + cell.msprobe_input_kwargs = {} with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=False): - self.assertIsNone(forward_hook(cell, "input", "output")) + self.assertIsNone(hook_set.forward_hook(cell, "input", "output")) cell = Cell() + cell.msprobe_input_kwargs = {} self.checker.api_list = ["mindspore.ops.add"] self.checker.ori_func["mindspore.ops.add"] = "add" with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=True), \ patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.check_self", return_value="ret") as mock_check: - ret = forward_hook(cell, ("input",), ("output",)) + ret = hook_set.forward_hook(cell, ("input",), ("output",)) self.assertEqual(ret, "ret") mock_check.assert_called_with("Functional.add.0", ("output",), "add", "input") - self.assertIsNone(backward_hook("cell", "grad_input", "grad_output")) + self.assertIsNone(hook_set.backward_hook("cell", "grad_input", "grad_output")) def test_store_original_func(self): self.checker.api_list = ["mindspore.ops.add"] @@ -156,8 +164,8 @@ class TestApiPyNativeSelfCheck(TestCase): mock_warning.reset_mock() Config.stage = Const.FORWARD with patch.object(logger, "info") as mock_info, \ - patch.object(api_register, "api_set_ori_func") as mock_set_ori, \ - patch.object(api_register, "api_set_hook_func") as mock_set_hook, \ + patch.object(_api_register, "restore_all_api") as mock_set_ori, \ + patch.object(_api_register, "register_all_api") as mock_set_hook, \ patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.deal_fuzzed_and_original_result", return_value="ret"): args = (1.0, 1.0) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py index fa68b8896c26d4156833c54d2b2bf5b443164e8f..4f3ddd45b5a05162c60abe831967dd449f3f5ae6 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_self_check_tool_factory.py @@ -1,7 +1,6 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import os import unittest +from unittest.mock import patch +from msprobe.core.common.log import logger from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck from msprobe.mindspore.debugger.debugger_config import DebuggerConfig @@ -28,7 +28,8 @@ from msprobe.core.common.const import Const class TestSelfCheckToolFactory(unittest.TestCase): - def test_create(self): + @patch.object(logger, "error") + def test_create(self, mock_logger_error): common_config = CommonConfig({}) common_config.task = Const.FREE_BENCHMARK common_config.dump_path = os.path.dirname(os.path.realpath(__file__)) @@ -36,16 +37,16 @@ class TestSelfCheckToolFactory(unittest.TestCase): config = DebuggerConfig(common_config, task_config) config.level = "UNKNOWN" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): SelfCheckToolFactory.create(config) - self.assertEqual(str(context.exception), "UNKNOWN is not supported.") + mock_logger_error.assert_called_with("UNKNOWN is not supported.") + mock_logger_error.reset_mock() config.level = MsConst.API config.execution_mode = MsConst.GRAPH_KBYK_MODE - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): SelfCheckToolFactory.create(config) - self.assertEqual(str(context.exception), - f"Task free_benchmark is not supported in this mode: {MsConst.GRAPH_KBYK_MODE}.") + mock_logger_error.assert_called_with(f"Task free_benchmark is not supported in this mode: {MsConst.GRAPH_KBYK_MODE}.") config.execution_mode = MsConst.PYNATIVE_MODE tool = SelfCheckToolFactory.create(config) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py index 802769d9005916c8723d436349d13ca7f557a00a..fefdaffec0798e09a5fc5787d11a0bf89167ecef 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py @@ -1,6 +1,7 @@ import os import shutil import json +import time import numpy as np import mindspore as ms from unittest import TestCase, mock @@ -15,7 +16,8 @@ class TestGradAnalyzer(TestCase): @classmethod def setUpClass(cls): cls.output_path = "./test_output" - cls.dump_dir = f"{cls.output_path}/rank0/Dump" + cls.time_stamp = str(int(time.time())) + cls.dump_dir = f"{cls.output_path}/rank0/Dump{cls.time_stamp}" cls.save_dir = f"{cls.output_path}/rank0" os.makedirs(cls.dump_dir, exist_ok=True) @@ -31,7 +33,8 @@ class TestGradAnalyzer(TestCase): 'get_context.side_effect': lambda x: { GradConst.OUTPUT_PATH: self.output_path, GradConst.LEVEL: GradConst.LEVEL2, - GradConst.BOUNDS: [-0.1, 0.0, 0.1] + GradConst.BOUNDS: [-0.1, 0.0, 0.1], + GradConst.TIME_STAMP: self.time_stamp }[x] })) # Clear dump directory before each test diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py deleted file mode 100644 index ae24457a444bfdddc796802126150577280d7e62..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_ms_grad_monitor.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hashlib -import json -import os -import shutil -from unittest import TestCase -from unittest.mock import patch - -import numpy as np -import mindspore -from mindspore import nn, Tensor -from mindspore.nn import SGD - -from msprobe.core.common.file_utils import FileOpen -from msprobe.core.grad_probe.constant import GradConst -from msprobe.mindspore import PrecisionDebugger -from msprobe.mindspore.grad_probe.global_context import grad_context - - -file_path = os.path.abspath(__file__) -directory = os.path.dirname(file_path) -config_json_path = os.path.join(directory, "config.json") - - -def main(): - PrecisionDebugger._instance = None - PrecisionDebugger.initialized = False - grad_context._setting[GradConst.CURRENT_STEP] = 0 - with patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): - debugger = PrecisionDebugger(config_json_path) - - class SimpleNet(nn.Cell): - def __init__(self): - super().__init__() - self.my_dense = nn.Dense(16, 5) - - def construct(self, x): - x = self.flatten(x) - logits = self.my_dense(x) - return logits - model = SimpleNet() - optimizer = SGD(model.trainable_params(), learning_rate=0.001) - - debugger.monitor(optimizer) - - fix_gradient = tuple([Tensor(np.arange(5*16).reshape((5, 16)), dtype=mindspore.float32), - Tensor(np.arange(5).reshape(5), dtype=mindspore.float32)]) - - steps = 10 - - for _ in range(steps): - optimizer(fix_gradient) - - -def save_dict_as_json(data, json_file_path): - with FileOpen(json_file_path, 'w') as f: - json.dump(data, f, ensure_ascii=False, indent=4) - print(f"字典已保存为json文件: {json_file_path}") - - -def get_hash(file_path): - with FileOpen(file_path, 'rb') as file: - hash_object = hashlib.md5() - for chunk in iter(lambda: file.read(4096), b""): - hash_object.update(chunk) - return hash_object.hexdigest() - - -class TestMsGradientMonitor(TestCase): - def test_gradient_monitor_L2(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L2", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - my_dense_bias_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.bias.npy") - self.assertTrue(os.path.isfile(my_dense_bias_path), "bias npy file not found") - my_dense_bias_real = np.load(my_dense_bias_path) - my_dense_bias_target = np.arange(5).reshape(5) > 0 - - self.assertTrue((my_dense_bias_real == my_dense_bias_target).all(), "bias ndarray not same as target") - - my_dense_weight_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.weight.npy") - self.assertTrue(os.path.isfile(my_dense_weight_path), "weight npy file not found") - my_dense_weight_real = np.load(my_dense_weight_path) - my_dense_weight_target = np.arange(5*16).reshape((5, 16)) > 0 - - self.assertTrue((my_dense_weight_real == my_dense_weight_target).all(), "weight ndarray not same as target") - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "d5e71f1aa37d48ef0ca0a75932597a29" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") - - def test_gradient_monitor_L1(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L1", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - my_dense_bias_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.bias.npy") - self.assertTrue(os.path.isfile(my_dense_bias_path), "bias npy file not found") - my_dense_bias_real = np.load(my_dense_bias_path) - my_dense_bias_target = np.arange(5).reshape(5) > 0 - - self.assertTrue((my_dense_bias_real == my_dense_bias_target).all(), "bias ndarray not same as target") - - my_dense_weight_path = os.path.join(gradient_output_path, "rank0", "step1", "my_dense.weight.npy") - self.assertTrue(os.path.isfile(my_dense_weight_path), "weight npy file not found") - my_dense_weight_real = np.load(my_dense_weight_path) - my_dense_weight_target = np.arange(5*16).reshape((5, 16)) > 0 - - self.assertTrue((my_dense_weight_real == my_dense_weight_target).all(), "weight ndarray not same as target") - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "a4ad300992cb10965fbc12c2ee19dd37" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") - - def test_gradient_monitor_L0(self): - gradient_output_path = os.path.join(directory, "gradient_output") - if os.path.isfile(config_json_path): - os.remove(config_json_path) - if os.path.isdir(gradient_output_path): - shutil.rmtree(gradient_output_path) - config_dict = { - "task": "grad_probe", - "dump_path": gradient_output_path, - "rank": [], - "step": [1], - "grad_probe": { - "grad_level": "L0", - "param_list": [] - } - } - save_dict_as_json(config_dict, config_json_path) - - main() - - real_md5_value = get_hash(os.path.join(gradient_output_path, "rank0", "grad_summary_1.csv")) - target_md5_value = "62e137a119c0d1a44623f10049c3f80d" - self.assertEqual(real_md5_value, target_md5_value, "hash value of grad_summary_1.csv is not same as target") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6ebe9b923dc543ed5fd675f667ea5a6310894b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_hook_manager.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager +from msprobe.core.common.const import Const +from msprobe.core.hook_manager import HookSet, BaseHookManager + + +class TestMindsproeHookManager(unittest.TestCase): + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.mock_config.task = "statistics" + self.mock_config.level = Const.LEVEL_L1 + self.manager = MindsproeHookManager( + self.mock_data_collector, + self.mock_config + ) + BaseHookManager.inner_switch = False + + def test_properties(self): + self.assertIsNone(self.manager._is_recompute) + + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager._no_grad') as mock_no_grad: + ctx = self.manager._no_grad_context() + mock_no_grad.assert_called_once() + + def test_add_count(self): + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager.HOOKCell.add_cell_count') as mock_add: + self.manager._add_count("test_module") + mock_add.assert_called_once_with("test_module") + + def test_process_kwargs_and_output(self): + mock_module = MagicMock() + mock_module.msprobe_input_kwargs = {"kw1": "v1"} + + kwargs, output = self.manager._process_kwargs_and_output( + mock_module, Const.API, "output_value", "ignored" + ) + self.assertEqual(kwargs, {"kw1": "v1"}) + self.assertEqual(output, "output_value") + + with patch('msprobe.mindspore.dump.hook_cell.ms_hook_manager.has_kwargs_in_forward_hook', return_value=True): + kwargs, output = self.manager._process_kwargs_and_output( + mock_module, Const.MODULE, "kwargs_value", "output_value" + ) + self.assertEqual(kwargs, "kwargs_value") + self.assertEqual(output, "output_value") + + def test_build_hook(self): + hookset = self.manager.build_hook(Const.API, "test_api") + self.assertIsInstance(hookset, HookSet) + self.assertEqual(hookset.forward_hook.__name__, "forward_hook") + self.assertEqual(hookset.forward_pre_hook.__name__, "forward_pre_hook") + self.assertEqual(hookset.backward_hook.__name__, "backward_hook") + self.assertEqual(hookset.backward_pre_hook.__name__, "backward_pre_hook") + hookset = self.manager.build_hook(Const.MODULE, "test_module") + self.assertEqual(hookset.forward_pre_hook.__name__, "forward_pre_hook") + + def test_need_exchange(self): + mock_module = MagicMock() + del mock_module.has_pre_hook_called + self.assertFalse(self.manager._need_exchange(mock_module)) + + mock_module.has_pre_hook_called = False + self.assertFalse(self.manager._need_exchange(mock_module)) + + mock_module.has_pre_hook_called = True + self.assertTrue(self.manager._need_exchange(mock_module)) + + def test_get_params_dict(self): + mock_module = MagicMock() + + self.mock_config.task = Const.STRUCTURE + params_dict = self.manager._get_params_dict(mock_module) + self.assertEqual(params_dict, {}) + + self.mock_config.task = "statistics" + mock_params = { + "test_module.weight": "w1", + "test_module.bias": "b1" + } + mock_module.parameters_dict.return_value = mock_params + params_dict = self.manager._get_params_dict(mock_module) + mock_module.parameters_dict.assert_called_once_with(recurse=False) + self.assertEqual(params_dict, {"weight": "w1", "bias": "b1"}) + + def test_build_backward_pre_hook(self): + hook_fn = self.manager._build_backward_pre_hook(Const.MODULE, "test_module_backward") + + mock_module = MagicMock() + mock_grad_input = ("grad1", "grad2") + + with patch.object(self.manager, '_should_execute_hook', return_value=False): + hook_fn(mock_module, mock_grad_input) + self.mock_data_collector.backward_input_data_collect.assert_not_called() + + self.mock_config.level = Const.LEVEL_L2 + with patch.object(self.manager, '_should_execute_hook', return_value=True): + hook_fn(mock_module, mock_grad_input) + + self.mock_data_collector.update_api_or_module_name.assert_called_with("test_module_backward") + self.mock_data_collector.backward_input_data_collect.assert_called_once() + + call_args = self.mock_data_collector.backward_input_data_collect.call_args[0] + module_input = call_args[3] + self.assertEqual(module_input.grad_input, mock_grad_input) + + self.assertFalse(BaseHookManager.inner_switch) + + self.mock_config.level = Const.LEVEL_L1 + with patch.object(self.manager, '_should_execute_hook', return_value=True): + hook_fn(mock_module, mock_grad_input) + self.mock_data_collector.backward_input_data_collect.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py new file mode 100644 index 0000000000000000000000000000000000000000..d0753c5c2300b58814504f1e2a0e1bfe7cb56e12 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_common_func.py @@ -0,0 +1,120 @@ +import pytest +from unittest.mock import patch, MagicMock +from mindspore import nn, context +from mindspore.common.initializer import Normal +import mindspore as ms + +from msprobe.mindspore.monitor.common_func import ( + is_valid_instance, + get_submodules, + get_parameters, + get_rank, + comm_is_initialized, + optimizer_pre_hook, + optimizer_post_hook +) + +TORCH_AVAILABLE = False +try: + import torch + import torch.nn as torch_nn + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +class TestModelUtils: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.ms_model = MSModel() + if TORCH_AVAILABLE: + cls.torch_model = TorchModel() + + @classmethod + def teardown_class(cls): + """Cleanup after all tests in this class""" + pass + + + def test_is_valid_instance_if_model_is_cell_or_module_then_return_true(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + if TORCH_AVAILABLE: + mock_is_mindtorch.return_value = True + assert is_valid_instance(self.torch_model) + mock_is_mindtorch.return_value = False + assert is_valid_instance(self.ms_model) + + def test_is_valid_instance_if_input_is_string_then_return_false(self): + assert not is_valid_instance("not a model") + + def test_is_valid_instance_if_input_is_number_then_return_false(self): + assert not is_valid_instance(123) + + def test_get_submodules_if_model_is_valid_then_return_non_empty_dict(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = True + if TORCH_AVAILABLE: + submodules = dict(get_submodules(self.torch_model)) + assert len(submodules) > 0 + assert any(name == 'conv1' for name in submodules) + + mock_is_mindtorch.return_value = False + submodules = dict(get_submodules(self.ms_model)) + assert len(submodules) > 0 + assert any(name.endswith('conv1') for name in submodules) + + + def test_get_submodules_if_model_is_invalid_then_return_empty_dict(self): + assert get_submodules("invalid") == {} + + def test_get_parameters_if_model_is_valid_then_return_non_empty_dict(self): + with patch('msprobe.mindspore.monitor.common_func.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = True + if TORCH_AVAILABLE: + params = dict(get_parameters(self.torch_model)) + assert any(name == 'conv1.weight' for name in params) + mock_is_mindtorch.return_value = False + params = dict(get_parameters(self.ms_model)) + assert any('conv1.weight' in name for name in params) + + + def test_get_parameters_if_model_is_invalid_then_return_empty_dict(self): + assert get_parameters(123) == {} + + def test_get_rank_if_comm_initialized_then_return_integer(self): + rank = get_rank() + assert isinstance(rank, int) + assert rank >= 0 + + def test_comm_is_initialized_when_called_then_return_boolean(self): + assert isinstance(comm_is_initialized(), bool) + + +# Test models +class MSModel(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, 3, has_bias=True, weight_init=Normal(0.02)) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + return x + +if TORCH_AVAILABLE: + class TorchModel(torch_nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch_nn.Conv2d(3, 64, 3) + self.bn1 = torch_nn.BatchNorm2d(64) + self.relu = torch_nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + return x \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py new file mode 100644 index 0000000000000000000000000000000000000000..df3e54f1a173a943ec04d73957f80619547ce977 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_opt_collect.py @@ -0,0 +1,225 @@ +import pytest +import numpy as np +from mindspore import Tensor, nn, ops +from unittest.mock import MagicMock, patch + +from msprobe.core.common.const import MonitorConst +# Import the classes to test +from msprobe.core.common.log import logger +from msprobe.mindspore.monitor.optimizer_collect import ( + OptimizerMon, + MixPrecisionOptimizerMon, + MegatronDistributedOptimizerMon, + MegatronChainedDistributedOptimizerMon, + MegatronChainedMixPrecisionOptimizerMon, + DeepSpeedZeroOptimizerMon, + DeepSpeedZeroOptimizerStage0Mon, + DeepSpeedZeroOptimizerStage1or2Mon, + DeepSpeedZeroOptimizerStage3Mon, + OptimizerMonFactory +) + +class TestOptimizerMon: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.mock_monitor = MagicMock() + cls.mock_monitor.name2tag = {"test_param": {MonitorConst.POST_GRAD: "test_tag"}} + cls.mock_monitor.duplicate_param = {} + cls.mock_monitor.params_have_main_grad = False + cls.mock_monitor.fsdp_wrapped_module = False + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {"test_param": MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {"test_param": MagicMock()} + + def test_fetch_grad_if_param_has_valid_grad_then_return_correct_grad_values(self): + # Setup + param = MagicMock() + expected_grad = Tensor([1.0, 2.0, 3.0]) + param.grad = expected_grad + params2name = {param: "test_param"} + optimizer = MagicMock() + mon = OptimizerMon(optimizer) + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + self.mock_monitor.register_param_call_id.assert_called_once_with("hook_optimizer", "test_tag") + + def test_fetch_grad_if_param_has_main_grad_then_return_main_grad_values(self): + # Setup + param = MagicMock() + expected_grad = Tensor(np.array([1.5, 2.5])) + param.main_grad = expected_grad + param.grad = None + params2name = {param: "test_param"} + optimizer = MagicMock() + self.mock_monitor.params_have_main_grad = True + mon = OptimizerMon(optimizer) + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + + def test_fetch_mv_if_state_complete_then_return_correct_momentum_values(self): + # Setup + param = MagicMock() + params2name = {param: "test_param"} + optimizer = MagicMock() + optimizer.state = { + param: { + "exp_avg": Tensor([0.1]), + "exp_avg_sq": Tensor([0.2]), + "step": 10 + } + } + del optimizer.chained_optimizers + del optimizer.param_to_cpu_states_map + optimizer.defaults = {'betas': (0.9, 0.999), 'eps': 1e-8} + optimizer.param_groups = [{}] + + mon = OptimizerMon(optimizer) + mon.fp16_to_fp32_param = {} + + # Execute + exp_avg, exp_avg_sq, update, ratio = mon.fetch_mv(self.mock_monitor, params2name) + + # Verify + beta1, beta2 = optimizer.defaults['betas'] + step = optimizer.state[param]['step'] + + expected_exp_avg_hat = 0.1 / (1 - beta1**step) + expected_exp_avg_sq_hat = 0.2 / (1 - beta2**step) + expected_update = expected_exp_avg_hat / (np.sqrt(expected_exp_avg_sq_hat) + optimizer.defaults['eps']) + expected_ratio = expected_exp_avg_hat / np.sqrt(expected_exp_avg_sq_hat) + + assert exp_avg["test_param"] == Tensor([0.1]) + assert exp_avg_sq["test_param"] == Tensor([0.2]) + assert update["test_param"] == Tensor([expected_update]) + assert ratio["test_param"] == Tensor([expected_ratio]) + + def test_narrow_from_flatten_if_state_not_partitioned_then_return_original_state(self): + # Setup + param = MagicMock() + flatten_state = Tensor([1.0, 2.0, 3.0]) + mon = OptimizerMon(MagicMock()) + + # Execute + result = mon.narrow_from_flatten(param, flatten_state) + + # Verify + assert (result == flatten_state).all() + +class TestMixPrecisionOptimizerMon: + @classmethod + def setup_class(cls): + cls.mock_monitor = MagicMock() + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + + def test_map_fp16_to_fp32_param_if_multiple_groups_then_create_correct_mappings(self): + # Setup + optimizer = MagicMock() + fp16_params = [MagicMock(), MagicMock(), MagicMock()] + fp32_params = [MagicMock(), MagicMock(), MagicMock()] + optimizer.float16_groups = [fp16_params[:2], [fp16_params[2]]] + optimizer.fp32_from_float16_groups = [fp32_params[:2], [fp32_params[2]]] + + mon = MixPrecisionOptimizerMon(optimizer) + + # Execute + mon.map_fp16_to_fp32_param(optimizer) + + # Verify + assert len(mon.fp16_to_fp32_param) == 3 + for fp16, fp32 in zip(fp16_params, fp32_params): + assert mon.fp16_to_fp32_param[fp16] == fp32 + +class TestDeepSpeedZeroOptimizerStage1or2Mon: + @classmethod + def setup_class(cls): + """Setup once for all tests in this class""" + cls.mock_monitor = MagicMock() + cls.mock_monitor.name2tag = {"test_param": {MonitorConst.POST_GRAD: "test_tag"}} + cls.mock_monitor.duplicate_param = {} + cls.mock_monitor.params_have_main_grad = False + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + + def test_fetch_grad_if_param_in_partition_then_return_correct_grad_slice(self): + # Setup + optimizer = MagicMock() + param = MagicMock() + params2name = {param: "test_param"} + expected_grad = Tensor(np.array([1.0, 2.0, 3.0])) + param.main_grad = expected_grad + param.grad = None + optimizer.bit16_groups = [[param]] + optimizer.cpu_offload = False + mon = DeepSpeedZeroOptimizerStage1or2Mon(optimizer) + mon.param2group = {param: 0} + mon.get_param_index = MagicMock(return_value=1) + mon.param_not_in_partition = MagicMock(return_value=False) + mon.get_position = MagicMock(return_value=(3, 3)) # start at index 3, length 3 + + # MagicMock the averaged_gradients structure + optimizer.averaged_gradients = { + 0: [ + None, # index 0 + Tensor(np.array([1.0, 2.0, 3.0])) # index 1 + ] + } + + # Execute + result = mon.fetch_grad(self.mock_monitor, params2name) + + # Verify + assert len(result) == 1 + assert (result["test_tag"] == expected_grad).all() + +class TestOptimizerMonFactory: + @classmethod + def setup_class(cls): + cls.mock_monitor = MagicMock() + cls.mock_monitor.mv_distribution = True + cls.mock_monitor.mg_direction = True + cls.mock_monitor.ur_distribution = True + cls.mock_monitor.update_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + cls.mock_monitor.ratio_heatmap_visualizer = {'param1': MagicMock(), 'param2': MagicMock()} + + def test_create_optimizer_mon_if_chained_optimizer_then_return_correct_monitor_type(self): + # Setup + base_optimizer = MagicMock() + base_optimizer.__class__.__name__ = "DistributedOptimizer" + optimizer = MagicMock() + optimizer.__class__.__name__ = "ChainedOptimizer" + optimizer.chained_optimizers = [base_optimizer] + + # Execute + result = OptimizerMonFactory.create_optimizer_mon(optimizer) + + # Verify + assert isinstance(result, MegatronChainedDistributedOptimizerMon) + + def test_create_optimizer_mon_if_deepspeed_stage3_then_return_stage3_monitor(self): + # Setup + optimizer = MagicMock() + optimizer.__class__.__name__ = "DeepSpeedZeroOptimizer_Stage3" + + # Execute + result = OptimizerMonFactory.create_optimizer_mon(optimizer) + + # Verify + assert isinstance(result, DeepSpeedZeroOptimizerStage3Mon) + assert result.stage == '3' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py new file mode 100644 index 0000000000000000000000000000000000000000..fcefbb8c339ad6de1d14eaae7f75e6947efc5196 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/save/test_debugger_save_mindspore.py @@ -0,0 +1,364 @@ +import unittest +import os +import json +import mindspore +import numpy as np +import shutil +from unittest.mock import patch + +from msprobe.mindspore import PrecisionDebugger +from msprobe.core.data_dump.data_processor.mindspore_processor import MindsporeDataProcessor +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register + +current_file = __file__ +parent_dir = os.path.abspath(os.path.dirname(current_file)) +test_dir = os.path.join(parent_dir, "test_dir") + +def deep_compare(obj1, obj2, float_tolerance=1e-5): + """ + Recursively compare two objects to check if they are the same. + Supports nested dictionaries and lists. + """ + if type(obj1) != type(obj2): + return False + if isinstance(obj1, dict): + if obj1.keys() != obj2.keys(): + return False + return all(deep_compare(obj1[key], obj2[key]) for key in obj1) + if isinstance(obj1, (tuple, list)): + if len(obj1) != len(obj2): + return False + return all(deep_compare(item1, item2) for item1, item2 in zip(obj1, obj2)) + if isinstance(obj1, (int, float)): + return abs(obj1 - obj2) < float_tolerance + return obj1 == obj2 + +class TestDebuggerSave(unittest.TestCase): + @staticmethod + def write_config_json(step, async_dump, mode, dump_path, config_file_path): + task = "tensor" if mode == "tensor" else "statistics" + statistics_summary_mode = "statistics" if mode == "statistics" else "md5" + config = { + "task": task, + "dump_path": dump_path, + "rank": [], + "step": step, + "level": "debug", + "enable_dataloader": False, + "async_dump": async_dump, + "statistics": { + "summary_mode": statistics_summary_mode, + } + } + with open(config_file_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + @staticmethod + def read_debug_json_into_dict(debug_json_path): + with open(debug_json_path, "r", encoding="utf-8") as f: + debug_json = json.load(f) + return debug_json + + @staticmethod + def check_real_npy(npy_path, target_ms_tensor, check_values=True, rtol=1e-5, atol=1e-8): + """ + Enhanced version with optional value comparison. + + Args: + npy_path (str): Path to the .npy file + target_ms_tensor: Target mindspore tensor to compare + check_values (bool): If True, also compare array values + rtol, atol: Relative and absolute tolerances for value comparison + + Returns: + bool: True if all checks pass + """ + # Convert mindspore tensor to numpy if needed + if hasattr(target_ms_tensor, 'numpy'): + target_ms_tensor = target_ms_tensor.numpy() + # Load the npy file + try: + npy_data = np.load(npy_path) + except FileNotFoundError: + print(f"Error: The file {npy_path} does not exist.") + return False + except Exception as e: + print(f"Error loading npy file: {e}") + return False + # Check shapes + if npy_data.shape != target_ms_tensor.shape: + print(f"Shape mismatch: npy data shape is {npy_data.shape}, target tensor shape is {target_ms_tensor.shape}") + return False + # Check dtypes + if npy_data.dtype != target_ms_tensor.dtype: + print(f"Shape mismatch: npy data dtype is {npy_data.dtype}, target tensor dtype is {target_ms_tensor.dtype}") + return False + # Optionally check values + if check_values: + if not np.allclose(npy_data, target_ms_tensor, rtol=rtol, atol=atol): + print("Value mismatch: npy data and target tensor values do not match within the specified tolerances.") + return False + + return True + + def setUp(self): + if not os.path.exists(test_dir): + os.makedirs(test_dir) + PrecisionDebugger._instance = None + self.original_mindspore_special_type = MindsporeDataProcessor.mindspore_special_type + MindsporeDataProcessor.mindspore_special_type = tuple([mindspore.Tensor]) + + def tearDown(self): + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + PrecisionDebugger._instance = None + MindsporeDataProcessor.mindspore_special_type = self.original_mindspore_special_type + get_api_register(True).restore_all_api() + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_real_tensor(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check npy file + npy_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "data_dict.0.debug.a.npy" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_md5(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "md5": "2e3fa576" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_multiple_steps(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [0, 1, 2] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in step: + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check npy file + for i in step: + npy_path = os.path.join(dump_path, f"step{i}", "rank", "dump_tensor_data", "data_dict.0.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "data_dict.0.debug.a.npy" + } + } + for i in step: + debug_json_path = os.path.join(dump_path, f"step{i}", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_async_save_tensor(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = True + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check npy file + npy_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "data_name": "data_dict.0.debug.a.npy", + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302 + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_async_save_md5(self, _): + # async_dump case, md5 configuration not working,only save statistics + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + async_dump = True + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302 + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_multiple_times(self, _): + data = {"a": mindspore.Tensor([1., 2.])} + step = [] + call_times = 3 + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in range(call_times): + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check npy file + for i in range(call_times): + npy_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", f"data_dict.{i}.debug.a.npy") + assert self.check_real_npy(npy_path, data["a"]) + # check debug json + for i in range(call_times): + target_debug_info = { + "a": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": f"data_dict.{i}.debug.a.npy" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"][f"data_dict.{i}.debug"], target_debug_info) + + @patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions") + def test_save_compilcated_data_structure(self, _): + x = mindspore.Tensor([1., 2.]) + complicated_structure = [{"a_key": x}] + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(complicated_structure, "complicated_structure") + PrecisionDebugger.step() + complicated_structure_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure.0.debug.0.a_key.npy"), + "complicated_structure.0.debug", + [ + { + "a_key": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "data_name": "complicated_structure.0.debug.0.a_key.npy" + } + } + ], + ] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = complicated_structure_info_list + assert self.check_real_npy(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py index 40f5c0164115e18cdd49c046ce29967e7a3f63eb..a3687705ea47b2cf3e031d9d71328374908bccfa 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,132 +16,349 @@ import unittest from unittest.mock import MagicMock, patch +import mindspore as ms +from mindspore import Tensor +from mindspore.ops.operations import _inner_ops + from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.mindspore.cell_processor import CellProcessor - - -class MockCell: - def __init__(self): - self.mindstudio_reserved_name = None +from msprobe.core.hook_manager import HookSet +from msprobe.mindspore.cell_processor import CellProcessor, get_cell_construct +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register class TestCellProcessor(unittest.TestCase): + @classmethod + def setUpClass(cls): + CellProcessor.reset_cell_stats() + cls.scope = MagicMock(spec=ModuleRangeScope) + cls.processor = CellProcessor(cls.scope) + get_api_register().restore_all_api() + - def setUp(self): - # 重置静态变量 + @classmethod + def tearDownClass(cls): CellProcessor.reset_cell_stats() - self.scope = MagicMock(spec=ModuleRangeScope) - self.processor = CellProcessor(self.scope) - def test_init_with_module_range_scope(self): - self.assertIsInstance(self.processor.scope, ModuleRangeScope) + def test_class_attribute(self): + self.assertTrue(hasattr(CellProcessor, 'cell_count')) + self.assertTrue(hasattr(CellProcessor, 'cell_stack')) + self.assertTrue(hasattr(CellProcessor, 'api_parent_node')) + self.assertTrue(hasattr(CellProcessor, 'module_node')) + self.assertTrue(hasattr(CellProcessor, 'cell_bw_hook_kernels')) + self.assertTrue(hasattr(CellProcessor, 'cell_backward_pre_hook')) + self.assertTrue(hasattr(CellProcessor, 'cell_backward_hook')) - def test_init_with_none_scope(self): + def test__init(self): + self.assertIsInstance(self.processor.scope, ModuleRangeScope) processor = CellProcessor(None) self.assertIsNone(processor.scope) - def test_set_cell_count_new_cell(self): - count = self.processor.set_cell_count("cell1") + def test_get_cell_construct(self): + def construct(self, *args, **kwargs): + return len(args) + + _constrct = get_cell_construct(construct) + ret = _constrct(self, 'argument') + self.assertFalse(hasattr(self, 'msprobe_input_kwargs')) + self.assertEqual(ret, 1) + + setattr(self, 'msprobe_hook', True) + _constrct = get_cell_construct(construct) + ret = _constrct(self, 'argument') + self.assertEqual(self.msprobe_input_kwargs, {}) + self.assertEqual(ret, 1) + + del self.msprobe_hook + del self.msprobe_input_kwargs + + def test_set_and_get_calls_number(self): + CellProcessor.cell_count = {} + count = self.processor.set_and_get_calls_number("cell") self.assertEqual(count, 0) - self.assertEqual(CellProcessor.cell_count["cell1"], 0) + self.assertEqual(CellProcessor.cell_count["cell"], 0) - def test_set_cell_count_existing_cell(self): - self.processor.set_cell_count("cell1") - count = self.processor.set_cell_count("cell1") + count = self.processor.set_and_get_calls_number("cell") self.assertEqual(count, 1) - self.assertEqual(CellProcessor.cell_count["cell1"], 1) + self.assertEqual(CellProcessor.cell_count["cell"], 1) + + CellProcessor.cell_count = {} def test_reset_cell_stats(self): - self.processor.set_cell_count("cell1") + CellProcessor.cell_count['cell'] = 0 + CellProcessor.cell_stack.append('cell') + CellProcessor.api_parent_node = 'cell' + CellProcessor.module_node['cell'] = 'null' + CellProcessor.cell_bw_hook_kernels['cell'] = 'bw' + CellProcessor.cell_backward_pre_hook.append('backward_pre_hook') + CellProcessor.cell_backward_hook.append('backward_hook') + CellProcessor.reset_cell_stats() self.assertEqual(CellProcessor.cell_count, {}) self.assertEqual(CellProcessor.cell_stack, []) - self.assertEqual(CellProcessor.api_parent_node, "") + self.assertIsNone(CellProcessor.api_parent_node) self.assertEqual(CellProcessor.module_node, {}) + self.assertEqual(CellProcessor.cell_bw_hook_kernels, {}) + self.assertEqual(CellProcessor.cell_backward_pre_hook, []) + self.assertEqual(CellProcessor.cell_backward_hook, []) - @patch('msprobe.core.common.const.Const') - def test_node_hook_begin(self, mock_const): - mock_const.SEP = "." # 确保 SEPARATOR 设置为字符串 - mock_const.START = "start" - cell = MockCell() - self.processor.node_hook("prefix", "start")(cell, "input") - - expected_name = "prefix" + mock_const.SEP + "0" - self.assertEqual(cell.mindstudio_reserved_name, expected_name) - self.assertIn(expected_name, CellProcessor.cell_stack) - self.assertEqual(CellProcessor.api_parent_node, expected_name) - self.scope.begin_module.assert_called_once_with(expected_name) - - @patch('msprobe.core.common.const.Const') - def test_node_hook_end(self, mock_const): - mock_const.START = "start" - cell = MockCell() - self.processor.node_hook("prefix", "start")(cell, "input") - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - - self.assertEqual(len(CellProcessor.cell_stack), 0) - self.assertIsNone(CellProcessor.api_parent_node) - self.scope.end_module.assert_called_once_with(cell.mindstudio_reserved_name) + def test_register_cell_hook(self): + with self.assertRaises(MsprobeException) as context: + self.processor.register_cell_hook([], None, 'config') + self.assertEqual(str(context.exception), '[msprobe] 无效参数:The model cannot be None, when level is "L0" or "mix"') - @patch('msprobe.core.common.const.Const') - def test_multiple_node_hook_calls(self, mock_const): - mock_const.SEP = "." # 确保 SEPARATOR 设置为字符串 - mock_const.START = "start" - cell = MockCell() + with patch('msprobe.mindspore.cell_processor.is_mindtorch') as mock_is_mindtorch, \ + patch('msprobe.mindspore.cell_processor.get_cells_and_names_with_index') as mock_get_cells_and_names, \ + patch('msprobe.mindspore.cell_processor.CellProcessor.build_cell_hook') as mock_build_cell_hook, \ + patch('msprobe.mindspore.cell_processor.get_cell_construct') as mock_get_cell_construct, \ + patch('msprobe.mindspore.cell_processor.is_graph_mode_cell_dump_allowed') \ + as mock_is_graph_mode_cell_dump_allowed, \ + patch.object(logger, 'info') as mock_logger_info: + mock_cell = MagicMock() + mock_sub_cell = MagicMock() + mock_get_cells_and_names.return_value = ({'-1': [('cell', mock_cell), ('sub_cell', mock_sub_cell)]}, {}) + mock_build_cell_hook.return_value = 'forward_pre_hook' + mock_get_cell_construct.return_value = '_construct' + mock_is_graph_mode_cell_dump_allowed.return_value = False - # First call - self.processor.node_hook("prefix", "start")(cell, "input") - expected_name1 = "prefix" + mock_const.SEP + "0" + mock_is_mindtorch.return_value = False + setattr(MagicMock, '_run_construct', '_run_construct') + self.processor.register_cell_hook(mock_cell, None, 'config') + self.assertTrue(mock_sub_cell.__class__.msprobe_construct) + mock_get_cell_construct.assert_called_with('_run_construct') + self.assertEqual(mock_sub_cell.__class__._run_construct, '_construct') + self.assertTrue(mock_sub_cell.msprobe_hook) + mock_build_cell_hook.assert_called_with('Cell.sub_cell.MagicMock.', None) + mock_cell.assert_not_called() + mock_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_sub_cell.register_forward_hook.assert_not_called() + mock_logger_info.assert_called_with('The cell hook function is successfully mounted to the model.') - # Second call - self.processor.node_hook("prefix", "start")(cell, "input") - expected_name2 = "prefix" + mock_const.SEP + "1" + del MagicMock._run_construct + del mock_sub_cell.__class__._run_construct + del mock_sub_cell.__class__.msprobe_construct - self.assertEqual(cell.mindstudio_reserved_name, expected_name2) - self.assertEqual(CellProcessor.api_parent_node, expected_name2) + mock_get_cell_construct.reset_mock() + mock_another_sub_cell = MagicMock() + setattr(mock_another_sub_cell.__class__, 'msprobe_construct', True) + mock_get_cells_and_names.return_value = ( + {'-1': [('cell', mock_cell), ('another_sub_cell', mock_another_sub_cell)]}, + {} + ) + self.processor.register_cell_hook(mock_cell, None, 'config') + mock_get_cell_construct.assert_not_called() + mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_another_sub_cell.register_forward_hook.assert_not_called() - # End first call - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - self.assertEqual(len(CellProcessor.cell_stack), 1) # Still one item in stack - self.assertEqual(CellProcessor.api_parent_node, expected_name1) + del mock_another_sub_cell.__class__.msprobe_construct - # End second call - self.processor.node_hook("prefix", "stop")(cell, "input", "output") - self.assertEqual(len(CellProcessor.cell_stack), 0) # Stack should be empty now - self.assertIsNone(CellProcessor.api_parent_node) + mock_build_cell_hook.reset_mock() + mock_get_cell_construct.reset_mock() + mock_another_sub_cell.reset_mock() + setattr(MagicMock, '_call_impl', '_call_impl') + mock_is_mindtorch.return_value = True + self.processor.register_cell_hook(mock_cell, None, 'config') + self.assertTrue(mock_another_sub_cell.__class__.msprobe_construct) + mock_get_cell_construct.assert_called_with('_call_impl') + mock_build_cell_hook.assert_called_with('Module.another_sub_cell.MagicMock.', None) + mock_cell.assert_not_called() + mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') + mock_another_sub_cell.register_forward_hook.assert_not_called() + + del MagicMock._call_impl + del mock_another_sub_cell.__class__._call_impl + del mock_another_sub_cell.__class__.msprobe_construct - def test_set_and_get_reserved_name(self): - cell = MockCell() - cell.mindstudio_reserved_name = "mindstudio_reserved_name" + def test_build_cell_hook(self): CellProcessor.reset_cell_stats() - cell_name = "Cell.net.Net.forward" - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, cell_name + Const.SEP + "0") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 0) - self.assertFalse(hasattr(cell, "has_pre_hook_called")) - - cell.has_pre_hook_called = False - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, cell_name + Const.SEP + "1") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 1) - self.assertFalse(cell.has_pre_hook_called) - - cell.has_pre_hook_called = True - cell.mindstudio_reserved_name = "mindstudio_reserved_name" + cell_name = 'Cell.cell.Cell.' + mock_build_data_hook = MagicMock() + mock_backward_data_hook = MagicMock() + target_grad_output = (Tensor([0.5]),) + mock_backward_data_hook.return_value = target_grad_output + mock_hook_set = HookSet(backward_hook=mock_backward_data_hook) + mock_build_data_hook.return_value = mock_hook_set + mock_cell = MagicMock() + + with patch.object(_inner_ops, 'CellBackwardHook') as mock_CellBackwardHook: + forward_pre_hook = self.processor.build_cell_hook(cell_name, mock_build_data_hook) + forward_hook = forward_pre_hook.__closure__[2].cell_contents + + mock_bw = mock_CellBackwardHook.return_value + mock_bw.return_value = (Tensor([0.0]),) + args = (Tensor([1.0]),) + target_args = (Tensor([0.0]),) + full_forward_name = f'{cell_name}{Const.FORWARD}.0' + full_backward_name = f'{cell_name}{Const.BACKWARD}.0' + # call testing function - forward_pre_hook + ret = forward_pre_hook(mock_cell, args) + self.assertIsNone(CellProcessor.module_node[full_forward_name]) + self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) + self.assertEqual(CellProcessor.api_parent_node, full_forward_name) + self.scope.begin_module.assert_called_with(full_forward_name) + mock_build_data_hook.assert_called_with('Module', full_forward_name) + self.assertEqual(len(CellProcessor.cell_backward_hook), 1) + mock_CellBackwardHook.assert_called_with(full_backward_name, mock_cell, + CellProcessor.cell_backward_hook[-1]) + mock_bw.register_backward_hook.assert_called_once() + mock_bw.assert_called_with(*args) + self.assertTrue((ret[0] == target_args[0]).all()) + + backward_hook = CellProcessor.cell_backward_hook[-1][full_backward_name] + grad_input = (Tensor([1.0]),) + grad_output = (Tensor([2.0]),) + # call testing function - backward_hook + ret = backward_hook(mock_cell, grad_input, grad_output) + mock_backward_data_hook.assert_called_with(mock_cell, grad_input, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertTrue((ret[0] == target_grad_output[0]).all()) + + mock_build_data_hook.reset_mock() + args = (Tensor([1], dtype=ms.int32),) + full_forward_name = f'{cell_name}{Const.FORWARD}.1' + # call testing function - forward_pre_hook + ret = forward_pre_hook(mock_cell, args) + self.assertIsNone(CellProcessor.module_node[full_forward_name]) + self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) + self.assertEqual(CellProcessor.api_parent_node, full_forward_name) + self.scope.begin_module.assert_called_with(full_forward_name) + self.assertEqual(len(CellProcessor.cell_backward_hook), 1) + mock_build_data_hook.assert_not_called() + + full_forward_name = f'{cell_name}{Const.FORWARD}.0' + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + self.scope.reset_mock() + mock_CellBackwardHook.reset_mock() + mock_bw.reset_mock() + target_output = Tensor([0.5]) + args = (Tensor([1.0]),) + output = Tensor([2.0]) + mock_bw.return_value = target_output + mock_backward_data_hook.reset_mock() + mock_forward_data_hook = MagicMock() + mock_forward_data_hook.return_value = output + mock_build_data_hook.return_value = HookSet( + forward_hook=mock_forward_data_hook, backward_hook=mock_backward_data_hook + ) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(CellProcessor.cell_count.get(cell_name), 0) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with(full_forward_name) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], output) + self.assertEqual(mock_bw.call_args_list[1][0][0], target_output) + self.assertEqual(mock_CellBackwardHook.call_count, 1) + self.assertEqual(len(CellProcessor.cell_backward_pre_hook), 1) + self.assertTrue((ret == target_output).all()) + + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + mock_backward_data_hook.reset_mock() + grad_output = (Tensor([2.0]),) + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertTrue(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + self.assertEqual(CellProcessor.cell_stack, [full_backward_name]) + self.assertEqual(CellProcessor.api_parent_node, full_backward_name) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_not_called() + self.assertIsNone(ret) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + mock_bw.reset_mock() + args = (Tensor([1.0]),) + output = (Tensor([2.0]),) + mock_forward_data_hook.return_value = output + target_output = (Tensor([0.5]),) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], *output) + self.assertEqual(mock_bw.call_args_list[1][0][0], mock_bw.return_value) + self.assertTrue((ret[0] == target_output[0]).all()) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + CellProcessor.cell_bw_hook_kernels.clear() + CellProcessor.cell_backward_pre_hook.clear() + mock_bw.reset_mock() + mock_bw.return_value = (Tensor([0.5]),) + output = (Tensor([1.0]), Tensor([2.0])) + mock_forward_data_hook.return_value = output + with self.assertRaises(TypeError) as context: + # call testing function - forward_hook + forward_hook(mock_cell, args, output) + self.assertEqual(str(context.exception), + 'The backward pre hook return value size is 1 not equal to output size 2') + mock_bw.assert_called_with(*output) + + self.scope.reset_mock() + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_called_with(mock_cell, (), grad_output) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertIsNone(ret) + + CellProcessor.reset_cell_stats() + + def test_set_construct_info_in_pre_hook(self): CellProcessor.reset_cell_stats() - ret = self.processor.set_and_get_reserved_name(cell, cell_name) - self.assertEqual(ret, "mindstudio_reserved_name") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count, {}) - self.assertFalse(cell.has_pre_hook_called) + self.processor.set_construct_info_in_pre_hook('full_name') + self.assertEqual(CellProcessor.module_node['full_name'], None) + self.assertEqual(CellProcessor.cell_stack, ['full_name']) + self.assertEqual(CellProcessor.api_parent_node, 'full_name') + self.scope.begin_module.assert_called_with('full_name') + + self.scope.begin_module.reset_mock() + self.processor.set_construct_info_in_pre_hook('sub_cell_name') + self.assertEqual(CellProcessor.module_node, {'full_name': None, 'sub_cell_name': 'full_name'}) + self.assertEqual(CellProcessor.cell_stack, ['full_name', 'sub_cell_name']) + self.assertEqual(CellProcessor.api_parent_node, 'sub_cell_name') + self.scope.begin_module.assert_called_with('sub_cell_name') + + CellProcessor.reset_cell_stats() + + def test_set_construct_info_in_hook(self): + CellProcessor.reset_cell_stats() + self.processor.set_construct_info_in_hook('full_name') + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with('full_name') + + self.scope.end_module.reset_mock() + CellProcessor.cell_stack = ['full_name'] + self.processor.set_construct_info_in_hook('full_name') + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with('full_name') + + self.scope.end_module.reset_mock() + CellProcessor.cell_stack = ['Cell.0', 'Cell.1'] + self.processor.set_construct_info_in_hook('full_name') + self.assertEqual(CellProcessor.cell_stack, ['Cell.0']) + self.assertEqual(CellProcessor.api_parent_node, 'Cell.0') + self.scope.end_module.assert_called_with('full_name') - ret = self.processor.set_and_get_reserved_name(cell, cell_name, is_called_by_pre_hook=True) - self.assertEqual(ret, cell_name + Const.SEP + "0") - self.assertEqual(cell.mindstudio_reserved_name, ret) - self.assertEqual(CellProcessor.cell_count[cell_name], 0) - self.assertTrue(cell.has_pre_hook_called) CellProcessor.reset_cell_stats() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py index 8f5d207c41923175b6efe4f9dc313896f879fd89..ce733487b77caf157784d03fe470ffb84929975c 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,16 +16,19 @@ from unittest import TestCase from unittest.mock import patch +from msprobe.core.common.log import logger from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.const import Const as CoreConst from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory +from msprobe.mindspore.ms_config import StatisticsConfig class TestDumpToolFactory(TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_create(self, _): + def test_create(self, _, mock_logger_error): json_config = { "task": "statistics", "dump_path": "/absolute_path", @@ -35,7 +38,7 @@ class TestDumpToolFactory(TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) config.data_mode = [CoreConst.INPUT, CoreConst.OUTPUT] @@ -54,18 +57,12 @@ class TestDumpToolFactory(TestCase): DumpToolFactory.create(config) self.assertEqual(str(context.exception), "Valid level is needed.") - config.level = Const.KERNEL - with self.assertRaises(Exception) as context: - DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "Data dump is not supported in None mode when dump level is kernel.") - config.execution_mode = Const.GRAPH_GE_MODE config.level = Const.CELL with self.assertRaises(Exception) as context: DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "Data dump is not supported in graph_ge mode when dump level is cell.") + self.assertEqual(str(context.exception), "The model is empty and cell dump is not enabled.") - config.execution_mode = Const.GRAPH_KBYK_MODE config.level = Const.KERNEL dumper = DumpToolFactory.create(config) - self.assertEqual(dumper.dump_json["common_dump_settings"]["net_name"], "Net") + self.assertIsInstance(dumper, tuple) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py index 329274b19d862c8c0e50af0fdbd051909e6a60d6..e607f2c2a8a701417ae28a6c353349d1430d5e98 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. import os +import sys from unittest import TestCase from unittest.mock import patch @@ -21,6 +22,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.core.common.file_utils import move_file class TestKernelGraphDump(TestCase): @@ -44,10 +46,26 @@ class TestKernelGraphDump(TestCase): self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin") self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2) + _msprobe_c_existed = True + try: + from msprobe.lib import _msprobe_c + except ImportError: + _msprobe_c_existed = False + with patch("msprobe.mindspore.dump.kernel_graph_dump.create_directory"), \ patch("msprobe.mindspore.dump.kernel_graph_dump.logger.info"), \ patch("msprobe.mindspore.dump.kernel_graph_dump.save_json") as mock_save_json: + if _msprobe_c_existed: + dumper.handle() + mock_save_json.assert_not_called() + + _msprobe_c_path = _msprobe_c.__file__ + _msprobe_c_test_path = _msprobe_c_path.replace('_msprobe_c.so', '_msprobe_c_test.so') + move_file(_msprobe_c_path, _msprobe_c_test_path) + sys.modules.pop('msprobe.lib') + sys.modules.pop('msprobe.lib._msprobe_c') + os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: dumper.handle() @@ -63,3 +81,5 @@ class TestKernelGraphDump(TestCase): del os.environ["MINDSPORE_DUMP_CONFIG"] if "MS_ACL_DUMP_CFG_PATH" in os.environ: del os.environ["MS_ACL_DUMP_CFG_PATH"] + if _msprobe_c_existed: + move_file(_msprobe_c_test_path, _msprobe_c_path) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py index b484bc9b7cdceec3b8906600b16b2d4fdc6b1b5e..b3b2466e8104de3735cd1f1d1de17a79e52bd38e 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +14,7 @@ # limitations under the License. import os +import sys from unittest import TestCase from unittest.mock import patch @@ -21,6 +22,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck +from msprobe.core.common.file_utils import move_file class TestKernelGraphOverflowCheck(TestCase): @@ -41,11 +43,27 @@ class TestKernelGraphOverflowCheck(TestCase): checker = KernelGraphOverflowCheck(config) self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2) + _msprobe_c_existed = True + try: + from msprobe.lib import _msprobe_c + except ImportError: + _msprobe_c_existed = False + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" with patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.create_directory"), \ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"), \ patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.save_json") as mock_save_json: + if _msprobe_c_existed: + checker.handle() + mock_save_json.assert_not_called() + + _msprobe_c_path = _msprobe_c.__file__ + _msprobe_c_test_path = _msprobe_c_path.replace('_msprobe_c.so', '_msprobe_c_test.so') + move_file(_msprobe_c_path, _msprobe_c_test_path) + sys.modules.pop('msprobe.lib') + sys.modules.pop('msprobe.lib._msprobe_c') + os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: checker.handle() @@ -60,3 +78,5 @@ class TestKernelGraphOverflowCheck(TestCase): if "MINDSPORE_DUMP_CONFIG" in os.environ: del os.environ["MINDSPORE_DUMP_CONFIG"] + if _msprobe_c_existed: + move_file(_msprobe_c_test_path, _msprobe_c_path) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py index c52ea4de2adef5a3c579c3deceece9d84b89309c..9be887eb4be1be00d68b12ca5181f7371dcf075e 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py @@ -21,6 +21,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.ms_config import StatisticsConfig from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump @@ -36,7 +37,7 @@ class TestKernelKbykDump(TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) dumper = KernelKbykDump(config) self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") @@ -53,6 +54,138 @@ class TestKernelKbykDump(TestCase): if "MINDSPORE_DUMP_CONFIG" in os.environ: del os.environ["MINDSPORE_DUMP_CONFIG"] + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_async_dump_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "async_dump": True + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["enable"], False) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_device_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "device": "device", + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config["statistics"]) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["stat_calc_mode"], "device") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_precision_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "precision": "low", + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config["statistics"]) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["device_stat_precision_mode"], "low") + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") + def test_handle_when_default_then_pass(self, _): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L2", + "statistics": { + "list": [], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + config = DebuggerConfig(common_config, task_config) + dumper = KernelKbykDump(config) + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["device_stat_precision_mode"], "high") + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["stat_calc_mode"], "host") + self.assertEqual(dumper.dump_json["e2e_dump_settings"]["enable"], True) + + os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.create_directory"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info") as mock_info, \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.save_json") as mock_save_json: + dumper.handle() + self.assertIn("kernel_kbyk_dump.json", mock_save_json.call_args_list[0][0][0]) + mock_info.assert_called_with("/absolute_path/kernel_kbyk_dump.json has been created.") + + self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) + if "MINDSPORE_DUMP_CONFIG" in os.environ: + del os.environ["MINDSPORE_DUMP_CONFIG"] + @patch("msprobe.mindspore.debugger.debugger_config.create_directory") def test_handle_tensor(self, _): json_config = { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py index 7717f9c336202d67ee524f59c3c5f328e70a045f..9320f49ad7e1719288f675d699c3086a3ac53671 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py @@ -17,32 +17,11 @@ import unittest from unittest.mock import patch from msprobe.core.common.const import Const -from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config, +from msprobe.mindspore.ms_config import (parse_task_config, TensorConfig, StatisticsConfig, OverflowCheckConfig, FreeBenchmarkConfig) class TestMsConfig(unittest.TestCase): - def test_parse_json_config(self): - mock_json_data = { - "dump_path": "./dump/", - "rank": [], - "step": [], - "level": "L1", - "statistics": { - "scope": [], - "list": [], - "data_mode": ["all"], - "summary_mode": "statistics" - } - } - with patch("msprobe.mindspore.ms_config.load_json", return_value=mock_json_data): - common_config, task_config = parse_json_config("./config.json") - self.assertEqual(common_config.task, Const.STATISTICS) - self.assertEqual(task_config.data_mode, ["all"]) - - with self.assertRaises(Exception) as context: - parse_json_config(None) - self.assertEqual(str(context.exception), "json file path is None") def test_parse_task_config(self): mock_json_config = { diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py deleted file mode 100644 index 495eedbf41384f820c2ca054fd73192d1966a8bd..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest import TestCase -from unittest.mock import patch -import mindspore - -from msprobe.mindspore import PrecisionDebugger -from msprobe.core.common_config import CommonConfig, BaseConfig - -class TestMindsporeDebuggerSave(TestCase): - def setUp(self): - PrecisionDebugger._instance = None - mindspore.set_context(mode=mindspore.PYNATIVE_MODE) - statistics_task_json = { - "task": "statistics", - "dump_path": "./dump_path", - "rank": [], - "step": [], - "level": "debug", - "enable_dataloader": False, - "statistics": { - "summary_mode": "statistics" - } - } - common_config = CommonConfig(statistics_task_json) - task_config = BaseConfig(statistics_task_json) - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)), \ - patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): - self.debugger = PrecisionDebugger() - - def test_forward_and_backward(self): - def forward_func(x, y): - PrecisionDebugger.save(x, "x_tensor") - return x * y - x = mindspore.Tensor([1.]) - y = mindspore.Tensor([2.]) - result_json = { - "task": "statistics", - "level": "debug", - "framework": "mindspore", - "dump_data_dir": None, - "data": { - "x_tensor.0": { - "type": "mindspore.Tensor", - "dtype": "Float32", - "shape": (1,), - "Max": 1.0, - "Min": 1.0, - "Mean": 1.0, - "Norm": 1.0 - }, - "x_tensor_grad.0": { - "type": "mindspore.Tensor", - "dtype": "Float32", - "shape": (1,), - "Max": 2.0, - "Min": 2.0, - "Mean": 2.0, - "Norm": 2.0 - } - } - } - grad_fn = mindspore.value_and_grad(forward_func, (0, 1)) - grad_fn(x, y) - self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 912830ea1ab705aae63c69f5c240887d4b4ce5b7..1d777a52752b9d8273f8e3604269af0383677e18 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -1,7 +1,7 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# Copyright (c) 2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -11,291 +11,148 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. language governing permissions and # limitations under the License. -import unittest from collections import defaultdict +import unittest from unittest.mock import MagicMock, patch - -from mindspore import nn, ops - -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.utils import Const, DumpPathAggregation -from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor -from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import register_backward_hook_functions -from msprobe.mindspore.dump.hook_cell.api_registry import ApiRegistry, api_register -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump -from msprobe.mindspore.service import Service +from msprobe.mindspore.mindspore_service import MindsporeService +from msprobe.core.common.utils import Const +from mindspore import ops +try: + from mindspore.common._pijit_context import PIJitCaptureContext +except ImportError: + pijit_label = False +else: + pijit_label = True -class TestService(unittest.TestCase): + +class TestMindsporeService(unittest.TestCase): def setUp(self): - self.config_mock = MagicMock() - self.config_mock.level_ori = Const.LEVEL_L0 - self.config_mock.dump_path = "/tmp/dump" - self.config_mock.step = [] - self.config_mock.rank = [] - self.config_mock.task = Const.TENSOR - self.config_mock.framework = Const.MS_FRAMEWORK - self.config_mock.list = [] - self.config_mock.scope = [] - self.service = Service(self.config_mock) - self.service.model = MagicMock(spec=nn.Cell) + + self.config = MagicMock() + self.config.step = [] + self.config.rank = [] + self.config.level_ori = Const.LEVEL_MIX + self.config.task = Const.STATISTICS + + with patch('msprobe.core.service.build_data_collector'): + self.service = MindsporeService(self.config) + + self.service.logger = MagicMock() self.service.data_collector = MagicMock() self.service.primitive_hook_service = MagicMock() - - def tearDown(self) -> None: - api_register.api_set_ori_func() - - def test_init(self): - self.assertEqual(self.service.config.level, "L0") - self.assertFalse(self.service.switch) - self.assertFalse(self.service.should_stop_service) - self.assertFalse(self.service.start_call) - self.assertTrue(self.service.first_start) - - def test_check_model_valid_with_valid_cell(self): - model = nn.Cell() - model_list = [model] - self.assertEqual(self.service.check_model_valid(model), model) - self.assertEqual(self.service.check_model_valid(model_list), model_list) - - def test_check_model_valid_with_invalid_type(self): - model = nn.Cell() - with self.assertRaises(MsprobeException): - self.service.check_model_valid("not a cell") - with self.assertRaises(MsprobeException): - self.service.check_model_valid(["not a cell", model]) - - def test_update_primitive_counters(self): - self.service.primitive_counters = {} - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 0) - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 1) - - @patch('msprobe.mindspore.service.create_directory') - def test_create_dirs(self, mock_create_directory): - self.service.current_iter = 1 - self.service.current_rank = 0 - self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] - self.service.data_collector.update_dump_paths = MagicMock() - self.service.create_dirs() - expected_calls = [ - ("/tmp/dump"), - ("/tmp/dump/step1/rank0"), - "/tmp/dump/step1/rank0/dump_tensor_data" - ] - mock_create_directory.assert_has_calls( - [unittest.mock.call(path) for path in expected_calls], any_order=True) - - args, _ = self.service.data_collector.update_dump_paths.call_args - self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - self.service.data_collector.initialize_json_file.assert_called_once_with( - framework=Const.MS_FRAMEWORK + self.service.cell_processor = MagicMock() + self.service.api_register = MagicMock() + + @patch('msprobe.mindspore.mindspore_service.is_mindtorch') + def test_framework_type(self, mock_is_mindtorch): + mock_is_mindtorch.return_value = True + self.assertEqual(self.service._get_framework_type, Const.MT_FRAMEWORK) + mock_is_mindtorch.return_value = False + self.assertEqual(self.service._get_framework_type, Const.MS_FRAMEWORK) + + @patch('msprobe.mindspore.mindspore_service.get_rank_if_initialized') + def test_get_current_rank(self, mock_get_rank): + mock_get_rank.return_value = 3 + self.assertEqual(MindsporeService._get_current_rank(), 3) + + def test_init_specific_components(self): + with patch('msprobe.core.service.build_data_collector'): + service = MindsporeService(self.config) + + self.assertIsNotNone(service.logger) + self.assertIsNotNone(service.api_register) + self.assertIsNotNone(service.primitive_hook_service) + self.assertIsNotNone(service.cell_processor) + self.assertIsNotNone(service.hook_manager) + + @patch.object(JitDump, "set_data_collector") + @patch.object(JitDump, "set_config") + @patch('msprobe.mindspore.mindspore_service.ms.common.api') + def test_setup_jit_context_with_pijit(self, mock_ms_api, mock_jit_set_config, mock_set_data_collector): + mock_ms_api.__dict__['_MindsporeFunctionExecutor'] = MagicMock() + self.service._setup_jit_context() + + mock_jit_set_config.assert_called_once_with(self.config) + mock_set_data_collector.assert_called_once_with(self.service.data_collector) + self.assertEqual(mock_ms_api._MindsporeFunctionExecutor, JitDump) + self.assertEqual(mock_ms_api._PyNativeExecutor.grad, JitDump.grad) + if pijit_label: + self.assertEqual(PIJitCaptureContext.__enter__, self.service.empty) + self.assertEqual(PIJitCaptureContext.__exit__, self.service.empty) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_change_jit_switch(self, mock_jit_dump): + self.service._change_jit_switch(True) + self.assertTrue(mock_jit_dump.jit_dump_switch) + + self.service._change_jit_switch(False) + self.assertFalse(mock_jit_dump.jit_dump_switch) + + def test_register_module_hook(self): + model_mock = MagicMock() + self.service.model = model_mock + self.service._register_module_hook() + + self.service.cell_processor.register_cell_hook.assert_called_once_with( + model_mock, self.service.build_hook, self.config ) - - @patch.object(Service, 'need_end_service', return_value=False) - def test_start_stop_cycle(self, mock_need_end_service): - self.service.model = nn.Cell() - with patch.object(self.service, 'register_cell_hook') as mock_register_hook: - self.should_stop_service = False - self.service.start(self.service.model) - self.assertTrue(self.service.switch) - self.service.stop() - self.assertFalse(self.service.switch) - mock_register_hook.assert_called_once() - mock_need_end_service.assert_called_once() - - def test_should_execute_hook_return_false(self): - cell = MagicMock() - self.service.switch = False - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - self.assertFalse(self.service.should_execute_hook("api", cell, True)) - - self.service.switch = True - cell.forward_data_collected = False - self.assertFalse(self.service.should_execute_hook("api", cell, False)) - - self.service.inner_switch = True - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - self.service.inner_switch = False - self.service.data_collector = None - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - def test_should_execute_hook_return_true(self): - cell = MagicMock() - self.service.switch = True - self.service.inner_switch = False - self.service.data_collector = MagicMock() - self.service.data_collector.data_processor = MagicMock() - self.service.data_collector.data_processor.is_terminated = False - self.assertTrue(self.service.should_execute_hook("Module", cell, True)) - - cell.forward_data_collected = True - self.assertTrue(self.service.should_execute_hook("api", cell, False)) - - def test_need_end_service_with_high_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.assertTrue(self.service.need_end_service()) - - def test_need_end_service_with_low_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 2 - self.service.data_collector.data_processor.is_terminated = False - self.assertFalse(self.service.need_end_service()) - - def test_start_with_termination_condition(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.service.start() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - self.assertFalse(self.service.primitive_switch) - - @patch('msprobe.mindspore.service.print_tools_ends_info') - @patch.object(Service, 'need_end_service', return_value=True) - def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): - self.service.start(self.service.model) - mock_need_end_service.assert_called_once() - mock_print_tools_ends_info.assert_called_once() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(logger, 'info') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'create_dirs') - @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) - def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, - mock_register_cell_hook, mock_logger, mock_need_end_service): - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_get_rank.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - mock_need_end_service.assert_called_once() - mock_create_dirs.assert_called_once() - self.assertFalse(self.service.first_start) - self.assertTrue(self.service.switch) - self.assertTrue(self.service.primitive_switch) - mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") - - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(JitDump, 'set_config') - @patch.object(JitDump, 'set_data_collector') - @patch.object(ApiRegistry, 'api_set_hook_func') - def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, - mock_set_config, mock_need_end_service, mock_register_cell_hook, - mock_register_primitive_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_set_config.assert_called_with(self.service.config) - mock_set_data_collector.assert_called_with(self.service.data_collector) - mock_api_set_hook_func.assert_called_once() - mock_need_end_service.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - self.assertTrue(JitDump.jit_dump_switch) - - def test_step_updates(self): - CellProcessor.cell_count = {"test_api": 1} - HOOKCell.cell_count = {"test_api": 1} - JitDump.jit_count = {"test_api": 1} - self.service.primitive_hook_service.primitive_counters = {"test_api": 1} - self.service.current_iter = 0 - self.service.step() - self.assertEqual(self.service.current_iter, 1) - self.service.data_collector.update_iter.assert_called_once_with(1) - self.service.data_collector.reset_status.assert_called_once() - self.assertEqual(JitDump.jit_count, defaultdict(int)) - self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) - - @patch.object(Service, 'should_execute_hook') - def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): - mock_should_execute_hook.return_value = True - self.service.data_collector = MagicMock() - self.service.data_collector.update_api_or_module_name = MagicMock() - self.service.data_collector.forward_data_collect = MagicMock() - self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) - self.service.data_collector.backward_data_collect = MagicMock() - - mock_cell = MagicMock() - mock_cell.mindstudio_reserved_name = "TestCell" - mock_input = (MagicMock(),) - mock_output = MagicMock() - - _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") - - forward_hook(mock_cell, mock_input, mock_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') - self.service.data_collector.forward_data_collect.assert_called() - - self.service.data_collector.reset_mock() - - mock_grad_input = (MagicMock(),) - mock_grad_output = MagicMock() - - backward_hook(mock_cell, mock_grad_input, mock_grad_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') - self.service.data_collector.backward_data_collect.assert_called() - + def test_register_primitive_hook(self): self.service.config.level = Const.LEVEL_MIX primitive_attr = ops.Add() primitive_name = "primitive_api" + mock_model = MagicMock() cell_mock = MagicMock() cell_mock.primitive_api = primitive_attr primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - self.service.register_primitive_hook() + self.service.model = mock_model + with patch('msprobe.mindspore.mindspore_service.get_cells_and_names_with_index') as mock_get_cells_and_names: + mock_get_cells_and_names.return_value = ({'-1': [("cell_name", cell_mock)]}, {}) + self.service._register_primitive_hook() self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], primitive_combined_name) - - @patch.object(ApiRegistry, 'initialize_hook') - @patch.object(ApiRegistry, 'api_set_hook_func') - @patch("msprobe.mindspore.service.logger.info") - def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.register_api_hook() - self.service.register_cell_hook() - mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " - "is successfully mounted to the model.") - mock_api_set_hook_func.assert_called() - mock_initialize_hook.assert_called() - - @patch.object(CellProcessor, 'node_hook') - def test_register_hook_new_with_level_l0(self, mock_node_hook): - global register_backward_hook_functions - self.service.config.level = Const.LEVEL_L0 - cell_mock = MagicMock() - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - register_backward_hook_functions["pre"] = cell_mock.register_backward_pre_hook - register_backward_hook_functions["full"] = cell_mock.register_backward_hook - self.service.register_cell_hook() - cell_mock.register_forward_hook.assert_called() - cell_mock.register_backward_hook.assert_called() - mock_node_hook.assert_called() - register_backward_hook_functions = {} - - def test_register_hook_new_without_model_raises_exception(self): - self.service.config.level = Const.LEVEL_L0 - self.service.model = None - with self.assertRaises(MsprobeException): - self.service.register_cell_hook() + + def test_reset_status(self): + self.service.primitive_hook_service.primitive_counters = defaultdict(int) + self.service.primitive_hook_service.primitive_counters['test_prim'] = 5 + self.service._reset_status() + self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) + with patch('msprobe.mindspore.mindspore_service.JitDump') as mock_jit_dump: + mock_jit_dump.jit_count = defaultdict(int) + mock_jit_dump.jit_count['test_jit'] = 3 + self.service._reset_status() + self.assertEqual(mock_jit_dump.jit_count, {}) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_start_jit_enabled(self, mock_jit_dump): + self.service.data_collector.data_processor.is_terminated = False + self.service.config.online_run_ut = None + model_mock = MagicMock() + self.service.start(model=model_mock) + self.assertTrue(mock_jit_dump.jit_dump_switch) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + def test_stop_jit_disabled(self, mock_jit_dump): + self.service.config.online_run_ut = None + self.config.level = Const.LEVEL_MIX + self.service.current_iter = 1 + self.service.current_rank = 0 + + self.service.stop() + + self.assertFalse(mock_jit_dump.jit_dump_switch) + + @patch('msprobe.mindspore.mindspore_service.JitDump') + @patch('msprobe.mindspore.mindspore_service.ms.common.api') + def test_setup_jit_context_level_not_supported(self, mock_ms_api, mock_jit_dump): + self.service.config.level = Const.LEVEL_DEBUG + + self.service._setup_jit_context() + + mock_jit_dump.set_config.assert_not_called() + mock_jit_dump.set_data_collector.assert_not_called() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py index f46f171aa38585ea801f1fd3a9716bd3876a63a5..520a688dcf475fd3e6831ab0b25c7dda9faeb31b 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py @@ -1,7 +1,6 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,19 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + from unittest import TestCase from unittest.mock import patch -from msprobe.mindspore.common.const import Const +from msprobe.core.common.log import logger from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory class TestOverflowCheckToolFactory(TestCase): + @patch.object(logger, "error") @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def test_create(self, _): + def test_create(self, _, mock_logger_error): json_config = { "task": "overflow_check", "dump_path": "/absolute_path", @@ -45,12 +46,11 @@ class TestOverflowCheckToolFactory(TestCase): config.execution_mode = Const.GRAPH_GE_MODE config.level = "cell" - with self.assertRaises(Exception) as context: + with self.assertRaises(ValueError): OverflowCheckToolFactory.create(config) - self.assertEqual(str(context.exception), - f"Overflow check is not supported in {config.execution_mode} mode " - f"when level is {config.level}.") + mock_logger_error.assert_called_with(f"Overflow check is not supported in {config.execution_mode} mode " + f"when level is {config.level}.") config.level = "kernel" - dumper = OverflowCheckToolFactory.create(config) + dumper = OverflowCheckToolFactory.create(config)[0] self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "npy") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 3cafd49f2c101c45dbb65a08803dd77c6bca485d..a69caed2569cd875224cf7b87ea16ced69ce3ae5 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,95 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +from collections import defaultdict +import tempfile import unittest -import mindspore as ms -import numpy as np -import os from unittest.mock import Mock, patch -from mindspore import nn +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops -import tempfile from msprobe.core.common.utils import Const -from msprobe.mindspore.service import Service -from msprobe.core.common.exceptions import MsprobeException +from msprobe.mindspore.mindspore_service import MindsporeService +from msprobe.core.common.runtime import Runtime from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from collections import defaultdict from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from mindspore.common.tensor import Tensor - - -class DummyModel(nn.Cell): - def __init__(self): - super(DummyModel, self).__init__() - self.dense = nn.Dense(2, 2) - - def construct(self, x): - return self.dense(x) - - -class TestService(unittest.TestCase): - @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def setUp(self, _): - json_config = { - "task": "statistics", - "dump_path": "/absolute_path", - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - self.service.primitive_switch = True # Make sure the switch is on for testing - - def test_check_model_valid_none(self): - model = None - self.assertIsNone(self.service.check_model_valid(model)) - - def test_check_model_valid_valid_model(self): - model = DummyModel() - self.assertEqual(self.service.check_model_valid(model), model) - - def test_check_model_valid_invalid_model(self): - model = "invalid_model" - with self.assertRaises(MsprobeException) as context: - self.service.check_model_valid(model) - - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) - - def test_step_updates_iteration(self): - initial_iter = self.service.current_iter - self.service.step() - self.assertEqual(self.service.current_iter, initial_iter + 1) - - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 - self.service.step() - self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) - self.assertEqual(HOOKCell.cell_count, defaultdict(int)) - - def test_step_calls_update_iter(self): - # 检查是否在调用 step 时调用了 update_iter - with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: - initial_iter = self.service.current_iter - self.service.step() - mock_update_iter.assert_called_once_with(initial_iter + 1) +from msprobe.mindspore.ms_config import StatisticsConfig class TestPrimitiveHookService(unittest.TestCase): @@ -118,21 +46,16 @@ class TestPrimitiveHookService(unittest.TestCase): } common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) + task_config = StatisticsConfig(json_config) config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - # 模拟一个 service_instance 和 data_collector - self.mock_service_instance = Service(config) - self.mock_service_instance.switch = True - self.mock_service_instance.data_collector = Mock() - self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] - # 初始化 PrimitiveHookService - self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) + with patch('msprobe.core.service.build_data_collector'), \ + patch('msprobe.mindspore.mindspore_service.CellProcessor'), \ + patch('msprobe.mindspore.mindspore_service.PrimitiveHookService'), \ + patch('msprobe.mindspore.mindspore_service.get_api_register'): + self.mock_service_instance = MindsporeService(config) + Runtime.is_running = True + self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) def tearDown(self): # 测试结束时删除临时目录 @@ -147,7 +70,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -162,7 +84,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called) @@ -176,7 +97,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -213,14 +133,7 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") - if wrapped_primitive_call.__closure__: - for i, closure in enumerate(wrapped_primitive_call.__closure__): - print(f"Closure[{i}]:", closure.cell_contents) - - if hook_primitive_inputs.__closure__: - for i, closure in enumerate(hook_primitive_inputs.__closure__): - print(f"2Closure[{i}]:", closure.cell_contents) + create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents backward_hook = create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type) @@ -234,7 +147,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called) @@ -281,18 +193,15 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_input" # 调用 hook_primitive_inputs - hooked_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents(args, - captured_grads_input, - updated_primitive_name) - - # 验证 hooked_inputs 是否正确添加了 hook - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - print(f"Captured hooked_arg after hook: {hooked_arg}") - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after hook: {captured_grads_input}") + hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) def test_hook_primitive_outputs(self): # 模拟前向输出 @@ -301,17 +210,16 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_output" # 调用 hook_primitive_outputs - hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[ - 1].cell_contents - hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) - - # 验证 hooked_outputs 是否正确添加了 hook - for tensor, hooked_tensor in zip(out, hooked_outputs): - if isinstance(tensor, Tensor): - self.assertTrue(hasattr(hooked_tensor, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after output hook: {captured_grads_output}") + hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, + "example").__closure__[1].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(out)) + for hooked_output in hooked_outputs: + self.assertTrue((hooked_output == target_value).all()) def test_wrapped_primitive_call_args(self): # 模拟前向输入 @@ -324,19 +232,18 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrapped_primitive_call 并检查 hooked_inputs 是否与原始 args 相同 try: - hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, - updated_primitive_name) - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - self.assertTrue(np.array_equal(arg.asnumpy(), hooked_arg.asnumpy())) - print(f"Arg type: {type(arg)}, Hooked input type: {type(hooked_arg)}") - else: - self.assertEqual(arg, hooked_arg) + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, + updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) except Exception as e: self.fail(f"wrapped_primitive_call raised an exception: {e}") - def test_update_primitive_counters_multiple(self): # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] @@ -366,7 +273,7 @@ class TestPrimitiveHookService(unittest.TestCase): def test_wrap_primitive_no_hook_with_invalid_input(self): # 测试在 switch 关闭时传入无效输入时的行为 - self.mock_service_instance.switch = False + Runtime.is_running = False invalid_inputs = [None, "invalid_tensor", 123] @@ -415,13 +322,11 @@ class TestPrimitiveHookService(unittest.TestCase): for captured_grads in captured_grads_sets: updated_primitive_name = "MatMul.Backward" - num_tensors = len(captured_grads) hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) self.assertIsNotNone(backward_hook) - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): # 模拟前向和后向钩子在同一个 primitive 中的行为 @@ -446,9 +351,6 @@ class TestPrimitiveHookService(unittest.TestCase): self.primitive_hook_service.update_primitive_counters(name) self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - - def test_update_primitive_counters(self): primitive_name = "MatMul" self.primitive_hook_service.update_primitive_counters(primitive_name) @@ -495,7 +397,7 @@ class TestPrimitiveHookService(unittest.TestCase): wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") # 模拟反向传播过程,调用包装的 primitive - with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: + with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'): result = wrapped_func(Mock(), input_tensor) # 验证结果是 Tensor 实例 @@ -503,7 +405,7 @@ class TestPrimitiveHookService(unittest.TestCase): def test_wrap_primitive_no_hook_when_switch_off(self): # 模拟 switch 关闭的情况 - self.mock_service_instance.switch = False + Runtime.is_running = False # 模拟 Tensor 输入 input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) @@ -543,7 +445,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 测试 create_backward_hook 的功能 captured_grads = [] updated_primitive_name = "MatMul.Backward" - num_tensors = 2 # 创建 backward hook backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py index 752b5f916d50083dba842707feeeda2edcbe14f0..bce627650093f0bd19e3d6d831bd9e5b065dd286 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py @@ -20,6 +20,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump from msprobe.mindspore.task_handler_factory import TaskHandlerFactory from msprobe.mindspore.common.const import Const @@ -47,7 +48,9 @@ class TestTaskHandlerFactory(TestCase): config.execution_mode = Const.GRAPH_GE_MODE handler = TaskHandlerFactory.create(config) - self.assertTrue(isinstance(handler, KernelGraphDump)) + self.assertTrue(isinstance(handler, tuple)) + self.assertTrue(isinstance(handler[1], KernelKbykDump)) + self.assertTrue(isinstance(handler[0], KernelGraphDump)) with patch("msprobe.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): with self.assertRaises(Exception) as context: diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..833844e5beb3c0b762bf34369ea977b5ce79dee8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os.path +import unittest +from unittest.mock import patch +import argparse + +from msprobe.nan_analyze.analyzer import _nan_analyze_parser, NanAnalyzer + + +class DumpDataBuilder: + def __init__(self): + self.nodes = {} + self.layer = {} + + @staticmethod + def gen_data(is_normal, **kwargs): + def gen_single_data(normal): + return { + 'type': 'torch.Tensor', + 'dtype': 'torch.float32', + 'shape': [ + 2, + 1024 + ], + 'Max': 2.0 if normal else 'inf', + 'Min': 1.0 if normal else '-inf', + 'Mean': 1.5 if normal else 'nan', + 'Norm': 2.236 if normal else 'nan', + 'requires_grad': False + } + + def gen_int(value): + return { + 'type': 'int', + 'value': value + } + + def gen_process_group(ranks): + return { + 'type': 'ProcessGroup', + 'group_ranks': ranks + } + + data_type = kwargs.get('type') + if data_type == 'compute': + return { + 'input_args': [gen_single_data(True)], + 'input_kwargs': {}, + 'output': [gen_single_data(is_normal)] + } + if data_type == 'p2p_src': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {'dst': gen_int(kwargs.get('dst'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2p_dst': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {'src': gen_int(kwargs.get('src'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2g_src': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal')), gen_int(kwargs.get('src'))], + 'input_kwargs': {'group': gen_process_group(kwargs.get('ranks'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'p2g_dst': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal')), gen_int(kwargs.get('dst'))], + 'input_kwargs': {'group': gen_process_group(kwargs.get('ranks'))}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + if data_type == 'link': + return { + 'input_args': [gen_single_data(kwargs.get('is_input_normal'))], + 'input_kwargs': {}, + 'output': [gen_single_data(kwargs.get('is_output_normal'))] + } + + def add_node(self, is_normal, **kwargs): + name = kwargs.get("name", 'operator') + layer = self.layer.get(name, 0) + if kwargs.get('type') == 'compute': + node_name = f'Torch.operator.{layer}.forward' + else: + node_name = f'Distributed.{name}.{layer}.forward' + self.nodes[node_name] = self.gen_data(is_normal, **kwargs) + self.layer[name] = layer + 1 + return self + + def build(self): + return self.nodes + + +rank_order_dict = { + # (name, type, src, dst, ranks) + 0: [(0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 1, []), + ('recv', 'p2p_dst', 1, 0, []), + (0, 'compute', 0, 0, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 1: [('recv', 'p2p_dst', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 2, []), + ('recv', 'p2p_dst', 2, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 0, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 2: [('recv', 'p2p_dst', 1, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 3, []), + ('recv', 'p2p_dst', 3, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 1, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, [])], + 3: [('recv', 'p2p_dst', 2, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + ('send', 'p2p_src', 0, 2, []), + ('broadcast', 'p2g_src', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + ('all_gather', 'link', 0, 0, []), + (0, 'compute', 0, 0, []), + ('gather', 'p2g_dst', 0, 0, [0, 1, 2, 3]), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, []), + (0, 'compute', 0, 0, [])] +} + + +def do_nothing(*args, **kwargs): + return + + +def gen_normal_dump_json(rank): + builder = DumpDataBuilder() + for name, data_type, src, dst, ranks in rank_order_dict[rank]: + builder = builder.add_node(True, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_pre_anomaly_dump_json(rank): + builder = DumpDataBuilder() + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = True + if i == rank and i in [0, 1]: + is_normal = False + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_anomaly_dump_json(rank): + builder = DumpDataBuilder() + start = 999 + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = True + is_input_normal = True + is_output_normal = True + if rank == 0: + if i == 7: + is_normal = False + elif i == 8: + is_input_normal = False + is_output_normal = False + else: + if name == 'broadcast': + start = i + is_output_normal = False + elif i > start: + is_normal = False + is_input_normal = False + is_output_normal = False + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=is_input_normal, is_output_normal=is_output_normal) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +def gen_after_anomaly_dump_json(rank): + builder = DumpDataBuilder() + for i, (name, data_type, src, dst, ranks) in enumerate(rank_order_dict[rank]): + is_normal = (rank != 2 or i != 13) and (rank != 3 or i != 11) + builder = builder.add_node(is_normal, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=True) + return {'task': 'statistics', + 'level': 'mix', + 'dump_data_dir': None, + 'data': builder.build() + } + + +json_dict = {os.path.join('./step0', f'rank{i if i > 0 else ""}', 'construct.json'): {} for i in range(4)} + + +def gen_stack_json(rank): + return {f'0': [list(json_dict[os.path.join('./step0', f'rank{rank if rank > 0 else ""}', 'dump.json')]['data'].keys()), + ['File /root/example.py, line 10, in test_fcn, \\n test(tensor)']]} + + +class mock_time: + _uni_value = 1 + + @staticmethod + def set_uni_value(var): + mock_time._uni_value = var + + @staticmethod + def time_ns(): + return mock_time._uni_value + + + +class MockedFileCache: + def load_json(self, file_path): + return json_dict[file_path] + + +class TestAnalyzer(unittest.TestCase): + def setUp(self): + self.output = {} + self.input_path = './step0' + self.output_path = './output' + with patch('os.listdir', return_value=['rank', 'rank1', 'rank2', 'rank3', 'rank_others']), \ + patch('msprobe.nan_analyze.utils.check_file_or_directory_path', do_nothing), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache): + self.analyzer = NanAnalyzer(self.input_path, self.output_path) + + def mocked_save_json(self, file, content, indent): + self.output[file] = content + + def test_nan_analyze_parser(self): + args = [ + '-i', '/path/to/input', + '-o', '/path/to/output', + ] + + parser = argparse.ArgumentParser() + _nan_analyze_parser(parser) + parsed_args = parser.parse_args(args) + self.assertEqual(parsed_args.input_path, '/path/to/input') + self.assertEqual(parsed_args.output_path, '/path/to/output') + + def test_normal(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_normal_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache): + self.analyzer.analyze() + self.assertFalse(bool(self.output)) + + def test_pre_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_pre_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(1) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_1.json')) + self.assertTrue(bool(res_json)) + self.assertEqual('Torch.operator.0.forward', res_json['rank_0'][0]['op_name']) + + def test_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(2) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_2.json')) + self.assertTrue(bool(res_json)) + self.assertEqual('Torch.operator.5.forward', res_json['rank_0'][0]['op_name']) + + def test_after_anomaly(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_after_anomaly_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + with patch('os.path.exists', return_value=True), \ + patch('msprobe.nan_analyze.analyzer.save_json', self.mocked_save_json), \ + patch('msprobe.nan_analyze.analyzer.logger.info', print), \ + patch('msprobe.nan_analyze.analyzer.logger.warning', print), \ + patch('msprobe.nan_analyze.analyzer.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.analyzer.time', mock_time): + mock_time.set_uni_value(3) + self.analyzer.analyze() + res_json = self.output.get(os.path.join('./output', 'anomaly_analyze_3.json')) + self.assertTrue(bool(res_json)) + self.assertEqual(res_json['rank_2'][0]['op_name'], 'Torch.operator.6.forward') + self.assertEqual(res_json['rank_3'][0]['op_name'], 'Torch.operator.6.forward') + + diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1c43b3b109d4d232622cb581356e29f4eb83b8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_graph.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +import os +from unittest.mock import patch + +from msprobe.nan_analyze.graph import CommunicationNode, DataNode +from msprobe.nan_analyze.utils import RankPath +from msprobe.core.common.exceptions import MsprobeException +from test_nan_analyzer import DumpDataBuilder, gen_normal_dump_json, MockedFileCache, json_dict, gen_stack_json, do_nothing + + +dump_json = {i: gen_normal_dump_json(i) for i in range(4)} + + +class TestCommunicationNode(unittest.TestCase): + def test_add_next(self): + op_name_0 = 'Distributed.send.0.forward' + op_name_1 = 'Distributed.recv.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name_0}', 0, DataNode(op_name_0, 0, dump_json[0]['data'][op_name_0])) + comm_node_1 = CommunicationNode(f'0.{op_name_1}', 0, DataNode(op_name_1, 0, dump_json[0]['data'][op_name_1])) + comm_node_0.add_next(comm_node_1) + self.assertEqual(comm_node_0.layer + 1, comm_node_1.layer) + self.assertTrue(comm_node_0 is comm_node_1.pre_node) + self.assertTrue(comm_node_1.node_id in comm_node_0.next_nodes) + + def test_add_link(self): + op_name = 'Distributed.all_gather.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + comm_node_0.add_link(comm_node_1) + self.assertEqual(comm_node_0.layer, comm_node_1.layer) + self.assertTrue(comm_node_0.node_id in comm_node_1.link_nodes) + self.assertTrue(comm_node_1.node_id in comm_node_0.link_nodes) + + def test_add_dst(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + comm_node_0.add_dst(comm_node_1) + self.assertEqual(comm_node_0.layer, comm_node_1.layer) + self.assertTrue(comm_node_0.node_id in comm_node_1.src_nodes) + self.assertTrue(comm_node_1.node_id in comm_node_0.dst_nodes) + + def test_delete(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + op_name = 'Distributed.recv.0.forward' + comm_node_2 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_2.add_next(comm_node_0) + comm_node_0.add_dst(comm_node_1) + comm_node_0.delete() + self.assertFalse(comm_node_1.src_nodes) + self.assertFalse(comm_node_2.next_nodes) + + def test_has_nan_inf(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertFalse(comm_node_0.has_nan_inf()) + + def test_input_has_nan_inf(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertFalse(comm_node_0.input_has_nan_inf()) + + def test_find_connected_nodes(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + comm_node_2 = CommunicationNode(f'2.{op_name}', 2, DataNode(op_name, 2, dump_json[2]['data'][op_name])) + comm_node_3 = CommunicationNode(f'3.{op_name}', 3, DataNode(op_name, 3, dump_json[3]['data'][op_name])) + comm_node_0.add_dst(comm_node_1) + comm_node_0.add_dst(comm_node_2) + comm_node_0.add_dst(comm_node_3) + conn_info = comm_node_0.find_connected_nodes() + self.assertEqual(conn_info['ranks'], {0, 1, 2, 3}) + self.assertEqual(conn_info['api'], 'Distributed.broadcast') + self.assertEqual(conn_info['type'], 'dst') + + def test_resolve_type(self): + op_name = 'Distributed.broadcast.0.forward' + comm_node_0 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + comm_node_1 = CommunicationNode(f'1.{op_name}', 1, DataNode(op_name, 1, dump_json[1]['data'][op_name])) + self.assertEqual(comm_node_0.type, 'src') + self.assertEqual(comm_node_1.type, 'dst') + + op_name = 'Distributed.all_gather.0.forward' + comm_node_2 = CommunicationNode(f'0.{op_name}', 0, DataNode(op_name, 0, dump_json[0]['data'][op_name])) + self.assertEqual(comm_node_2.type, 'link') + + +class TestDataNode(unittest.TestCase): + def setUp(self): + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'dump.json'): gen_normal_dump_json(i) for i in range(4)}) + json_dict.update({os.path.join('./step0', f'rank{i if i > 0 else ""}', 'stack.json'): gen_stack_json(i) for i in range(4)}) + json_dict[os.path.join('./step0', 'rank', 'construct.json')] = { + 'Torch.operator.1.forward': 'Module.module.test_model.forward.0', + 'Module.module.test_model.forward.0': 'Module.module.parent_model.forward.0', + 'Module.module.parent_model.forward.0': 'Module.module.root_model.forward.0', + 'Module.module.root_model.forward.0': None + } + + def test_find_stack(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache): + op_name = 'Torch.operator.1.forward' + data_node = DataNode(op_name, 0, dump_json[0]['data'][op_name]) + stack_info = data_node.find_stack(json_dict[os.path.join('./step0', 'rank', 'stack.json')]) + self.assertEqual(stack_info[0], 'File /root/example.py, line 10, in test_fcn, \\n test(tensor)') + with self.assertRaises(MsprobeException) as context: + data_node.find_stack({op_name: 'blabla'}) + self.assertEqual(context.exception.code, 4) + + def test_find_complete_construct(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache): + op_name = 'Torch.operator.1.forward' + construct = DataNode.find_complete_construct(json_dict[os.path.join('./step0', 'rank', 'construct.json')], + op_name) + self.assertEqual(len(construct), 4) + self.assertEqual(construct[0], 'Module.module.root_model.forward.0') + + def test_is_anomaly(self): + data_node_0 = DataNode('Torch.operator.1.forward', 0, DumpDataBuilder.gen_data(False, type='compute')) + data_node_1 = DataNode('Torch.operator.1.forward', 0, DumpDataBuilder.gen_data(True, type='compute')) + self.assertTrue(data_node_0.is_anomaly()) + self.assertFalse(data_node_1.is_anomaly()) + + def test_gen_node_info(self): + with patch('msprobe.nan_analyze.graph.FileCache', MockedFileCache), \ + patch('msprobe.nan_analyze.utils.check_file_or_directory_path', do_nothing): + op_name = 'Torch.operator.1.forward' + data_node = DataNode(op_name, 0, dump_json[0]['data'][op_name]) + node_info = data_node.gen_node_info(RankPath(0, os.path.join('./step0', 'rank', 'dump.json'), + os.path.join('./step0', 'rank', 'construct.json'), + os.path.join('./step0', 'rank', 'stack.json'))) + data_info = node_info['data_info'] + self.assertEqual(data_info['input_args'][0]['Max'], 2.0) + stack_info = node_info['stack_info'] + self.assertEqual(stack_info[0], 'File /root/example.py, line 10, in test_fcn, \\n test(tensor)') \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f37e2e38b0f68045826c062d7bb7bcda261fa5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/nan_analyze_ut/test_nan_analyzer_utils.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import patch + +from msprobe.nan_analyze.utils import (FileCache, is_communication_op, is_ignore_op, check_item_anomaly, + analyze_anomaly_in_group) +from msprobe.nan_analyze.graph import CommunicationNode, DataNode +from test_nan_analyzer import DumpDataBuilder + + +json_dict = {chr(no): {f'test_{chr(no)}_{i}': [f'content_{j}' for j in range(10)] for i in range(10)} for no in range(ord('a'), ord('z') + 1)} + + +def mocked_load_json(json_path): + return json_dict.get(json_path) + + +class MockedMemory: + def __init__(self): + self.available = 100000 + + +def mocked_virtual_memory(): + return MockedMemory() + + +class TestFileCache(unittest.TestCase): + def test_load_json(self): + with patch('msprobe.nan_analyze.utils.load_json', mocked_load_json), \ + patch('psutil.virtual_memory', mocked_virtual_memory): + cache = FileCache() + self.assertFalse('a' in cache._cache) + a = cache.load_json('a') + self.assertTrue('a' in cache._cache) + self.assertTrue('test_a_5' in a) + + def test_clean_up(self): + with patch('msprobe.nan_analyze.utils.load_json', mocked_load_json), \ + patch('psutil.virtual_memory', mocked_virtual_memory): + cache = FileCache() + for _ in range(100): + cache.load_json('a') + for i, no in enumerate(range(ord('a'), ord('g'))): + cache.load_json(chr(no)) + self.assertEqual('b' in cache._cache, 0 < i < 3) + self.assertTrue('a' in cache._cache) + +class TestUtils(unittest.TestCase): + def test_is_communication_op(self): + self.assertTrue(is_communication_op('Distributed.broadcast.0.forward')) + self.assertFalse(is_communication_op('Torch.operator.1.forward')) + + def test_is_ignore_op(self): + self.assertTrue(is_ignore_op('Torch.empty.1.forward')) + self.assertFalse(is_ignore_op('Torch.operator.1.forward')) + + def test_check_item_anomaly(self): + self.assertTrue(check_item_anomaly(DumpDataBuilder.gen_data(False, type='compute')['output'])) + self.assertFalse(check_item_anomaly(DumpDataBuilder.gen_data(True, type='compute')['output'])) + + def test_analyze_anomaly_in_group(self): + name = 'broadcast' + data_type = 'p2g_src' + src = 0 + dst = 0 + ranks = [0, 1, 2, 3] + op_name = f'Distributed.{name}.0.forward' + data = DumpDataBuilder.gen_data(False, name=name, type=data_type, src=src, dst=dst, ranks=ranks, + is_input_normal=True, is_output_normal=False) + node_id = f'0.{op_name}' + node = CommunicationNode(node_id, 0, DataNode(op_name, 0, data)) + anomalies = analyze_anomaly_in_group([node]) + self.assertEqual(anomalies[0].op_name, op_name) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py index df03485dc6c77371750fd0b67ca2c37ff7e2ed7b..30fa11d94de0dd4fec483502a51d0474e8b7646a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py @@ -16,6 +16,8 @@ class TestUtConfig(): self.port = 8080 self.rank_list = [0, 1, 2] self.tls_path = '/path/to/tls' + self.master_ip = '127.0.0.1' + self.master_port = 8888 class TestConfig(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py index 377a29f2237e2b3172e6fc35a712ff36cc69972d..f1cc0d31363c326c3412824f4a5a176b70da1a90 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py @@ -208,3 +208,45 @@ class TestAlgorithmMethods(unittest.TestCase): ulp_err = alg.calc_ulp_err(self.bench_data, self.device_data, eb, exponent_num, data_type) expected_ulp_err = (self.device_data.astype(data_type) - self.bench_data).astype(data_type) * np.exp2(-eb + exponent_num) self.assertTrue(np.allclose(ulp_err, expected_ulp_err)) + + +class TestKahanLossRange(unittest.TestCase): + + def setUp(self): + self.cumsum = torch.tensor( + [[1000, 30], [1, 20], [10, 10]], dtype=torch.bfloat16) + self.addend = torch.tensor([[3, 0.2]], dtype=torch.bfloat16) + self.tensors = [ + torch.tensor([1000], dtype=torch.bfloat16), + torch.tensor([1004], dtype=torch.bfloat16), + torch.tensor([103], dtype=torch.bfloat16), + torch.tensor([4], dtype=torch.bfloat16)] + + def test_kahan_loss_positive(self): + # 测试最大化需要补偿的正损失, loss_res为历史损失中最大值,且mask会遮蔽小于0的部分 + loss_res, mask = alg.maximize_kahan_loss(self.cumsum, self.addend, negative=False) + expected_loss = torch.tensor([1, 0.0498], dtype=torch.bfloat16) + expected_mask = expected_loss >= 0 + self.assertTrue(torch.allclose(loss_res, expected_loss)) + self.assertTrue(torch.allclose(mask, expected_mask)) + + def test_kahan_loss_negative(self): + # 测试最大化需要补偿的负损失, loss_res为历史损失中最小值,且mask会遮蔽大于0的部分 + loss_res, mask = alg.maximize_kahan_loss(self.cumsum, self.addend, negative=True) + expected_loss = torch.tensor([0, -0.0127], dtype=torch.bfloat16) + expected_mask = expected_loss <= 0 + self.assertTrue(torch.allclose(loss_res, expected_loss)) + self.assertTrue(torch.allclose(mask, expected_mask)) + + def test_kahan_range_empty_list(self): + # 测试输入为空列表的情况 + with self.assertRaises(ValueError): + alg.kahan_range([]) + + def test_kahan_range_min_max(self): + max_ = alg.kahan_range(self.tensors, negative=True) + min_ = alg.kahan_range(self.tensors, negative=False) + expected_min = torch.tensor(2096, dtype=torch.bfloat16) + expected_max = torch.tensor(2112, dtype=torch.bfloat16) + self.assertTrue(torch.allclose(min_, expected_min)) + self.assertTrue(torch.allclose(max_, expected_max)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py index 0a88476d600958b26eaf6ca20a9a70d35b4221cc..952a6dffbc85eea9dd2db87fa081bdf4bb3cae2a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py @@ -322,7 +322,7 @@ class TestDataGenerateMethods(unittest.TestCase): low_info = [1, float('-inf')] high_info = [2, float('-inf')] tensor = gen_common_tensor(low_info, high_info, shape, data_dtype, None) - self.assertTrue(torch.allclose(tensor.max(), torch.tensor(2.0), atol = 0.3)) + self.assertTrue(torch.allclose(tensor.max(), torch.tensor(2.0), atol = 0.5)) self.assertTrue(tensor.min() == float('-inf')) low_info = [1, float('nan')] diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_distributed_bench_function.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_distributed_bench_function.py new file mode 100644 index 0000000000000000000000000000000000000000..0b21a9559e90acec80c9cb4726d8ce039ddb6a71 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_distributed_bench_function.py @@ -0,0 +1,29 @@ +import torch +import unittest + +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import sort_all_input + +class TestSortAllInput(unittest.TestCase): + def setUp(self): + self.inputs = [ + torch.tensor([3.0, 2.0, 1.0]), + torch.tensor([6.0, 5.0, 4.0]), + torch.tensor([9.0, 8.0, 7.0]) + ] + + def test_normal_case(self): + # 测试正常情况 + sorted_inputs = sort_all_input(self.inputs) + expected_sorted_inputs = [ + torch.tensor([9.0, 8.0, 7.0]), + torch.tensor([6.0, 5.0, 4.0]), + torch.tensor([3.0, 2.0, 1.0]) + ] + for result, expected in zip(sorted_inputs, expected_sorted_inputs): + self.assertTrue(torch.equal(result, expected)) + + def test_single_tensor(self): + # 测试只有一个张量的情况 + single_input = [torch.tensor([2.0])] + sorted_inputs = sort_all_input(single_input) + self.assertTrue(torch.equal(sorted_inputs[0], single_input[0])) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py index 1ad191a0d4e85715e6199367d1d305c10a728630..8eb8fde4fdca88c97a4165f541f6dd6e7133303f 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py @@ -136,7 +136,7 @@ class TestMultiRunUT(unittest.TestCase): def setUp(self): self.test_json_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json") - self.test_data = {'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}} + self.test_data = {'dump_data_dir': '/test', 'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}} self.test_json_content = json.dumps(self.test_data) self.forward_split_files_content = [ {'key1': 'TRUE', 'key2': 'TRUE'}, diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py index cb54b4ccfef5c1aa19c4a3527b6b5cfdac7dcc77..2ea3d849096d91c59a75b072ed91d4370673b6d5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py @@ -58,80 +58,80 @@ class TestFileCheck(unittest.TestCase): def test_config_path_soft_link_check(self): args = Args(config_path=self.soft_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_api_info_path_soft_link_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.soft_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_out_path_soft_link_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.soft_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) - + def test_result_csv_path_soft_link_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=self.csv_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.SOFT_LINK_ERROR) def test_config_path_empty_check(self): args = Args(config_path=self.empty_path, api_info_path=self.hard_json_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_api_info_path_empty_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.empty_path, out_path=self.hard_path) - + with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_out_path_empty_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.empty_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_result_csv_path_empty_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=self.empty_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_config_path_invalid_check(self): args = Args(config_path=123, api_info_path=self.hard_json_path, out_path=self.hard_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_api_info_path_invalid_check(self): args = Args(config_path=self.hard_json_path, api_info_path="123", out_path=self.hard_path) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_out_path_invalid_check(self): args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=123) with self.assertRaises(Exception) as context: run_ut_command(args) self.assertEqual(context.exception.code, FileCheckException.ILLEGAL_PATH_ERROR) - + def test_result_csv_path_invalid_check(self): - args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, + args = Args(config_path=self.hard_json_path, api_info_path=self.hard_json_path, out_path=self.hard_path, result_csv_path=123) with self.assertRaises(Exception) as context: run_ut_command(args) @@ -196,26 +196,26 @@ class TestRunUtMethods(unittest.TestCase): self.assertIsNone(data_info.bench_output) self.assertIsNone(data_info.grad_in) self.assertIsNone(data_info.in_fwd_data_list) - + def test_blacklist_and_whitelist_filter(self): api_name = "test_api" black_list = ["test_api"] white_list = [] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = [] white_list = ["another_api"] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = ["test_api"] white_list = ["test_api"] result = blacklist_and_whitelist_filter(api_name, black_list, white_list) self.assertTrue(result) - + api_name = "test_api" black_list = [] white_list = ["test_api"] @@ -230,23 +230,23 @@ class TestRunUtMethods(unittest.TestCase): api_name = "Distributed.all_reduce" result = is_unsupported_api(api_name) self.assertTrue(result) - + def test_no_backward(self): grad_index = None - out = (1, 2, 3) + out = (1, 2, 3) result = need_to_backward(grad_index, out) self.assertFalse(result) grad_index = 0 - out = 42 + out = 42 result = need_to_backward(grad_index, out) self.assertTrue(result) + class TestRunUtOnlineConfig(unittest.TestCase): - @patch('msprobe.pytorch.api_accuracy_checker.run_ut.run_ut.check_crt_valid') - def test_checked_online_config(self, mock_check_crt_valid): + def test_checked_online_config(self): class OnlineConfigClass: is_online = True rank_list = [0, 1] @@ -255,8 +255,6 @@ class TestRunUtOnlineConfig(unittest.TestCase): host = "127.0.0.1" port = 12345 - mock_check_crt_valid.return_value = None - online_config = OnlineConfigClass() res = checked_online_config(online_config) self.assertIsNone(res) @@ -297,6 +295,8 @@ class TestRunUtOnlineConfig(unittest.TestCase): file.write("1") with open(os.path.join(online_config.tls_path, "server.crt"), 'w') as file: file.write("1") + with open(os.path.join(online_config.tls_path, "ca.crt"), 'w') as file: + file.write("1") checked_online_config(online_config) shutil.rmtree(online_config.tls_path) online_config.tls_path = "" @@ -314,3 +314,7 @@ class TestRunUtOnlineConfig(unittest.TestCase): checked_online_config(online_config) self.assertIn(str(context.exception), f"port: {online_config.port} is invalid, port range 0-65535.") online_config.port = 6123 + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py index 0cf30461aec70b85577c38ebed011bf9f818874d..751d3f6affd10c82f9aeee941bed8cf5453daad8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py @@ -1,13 +1,28 @@ -# coding=utf-8 +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest -from unittest.mock import patch, MagicMock + import torch + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import * from msprobe.core.common.file_utils import create_directory, write_csv class TestRunUtUtils(unittest.TestCase): - + def setUp(self): save_path = "temp_save_path" create_directory(save_path) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py index b60cfdc323bed57e1cda1fc2d9db3197638cee4c..726714b7993081044a2ca6909db357d3995ad296 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_accuracy_server.py @@ -86,8 +86,8 @@ class TestServerProtocol(unittest.TestCase): ]) self.server_protocol.transport.write.called_once_with(expected_value) - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.hashlib.md5") - def test_post_process_error(self, mock_hashlib_md5): + @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.zlib.crc32") + def test_post_process_error(self, mock_zlib_crc32): self.shared_queue.maxsize = 1 self.server_protocol.send_ack = MagicMock() @@ -99,17 +99,18 @@ class TestServerProtocol(unittest.TestCase): self.server_protocol.send_ack.side_effect = [mock_send_ack_method1, mock_send_ack_method2] self.server_protocol.check_sum = True - mock_hashlib_md5.hexdiges.return_value = "123" + mock_zlib_crc32.return_value = 123 self.server_protocol.rank = 0 self.server_protocol.step = 0 self.server_protocol.post_process() - mock_hashlib_md5.assert_called() + mock_zlib_crc32.assert_called() self.server_protocol.send_ack.assert_any_call(self.server_protocol.ACK_ERROR) self.assertEqual(self.server_protocol.rank, -1) self.assertEqual(self.server_protocol.step, -1) - @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.hashlib.md5") - def test_post_process_success(self, _): + @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server.zlib.crc32") + def test_post_process_success(self, mock_zlib_crc32): + mock_zlib_crc32.return_value = 123 self.shared_queue.maxsize = 1 self.server_protocol.send_ack = MagicMock() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py index 79f569cdcaeaa662f403b73fd4047caf7c2f0311..5df5ee879287931512dcca5f7de2daca5bcef284 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_pt_device_dispatch.py @@ -33,7 +33,6 @@ class TestDeviceDispatchFunc(unittest.TestCase): mock_consumer_queue.get.side_effect = [mock_api_data, "KILL_"] run_ut_process(xpu_id, mock_consumer_queue, None, None) - mock_torch.device.assert_called_once_with('cuda:1') mock_online_compare.assert_called_with(mock_api_data, mock_torch.device(), None) @patch("msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch.UtDataInfo") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b363c5f0316ad7f66d660b5644b5b25a104be82f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/tensor_transport_layer/test_ttl_utils.py @@ -0,0 +1,44 @@ +import unittest +from unittest.mock import Mock, patch + +from OpenSSL import crypto, SSL + +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import verify_callback, is_certificate_revoked + + +class TestVerifyCallback(unittest.TestCase): + """ + Test for verify_callback and is_certificate_revoked. + """ + + def setUp(self): + self.conn = Mock(spec=SSL.Connection) + self.cert = Mock(spec=crypto.X509) + self.crl = [Mock()] + self.crl[0].serial_number = 89981275109692867917699502952114227065605526936 + + @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils.is_certificate_revoked') + def test_preverify_ok(self, mock_is_certificate_revoked): + mock_is_certificate_revoked.return_value = False + self.assertTrue(verify_callback(self.conn, self.cert, 0, 0, 1, self.crl)) + + @patch('msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils.is_certificate_revoked') + def test_preverify_not_ok(self, mock_is_certificate_revoked): + self.assertFalse(verify_callback(self.conn, self.cert, 0, 0, 0, None)) + + mock_is_certificate_revoked.return_value = False + self.assertEqual(verify_callback(self.conn, self.cert, 0, 0, 1, self.crl), 1) + + def test_is_certificate_revoked_true(self): + self.cert.get_serial_number.return_value = 89981275109692867917699502952114227065605526936 + result = is_certificate_revoked(self.cert, self.crl) + self.assertTrue(result) + + def test_is_certificate_revoked_false(self): + self.cert.get_serial_number.return_value = 89981275109692867917699502952114227065605526937 + result = is_certificate_revoked(self.cert, self.crl) + self.assertFalse(result) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py index cdc922cc98d59b59ec0be85833d2000cd38913c8..0a25e6edf5983df968cd788e55348643e8098438 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py @@ -1,17 +1,44 @@ -import os +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import io +import os +import tempfile import unittest from unittest.mock import MagicMock, patch -import tempfile import torch import torch.distributed as dist - -from msprobe.core.common.file_utils import FileCheckConst from msprobe.core.common.exceptions import DistributedNotInitializedError +from msprobe.core.common.file_utils import FileCheckConst from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData -from msprobe.pytorch.common.utils import parameter_adapter, get_rank_if_initialized, \ - get_tensor_rank, get_rank_id, print_rank_0, load_pt, save_pt, save_api_data, load_api_data, save_pkl, load_pkl +from msprobe.pytorch.common.utils import ( + parameter_adapter, + get_rank_if_initialized, + get_tensor_rank, + get_rank_id, + print_rank_0, + load_pt, + save_pt, + save_api_data, + load_api_data, + save_pkl, + load_pkl, + is_float8_tensor, + is_hifloat8_tensor +) class TestParameterAdapter(unittest.TestCase): @@ -19,7 +46,7 @@ class TestParameterAdapter(unittest.TestCase): def setUp(self): self.func_mock = MagicMock() self.decorated_func = parameter_adapter(self.func_mock) - self.op_name_ = "__getitem__" + self.api_name = "__getitem__" def test_handle_masked_select_bfloat16(self): input_tensor = torch.tensor([1.0, 2.0], dtype=torch.bfloat16) @@ -45,7 +72,7 @@ class TestParameterAdapter(unittest.TestCase): self.assertTrue(torch.equal(result, torch.tensor([20.0, 30.0]))) def test_op_name_eq_with_none(self): - self.op_name_ = "__eq__" + self.api_name = "__eq__" args = (torch.tensor([1]), None) result = self.decorated_func(self, *args) self.assertFalse(result) @@ -186,6 +213,12 @@ class TestSavePT(unittest.TestCase): self.tensor = torch.tensor([1, 2, 3]) self.filepath = 'temp_tensor.pt' + def tearDown(self): + try: + os.remove(self.filepath) + except FileNotFoundError: + pass + @patch('msprobe.pytorch.common.utils.save_pt') @patch('os.path.realpath', return_value='temp_tensor.pt') @patch('msprobe.core.common.file_utils.check_path_before_create') @@ -193,21 +226,6 @@ class TestSavePT(unittest.TestCase): def test_save_pt_success(self, mock_change_mode, mock_check_path, mock_realpath, mock_torch_save): mock_torch_save(self.tensor, self.filepath) mock_torch_save.assert_called_once_with(self.tensor, self.filepath) - mock_change_mode.assert_called_once_with(self.filepath, FileCheckConst.DATA_FILE_AUTHORITY) - -class TestSavePT(unittest.TestCase): - - def setUp(self): - self.tensor = torch.tensor([1, 2, 3]) - self.filepath = 'temp_tensor.pt' - - @patch('torch.save') - @patch('os.path.realpath', return_value='temp_tensor.pt') - @patch('msprobe.core.common.file_utils.check_path_before_create') - @patch('msprobe.core.common.file_utils.change_mode') - def test_save_pt_success(self, mock_change_mode, mock_check_path, mock_realpath, mock_torch_save): - save_pt(self.tensor, self.filepath) - mock_torch_save.assert_called_once_with(self.tensor, self.filepath) @patch('torch.save', side_effect=Exception("Save failed")) @patch('os.path.realpath', return_value='temp_tensor.pt') @@ -218,12 +236,6 @@ class TestSavePT(unittest.TestCase): save_pt(self.tensor, self.filepath) self.assertIn("save pt file temp_tensor.pt failed", str(context.exception)) - def tearDown(self): - try: - os.remove(self.filepath) - except FileNotFoundError: - pass - class TestSaveApiData(unittest.TestCase): @@ -299,3 +311,24 @@ class TestSavePkl(unittest.TestCase): load_pkl(self.filepath) self.assertIn("Unsupported object type: os.system", str(context.exception)) os.remove(self.filepath) + +class TestFloat8Tensor(unittest.TestCase): + def setUp(self): + self.tensor = MagicMock() + + def test_is_float8_tensor(self): + self.tensor.dtype = "torch.float8_e5m2" + res = is_float8_tensor(self.tensor) + self.assertTrue(res) + + self.tensor.dtype = "torch.float8_e4m3fn" + res = is_float8_tensor(self.tensor) + self.assertTrue(res) + + def test_is_not_float8_tensor(self): + self.tensor.dtype = 123 + res = is_float8_tensor(self.tensor) + self.assertFalse(res) + + res = is_hifloat8_tensor(self.tensor) + self.assertFalse(res) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py deleted file mode 100644 index ac28e994e9c8e77f8ae675fec3322eaf64a64321..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py +++ /dev/null @@ -1,20 +0,0 @@ -# coding=utf-8 -import unittest -from msprobe.pytorch.compare import match - - -class TestMatch(unittest.TestCase): - def test_graph_mapping(self): - op1 = "Aten_convolution_1_forward_0.input.0" - op2 = "Torch_conv2d_0_forward_0.input.0" - op3 = "Torch_batch_norm_0_forward_0.input.0" - op4 = "Aten_convolution.default_1_forward_0.input.0" - op5 = "Aten_foo_1_forward_0.input.0" - self.assertTrue(match.graph_mapping.match(op1, op2)) - self.assertTrue(match.graph_mapping.match(op2, op1)) - self.assertTrue(match.graph_mapping.match(op4, op2)) - self.assertTrue(match.graph_mapping.match(op2, op4)) - self.assertFalse(match.graph_mapping.match(op1, op3)) - self.assertFalse(match.graph_mapping.match(op3, op1)) - self.assertFalse(match.graph_mapping.match(op5, op2)) - self.assertFalse(match.graph_mapping.match(op2, op5)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py index b079e646c4a8f4098bb233e3e6259ef3ebea9c94..e4c8b722b182b8c0a4e82ba1b0eeb1a6ed847ee2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare.py @@ -3,16 +3,12 @@ import os import shutil import unittest -import numpy as np import torch -from msprobe.core.common.const import Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import ModeConfig -from msprobe.pytorch.compare.pt_compare import PTComparator, compare +from msprobe.pytorch.compare.pt_compare import compare from msprobe.test.core_ut.compare.test_acc_compare import generate_dump_json, generate_stack_json - base_dir1 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare1') base_dir2 = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare2') @@ -40,36 +36,6 @@ class TestUtilsMethods(unittest.TestCase): if os.path.exists(base_dir2): shutil.rmtree(base_dir2) - def test_read_npy_data_bf16(self): - generate_bf16_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - result = pt_comparator.read_npy_data(base_dir1, 'bf16.pt') - - target_result = torch.tensor([1, 2, 3, 4], dtype=torch.float32).numpy() - self.assertTrue(np.array_equal(result, target_result)) - - def test_read_npy_data_dict(self): - generate_dict_pt(base_dir1) - - stack_mode = True - auto_analyze = True - fuzzy_match = False - dump_mode = Const.ALL - mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) - - pt_comparator = PTComparator(mode_config) - - with self.assertRaises(CompareException) as context: - result = pt_comparator.read_npy_data(base_dir1, 'dict.pt') - self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) - def test_compare(self): generate_dump_json(base_dir2) generate_stack_json(base_dir2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..558df47a108f27858cc571f6854ca3f403fc6fee --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_pt_compare_utils.py @@ -0,0 +1,72 @@ +import os +import shutil +import threading +import unittest +from unittest import mock +from unittest.mock import patch + +import numpy as np + +from msprobe.pytorch.compare import utils +from msprobe.pytorch.compare.utils import read_pt_data +from msprobe.test.core_ut.compare.test_acc_compare import generate_pt +from msprobe.core.common.utils import CompareException + + +base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f'test_pt_compare_utils_data') +pt_dir = os.path.join(base_dir, f'dump_data_dir') + + +class TestReadPtData(unittest.TestCase): + + def setUp(self): + os.makedirs(base_dir, mode=0o750, exist_ok=True) + os.makedirs(pt_dir, mode=0o750, exist_ok=True) + + self.lock = threading.Lock() + + def tearDown(self): + if os.path.exists(pt_dir): + shutil.rmtree(pt_dir) + if os.path.exists(base_dir): + shutil.rmtree(base_dir) + + def test_read_pt_data_normal(self): + generate_pt(pt_dir) + result = read_pt_data(pt_dir, 'Functional.linear.0.forward.input.0.pt') + expected = np.array([1.0, 2.0, 3.0, 4.0]) + self.assertTrue(np.array_equal(result, expected)) + + def test_read_pt_data_no_file_name(self): + result = read_pt_data(pt_dir, None) + self.assertEqual(result, None) + + @patch.object(utils, 'load_pt') + @patch.object(utils, 'FileChecker') + def test_read_pt_data_runtime_error(self, mock_file_checker_class, mock_load_pt): + mock_file_checker = mock.Mock() + mock_file_checker.common_check.return_value = 'fake/path/file.pt' + mock_file_checker_class.return_value = mock_file_checker + + mock_load_pt.side_effect = RuntimeError('failed to load') + + with self.assertRaises(CompareException) as context: + read_pt_data('fake/path', 'file.pt') + self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) + + @patch.object(utils, 'load_pt') + @patch.object(utils, 'FileChecker') + def test_read_pt_data_attribute_error(self, mock_file_checker_class, mock_load_pt): + mock_file_checker = mock.Mock() + mock_file_checker.common_check.return_value = 'fake/path/file.pt' + mock_file_checker_class.return_value = mock_file_checker + + class FakeTensor: + def detach(self): + raise AttributeError('no detach') + + mock_load_pt.return_value = FakeTensor() + + with self.assertRaises(CompareException) as context: + read_pt_data('fake/path', 'file.pt') + self.assertEqual(context.exception.code, CompareException.DETACH_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/bench.sh b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/bench.sh new file mode 100644 index 0000000000000000000000000000000000000000..217676ef0f451b6b8f2d2cecb14545d9a7f8dd8b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/bench.sh @@ -0,0 +1,25 @@ +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +CKPT_SAVE_DIR="your model save ckpt path" +DATA_PATH="your data path" +TOKENIZER_MODEL="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" +TP=1 + +DISTRIBUTED_ARGS=" + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --sequence-parallel \ + --tokenizer-model ${TOKENIZER_MODEL} \ +" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee logs/train_llama2_7b.log \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/cmp.sh b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/cmp.sh new file mode 100644 index 0000000000000000000000000000000000000000..8df9e6507975c7edbcfee105d838563171c720e4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/cmp.sh @@ -0,0 +1,25 @@ +MASTER_PORT=6001 +NNODES=1 +NODE_RANK=0 +CKPT_SAVE_DIR="./aaa" +DATA_PATH="./aaa" +TOKENIZER_MODEL="./aaa" +CKPT_LOAD_DIR="./aaa" +TP=2 + +DISTRIBUTED_ARGS=" + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --sequence-parallel \ + --tokenizer-model ${TOKENIZER_MODEL} \ +" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee logs/train_llama2_7b.log \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_config_checking.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_config_checking.py new file mode 100644 index 0000000000000000000000000000000000000000..27b6b6e4364ff440a74d9619d1439e349a696efe --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_config_checking.py @@ -0,0 +1,131 @@ +import os +import random +import shutil +import unittest +import torch +import json +import numpy as np +import torch.nn as nn +from msprobe.pytorch.config_checking.config_checker import ConfigChecker +from msprobe.pytorch.config_checking.checkers.pip_checker import PipPackageChecker +from msprobe.pytorch.config_checking.checkers.random_checker import RandomChecker +from msprobe.pytorch.config_checking.checkers.dataset_checker import DatasetChecker +from msprobe.pytorch.config_checking.checkers.weights_checker import WeightsChecker +from msprobe.pytorch.config_checking.checkers.random_checker import apply_patches +from msprobe.core.common.file_utils import read_xlsx + +testdir = os.path.dirname(__file__) +config_checking_dir = os.path.dirname(testdir) +temp_dir = os.path.join(config_checking_dir, "temp") +os.makedirs(temp_dir, exist_ok=True) + + +def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + + +class MockModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + self.relu = nn.ReLU() + + def forward(self, x, y): + x1 = self.linear(x) + x2 = self.relu(x1) + return x2 + + +def get_test_dataset(): + inputs = [torch.rand(10, 10) for _ in range(10)] + labels = [torch.randint(0, 5, (10,)) for _ in range(10)] + return zip(inputs, labels) + + +def get_test_model(): + test_module = MockModule() + nn.init.constant_(test_module.linear.weight, 1.0) + nn.init.constant_(test_module.linear.bias, 1.0) + return test_module + + +@unittest.mock.patch("msprobe.pytorch.config_checking.checkers.pip_checker.collect_pip_data") +@unittest.mock.patch("msprobe.pytorch.config_checking.checkers.env_args_checker.collect_env_data") +def train_test(seed, output_zip_path, shell_path, mock_env, mock_pip): + mock_env.return_value = {"HCCL_DETERMINISTIC": False} + if seed == 1234: + mock_pip.return_value = "transformers=0.0.1" + else: + mock_pip.return_value = "transformers=0.0.2" + seed_all(seed) + + loss_fun = nn.CrossEntropyLoss() + test_module = get_test_model() + optimizer = torch.optim.SGD(test_module.parameters(), lr=1e-2) + + ConfigChecker(test_module, shell_path, output_zip_path) + + try: + for input_data, label in get_test_dataset(): + output = test_module(input_data, y=input_data) + loss = loss_fun(output, label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + except Exception: + pass + + +class TestConfigChecker(unittest.TestCase): + def tearDown(self): + shutil.rmtree(temp_dir) + + def test_all(self): + train_test(1234, os.path.join(temp_dir, "config_check_pack1.zip"), [os.path.join(testdir, "cmp.sh")]) + + ConfigChecker.pre_forward_fun_list = [] + ConfigChecker.step = 0 + RandomChecker.write_once = False + apply_patches() + + train_test(1233, os.path.join(temp_dir, "config_check_pack2.zip"), [os.path.join(testdir, "bench.sh")]) + + ConfigChecker.compare(os.path.join(temp_dir, "config_check_pack1.zip"), + os.path.join(temp_dir, "config_check_pack2.zip"), + os.path.join(temp_dir, "compare_output")) + + compare_output_dir = os.path.join(temp_dir, "compare_output") + + + total_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename)) + self.assertEqual(total_check_result.columns.tolist(), ConfigChecker.result_header) + target_total_check_result = [ + ['env', True], + ['pip', False], + ['dataset', False], + ['weights', False], + ['hyperparameters', True], + ['random', False] + ] + self.assertEqual(total_check_result.values.tolist(), target_total_check_result) + + pip_data_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), sheet_name=PipPackageChecker.target_name_in_zip) + self.assertEqual(pip_data_check_result.columns.tolist(), PipPackageChecker.result_header) + self.assertEqual(pip_data_check_result.iloc[0].tolist(), ['transformers', '0.0.1', '0.0.2', 'error']) + + random_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), sheet_name=RandomChecker.target_name_in_zip) + self.assertEqual(random_check_result.columns.tolist(), RandomChecker.result_header) + self.assertEqual(len(random_check_result), 3) + + dataset_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), sheet_name=DatasetChecker.target_name_in_zip) + self.assertEqual(dataset_check_result.columns.tolist(), DatasetChecker.result_header) + self.assertEqual(len(dataset_check_result), 20) + + weight_check_result = read_xlsx(os.path.join(compare_output_dir, ConfigChecker.result_filename), sheet_name=WeightsChecker.target_name_in_zip) + self.assertEqual(weight_check_result.columns.tolist(), WeightsChecker.result_header) + self.assertEqual(len(weight_check_result), 20) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_dataset_checker.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_dataset_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..0898a4f8cf22847b87b66061e38a50c80bac9127 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_dataset_checker.py @@ -0,0 +1,82 @@ +import unittest +import torch +import pandas as pd +from unittest.mock import patch, MagicMock + +from msprobe.pytorch.config_checking.checkers.dataset_checker import compare_dataset, \ + compare_dataset_dicts, parse_args_and_kargs, process_obj, process_tensor + + +class TestTensorProcessing(unittest.TestCase): + + def test_process_tensor(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = process_tensor(tensor) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'max', 'min', 'mean', 'norm'}) + + def test_process_obj_tensor(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = process_obj(tensor) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'max', 'min', 'mean', 'norm'}) + + def test_process_obj_list(self): + obj = [torch.tensor([1.0]), torch.tensor([2.0])] + result = process_obj(obj) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {0, 1}) + + def test_process_obj_dict(self): + obj = {'a': torch.tensor([1.0]), 'b': torch.tensor([2.0])} + result = process_obj(obj) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'a', 'b'}) + + def test_process_obj_other(self): + obj = "test" + result = process_obj(obj) + self.assertEqual(result, "") + + def test_parse_args_and_kargs(self): + args = (torch.tensor([1.0]),) + kwargs = {'a': torch.tensor([2.0])} + result = parse_args_and_kargs(args, kwargs) + self.assertEqual(isinstance(result, dict), True) + self.assertEqual(set(result.keys()), {'args', 'kwargs'}) + + def test_compare_dataset_dicts_equal(self): + dict1 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + dict2 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['equal'], True) + + def test_compare_dataset_dicts_not_equal(self): + dict1 = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + dict2 = {'a': {'max': 2.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['equal'], False) + + def test_compare_dataset_dicts_nested(self): + dict1 = {'a': {'b': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}}} + dict2 = {'a': {'b': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}}} + results = compare_dataset_dicts(dict1, dict2) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['tag'], 'a.b') + + @patch('os.listdir', return_value=['step1']) + @patch('os.path.isdir', return_value=True) + @patch('os.path.isfile', return_value=True) + @patch('msprobe.pytorch.config_checking.checkers.dataset_checker.load_json') + def test_compare_dataset(self, mock_load_json, mock_isfile, mock_isdir, mock_listdir): + mock_load_json.return_value = {'a': {'max': 1.0, 'min': 0.0, 'mean': 0.5, 'norm': 0.7}} + bench_dir = 'bench' + cmp_dir = 'cmp' + result = compare_dataset(bench_dir, cmp_dir) + self.assertEqual(isinstance(result, pd.DataFrame), True) + + + + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_random_checker.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_random_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..4b04351862796ccd433d7136e318f47dca856349 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_random_checker.py @@ -0,0 +1,77 @@ +import unittest +import pandas as pd +from unittest.mock import patch, MagicMock + +from msprobe.pytorch.config_checking.checkers.random_checker import compare_json_files, compare_random, get_file_and_line + + +class TestCompareRandom(unittest.TestCase): + + @patch('os.listdir', return_value=['rank1.json', 'rank2.json']) + @patch('os.path.join', return_value='test_path') + @patch("msprobe.pytorch.config_checking.checkers.random_checker.load_json") + def test_compare_random_with_files(self, mock_load_json, mock_path, mock_listdir): + mock_load_json.return_value = {"op1": {"position1": 1}} + bench_dir = 'test_bench' + cmp_dir = 'test_cmp' + result = compare_random(bench_dir, cmp_dir) + self.assertEqual(isinstance(result, pd.DataFrame), True) + + @patch('os.listdir', return_value=[]) + @patch('os.path.join', return_value='test_path') + def test_compare_random_no_files(self, mock_path, mock_listdir): + bench_dir = 'test_bench' + cmp_dir = 'test_cmp' + result = compare_random(bench_dir, cmp_dir) + self.assertEqual(isinstance(result, pd.DataFrame), True) + self.assertEqual(len(result), 0) + + def test_get_file_and_line_with_valid_input(self): + position = '/path/to/file.py:10' + result = get_file_and_line(position) + self.assertEqual(isinstance(result, str), True) + self.assertEqual(result, 'file.py:10') + + def test_get_file_and_line_with_invalid_input(self): + position = 'invalid_position' + result = get_file_and_line(position) + self.assertEqual(isinstance(result, str), True) + self.assertEqual(result, 'invalid_position') + + @patch('os.listdir', return_value=['rank1.json', 'rank2.json']) + @patch('os.path.join', return_value='test_path') + def test_compare_json_files_same_data(self, mock_path, mock_listdir): + bench_data = {"op1": {"position1:10": 1}} + cmp_data = {"op1": {"position1:10": 1}} + result = compare_json_files(bench_data, cmp_data) + self.assertEqual(isinstance(result, list), True) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][2], True) + + @patch('os.listdir', return_value=['rank1.json', 'rank2.json']) + @patch('os.path.join', return_value='test_path') + def test_compare_json_files_different_data(self, mock_path, mock_listdir): + bench_data = {"op1": {"position1:10": 1}} + cmp_data = {"op1": {"position1:10": 2}} + result = compare_json_files(bench_data, cmp_data) + self.assertEqual(isinstance(result, list), True) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][2], False) + + @patch('os.listdir', return_value=['rank1.json', 'rank2.json']) + @patch('os.path.join', return_value='test_path') + def test_compare_json_files_missing_op_in_bench(self, mock_path, mock_listdir): + bench_data = {} + cmp_data = {"op1": {"position1:10": 1}} + result = compare_json_files(bench_data, cmp_data) + self.assertEqual(isinstance(result, list), True) + self.assertEqual(len(result), 1) + + @patch('os.listdir', return_value=['rank1.json', 'rank2.json']) + @patch('os.path.join', return_value='test_path') + def test_compare_json_files_missing_op_in_cmp(self, mock_path, mock_listdir): + bench_data = {"op1": {"position1:10": 1}} + cmp_data = {} + result = compare_json_files(bench_data, cmp_data) + self.assertEqual(isinstance(result, list), True) + self.assertEqual(len(result), 1) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_weight_checker.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_weight_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..7b76268d23799d31b7fc90fd091f83eb6cbba577 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/config_checking/test_weight_checker.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import patch +import pandas as pd +import os +import torch + +from msprobe.pytorch.config_checking.checkers.weights_checker import collect_weights_data, compare_weight, compare_weight_file + + +class TestWeightComparison(unittest.TestCase): + @patch('msprobe.pytorch.config_checking.utils.utils.get_tensor_features') + @patch('torch.nn.Module.named_parameters') + def test_collect_weights_data(self, mock_named_parameters, mock_get_tensor_features): + mock_model = unittest.mock.create_autospec(torch.nn.Module) + mock_named_parameters.return_value = [('param1', object())] + mock_get_tensor_features.return_value = {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1} + result = collect_weights_data(mock_model) + self.assertEqual(isinstance(result, dict), True) + + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.load_json') + def test_compare_weight_file(self, mock_load_json): + mock_load_json.side_effect = [ + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}}, + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + ] + result = compare_weight_file('bench.json', 'cmp.json') + self.assertEqual(isinstance(result, list), True) + + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = True + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.load_json') + def test_compare_weight_file_different_weights(self, mock_load_json): + mock_load_json.side_effect = [ + {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}}, + {'weight1': {'max': 2, 'min': 1, 'mean': 1.5, 'norm': 2}} + ] + result = compare_weight_file('bench.json', 'cmp.json') + self.assertEqual(isinstance(result, list), True) + for res in result: + if res["weight_name"] == "weight1": + self.assertEqual(res["equal"], False) + + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight_cmp_file_missing(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = False + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + self.assertEqual(len(result[result["equal"] == "only bench have"]), 1) + + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.os_walk_for_files') + @patch('msprobe.pytorch.config_checking.checkers.weights_checker.load_json') + @patch('os.path.exists') + def test_compare_weight_multiple_files(self, mock_exists, mock_load_json, mock_os_walk_for_files): + mock_os_walk_for_files.return_value = [ + {"root": "bench/step1/rank0", "file": "weights1.json"}, + {"root": "bench/step1/rank0", "file": "weights2.json"} + ] + mock_load_json.return_value = {'weight1': {'max': 1, 'min': 0, 'mean': 0.5, 'norm': 1}} + mock_exists.return_value = True + result = compare_weight('bench', 'cmp') + self.assertEqual(isinstance(result, pd.DataFrame), True) + self.assertEqual(len(result), 2) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py index 4fc27c267ebe65ea46ecf0f17bc47ff702eb241d..22e2de34aadeb53001140dfe4b870cedf1a2c564 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py @@ -1,6 +1,7 @@ import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch +import torch from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.pytorch.debugger.debugger_config import DebuggerConfig @@ -46,28 +47,77 @@ class TestDebuggerConfig(unittest.TestCase): self.assertEqual(debugger.nfs_path, "./nfs_path") self.assertEqual(debugger.port, 8080) - def test_valid_task_and_level(self): - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "L1") - config.check_kwargs() + def test_check_kwargs_with_invalid_task(self): + self.common_config.task = "invalid_task" + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The task is not in the {Const.TASK_LIST}", str(context.exception)) + + def test_check_kwargs_with_invalid_level(self): + self.common_config.level = "invalid_level" + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The level is not in the {Const.LEVEL_LIST}.", str(context.exception)) - def test_invalid_task(self): + def test_check_kwargs_with_invalid_dump_path(self): + self.common_config.dump_path = None with self.assertRaises(MsprobeException) as context: - config = DebuggerConfig(self.common_config, self.task_config, "invalid_task", None, "L1") - config.check_kwargs() - self.assertIn("not in the", str(context.exception)) + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The dump_path not found.", str(context.exception)) - def test_invalid_level(self): + def test_check_kwargs_with_invalid_async_dump(self): + self.common_config.async_dump = 1 with self.assertRaises(MsprobeException) as context: - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "invalid_level") - config.check_kwargs() - self.assertIn("not in the", str(context.exception)) + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"The parameters async_dump should be bool.", str(context.exception)) + + def test_check_kwargs_with_async_dump_and_not_debug(self): + self.common_config.async_dump = True + self.common_config.task = Const.TENSOR + self.common_config.level = Const.LEVEL_MIX + self.task_config.list = [] + with self.assertRaises(MsprobeException) as context: + DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertIn(f"the parameters list cannot be empty.", str(context.exception)) + + def test_check_kwargs_with_structure_task(self): + self.common_config.task = Const.STRUCTURE + self.common_config.level = Const.LEVEL_L1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + self.assertEqual(config.level, Const.LEVEL_MIX) + + def test_check_model_with_model_is_none(self): + self.common_config.level = Const.LEVEL_L0 + instance = MagicMock() + instance.model = None + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + with self.assertRaises(MsprobeException) as context: + config.check_model(instance, None, None) + self.assertIn("missing the parameter 'model'", str(context.exception)) + + def test_check_model_with_single_model(self): + self.common_config.level = Const.LEVEL_MIX + model1 = torch.nn.ReLU() + model2 = torch.nn.Linear(2, 2) - def test_missing_dump_path(self): + instance = MagicMock() + instance.model = model1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) + config.check_model(instance, model2, None) + + self.assertEqual(instance.model, model2) + + def test_check_model_with_incorrect_model(self): + self.common_config.level = Const.LEVEL_L0 + model1 = torch.nn.ReLU() + model2 = [torch.nn.Linear(2, 2), torch.nn.ReLU(), "test_model"] + + instance = MagicMock() + instance.model = model1 + config = DebuggerConfig(self.common_config, self.task_config, None, None, None) with self.assertRaises(MsprobeException) as context: - self.common_config.dump_path = None - config = DebuggerConfig(self.common_config, self.task_config, "tensor", None, "L1") - config.check_kwargs() - self.assertIn("dump_path not found", str(context.exception)) + config.check_model(instance, model2, None) + self.assertIn("must be a torch.nn.Module or list[torch.nn.Module]", str(context.exception)) def test_check_and_adjust_config_with_l2_scope_not_empty(self): self.common_config.dump_path = "./dump_path" @@ -100,3 +150,40 @@ class TestDebuggerConfig(unittest.TestCase): debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) debugger._check_and_adjust_config_with_l2() self.assertIn("Functional.conv2d.0.forward", self.task_config.list) + + def test_check_and_adjust_config_with_l2_task_not_tensor(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + + self.task_config.scope = [] + self.task_config.list = ["Functional.conv2d.0.forward"] + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + with self.assertRaises(MsprobeException) as context: + debugger._check_and_adjust_config_with_l2() + self.assertIn("the task must be set to tensor", str(context.exception)) + + def test_check_statistics_config_task_not_statistics(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.TENSOR + + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertFalse(hasattr(debugger, "tensor_list")) + + def test_check_statistics_config_not_tensor_list(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + delattr(self.task_config, "tensor_list") + + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertEqual(debugger.tensor_list, []) + + def test_check_statistics_config_success(self): + self.common_config.dump_path = "./dump_path" + self.common_config.task = Const.STATISTICS + + self.task_config.tensor_list = ["Functional.conv2d"] + debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None) + debugger._check_statistics_config(self.task_config) + self.assertEqual(debugger.tensor_list, self.task_config.tensor_list) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py new file mode 100644 index 0000000000000000000000000000000000000000..feab969830d670532c0d0708770240f5a2f58d30 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_start.py @@ -0,0 +1,105 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import TensorDataset, DataLoader +import unittest +from unittest.mock import patch +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger +from msprobe.pytorch.pt_config import StatisticsConfig +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger +from msprobe.core.common.file_utils import load_json +import shutil + +# 生成随机分类数据 +X = torch.randn(100, 2) +y = ((X[:, 0] + X[:, 1]) > 0).float().reshape(-1, 1) + +# 创建数据加载器 +dataset = TensorDataset(X, y) +dataloader = DataLoader(dataset, batch_size=10) + +# 定义单层神经网络 +class SingleLayerNet(nn.Module): + def __init__(self): + super().__init__() + self.layer = nn.Linear(2, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + return self.sigmoid(x) + + +class MultiStartDebugger: + debugger = None + dump_path = None + hooked_model = [] + + @classmethod + def init(cls, dump_path): + cls.dump_path = dump_path + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + common_config = CommonConfig(json_config) + task_config = StatisticsConfig(json_config) + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): + cls.debugger = PrecisionDebugger(task="statistics", level="L0", dump_path=dump_path) + + @classmethod + def debugger_start(cls, model, tag): + cls.debugger.service.first_start = True if model not in cls.hooked_model else False + cls.debugger.service.config.dump_path = os.path.join(cls.dump_path, tag) + cls.debugger.start(model=model) + if model not in cls.hooked_model: + cls.hooked_model.append(model) + + @classmethod + def debugger_stop(cls): + cls.debugger.stop() + cls.debugger.service._reset_status() + + @classmethod + def debugger_step(cls): + cls.debugger.step() + + +class TestPTDebuggerStart(unittest.TestCase): + def test_debugger_multiple_start(self): + dump_path = "./test_debugger_multiple_start_dump" + + model1 = SingleLayerNet() + model2 = SingleLayerNet() + MultiStartDebugger.init(dump_path) + + for batch_X, batch_y in dataloader: + MultiStartDebugger.debugger_start(model=model1, tag="model1") + output1 = model1(batch_X) + MultiStartDebugger.debugger_stop() + + MultiStartDebugger.debugger_start(model=model2, tag="model2") + output2 = model2(batch_X) + MultiStartDebugger.debugger_stop() + MultiStartDebugger.debugger_step() + + model1_dump_path = os.path.join(dump_path, "model1") + self.assertTrue(os.path.exists(model1_dump_path)) + self.assertEqual(len(os.listdir(model1_dump_path)), 10) + model1_construct_json = load_json(os.path.join(model1_dump_path, "step0", "rank", "construct.json")) + self.assertEqual(len(model1_construct_json), 1) + + model2_dump_path = os.path.join(dump_path, "model2") + self.assertTrue(os.path.exists(model2_dump_path)) + self.assertEqual(len(os.listdir(model2_dump_path)), 10) + model2_construct_json = load_json(os.path.join(model2_dump_path, "step0", "rank", "construct.json")) + self.assertEqual(len(model2_construct_json), 1) + + shutil.rmtree(dump_path) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py index a2f3e8a816e356b68e598138b30a9e14b42107d9..249432717189e7b675b43829d13b9adc806a3274 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_precision_debugger.py @@ -8,9 +8,12 @@ import torch from msprobe.core.common.const import Const, MsgConst from msprobe.core.common.utils import get_real_step_or_rank from msprobe.core.common.exceptions import MsprobeException, FileCheckException -from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger, iter_tracer +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor from msprobe.test.pytorch_ut.grad_probe.test_grad_monitor import common_config, task_config +from msprobe.core.common_config import CommonConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger +from msprobe.pytorch.pt_config import StatisticsConfig, GradToolConfig class Args: @@ -23,6 +26,29 @@ class Args: class TestPrecisionDebugger(unittest.TestCase): + grad_json_config = { + "task": Const.GRAD_PROBE, + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + grad_common_config = CommonConfig(grad_json_config) + grad_task_config = GradToolConfig(grad_json_config) + + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1", + "async_dump": False + } + + statistics_common_config = CommonConfig(json_config) + statistics_task_config = StatisticsConfig(json_config) def test_init(self): gm = GradientMonitor(common_config, task_config) @@ -30,43 +56,43 @@ class TestPrecisionDebugger(unittest.TestCase): step = get_real_step_or_rank([0, 1, "3-5"], Const.STEP) self.assertListEqual(step, [0, 1, 3, 4, 5]) - def test_instance(self): - debugger1 = PrecisionDebugger(dump_path="./dump_path") - debugger2 = PrecisionDebugger(dump_path="./dump_path") - self.assertIs(debugger1.instance, debugger2.instance) - def test_check_input_params(self): - args = Args(config_path = 1) + args = Args(config_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = "./") + args = Args(config_path="./") with self.assertRaises(FileCheckException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, FileCheckException.INVALID_FILE_ERROR) - args = Args(task = 1) + args = Args(task=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(dump_path = 1) + args = Args(dump_path=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(level = 1) + args = Args(level=1) with self.assertRaises(MsprobeException) as context: - PrecisionDebugger.check_input_params(args) + PrecisionDebugger._check_input_params(args.config_path, args.task, args.dump_path, args.level) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - args = Args(config_path = os.path.join(os.path.dirname(__file__), "../../../config.json"), - task = Const.TASK_LIST[0], - dump_path="./dump_path", - level = Const.LEVEL_LIST[0], - model = torch.nn.Module()) - checked_input_params = PrecisionDebugger.check_input_params(args) + args = Args(config_path=os.path.join(os.path.dirname(__file__), "../../../config.json"), + task=Const.TASK_LIST[0], + dump_path="./dump_path", + level=Const.LEVEL_LIST[0], + model=torch.nn.Module()) + checked_input_params = PrecisionDebugger._check_input_params( + args.config_path, + args.task, + args.dump_path, + args.level + ) self.assertIsNone(checked_input_params) def test_start_grad_probe(self): @@ -75,12 +101,16 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.start() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_start = PrecisionDebugger.start() self.assertIsNone(checked_start) def test_start_statistics(self): - debugger = PrecisionDebugger(dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.statistics_common_config, self.statistics_task_config)): + debugger = PrecisionDebugger(dump_path="./dump_path") debugger.service = MagicMock() debugger.config = MagicMock() debugger.task = 'statistics' @@ -88,7 +118,12 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.service.start.assert_called_once() def test_forward_backward_dump_end(self): - debugger = PrecisionDebugger(dump_path="./dump_path") + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config,self.statistics_task_config) + ): + debugger = PrecisionDebugger(dump_path="./dump_path", task='statistics') debugger.service = MagicMock() debugger.config = MagicMock() debugger.task = 'statistics' @@ -101,7 +136,9 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.stop() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_stop = PrecisionDebugger.stop() self.assertIsNone(checked_stop) @@ -117,8 +154,9 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger._instance = None PrecisionDebugger.step() self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - - PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") + with patch.object(BasePrecisionDebugger, "_parse_config_path", + return_value=(self.grad_common_config, self.grad_task_config)): + PrecisionDebugger._instance = PrecisionDebugger(task=Const.GRAD_PROBE, dump_path="./dump_path") checked_step = PrecisionDebugger.step() self.assertIsNone(checked_step) @@ -135,7 +173,12 @@ class TestPrecisionDebugger(unittest.TestCase): PrecisionDebugger.monitor(torch.nn.Module()) self.assertEqual(str(context.exception), MsgConst.NOT_CREATED_INSTANCE) - debugger = PrecisionDebugger(task=Const.STATISTICS, dump_path="./dump_path") + with patch.object( + BasePrecisionDebugger, + "_parse_config_path", + return_value=(self.statistics_common_config, self.statistics_task_config) + ): + debugger = PrecisionDebugger(task=Const.STATISTICS, dump_path="./dump_path") checked_monitor = debugger.monitor(torch.nn.Module()) self.assertIsNone(checked_monitor) @@ -146,40 +189,57 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.gm.monitor(torch.nn.Module()) debugger.gm.monitor.assert_called_once() - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = False - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_called_once() - mock_debugger_instance.step.assert_called_once() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - - @patch('msprobe.pytorch.debugger.precision_debugger.PrecisionDebugger') - def test_iter_tracer_first_start(self, mock_debugger): - mock_debugger_instance = mock_debugger.instance = MagicMock() - mock_debugger_instance.service.first_start = True - - @iter_tracer - def dataloader_func(): - return "test_iter_tracer" - result = dataloader_func() - self.assertEqual(result, "test_iter_tracer") - - mock_debugger_instance.stop.assert_not_called() - mock_debugger_instance.step.assert_not_called() - mock_debugger_instance.start.assert_called_once() - self.assertTrue(mock_debugger_instance.enable_dataloader) - def tearDown(self): if os.path.exists("./dump_path/"): shutil.rmtree("./dump_path/") if os.path.exists("./grad_output/"): shutil.rmtree("./grad_output/") + + +class TestIterTracer(unittest.TestCase): + def setUp(self): + self.debugger = MagicMock() + self.debugger.service.first_start = False + self.debugger.enable_dataloader = True + self.ori_instance = PrecisionDebugger._instance + PrecisionDebugger._instance = self.debugger + + def tearDown(self): + PrecisionDebugger._instance = self.ori_instance + + def test_debugger_with_not_first_start(self): + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 1" + + result = test_func() + + self.assertEqual(result, "test case 1") + self.debugger.stop.assert_called_once() + self.debugger.step.assert_called_once() + self.debugger.start.assert_called_once() + + def test_debugger_with_first_start(self): + self.debugger.service.first_start = True + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 2" + + result = test_func() + self.assertEqual(result, "test case 2") + self.debugger.stop.assert_not_called() + self.debugger.step.assert_not_called() + self.debugger.start.assert_called_once() + + def test_no_debugger_instance(self): + PrecisionDebugger._instance = None + + @PrecisionDebugger._iter_tracer + def test_func(): + return "test case 3" + + with self.assertRaises(MsprobeException) as context: + result = test_func() + self.assertEqual(result, "test case 3") + self.assertEqual(context.exception.code, MsprobeException.INTERFACE_USAGE_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3d1dd2362146f56d4f5bcc53e4792689df3e90 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger_save/test_debugger_save_pytorch.py @@ -0,0 +1,449 @@ +import unittest +import os +import json +import torch +import numpy as np +import shutil + +from msprobe.pytorch import PrecisionDebugger + +current_file = __file__ +parent_dir = os.path.abspath(os.path.dirname(current_file)) +test_dir = os.path.join(parent_dir, "test_dir") + +def deep_compare(obj1, obj2, float_tolerance=1e-5): + """ + Recursively compare two objects to check if they are the same. + Supports nested dictionaries and lists. + """ + if type(obj1) != type(obj2): + return False + if isinstance(obj1, dict): + if obj1.keys() != obj2.keys(): + return False + return all(deep_compare(obj1[key], obj2[key]) for key in obj1) + if isinstance(obj1, (tuple, list)): + if len(obj1) != len(obj2): + return False + return all(deep_compare(item1, item2) for item1, item2 in zip(obj1, obj2)) + if isinstance(obj1, (int, float)): + return abs(obj1 - obj2) < float_tolerance + return obj1 == obj2 + +class TestDebuggerSave(unittest.TestCase): + @staticmethod + def write_config_json(step, async_dump, mode, dump_path, config_file_path): + task = "tensor" if mode == "tensor" else "statistics" + statistics_summary_mode = "statistics" if mode == "statistics" else "md5" + config = { + "task": task, + "dump_path": dump_path, + "rank": [], + "step": step, + "level": "debug", + "enable_dataloader": False, + "async_dump": async_dump, + "statistics": { + "summary_mode": statistics_summary_mode, + } + } + with open(config_file_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + @staticmethod + def read_debug_json_into_dict(debug_json_path): + with open(debug_json_path, "r", encoding="utf-8") as f: + debug_json = json.load(f) + return debug_json + + + @staticmethod + def check_real_pt(pt_path, target_pt_tensor, check_values=True, rtol=1e-5, atol=1e-8): + """ + Enhanced version with optional value comparison. + + Args: + pt_path (str): Path to the .pt file + target_pt_tensor: Target torch tensor to compare + check_values (bool): If True, also compare array values + rtol, atol: Relative and absolute tolerances for value comparison + + Returns: + bool: True if all checks pass + """ + # Load the pt file + try: + pt_data = torch.load(pt_path) + except FileNotFoundError: + print(f"Error: The file {pt_path} does not exist.") + return False + except Exception as e: + print(f"Error loading pt file: {e}") + return False + # Check shapes + if pt_data.shape != target_pt_tensor.shape: + print(f"Shape mismatch: pt data shape is {pt_data.shape}, target tensor shape is {target_pt_tensor.shape}") + return False + # Check dtypes + if pt_data.dtype != target_pt_tensor.dtype: + print(f"Shape mismatch: pt data dtype is {pt_data.dtype}, target tensor dtype is {target_pt_tensor.dtype}") + return False + # Optionally check values + if check_values: + if not torch.allclose(pt_data, target_pt_tensor, rtol=rtol, atol=atol): + print("Value mismatch: pt data and target tensor values do not match within the specified tolerances.") + return False + return True + + def setUp(self): + if not os.path.exists(test_dir): + os.makedirs(test_dir) + PrecisionDebugger._instance = None + + def tearDown(self): + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + PrecisionDebugger._instance = None + + def test_save_real_tensor(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check pt file + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": "data_dict.0.debug.a.pt" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_md5(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = False + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "md5": "2e3fa576" + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_multiple_steps(self): + data = {"a": torch.Tensor([1., 2.])} + step = [0, 1, 2] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in step: + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check pt file + for i in step: + pt_path = os.path.join(dump_path, f"step{i}", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": "data_dict.0.debug.a.pt" + } + } + for i in step: + debug_json_path = os.path.join(dump_path, f"step{i}", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_async_save_tensor(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = True + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check pt file + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "data_dict.0.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "data_name": "data_dict.0.debug.a.pt", + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_async_save_md5(self): + # async_dump case, md5 configuration not working,only save statistics + data = {"a": torch.Tensor([1., 2.])} + step = [] + async_dump = True + mode = "md5" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + # check debug json + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + } + } + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"]["data_dict.0.debug"], target_debug_info) + + def test_save_multiple_times(self): + data = {"a": torch.Tensor([1., 2.])} + step = [] + call_times = 3 + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + for _ in range(call_times): + PrecisionDebugger.save(data, "data_dict", save_backward=False) + PrecisionDebugger.step() + + # check pt file + for i in range(call_times): + pt_path = os.path.join(dump_path, "step0", "rank", "dump_tensor_data", f"data_dict.{i}.debug.a.pt") + assert self.check_real_pt(pt_path, data["a"]) + + # check debug json + for i in range(call_times): + target_debug_info = { + "a": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": False, + "data_name": f"data_dict.{i}.debug.a.pt" + } + } + + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + assert deep_compare(debug_json_dict["data"][f"data_dict.{i}.debug"], target_debug_info) + + def test_save_backward(self): + x = torch.Tensor([1., 2.]) + target_x_grad = torch.Tensor([1., 1.]) + def _forward_simple_func(x): + PrecisionDebugger.save(x, "x_tensor") + return x.sum() + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + x.requires_grad = True + loss = _forward_simple_func(x) + loss.backward() + PrecisionDebugger.step() + x_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "x_tensor.0.debug.pt"), + "x_tensor.0.debug", + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": True, + "data_name": "x_tensor.0.debug.pt" + }, + ] + x_grad_info_list = [ + target_x_grad, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "x_tensor_grad.0.debug.pt"), + "x_tensor_grad.0.debug", + { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.4142135381698608, + "requires_grad": False, + "data_name": "x_tensor_grad.0.debug.pt" + }, + ] + check_list = [x_info_list, x_grad_info_list] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + for check_info in check_list: + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = check_info + assert self.check_real_pt(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) + + def test_save_compilcated_data_structure_backward(self): + x = torch.Tensor([1., 2.]) + target_x_grad = torch.Tensor([1., 1.]) + def _forward_complicated_func(x): + complicated_structure = [{"a_key": x}] + PrecisionDebugger.save(complicated_structure, "complicated_structure") + return complicated_structure[0]["a_key"].sum() + step = [] + async_dump = False + mode = "tensor" + dump_path = os.path.join(test_dir, "debug_save") + config_file_path = os.path.join(test_dir, "config.json") + self.write_config_json(step, async_dump, mode, dump_path, config_file_path) + debugger = PrecisionDebugger(config_file_path) + x.requires_grad = True + loss = _forward_complicated_func(x) + loss.backward() + PrecisionDebugger.step() + complicated_structure_info_list = [ + x, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure.0.debug.0.a_key.pt"), + "complicated_structure.0.debug", + [ + { + "a_key": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 2.0, + "Min": 1.0, + "Mean": 1.5, + "Norm": 2.2360680103302, + "requires_grad": True, + "data_name": "complicated_structure.0.debug.0.a_key.pt" + } + } + ], + ] + complicated_structure_grad_info_list = [ + target_x_grad, + os.path.join(dump_path, "step0", "rank", "dump_tensor_data", "complicated_structure_grad.0.debug.0.a_key.pt"), + "complicated_structure_grad.0.debug", + [ + { + "a_key": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": [ + 2 + ], + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.4142135381698608, + "requires_grad": False, + "data_name": "complicated_structure_grad.0.debug.0.a_key.pt" + } + } + ], + ] + check_list = [complicated_structure_info_list, complicated_structure_grad_info_list] + debug_json_path = os.path.join(dump_path, "step0", "rank", "debug.json") + debug_json_dict = self.read_debug_json_into_dict(debug_json_path) + for check_info in check_list: + target_tensor, target_tensor_path, target_tensor_key, target_tensor_info = check_info + assert self.check_real_pt(target_tensor_path, target_tensor) + assert deep_compare(debug_json_dict["data"][target_tensor_key], target_tensor_info) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 63d6abc3a2430bb6f092820c4b97a02cdf675612..4ba3556c277f3326520547a6124170f32a9cc8e8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -16,45 +16,68 @@ import unittest from unittest.mock import patch, MagicMock -import torch -import torch.nn as nn -from msprobe.pytorch import PrecisionDebugger -from msprobe.pytorch.hook_module.api_registry import api_register -from msprobe.pytorch.service import torch_version_above_or_equal_2 +from torch import nn + +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser class TestModuleDumper(unittest.TestCase): - @classmethod - def setUpClass(cls): - PrecisionDebugger._instance = None - api_register.api_originality() + def setUp(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register'): + self.module_dumper = ModuleDumper(self.service) - @classmethod - def tearDownClass(cls): - PrecisionDebugger._instance = None - api_register.api_originality() + def test__init__(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register') as mock_get_api_register: + self.module_dumper = ModuleDumper(self.service) + self.assertEqual(self.module_dumper.service, self.service) + mock_get_api_register.assert_called_once() - def setUp(self): - self.module = nn.Linear(8, 4) - debugger = PrecisionDebugger(dump_path="./") - self.module_dumper = debugger.module_dumper + def test_start_module_dump(self): + module = nn.Module() + with patch.object(logger, 'info_on_rank_0') as mock_info: + module.msprobe_hook = True + ModuleProcesser.enable_module_dump = False + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_called_with('The init dump is enabled, and the module dump function will not be available.') + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_not_called() + self.assertFalse(hasattr(module, 'msprobe_module_dump')) + + del module.msprobe_hook + mock_info.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_called_with( + module, + self.module_dumper.service.build_hook, + recursive=False, + module_names=['dump_name'] + ) + self.assertTrue(module.msprobe_module_dump) + ModuleProcesser.enable_module_dump = False + + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.service.module_processor.register_module_hook.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_not_called() + + ModuleProcesser.enable_module_dump = False def test_stop_module_dump(self): - self.module_dumper.hook_handle_list.extend([1, 2, 3]) - with patch('msprobe.pytorch.dump.module_dump.module_dump.api_register') as mock_api_register: - mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2]) - - self.module_dumper.stop_module_dump() - mock_handle1.remove.assert_called_once() - mock_handle2.remove.assert_called_once() - self.assertEqual(self.module_dumper.hook_handle_list, []) - mock_api_register.api_modularity.assert_called_once() - - def test_register_hook(self): - self.module_dumper.register_hook(self.module, "TestModule") - if torch_version_above_or_equal_2: - self.assertEqual(len(self.module_dumper.hook_handle_list), 6) - else: - self.assertEqual(len(self.module_dumper.hook_handle_list), 5) + ModuleProcesser.enable_module_dump = True + self.module_dumper.api_register.register_all_api.reset_mock() + self.module_dumper.stop_module_dump() + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.register_all_api.assert_called_once() + + self.module_dumper.api_register.register_all_api.reset_mock() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index f8a561b61b6a758a525675bdc59957e5c923b261..832f63f8fd99b53d8d1909bee45e7a5634c6ca92 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -1,10 +1,24 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock import torch from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.pytorch.common.utils import Const from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser @@ -25,80 +39,12 @@ class TestModuleProcesser(unittest.TestCase): processor = ModuleProcesser(scope) self.assertIsNone(processor.scope) - def test_clone_return_value_and_test_clone_if_tensor(self): - def func(x): - return x - - input = torch.tensor([1]) - input_tuple = (torch.tensor([1]), torch.tensor([2])) - input_list = [torch.tensor([1]), torch.tensor([2])] - input_dict = {"A": torch.tensor([1]), "B": torch.tensor([2])} - - result = ModuleProcesser.clone_return_value(func)(input) - result[0] = 2 - self.assertNotEqual(result, input) - result_tuple = ModuleProcesser.clone_return_value(func)(input_tuple) - result_tuple[0][0] = 2 - self.assertNotEqual(result_tuple, input_tuple) - result_list = ModuleProcesser.clone_return_value(func)(input_list) - result_list[0][0] = 2 - self.assertNotEqual(result_list, input_list) - result_dict = ModuleProcesser.clone_return_value(func)(input_dict) - result_dict["A"][0] = 2 - self.assertNotEqual(result_dict, input_dict) - - def test_module_count_func(self): + def test_set_and_get_calls_number(self): + ModuleProcesser.reset_module_stats() test = ModuleProcesser(None) self.assertEqual(test.module_count, {}) module_name = "nope" - test.module_count_func(module_name) + test.set_and_get_calls_number(module_name) self.assertEqual(test.module_count["nope"], 0) - def test_node_hook_forward_start(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - hook(module, input) - expected_name = f"forward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, expected_name) - - def test_node_hook_forward_stop(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.STOP) - ModuleProcesser.module_stack.append(f"forward_layer{Const.SEP}0") - - module = MagicMock() - input = (self.mock_tensor,) - reserved_name = f"forward_layer{Const.SEP}0" - module.mindstudio_reserved_name = [reserved_name] - hook(module, input) - self.assertNotIn([f"forward_layer{Const.SEP}0"], ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, reserved_name) - - def test_node_hook_backward(self): - name_prefix = "backward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - ModuleProcesser.module_node[f"forward_layer{Const.SEP}0"] = None - hook(module, input) - expected_name = f"backward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_node) - - def test_has_register_backward_hook(self): - module = MagicMock() - module._backward_hooks = {0: lambda: None} - module._is_full_backward_hook = False - result = self.processor.has_register_backward_hook(module) - self.assertTrue(result) - - module._is_full_backward_hook = True - result = self.processor.has_register_backward_hook(module) - self.assertFalse(result) + ModuleProcesser.reset_module_stats() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..88039390f1900bde2e81390af778b8f83c7eb8ff --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import torch + +from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_backward_hook + + +class TestWrapSetupBackwardHook(unittest.TestCase): + def setUp(self): + self.mock_func = MagicMock() + self.mock_func.return_value = ["clone_tensor1", "clone_tensor2"] + + self.decorated_func = wrap_setup_backward_hook(self.mock_func) + + self.tensor = torch.randn(3, requires_grad=True) + torch.set_grad_enabled(True) + + def test_insufficient_args(self): + result = self.decorated_func("test_case1") + self.mock_func.assert_called_once_with("test_case1") + self.assertListEqual(result, ["clone_tensor1", "clone_tensor2"]) + + def test_normal_processing_flow(self): + test_tensor = torch.randn(2, requires_grad=False) + test_data = { + "tensors": [self.tensor, torch.randn(2, requires_grad=True)], + "nested": { + "tuple": (self.tensor, test_tensor) + } + } + + mock_self = MagicMock() + mock_self.module.inplace = False + test_tensor1 = torch.randn(4, requires_grad=True) + test_tensor2 = torch.randn(4, requires_grad=True) + test_tensor3 = torch.randn(4, requires_grad=True) + self.mock_func.return_value = [test_tensor1, test_tensor2, test_tensor3] + result = self.decorated_func(mock_self, test_data) + + self.assertIsInstance(result, dict) + self.assertFalse(torch.equal(result["tensors"][0], self.tensor)) + self.assertTrue(torch.equal(result["tensors"][1], test_tensor2)) + self.assertIsInstance(result["nested"]["tuple"][0], torch.Tensor) + self.assertTrue(torch.equal(result["nested"]["tuple"][1], test_tensor)) + + def test_complex_data_structures(self): + test_case = [ + self.tensor, + {"dict": torch.randn(4, requires_grad=True)}, + (torch.randn(5, requires_grad=True),), + [torch.randn(6, requires_grad=True)] + ] + + mock_self = MagicMock() + mock_self.module.inplace = False + test_tensor1 = torch.randn(4, requires_grad=True) + test_tensor2 = torch.randn(5, requires_grad=True) + test_tensor3 = torch.randn(6, requires_grad=True) + self.mock_func.return_value = [self.tensor, test_tensor1, test_tensor2, test_tensor3] + result = self.decorated_func(mock_self, test_case) + + self.assertIsInstance(result, list) + self.assertTrue(torch.equal(result[1]["dict"], test_tensor1)) + self.assertTrue(torch.equal(result[2][0], test_tensor2)) + self.assertTrue(torch.equal(result[3][0], test_tensor3)) + + @patch('msprobe.pytorch.common.utils.is_float8_tensor', return_value=True) + def test_float8_tensor_handling(self, _): + test_data = [torch.randn(3, requires_grad=True)] + + mock_self = MagicMock() + self.mock_func.return_value = [] + result = self.decorated_func(mock_self, test_data) + + self.assertIsInstance(result, list) + self.assertListEqual(result, test_data) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py index fbeeb07ffc9ac43eedc22ed95d1fa142bb2dd6e4..89176f5f51ce9f93b13bc14906be62e0425d957c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_kernel_config.py @@ -16,11 +16,11 @@ import unittest from unittest.mock import patch -from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json +from msprobe.core.kernel_dump.kernel_config import create_kernel_config_json class TestPtKernelConfig(unittest.TestCase): - @patch("msprobe.pytorch.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_with_rank(self, mock_save_json): dump_path = "./step0" cur_rank = 0 @@ -36,7 +36,7 @@ class TestPtKernelConfig(unittest.TestCase): } mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4) - @patch("msprobe.pytorch.dump.kernel_dump.kernel_config.save_json") + @patch("msprobe.core.kernel_dump.kernel_config.save_json") def test_create_kernel_config_json_without_rank(self, mock_save_json): dump_path = "./step0" cur_rank = '' diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py index be2215dcd9cb22577a84954b9283ed68825de86e..d4e568303a8b22058cba4ad879b160b3169a6cae 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py @@ -166,7 +166,7 @@ class TestPerturbedLayer(TestCase): layer.pre_check(y) mock_logger.assert_called_with( "[msprobe] Free Benchmark: For test_api_name, " - "Maximun value is less than the minimun threshold. Cancel add noise." + "maximum value is less than the minimum threshold. Cancel adding noise." ) # 对于输入张量,add_noise扰动因子对大于极小值的部分增加一个小值 @@ -212,7 +212,7 @@ class TestPerturbedLayer(TestCase): layer.pre_check(y) mock_logger.assert_called_with( "[msprobe] Free Benchmark: For test_api_name, " - "Maximun value is less than the minimun threshold. Cancel add noise." + "maximum value is less than the minimum threshold. Cancel adding noise." ) # 对于低精度输入、run cpu会升精度在cpu上计算,并会打印日志 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py index f39d3f091faf8d57f80cccbadc15259ee54269f0..80e32ac2890dfa0247eb5cd76aefe45cbc735345 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py @@ -24,7 +24,7 @@ class TestGradCSV(unittest.TestCase): def test_level_L0_content(self): generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", level_adp["L0"], grad_tensor, [-1, 0, 1]) - self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 2.0, -2.0, 2.851315498352051, [2, 2]], + self.assertEqual(['model.conv2d', 'e2863940', 2.0, -2.0, 2.851315498352051, [2, 2]], generated_csv_line) def test_level_L1_content(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py deleted file mode 100644 index 837ad23df76be2a012a7408dab4879847937f229..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py +++ /dev/null @@ -1,130 +0,0 @@ -import unittest -from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu - - -class TestApiRegistry(unittest.TestCase): - - def test_store_ori_attr(self): - class A(): - a1 = 1 - class B(): - a = A() - b1 = 1 - b2 = 2 - - api_list = ["a.a1", "b1", "b2"] - expect_output = {"a.a1":1, "b1":1, "b2":2} - actual_output = dict() - ApiRegistry.store_ori_attr(B, api_list, actual_output) - self.assertEqual(actual_output, expect_output) - - - def test_set_api_attr(self): - class A(): - a1 = 1 - class B(): - a = A().__class__ - b1 = 1 - - attr_dict = {"a.a2":2, "b2":2, "b3":3} - ApiRegistry.set_api_attr(B, attr_dict) - - for k, v in attr_dict.items(): - if '.' in k: - sub_module_name, sub_op = k.rsplit('.', 1) - sub_module = getattr(B, sub_module_name, None) - - self.assertEqual(getattr(sub_module, sub_op), v) - else: - self.assertEqual(getattr(B, k), v) - - def test_api_modularity(self): - - import torch - import torch.distributed as dist - #import torch_npu #门禁没有安装torch_npu - from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 - - - - reg = ApiRegistry() - attr_dict = {"b2":2, "b3":3} - reg.tensor_hook_attr = attr_dict - reg.torch_hook_attr = attr_dict - reg.functional_hook_attr = attr_dict - reg.distributed_hook_attr = attr_dict - reg.npu_distributed_hook_attr = attr_dict - reg.aten_hook_attr = attr_dict - reg.vf_hook_attr = attr_dict - reg.torch_npu_hook_attr = attr_dict - - reg.api_modularity() - self.assertEqual(torch.Tensor.b2, 2) - - self.assertEqual(torch.b2, 2) - self.assertEqual(torch.nn.functional.b2, 2) - self.assertEqual(dist.b2, 2) - self.assertEqual(dist.distributed_c10d.b2, 2) - #if not is_gpu and not torch_without_guard_version: - #self.assertEqual(torch_npu.distributed.b2, 2) - #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2) - if torch_version_above_2: - self.assertEqual(torch.ops.aten.b2, 2) - self.assertEqual(torch._VF.b2, 2) - #if not is_gpu: - #self.assertEqual(torch_npu.b2, 2) - - - def test_api_originality(self): - import torch - import torch.distributed as dist - #import torch_npu #门禁没有安装torch_npu - from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 - - - - reg = ApiRegistry() - attr_dict = {"b2":2, "b3":3} - reg.tensor_hook_attr = attr_dict - reg.torch_hook_attr = attr_dict - reg.functional_hook_attr = attr_dict - reg.distributed_hook_attr = attr_dict - reg.npu_distributed_hook_attr = attr_dict - reg.aten_hook_attr = attr_dict - reg.vf_hook_attr = attr_dict - reg.torch_npu_hook_attr = attr_dict - - reg.api_originality() - self.assertEqual(torch.Tensor.b2, 2) - - self.assertEqual(torch.b2, 2) - self.assertEqual(torch.nn.functional.b2, 2) - self.assertEqual(dist.b2, 2) - self.assertEqual(dist.distributed_c10d.b2, 2) - #if not is_gpu and not torch_without_guard_version: - #self.assertEqual(torch_npu.distributed.b2, 2) - #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2) - if torch_version_above_2: - self.assertEqual(torch.ops.aten.b2, 2) - self.assertEqual(torch._VF.b2, 2) - #if not is_gpu: - #self.assertEqual(torch_npu.b2, 2) - - def test_initialize_hook(self): - def hook_test(): - pass - - reg = ApiRegistry() - reg.initialize_hook(hook_test) - empty_list = [] - self.assertFalse(empty_list==reg.tensor_hook_attr) - self.assertFalse(empty_list==reg.torch_hook_attr) - self.assertFalse(empty_list==reg.functional_hook_attr) - self.assertFalse(empty_list==reg.distributed_hook_attr) - self.assertFalse(empty_list==reg.npu_distributed_hook_attr) - if torch_version_above_2: - #print(True) - self.assertFalse(empty_list==reg.aten_hook_attr) - if not is_gpu: - #print(True) - self.assertFalse(empty_list==reg.torch_npu_hook_attr) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py deleted file mode 100644 index 1524a82ae1fc81eee245fa73bde4b4938cb89638..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ /dev/null @@ -1,34 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import threading -from msprobe.pytorch.hook_module.hook_module import HOOKModule - -class TestHOOKModuleInit(unittest.TestCase): - - def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - - def test_thread_handling(self): - module = HOOKModule(self.mock_build_hook) - current_thread_id = module.current_thread - self.assertEqual(current_thread_id, threading.current_thread().ident) - - -class TestHOOKModuleCall(unittest.TestCase): - def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - self.module = HOOKModule(self.mock_build_hook) - - @patch.object(HOOKModule, '_call_func') - def test_call_function(self, mock_call_func): - mock_call_func.return_value = "test_result" - result = self.module("input_data") - mock_call_func.assert_called_once_with("input_data", **{}) - self.assertEqual(result, "test_result") - - @patch.object(HOOKModule, '_call_func') - def test_call_func_with_hooks(self, mock_call_func): - mock_call_func.return_value = "test_result_with_hooks" - result = self.module("input_data") - self.assertEqual(result, "test_result_with_hooks") - HOOKModule.inner_stop_hook[self.module.current_thread] = False diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..da6e21049bf76e0389795aa816dbee8218bc87be --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_api_register.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import msprobe.pytorch.hook_module.api_register as api_register +from msprobe.pytorch.hook_module.api_register import ( + tensor_module_forward, + dist_module_forward, + npu_module_forward, + get_api_register, + ApiTemplate +) + + +class TestAPIRegister(unittest.TestCase): + def setUp(self): + api_register.api_register = None + + def test_tensor_module_forward(self): + mock_module = MagicMock() + mock_module.api_name = "test_name" + mock_module.api_func.return_value = "test_result" + + args = (1, 2, 3) + kwargs = {"key": "value"} + result = tensor_module_forward(mock_module, *args, **kwargs) + + mock_module.api_func.assert_called_once_with(*args, **kwargs) + self.assertEqual(result, "test_result") + + @patch('msprobe.pytorch.hook_module.api_register.logger.warning') + def test_basic_dist_module_forward(self, mock_logger): + mock_module = MagicMock() + mock_module.api_func.return_value = "test_handle" + mock_module.api_name = "test_api" + + result = dist_module_forward(mock_module, 1, 2, key="value") + mock_module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_handle") + mock_logger.assert_not_called() + + @patch('msprobe.pytorch.hook_module.api_register.ApiRegistry') + def test_get_api_register_with_new_obj(self, mock_api_registry): + get_api_register(return_new=True) + mock_api_registry.assert_called_once() + self.assertIsNone(api_register.api_register) + + @patch('msprobe.pytorch.hook_module.api_register.ApiRegistry') + def test_get_api_register_with_not_new_obj(self, mock_api_registry): + get_api_register() + mock_api_registry.assert_called_once() + self.assertIsNotNone(api_register.api_register) + + +class TestNpuModuleForward(unittest.TestCase): + def setUp(self): + self.npu_custom_functions = { + "custom_func": MagicMock(return_value="custom_result"), + "npu_fusion_attention": MagicMock(return_value="nfa_result"), + "gpu_fusion_attention": MagicMock(return_value="gfa_result") + } + + self.module = MagicMock() + self.module.api_func.return_value = "test_result" + + def test_with_hook_enabled(self): + self.module.need_hook = True + result = npu_module_forward(self.module, 1, 2, key="value") + self.module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_result") + + def test_with_unknown_api(self): + self.module.need_hook = False + self.module.api_name = "unknown_func" + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + with self.assertRaises(Exception) as context: + npu_module_forward(self.module, 1, 2, key="value") + self.assertIn("There is not bench function unknown_func", str(context.exception)) + + def test_cuda_device_with_mapping(self): + self.module.need_hook = False + self.module.api_name = "npu_fusion_attention" + self.module.device = 'cuda' + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.npu_custom_functions["gpu_fusion_attention"].assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "gfa_result") + + def test_cpu_device(self): + self.module.need_hook = False + self.module.api_name = "custom_func" + self.module.device = "cpu" + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.npu_custom_functions["custom_func"].assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "custom_result") + + def test_unsupported_device(self): + self.module.need_hook = False + self.module.api_name = "custom_func" + self.module.device = "unsupported_device" + + with patch('msprobe.pytorch.hook_module.api_register.npu_custom_functions', new=self.npu_custom_functions): + result = npu_module_forward(self.module, 1, 2, key="value") + self.module.api_func.assert_called_once_with(1, 2, key="value") + self.assertEqual(result, "test_result") + + +class TestApiTemplate(unittest.TestCase): + def setUp(self): + self.api_name = "Tensor.test_api" + self.api_func = MagicMock(return_value="test_result") + self.prefix = "test_prefix" + self.hook_build_func = MagicMock() + self.mock_hook_module = MagicMock() + + def test_init(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule') as mock_hook_module: + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False + ) + + self.assertEqual(template.api_name, self.api_name) + self.assertEqual(template.api_func, self.api_func) + self.assertEqual(template.prefix, self.prefix) + self.assertEqual(template.prefix_api_name, "test_prefix.test_api.") + self.assertEqual(template.device, "cpu") + self.assertFalse(template.need_hook) + + self.assertFalse(hasattr(template, 'op_is_distributed')) + + def test_init_with_distributed_prefix(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + self.prefix = "Distributed" + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + self.assertEqual(template.device, "npu") + self.assertEqual(template.prefix_api_name, "Distributed.test_api.") + self.assertTrue(template.op_is_distributed) + + def test_init_without_hook(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule') as mock_hook_module: + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + self.assertFalse(template.need_hook) + self.mock_hook_module.assert_not_called() + + def test_forward_with_prefix_match(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + self.prefix = "Tensor" + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + result = template.forward("arg1", key="value") + + self.assertEqual(result, "test_result") + + def test_forward_without_prefix_match(self): + with patch('msprobe.pytorch.hook_module.api_register.HOOKModule'): + template = ApiTemplate( + self.api_name, + self.api_func, + self.prefix, + self.hook_build_func, + need_hook=False, + device="npu" + ) + + result = template.forward("arg1", key="value") + + self.api_func.assert_called_once_with("arg1", key="value") + self.assertEqual(result, "test_result") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..aa21afae721ceff54fa671d817aaa60467c60fd8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_manager.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from contextlib import nullcontext +from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager +from msprobe.core.common.const import Const +from msprobe.core.hook_manager import HookSet, BaseHookManager + + +class TestPytorchHookManager(unittest.TestCase): + def setUp(self): + self.mock_data_collector = MagicMock() + self.mock_config = MagicMock() + self.mock_config.data_mode = ["all"] + self.mock_config.task = "statistics" + self.manager = PytorchHookManager( + self.mock_data_collector, + self.mock_config + ) + BaseHookManager.inner_switch = False + + def test_properties(self): + with patch('msprobe.pytorch.hook_module.pt_hook_manager.is_recomputation', return_value=True): + self.assertTrue(self.manager._is_recompute) + + with patch('msprobe.pytorch.hook_module.pt_hook_manager.is_recomputation', return_value=False): + self.assertFalse(self.manager._is_recompute) + + def test_no_grad_context(self): + self.assertIsInstance(self.manager._no_grad_context(), nullcontext) + + def test_add_count(self): + with patch('msprobe.pytorch.hook_module.pt_hook_manager.HOOKModule.add_module_count') as mock_add: + self.manager._add_count("test_layer") + mock_add.assert_called_once_with("test_layer") + + def test_process_kwargs_and_output(self): + with patch('msprobe.pytorch.hook_module.pt_hook_manager.torch_version_above_or_equal_2', new=True): + kwargs, output = self.manager._process_kwargs_and_output( + None, None, "kwargs_value", "output_value" + ) + self.assertEqual(kwargs, "kwargs_value") + self.assertEqual(output, "output_value") + + with patch('msprobe.pytorch.hook_module.pt_hook_manager.torch_version_above_or_equal_2', new=False): + kwargs, output = self.manager._process_kwargs_and_output( + None, None, "kwargs_value", "output_value" + ) + self.assertEqual(kwargs, {}) + self.assertEqual(output, "kwargs_value") + + def test_build_hook(self): + hookset = self.manager.build_hook(Const.API, "test_api") + self.assertIsInstance(hookset, HookSet) + self.assertTrue(callable(hookset.forward_hook)) + self.assertTrue(callable(hookset.forward_pre_hook)) + self.assertTrue(callable(hookset.backward_hook)) + self.assertIsNone(hookset.backward_pre_hook) + + hookset = self.manager.build_hook(Const.MODULE, "test_module") + self.assertEqual(hookset.forward_pre_hook.__name__, "forward_pre_hook") + + def test_need_exchange(self): + self.assertTrue(self.manager._need_exchange(None)) + self.assertTrue(self.manager._need_exchange(MagicMock())) + + def test_get_params_dict(self): + mock_module = MagicMock() + + self.mock_config.task = Const.STRUCTURE + params_dict = self.manager._get_params_dict(mock_module) + self.assertEqual(params_dict, {}) + + self.mock_config.task = "statistics" + + mock_named_params = [ + ("conv.weight", MagicMock()), + ("bn.bias", MagicMock()) + ] + mock_module.named_parameters.return_value = mock_named_params + params_dict = self.manager._get_params_dict(mock_module) + mock_module.named_parameters.assert_called_once_with(recurse=False) + + self.assertEqual(set(params_dict.keys()), {"weight", "bias"}) + self.assertEqual(params_dict["weight"], mock_named_params[0][1]) + self.assertEqual(params_dict["bias"], mock_named_params[1][1]) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py new file mode 100644 index 0000000000000000000000000000000000000000..2abb582ee4db611f00c83a44a466bea309a2192f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_hook_module.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import unittest +from collections import defaultdict +from unittest.mock import MagicMock, patch +from msprobe.core.hook_manager import HookSet +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class TestHOOKModule(unittest.TestCase): + def setUp(self): + self.mock_build_hook = MagicMock(return_value=HookSet(MagicMock(), MagicMock(), MagicMock())) + HOOKModule.module_count = defaultdict(int) + HOOKModule.inner_stop_hook = {} + + def test_init_with_stop_hook(self): + expected_thread = threading.current_thread().ident + HOOKModule.inner_stop_hook[expected_thread] = True + + module1 = HOOKModule(self.mock_build_hook) + current_thread = module1.current_thread + + self.assertEqual(current_thread, expected_thread) + self.assertTrue(module1.inner_stop_hook[current_thread]) + self.assertTrue(module1.stop_hook) + self.assertFalse(hasattr(module1, "forward_data_collected")) + + def test_init_with_start_hook(self): + module1 = HOOKModule(self.mock_build_hook) + current_thread = module1.current_thread + expected_thread = threading.current_thread().ident + + self.assertEqual(current_thread, expected_thread) + self.assertFalse(module1.inner_stop_hook[current_thread]) + self.assertFalse(module1.stop_hook) + self.assertTrue(hasattr(module1, "forward_data_collected")) + + @patch.object(HOOKModule, '_call_func') + def test_call_with_stop_hooks(self, mock_call_func): + mock_call_func.return_value = "test_result" + expected_thread = threading.current_thread().ident + HOOKModule.inner_stop_hook[expected_thread] = True + + module1 = HOOKModule(self.mock_build_hook) + self.assertTrue(module1.stop_hook) + + result = module1("arg1", "arg2", key="value") + mock_call_func.assert_called_once_with("arg1", "arg2", key="value") + self.assertEqual(result, "test_result") + self.assertTrue(HOOKModule.inner_stop_hook[expected_thread]) + + @patch.object(HOOKModule, '_call_func') + def test_call_with_start_hooks(self, mock_call_func): + mock_call_func.return_value = "test_result" + expected_thread = threading.current_thread().ident + + module1 = HOOKModule(self.mock_build_hook) + self.assertFalse(module1.stop_hook) + + result = module1("arg1", "arg2", key="value") + mock_call_func.assert_called_once_with("arg1", "arg2", key="value") + self.assertEqual(result, "test_result") + self.assertFalse(HOOKModule.inner_stop_hook[expected_thread]) + + def test_reset_module_stats(self): + HOOKModule.module_count = {"Tensor.add.0.forward": 0} + HOOKModule.reset_module_stats() + self.assertDictEqual(HOOKModule.module_count, defaultdict(int)) + + def test_add_module_count(self): + HOOKModule.add_module_count("Tensor.add.0.forward") + self.assertEqual(HOOKModule.module_count["Tensor.add.0.forward"], 1) + + def test_get_module_count(self): + HOOKModule.module_count = {"Tensor.add.0.forward": 0} + result = HOOKModule.get_module_count("Tensor.add.0.forward") + self.assertEqual(result, 0) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..61909523fc523dead62887ba94f399424b72a098 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import torch +from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func + + +class TestWrapJitScriptFunc(unittest.TestCase): + def setUp(self): + self.original_script = torch.jit.script + + self.mock_api_register = MagicMock() + self.mock_api_register.all_api_registered = True + self.mock_api_register.register_all_api = MagicMock() + self.mock_api_register.restore_all_api = MagicMock() + + def tearDown(self): + torch.jit.script = self.original_script + + @patch('torch.jit.script', new_callable=MagicMock) + @patch('msprobe.pytorch.hook_module.jit_script_wrapper.get_api_register', return_value=MagicMock()) + def test_patched_script(self, mock_get_api, mock_original_script): + mock_original_script.return_value = "mocked_result" + mock_get_api.return_value = self.mock_api_register + + wrap_jit_script_func() + + self.assertNotEqual(torch.jit.script, self.original_script) + + result = torch.jit.script("test_input") + + mock_original_script.assert_called_once_with("test_input") + self.assertEqual(result, "mocked_result") + + self.mock_api_register.restore_all_api.assert_called_once() + self.mock_api_register.register_all_api.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index af669cb5c73de85e51f36f62f9e7dc61bb599ca1..0ee3df9bccbd4e4fa93abb931ad2e379b44344e8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -1,15 +1,35 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import torch +from msprobe.core.hook_manager import HookSet from msprobe.pytorch.function_factory import npu_custom_grad_functions -from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, white_aten_ops, \ +from msprobe.pytorch.hook_module.wrap_aten import ( + AtenOPTemplate, + white_aten_ops, AtenOPPacketTemplate +) def mock_build_hook(prefix): - return (MagicMock(), MagicMock(), MagicMock(), MagicMock()) + return HookSet(MagicMock(), MagicMock(), MagicMock()) + class TestAtenOPTemplate(unittest.TestCase): @@ -79,8 +99,8 @@ class TestAtenOPPacketTemplate(unittest.TestCase): del self.mock_op_packet.nonexistent_attr with self.assertRaises(AttributeError) as context: _ = self.template.nonexistent_attr - self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", \ - str(context.exception)) + self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", + str(context.exception)) @patch('msprobe.pytorch.hook_module.wrap_aten.AtenOPTemplate', autospec=True) def test_getattr_op_overload(self, MockAtenOPTemplate): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py deleted file mode 100644 index 246feb56becf9942de9214f5b24b8471e9b4024a..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest -import torch.distributed as dist -from msprobe.pytorch.hook_module.wrap_distributed import * - -class TestWrapDistributed(unittest.TestCase): - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def test_get_distributed_ops(self): - ops = get_distributed_ops() - self.assertIsInstance(ops, set) - - def test_DistributedOPTemplate(self): - self.setUp() - op_name = 'all_reduce' - if op_name in get_distributed_ops(): - op = DistributedOPTemplate(op_name, self.hook) - self.assertEqual(op.op_name_, op_name) - - def test_wrap_distributed_op(self): - op_name = 'all_reduce' - if op_name in get_distributed_ops(): - wrapped_op = wrap_distributed_op(op_name, self.hook) - self.assertTrue(callable(wrapped_op)) - - def test_wrap_distributed_ops_and_bind(self): - wrap_distributed_ops_and_bind(self.hook) - for op_name in get_distributed_ops(): - self.assertTrue(hasattr(HOOKDistributedOP, "wrap_" + str(op_name))) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py deleted file mode 100644 index 282551e3cefdb2ae63efda284f5e7ae7482ae81c..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest -import torch -import torch.nn.functional as F -from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops, \ - wrap_functional_ops_and_bind, HOOKFunctionalOP -from msprobe.pytorch.common.utils import remove_dropout - - -class TestDropoutFunctions(unittest.TestCase): - - def setUp(self): - self.input_tensor = torch.ones(10, 10) - remove_dropout() - - def test_function_dropout_no_dropout(self): - output = F.dropout(self.input_tensor, p = 0., training = True) - self.assertTrue(torch.equal(self.input_tensor, output)) - - def test_function_dropout_train_vs_eval(self): - output_train = F.dropout(self.input_tensor, p = 0., training = True) - output_eval = F.dropout(self.input_tensor, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout_invalid_probability(self): - with self.assertRaises(ValueError): - F.dropout(self.input_tensor, p = -0.1) - with self.assertRaises(ValueError): - F.dropout(self.input_tensor, p = 1.1) - - def test_function_dropout2d_no_dropout(self): - output = F.dropout2d(self.input_tensor, p = 0., training = True) - self.assertTrue(torch.equal(self.input_tensor, output)) - - def test_function_dropout2d_train_vs_eval(self): - output_train = F.dropout2d(self.input_tensor, p = 0., training = True) - output_eval = F.dropout2d(self.input_tensor, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout2d_invalid_probability(self): - with self.assertRaises(ValueError): - F.dropout2d(self.input_tensor, p = -0.1) - with self.assertRaises(ValueError): - F.dropout2d(self.input_tensor, p = 1.1) - - def test_function_dropout3d_no_dropout(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - output = F.dropout3d(input_tensor_3d, p = 0., training = True) - self.assertTrue(torch.equal(input_tensor_3d, output)) - - def test_function_dropout3d_train_vs_eval(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - output_train = F.dropout3d(input_tensor_3d, p = 0., training = True) - output_eval = F.dropout3d(input_tensor_3d, p = 0., training = False) - self.assertTrue(torch.equal(output_train, output_eval)) - - def test_function_dropout3d_invalid_probability(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = -0.1) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = 1.1) - - -class TestWrapFunctional(unittest.TestCase): - - def test_get_functional_ops(self): - expected_ops = {'relu', 'sigmoid', 'softmax'} - actual_ops = get_functional_ops() - self.assertTrue(expected_ops.issubset(actual_ops)) - - def test_wrap_functional_ops_and_bind(self): - wrap_functional_ops_and_bind(None) - self.assertTrue(hasattr(HOOKFunctionalOP, 'wrap_relu')) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py deleted file mode 100644 index 573d6d000f37f429619b89507cecd1258fbe4c8b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.pytorch.function_factory import npu_custom_functions -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -class TestNpuOPTemplate(unittest.TestCase): - - def setUp(self): - self.mock_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - self.template = NpuOPTemplate("sum", self.mock_hook) - - def test_init(self): - self.assertEqual(self.template.op_name_, "sum") - self.assertEqual(self.template.prefix_op_name_, f"NPU{Const.SEP}sum{Const.SEP}") - self.assertTrue(self.template.need_hook) - self.assertEqual(self.template.device, Const.CPU_LOWERCASE) - - @patch('torch.ops.npu.sum') - def test_forward_without_hook(self, mock_npu_sum): - self.template.need_hook = False - npu_custom_functions["sum"] = MagicMock(return_value="output_from_custom") - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_custom") - mock_npu_sum.assert_not_called() - - @patch('torch.ops.npu.sum') - def test_forward_with_hook(self, mock_npu_sum): - self.template.need_hook = True - mock_npu_sum.return_value = "output_from_npu" - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_npu") - mock_npu_sum.assert_called_once_with(1, 2, key='value') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py deleted file mode 100644 index 6868c5bda7a88c84702d15e995c7f60af2b4e4c5..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind - -class TestWrapTensor(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def test_get_tensor_ops(self): - result = get_tensor_ops() - self.assertIsInstance(result, set) - - def test_HOOKTensor(self): - hook_tensor = HOOKTensor() - self.assertIsInstance(hook_tensor, HOOKTensor) - - def test_TensorOPTemplate(self): - tensor_op_template = TensorOPTemplate('add', self.hook) - self.assertTrue(tensor_op_template.op_name_, 'add') - - def test_wrap_tensor_op(self): - wrapped_op = wrap_tensor_op('add', self.hook) - self.assertTrue(callable(wrapped_op)) - - def test_wrap_tensor_ops_and_bind(self): - wrap_tensor_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTensor, 'wrap_add')) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py deleted file mode 100644 index e0e4d000c0bd83be4facbbb406357427faf875ec..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_torch import * - -class TestWrapTorch(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def setUp(self): - - self.op_name = 'add' - self.torch_op = wrap_torch_op(self.op_name, self.hook) - - def test_get_torch_ops(self): - self.setUp() - ops = get_torch_ops() - self.assertIsInstance(ops, set) - self.assertIn(self.op_name, ops) - - def test_TorchOPTemplate(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - self.assertEqual(template.op_name_, self.op_name) - self.assertEqual(template.prefix_op_name_, "Torch." + str(self.op_name) + ".") - - def test_forward(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) - torch.testing.assert_close(result, torch.tensor([5, 7, 9])) - - def test_wrap_torch_ops_and_bind(self): - self.setUp() - wrap_torch_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTorchOP, "wrap_" + self.op_name)) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py deleted file mode 100644 index 98efb4bc5b8a30284fe820124e48af7f487d1c54..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module import wrap_vf - -class TestWrapVF(unittest.TestCase): - def setUp(self): - self.hook = lambda x: x - - def test_get_vf_ops(self): - ops = wrap_vf.get_vf_ops() - self.assertIsInstance(ops, list) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json new file mode 100644 index 0000000000000000000000000000000000000000..461b447ce0cd33fdcbab3476f7c1e3bcdee9dfad --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/stack_config.json @@ -0,0 +1,5 @@ +{ + "targets": {}, + "format": "csv", + "stack_info": true +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py index f5de419440224cca261b62df2495e8ce28b8e2d4..820b1f7476d3d92288069bc00ac798c44bf14da6 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py @@ -1,7 +1,25 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn.functional as F from msprobe.pytorch import TrainerMon from msprobe.pytorch.common import seed_all +from msprobe.pytorch.hook_module.api_register import get_api_register + +get_api_register().restore_all_api() device = torch.device('cpu') dtype_float32 = torch.float32 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py deleted file mode 100644 index fa0960e2cc1842a138b47fad3f86c1ed0d089db8..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py +++ /dev/null @@ -1,291 +0,0 @@ -import unittest -from unittest import TestCase -from unittest.mock import patch - -from msprobe.pytorch.monitor.anomaly_detect import AnomalyTurbulence, AnomalyScanner, \ - AnomalyDataFactory, GradAnomalyData, BaseWriterWithAD, ScanRule, WriterInput - - -class TestScanRule(TestCase): - def test_apply_not_implemented(self): - scan_rule = ScanRule() - with self.assertRaises(Exception) as context: - scan_rule.apply(None, None) - - self.assertEqual(str(context.exception), "abstract method apply is not implemented") - - -class TestAnomalyTurbulence(TestCase): - - def setUp(self) -> None: - self.threshold = 0.2 - self.rule = AnomalyTurbulence(self.threshold) - - def test_apply_with_positive_baseline(self): - history = [10, 12, 14] - cur = 16 - result = self.rule.apply(history, cur) - self.assertTrue(result) - - def test_apply_with_non_positive_baseline(self): - history = [0, 0, 0] - cur = -1 - result = self.rule.apply(history, cur) - self.assertTrue(result) - - -class TestAnomalyScanner(TestCase): - - def test_load_rules_with_valied_spec(self): - specs = [ - {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}} - ] - rules = AnomalyScanner.load_rules(specs) - - self.assertEqual(len(rules), 1) - self.assertIsInstance(rules[0], AnomalyTurbulence) - self.assertEqual(rules[0].threshold, 0.2) - - rules = AnomalyScanner.load_rules(None) - self.assertEqual(len(rules), 0) - - @patch("msprobe.pytorch.monitor.anomaly_detect.logger") - def test_load_rules_with_missing_keys(self, mock_logger): - specs = [ - {"rule_name": "AnomalyTurbulence"} - ] - rules = AnomalyScanner.load_rules(specs) - - self.assertEqual(len(rules), 0) - mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}") - - def test_load_rules_with_invalid_rule(self): - # test invalid rule_name - specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}] - rules = AnomalyScanner.load_rules(specs) - self.assertEqual(len(rules), 0) - - # test invalid args - specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}] - rules = AnomalyScanner.load_rules(specs) - self.assertEqual(len(rules), 0) - - def test_scan(self): - ad_rules = [AnomalyTurbulence(0.2)] - # test scan with anomaly - expected = True, "AnomalyTurbulence" - self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected) - # test scan with no anomaly - expected = False, None - self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected) - - -class TestAnomalyDataFactory(TestCase): - - def setUp(self) -> None: - rank = 0 - pp_stage = 0 - group_mates = [0] - self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates) - - def test_set_call_id(self): - name2callid = {'param_name': 0} - self.AnomalyDataFactory.set_call_id(name2callid) - - self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0}) - - def test_create_success(self): - tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." - step = 2 - result = self.AnomalyDataFactory.create(tag, message, step) - - self.assertEqual(result.step, step) - self.assertEqual(result.tag_name, tag[0]) - self.assertEqual(result.message, message) - self.assertEqual(result.vpp_stage, 0) - - # test no vpp_stage - tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min') - result = self.AnomalyDataFactory.create(tag, message, step) - self.assertEqual(result.vpp_stage, 0) - - def test_create_failed(self): - error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output' - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." - step = 2 - with self.assertRaises(Exception) as context: - self.AnomalyDataFactory.create(error_tag, message, step) - self.assertEqual(str(context.exception), "tag must be a tuple with length 2") - - -class TestGradAnomalyData(TestCase): - - def setUp(self) -> None: - tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv" - message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2." - group_mates = [0] - self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates) - - def test_get_train_stage(self): - tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""] - expected_train_stage_list = [0, 1, 2, -1] - for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list): - train_stage = GradAnomalyData.get_train_stage(tag_name) - self.assertEqual(train_stage, expected_train_stage) - - def test_to_dict(self): - expected = { - 'rank': 0, - 'step': 0, - 'micro_step': 0, - 'pp_stage': 0, - 'vpp_stage': 0, - 'call_id': 0, - 'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv", - 'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.", - 'group_mates': [0] - } - - self.assertEqual(self.GradAnomalyData.to_dict(), expected) - - def test_get_key(self): - expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0" - - self.assertEqual(self.GradAnomalyData.get_key(), expected) - - def test_lt_different_step(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_different_micro_step(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_different_vpp_stage(self): - # same forward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - # same backward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data2, data1) - self.assertGreater(data1, data2) - - # diff train stage - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self): - # same forward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - # same backward - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data2, data1) - self.assertGreater(data1, data2) - - # diff train stage - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="") - self.assertLess(data1, data2) - self.assertGreater(data2, data1) - - def test_lt_same_data(self): - data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") - self.assertGreaterEqual(data1, data2) - self.assertLessEqual(data1, data2) - - def test_lt_not_instance(self): - data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0) - not_instance = "not an instance of GradAnomalyData" - self.assertEqual(data.__lt__(not_instance), NotImplemented) - - def test_le_same_instance(self): - # 测试相同实例的情况 - data1 = GradAnomalyData() - self.assertTrue(data1 <= data1) - - def test_le_different_instance(self): - # 测试不同实例的情况 - data1 = GradAnomalyData() - data2 = GradAnomalyData() - self.assertTrue(data1 <= data2) - - def test_le_not_instance(self): - # 测试非GradAnomalyData实例的情况 - data = GradAnomalyData() - not_instance = "Not an instance of GradAnomalyData" - self.assertEqual(data.__le__(not_instance), NotImplemented) - - def test_le_different_instance_not_equal(self): - # 测试不同实例且不相等的情况 - data1 = GradAnomalyData() - data2 = GradAnomalyData() - data2.some_attribute = "some value" - self.assertTrue(data1 <= data2) - - -class TestBaseWriterWithAD(TestCase): - - def setUp(self) -> None: - self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None)) - - def test_get_anomalies(self): - expected = [] - - self.assertEqual(self.BaseWriter.get_anomalies(), expected) - - def test_clear_anomalies(self): - self.BaseWriter.anomalies = ['anomaly1', 'anomaly2'] - self.BaseWriter.clear_anomalies() - - self.assertEqual(self.BaseWriter.anomalies, []) - - @patch("msprobe.pytorch.monitor.anomaly_detect.logger") - def test_add_scalar(self, mock_logger): - AnomalyTurbulence_obj = AnomalyTurbulence(0.2) - self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] - self.BaseWriter.tag2scalars = {'tag': {'avg': 1.0, 'count': 1}} - self.BaseWriter.add_scalar('tag', 2.0) - - mock_logger.info.assert_called_once() - - def test_ad(self): - AnomalyTurbulence_obj = AnomalyTurbulence(0.2) - self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] - expected = True, "AnomalyTurbulence" - - self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected) - - def test_update_tag2scalars(self): - self.BaseWriter._update_tag2scalars('tag1', 1.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1) - self.BaseWriter._update_tag2scalars('tag1', 2.0) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.5) - self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2) - - -if __name__ == '__main__': - unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2db.py new file mode 100644 index 0000000000000000000000000000000000000000..535d68179a25271dd68eeae6617febc5d6c89b45 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2db.py @@ -0,0 +1,386 @@ +import os +import re +import shutil +import sqlite3 +import tempfile +import unittest +from unittest.mock import patch + +import pandas as pd +from msprobe.pytorch.monitor.csv2db import ( + CSV_FILE_PATTERN, + CSV2DBConfig, + MonitorDB, + all_data_type_list, + check_data_type_list, + check_process_num, + csv2db, + pre_scan_single_rank, + process_single_rank, +) + + +class TestCSV2DB(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, "test.db") + self.config = CSV2DBConfig( + monitor_path=self.temp_dir, + process_num=2, + data_type_list=["actv", "actv_grad"], + step_partition=100 + ) + # 创建模拟的CSV目录结构 + self.rank_dir = os.path.join(self.temp_dir, "rank_0") + os.makedirs(self.rank_dir) + with open(os.path.join(self.rank_dir, "actv_100-200.csv"), "w") as f: + f.write( + "name,vpp_stage,micro_step,step,mean,max\n" + "op1,1,0,100,0.5,1.0\n" + "op2,1,0,100,0.3,0.8\n" + ) + with open(os.path.join(self.rank_dir, "grad_300-400.csv"), "w") as f: + f.write( + "name,vpp_stage,micro_step,step,min,max\n" + "op2,2,1,300,0.1,0.9\n" + ) + with open(os.path.join(self.rank_dir, "actv_500-600.csv"), "w") as f: + f.write( + "name,vpp_stage,micro_step,step,mean,max\n" + "op3,1,0,500,0.2,0.7\n" + ) + # 创建第二个rank的目录 + self.rank_dir1 = os.path.join(self.temp_dir, "rank_1") + os.makedirs(self.rank_dir1) + with open(os.path.join(self.rank_dir1, "actv_100-200.csv"), "w") as f: + f.write( + "name,vpp_stage,micro_step,step,mean,max\n" + "op1,1,0,100,0.6,1.1\n" + ) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_check_process_num_edge_cases(self): + """测试边界进程数""" + with self.assertRaises(ValueError): + check_process_num(0) # 0进程 + with self.assertRaises(ValueError): + check_process_num(-1) # 负数进程 + with self.assertRaises(ValueError): + check_process_num(129) # 超过最大限制(128) + # 边界有效值 + check_process_num(1) # 最小值 + check_process_num(128) # 最大值 + + def test_check_data_type_list_edge_cases(self): + """测试数据类型边界情况""" + # None使用默认值 + check_data_type_list(None) + # 空列表 + check_data_type_list([]) + # 所有支持的类型 + check_data_type_list(all_data_type_list) + # 混合有效和无效类型 + with self.assertRaises(ValueError): + check_data_type_list(["actv", "invalid_type"]) + # 非列表输入 + with self.assertRaises(ValueError): + check_data_type_list("actv") + + @patch("msprobe.pytorch.monitor.csv2db.read_csv") + def test_pre_scan_with_empty_file(self, mock_read_csv): + """测试空文件处理""" + mock_read_csv.return_value = pd.DataFrame() # 空DataFrame + + files = [os.path.join(self.rank_dir, "actv_100-200.csv")] + result = pre_scan_single_rank(0, files) + + self.assertEqual(result["metrics"], {"actv"}) + self.assertEqual(result["targets"], []) + + @patch("msprobe.pytorch.monitor.csv2db.read_csv") + def test_pre_scan_multiple_targets(self, mock_read_csv): + """测试扫描多个目标""" + mock_read_csv.return_value = pd.DataFrame({ + "name": ["op1", "op2"], + "vpp_stage": [1, 2], + "micro_step": [0, 1], + "step": [100, 100], + "mean": [0.5, 0.3], + "max": [1.0, 0.8] + }) + + files = [os.path.join(self.rank_dir, "actv_100-200.csv")] + result = pre_scan_single_rank(0, files) + + self.assertEqual(len(result["targets"]), 2) + self.assertIn((1, "op1", 0), result["targets"]) + self.assertIn((2, "op2", 1), result["targets"]) + + # 数据库操作测试 + def test_table_creation_logic(self): + """测试表创建逻辑""" + # _init_schema + db = MonitorDB(self.db_path, 100) + + # 验证初始表结构 + tables = db.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'").fetchall() + table_names = {t[0] for t in tables} + expected_tables = { + "monitoring_targets", "monitoring_metrics", + "metric_stats", "global_stats" + } + self.assertTrue(expected_tables.issubset(table_names)) + + # 验证全局统计初始值 + global_stats = db.conn.execute("SELECT * FROM global_stats").fetchall() + self.assertEqual(len(global_stats), 4) + + db.conn.close() + + def test_metric_table_partitioning(self): + """测试指标表的分区逻辑""" + db = MonitorDB(self.db_path, 100) + + # 模拟需要创建多个分区的情况 + db._create_metric_table(1, 0, ["mean", "max"]) # step 0-99 + db._create_metric_table(1, 100, ["mean", "max"]) # step 100-199 + db._create_metric_table(1, 200, ["mean", "max"]) # step 200-299 + + # 验证表是否创建创建 + self.assertTrue(db._table_exists("metric_1_step_0_99")) + self.assertTrue(db._table_exists("metric_1_step_100_199")) + self.assertTrue(db._table_exists("metric_1_step_200_299")) + + # 验证表结构 + columns = db.conn.execute( + "PRAGMA table_info(metric_1_step_0_99)").fetchall() + column_names = [col[1] for col in columns] + expected_columns = ["rank", "step", "target_id", "mean", "max"] + self.assertEqual(column_names, expected_columns) + + # 验证分区约束存在 + table_sql = db.conn.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='metric_1_step_0_99'" + ).fetchone()[0] + table_sql = table_sql.replace("\n", "") + table_sql = ' '.join(table_sql.split()) + self.assertIn("CHECK(step BETWEEN 0 AND 99", table_sql) + + db.conn.close() + + def test_target_insertion_conflict(self): + """测试目标插入冲突处理""" + db = MonitorDB(self.db_path, 100) + + # 第一次插入 + db.conn.executemany( + "INSERT OR IGNORE INTO monitoring_targets (vpp_stage, target_name, micro_step) VALUES (?, ?, ?)", + [(1, "op1", 0), (2, "op2", 1)] + ) + db.conn.commit() + + # 第二次插入相同目标 + db.conn.executemany( + "INSERT OR IGNORE INTO monitoring_targets (vpp_stage, target_name, micro_step) VALUES (?, ?, ?)", + [(1, "op1", 0), (3, "op3", 0)] + ) + db.conn.commit() + + # 验证只有新目标被插入 + targets = db.conn.execute( + "SELECT * FROM monitoring_targets").fetchall() + self.assertEqual(len(targets), 3) + + db.conn.close() + + # 数据处理测试 + @patch("msprobe.pytorch.monitor.csv2db.read_csv") + def test_process_single_rank_with_missing_data(self, mock_read_csv): + """测试处理缺失数据的情况""" + # 创建临时文件数据库 + test_db_path = os.path.join(self.temp_dir, "test_missing_data.db") + conn = sqlite3.connect(test_db_path) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute( + "CREATE TABLE monitoring_targets (target_id INTEGER PRIMARY KEY, target_name TEXT, vpp_stage INTEGER, micro_step INTEGER)") + conn.execute("INSERT INTO monitoring_targets VALUES (1, 'op1', 1, 0)") + conn.execute(""" + CREATE TABLE metric_1_step_100_199 ( + rank INTEGER, step INTEGER, target_id INTEGER, mean REAL, max REAL, + PRIMARY KEY (rank, step, target_id) + ) WITHOUT ROWID + """) + conn.commit() + conn.close() + + # 模拟CSV数据(包含缺失值) + mock_read_csv.return_value = pd.DataFrame({ + "name": ["op1", "op1", "op3"], # op3未在目标表中 + "vpp_stage": [1, 1, 3], + "micro_step": [0, 0, 0], + "step": [100, 101, 102], + "mean": [0.5, None, 0.7], # 缺失值 + "max": [1.0, 0.8, 0.9] + }) + + # 执行处理 + metric_id_dict = {"actv": [1, ["mean", "max"]]} + target_dict = {("op1", 1, 0): 1} # op3不存在 + process_single_rank( + (0, [os.path.join(self.rank_dir, "actv_100-200.csv")]), + metric_id_dict, + target_dict, + 100, + test_db_path # 使用文件数据库路径 + ) + + # 重新连接数据库验证结果 + conn = sqlite3.connect(test_db_path) + result = conn.execute("SELECT * FROM metric_1_step_100_199").fetchall() + + self.assertEqual(len(result), 2) # 只有两个有效行 + + # 第一行数据完整 + self.assertEqual(result[0], (0, 100, 1, 0.5, 1.0)) + + # 第二行有缺失值 + self.assertEqual(result[1][:3], (0, 101, 1)) + self.assertIsNone(result[1][3]) # mean为None + self.assertEqual(result[1][4], 0.8) # max存在 + + conn.close() + + @patch("msprobe.pytorch.monitor.csv2db.read_csv") + def test_process_single_rank_large_batch(self, mock_read_csv): + """测试大批量数据处理""" + # 创建内存数据库 + test_db_path = os.path.join(self.temp_dir, "test_missing_data.db") + conn = sqlite3.connect(test_db_path) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute( + "CREATE TABLE monitoring_targets (target_id INTEGER PRIMARY KEY, target_name TEXT, vpp_stage INTEGER, micro_step INTEGER)") + conn.execute("INSERT INTO monitoring_targets VALUES (1, 'op1', 1, 0)") + conn.execute(""" + CREATE TABLE metric_1_step_0_59999 ( + rank INTEGER, step INTEGER, target_id INTEGER, mean REAL, max REAL, + PRIMARY KEY (rank, step, target_id) + ) WITHOUT ROWID + """) + conn.commit() + + # 生成大量数据 (大于BATCH_SIZE) + num_rows = 60000 + data = { + "name": ["op1"] * num_rows, + "vpp_stage": [1] * num_rows, + "micro_step": [0] * num_rows, + "step": list(range(num_rows)), + "mean": [0.5] * num_rows, + "max": [1.0] * num_rows + } + mock_read_csv.return_value = pd.DataFrame(data) + + # 执行处理 + metric_id_dict = {"actv": [1, ["mean", "max"]]} + target_dict = {("op1", 1, 0): 1} + process_single_rank( + (0, [os.path.join(self.rank_dir, "actv_0-59999.csv")]), + metric_id_dict, + target_dict, + num_rows, + test_db_path + ) + + # 验证所有数据都被插入 + result = conn.execute( + "SELECT COUNT(*) FROM metric_1_step_0_59999").fetchone() + self.assertEqual(result[0], num_rows) + + # 错误处理测试 + @patch("msprobe.pytorch.monitor.csv2db.logger") + def test_process_single_rank_exception_handling(self, mock_logger): + """测试处理过程中的异常处理""" + # 模拟文件读取时抛出异常 + with patch("msprobe.pytorch.monitor.csv2db.read_csv", side_effect=Exception("Test error")): + process_single_rank( + (0, [os.path.join("invalid_step", "invalid_rank", "actv_300-300.csv")]), + {"actv": (1, ["norm"])}, + {}, + 100, + ":memory:" + ) + # 验证错误被记录 + mock_logger.error.assert_called_with("Error processing 0: Test error") + + @patch("msprobe.pytorch.monitor.csv2db.get_target_output_dir") + def test_csv2db_with_invalid_step_partition(self, mock_get_dir): + """测试无效step分区值""" + mock_get_dir.return_value = {0: self.rank_dir} + + # 无效分区大小 + with self.assertRaises(ValueError): + config = CSV2DBConfig( + monitor_path=self.temp_dir, + step_partition=0 # 无效值 + ) + csv2db(config) + + # 负分区大小 + with self.assertRaises(ValueError): + config = CSV2DBConfig( + monitor_path=self.temp_dir, + step_partition=-100 + ) + csv2db(config) + + def test_global_stats_update(self): + """测试全局统计更新""" + # 模拟预扫描结果 + + db = MonitorDB(self.db_path, 100) + rank_files = db._pre_scan( + {0: self.rank_dir, 1: self.rank_dir1}, + ["actv"], + workers=1 + ) + + # 验证只处理了有效文件 + self.assertEqual(len(rank_files), 2) + + # 验证全局统计 + stats = db.conn.execute( + "SELECT stat_name, stat_value FROM global_stats" + ).fetchall() + stats_dict = dict(stats) + + self.assertEqual(stats_dict["max_rank"], 1) + self.assertEqual(stats_dict["min_step"], 100) + self.assertEqual(stats_dict["max_step"], 600) # 根据文件名称落入 + self.assertEqual(stats_dict["step_partition_size"], 100) + + # 文件模式匹配测试 + def test_csv_file_pattern_matching(self): + """测试CSV文件模式匹配""" + valid_names = [ + "actv_100-200.csv", + "grad_0-1000.csv", + "param_updated_5000-6000.csv" + ] + + invalid_names = [ + "actv_100_200.csv", # 错误的分隔符 + "actv100-200.csv", # 缺少下划线 + "invalid.csv", # 不匹配的模式 + "actv_abc-def.csv" # 非数字步骤 + ] + + pattern = re.compile(CSV_FILE_PATTERN) + + for name in valid_names: + self.assertTrue(pattern.match(name), f"{name} should match") + + for name in invalid_names: + self.assertIsNone(pattern.match(name), f"{name} should not match") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py index f2bc82ffafc2a1f10719d4a46669bc0050c12782..09e860e7ac5048bd059f888eabfd8ad1d7f45d37 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py @@ -1,8 +1,22 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import shutil import random import unittest -import pytest import torch import numpy as np import torch.nn as nn @@ -11,14 +25,13 @@ from tensorboard.backend.event_processing.event_accumulator import EventAccumula from msprobe.pytorch import TrainerMon from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.csv2tb import parse_step_fn, csv2tensorboard_by_step +from msprobe.pytorch.hook_module.api_register import get_api_register +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) config_json_path = os.path.join(base_dir, "config", "all_config.json") monitor_output = os.path.join(base_dir, "./monitor_output_csv2tb") -os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = monitor_output -timestamp_dirpath = None -csv2tb_dirpath = None def seed_all(seed=1234, mode=False): @@ -28,8 +41,8 @@ def seed_all(seed=1234, mode=False): torch.manual_seed(seed) torch.use_deterministic_algorithms(mode) -seed_all() +seed_all() inputs = [torch.rand(10, 10) for _ in range(10)] labels = [torch.randint(0, 5, (10,)) for _ in range(10)] @@ -47,31 +60,6 @@ class MockModule(nn.Module): return x2 -def data_collect(): - loss_fun = nn.CrossEntropyLoss() - test_module = MockModule() - nn.init.constant_(test_module.linear.weight, 1.0) - nn.init.constant_(test_module.linear.bias, 1.0) - optimizer = torch.optim.Adam(test_module.parameters()) - - monitor = TrainerMon(config_json_path, params_have_main_grad=False) - monitor.set_monitor(test_module, grad_acc_steps=1, optimizer=optimizer) - - for input_data, label in zip(inputs, labels): - output = test_module(input_data) - loss = loss_fun(output, label) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - global timestamp_dirpath, csv2tb_dirpath - timestamp_dirpath = os.path.join(monitor_output, os.listdir(monitor_output)[0]) - csv2tensorboard_by_step(monitor_output) - for dirname in os.listdir(monitor_output): - if "csv2tensorboard" in dirname: - csv2tb_dirpath = os.path.join(monitor_output, dirname, "rank0") - - def extract_scalars_from_tensorboard(log_dir): # 初始化 EventAccumulator event_acc = EventAccumulator(log_dir) @@ -126,97 +114,102 @@ def compare_scalar_dicts(dict1, dict2): return True -@pytest.fixture(scope="session") -def setup_all(): - data_collect() - yield - shutil.rmtree(monitor_output) - -@pytest.mark.usefixtures("setup_all") class TestGradMonitor(unittest.TestCase): + timestamp_dirpath = None + csv2tb_dirpath = None + + @classmethod + def setUpClass(cls): + + os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = monitor_output + if os.path.exists(monitor_output): + shutil.rmtree(monitor_output) + + loss_fun = nn.CrossEntropyLoss() + test_module = MockModule() + nn.init.constant_(test_module.linear.weight, 1.0) + nn.init.constant_(test_module.linear.bias, 1.0) + optimizer = torch.optim.Adam(test_module.parameters()) + + monitor = TrainerMon(config_json_path, params_have_main_grad=False) + monitor.set_monitor(test_module, grad_acc_steps=1, optimizer=optimizer) + + for input_data, label in zip(inputs, labels): + output = test_module(input_data) + loss = loss_fun(output, label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + cls.timestamp_dirpath = os.path.join(monitor_output, os.listdir(monitor_output)[0]) + csv2tensorboard_by_step(monitor_output) + for dirname in os.listdir(monitor_output): + if "csv2tensorboard" in dirname: + cls.csv2tb_dirpath = os.path.join(monitor_output, dirname, "rank0") + os.environ.pop(MonitorConst.MONITOR_OUTPUT_DIR) def setUp(self): self.maxDiff = None - + def test_actv(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"actv_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "actv_0-2.csv")) result = { 'vp0:.input:micro0': { - 0: {'nans': 0.0,'norm': 5.550016}, - 1: {'nans': 0.0,'norm': 5.975112}, - 2: {'nans': 0.0,'norm': 5.789881} - }, + 0: {'nans': 0.0, 'norm': 5.550016}, + 1: {'nans': 0.0, 'norm': 5.975112}, + 2: {'nans': 0.0, 'norm': 5.789881} + }, 'vp0:.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:linear.input:micro0': { - 0: {'nans': 0.0,'norm': 5.550016}, - 1: {'nans': 0.0,'norm': 5.975112}, - 2: {'nans': 0.0,'norm': 5.789881} - }, + 0: {'nans': 0.0, 'norm': 5.550016}, + 1: {'nans': 0.0, 'norm': 5.975112}, + 2: {'nans': 0.0, 'norm': 5.789881} + }, 'vp0:linear.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:relu.input:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - }, + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} + }, 'vp0:relu.output:micro0': { - 0: {'nans': 0.0,'norm': 41.842655}, - 1: {'nans': 0.0,'norm': 44.40981}, - 2: {'nans': 0.0,'norm': 43.578354} - } + 0: {'nans': 0.0, 'norm': 41.842655}, + 1: {'nans': 0.0, 'norm': 44.40981}, + 2: {'nans': 0.0, 'norm': 43.578354} } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "actv")) print(tb_data) tb_result = { 'vp0:.input:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:.input:micro0/norm': [(0, 5.550015926361084), - (1, 5.975111961364746), - (2, 5.789881229400635), - (3, 6.052319049835205), - (4, 5.573315143585205), - (5, 5.864360809326172), - (6, 5.292460918426514), - (7, 5.477899074554443), - (8, 5.884613990783691), - (9, 5.456457138061523)], + (1, 5.975111961364746), + (2, 5.789881229400635), + (3, 6.052319049835205), + (4, 5.573315143585205), + (5, 5.864360809326172), + (6, 5.292460918426514), + (7, 5.477899074554443), + (8, 5.884613990783691), + (9, 5.456457138061523)], 'vp0:.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], - 'vp0:.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], - 'vp0:linear.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), @@ -226,117 +219,136 @@ class TestGradMonitor(unittest.TestCase): (7, 0.0), (8, 0.0), (9, 0.0)], + 'vp0:.output:micro0/norm': [(0, 41.842655181884766), + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], + 'vp0:linear.input:micro0/nans': [(0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.input:micro0/norm': [(0, 5.550015926361084), - (1, 5.975111961364746), - (2, 5.789881229400635), - (3, 6.052319049835205), - (4, 5.573315143585205), - (5, 5.864360809326172), - (6, 5.292460918426514), - (7, 5.477899074554443), - (8, 5.884613990783691), - (9, 5.456457138061523)], + (1, 5.975111961364746), + (2, 5.789881229400635), + (3, 6.052319049835205), + (4, 5.573315143585205), + (5, 5.864360809326172), + (6, 5.292460918426514), + (7, 5.477899074554443), + (8, 5.884613990783691), + (9, 5.456457138061523)], 'vp0:linear.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], 'vp0:relu.input:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:relu.input:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)], + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)], 'vp0:relu.output:micro0/nans': [(0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:relu.output:micro0/norm': [(0, 41.842655181884766), - (1, 44.40980911254883), - (2, 43.57835388183594), - (3, 45.83631134033203), - (4, 42.0673828125), - (5, 43.46839141845703), - (6, 39.77947235107422), - (7, 40.200843811035156), - (8, 44.453147888183594), - (9, 40.841522216796875)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + (1, 44.40980911254883), + (2, 43.57835388183594), + (3, 45.83631134033203), + (4, 42.0673828125), + (5, 43.46839141845703), + (6, 39.77947235107422), + (7, 40.200843811035156), + (8, 44.453147888183594), + (9, 40.841522216796875)]} + self.assertDictEqual(tb_data, tb_result) def test_actv_grad(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"actv_grad_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "actv_grad_0-2.csv")) nan = np.nan result = { 'vp0:.input:micro0': { - 0: {'norm': nan, 'nans': nan}, - 1: {'norm': nan, 'nans': nan}, + 0: {'norm': nan, 'nans': nan}, + 1: {'norm': nan, 'nans': nan}, 2: {'norm': nan, 'nans': nan} - }, + }, 'vp0:.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:relu.input:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:relu.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - }, + }, 'vp0:linear.input:micro0': { - 0: {'norm': nan, 'nans': nan}, - 1: {'norm': nan, 'nans': nan}, + 0: {'norm': nan, 'nans': nan}, + 1: {'norm': nan, 'nans': nan}, 2: {'norm': nan, 'nans': nan} - }, + }, 'vp0:linear.output:micro0': { - 0: {'norm': 0.282843, 'nans': 0.0}, - 1: {'norm': 0.282617, 'nans': 0.0}, + 0: {'norm': 0.282843, 'nans': 0.0}, + 1: {'norm': 0.282617, 'nans': 0.0}, 2: {'norm': 0.282655, 'nans': 0.0} - } } - self.assertEqual(dict_equal(data, result), True) - - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv_grad")) + } + print(data) + + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "actv_grad")) tb_result = { 'vp0:.input:micro0/nans': [(0, nan), (1, nan), @@ -457,88 +469,90 @@ class TestGradMonitor(unittest.TestCase): (6, 0.28316599130630493), (7, 0.28274500370025635), (8, 0.2833530008792877), - (9, 0.2825529873371124)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + (9, 0.2825529873371124)] + } + print(tb_data) - def test_param(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"param_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "param_origin_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 2.236068}, 1: {'nans': 0.0, 'norm': 2.236198}, 2: {'nans': 0.0, 'norm': 2.235769} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 7.071068}, 1: {'nans': 0.0, 'norm': 7.068808}, 2: {'nans': 0.0, 'norm': 7.06771} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "param")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "param_origin")) tb_result = { 'vp0:linear.weight/norm': [ - (0, 7.071067810058594), - (1, 7.068808078765869), - (2, 7.067709922790527), - (3, 7.0673418045043945), - (4, 7.066926956176758), - (5, 7.066311836242676), - (6, 7.065629959106445), - (7, 7.065262794494629), - (8, 7.065001964569092), - (9, 7.064840793609619)], + (0, 7.071067810058594), + (1, 7.068808078765869), + (2, 7.067709922790527), + (3, 7.0673418045043945), + (4, 7.066926956176758), + (5, 7.066311836242676), + (6, 7.065629959106445), + (7, 7.065262794494629), + (8, 7.065001964569092), + (9, 7.064840793609619)], 'vp0:linear.weight/nans': [ - (0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)], + (0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0)], 'vp0:linear.bias/norm': [ - (0, 2.2360680103302), - (1, 2.2361979484558105), - (2, 2.235769033432007), - (3, 2.235903024673462), - (4, 2.2360129356384277), - (5, 2.2359039783477783), - (6, 2.2357990741729736), - (7, 2.2357349395751953), - (8, 2.2356700897216797), - (9, 2.235619068145752)], + (0, 2.2360680103302), + (1, 2.2361979484558105), + (2, 2.235769033432007), + (3, 2.235903024673462), + (4, 2.2360129356384277), + (5, 2.2359039783477783), + (6, 2.2357990741729736), + (7, 2.2357349395751953), + (8, 2.2356700897216797), + (9, 2.235619068145752) + ], 'vp0:linear.bias/nans': [ - (0, 0.0), - (1, 0.0), - (2, 0.0), - (3, 0.0), - (4, 0.0), - (5, 0.0), - (6, 0.0), - (7, 0.0), - (8, 0.0), - (9, 0.0)] - } - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + (0, 0.0), + (1, 0.0), + (2, 0.0), + (3, 0.0), + (4, 0.0), + (5, 0.0), + (6, 0.0), + (7, 0.0), + (8, 0.0), + (9, 0.0) + ] + } + self.assertDictEqual(tb_data, tb_result) def test_exp_avg(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"exp_avg_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "exp_avg_0-2.csv")) result = { 'vp0:linear.bias': { 1: {'nans': 0.0, 'norm': 0.024495}, 2: {'nans': 0.0, 'norm': 0.052203} - }, + }, 'vp0:linear.weight': { 1: {'nans': 0.0, 'norm': 0.052394}, 2: {'nans': 0.0, 'norm': 0.099221} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "exp_avg")) tb_result = { 'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), @@ -576,22 +590,22 @@ class TestGradMonitor(unittest.TestCase): (7, 0.11372199654579163), (8, 0.12264800071716309), (9, 0.09017200022935867)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + self.assertDictEqual(tb_data, tb_result) def test_exp_avg_sq(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"exp_avg_sq_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "exp_avg_sq_0-2.csv")) result = { 'vp0:linear.bias': { 1: {'nans': 0.0, 'norm': 4.2e-05}, 2: {'nans': 0.0, 'norm': 9.6e-05} - }, + }, 'vp0:linear.weight': { 1: {'nans': 0.0, 'norm': 6.7e-05}, 2: {'nans': 0.0, 'norm': 0.000126} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg_sq")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "exp_avg_sq")) tb_result = { 'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), @@ -629,24 +643,24 @@ class TestGradMonitor(unittest.TestCase): (7, 0.00026000000070780516), (8, 0.00028700000257231295), (9, 0.0003060000017285347)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + self.assertDictEqual(tb_data, tb_result) + def test_grad_reduced(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"grad_reduced_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "grad_reduced_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 0.244949}, 1: {'nans': 0.0, 'norm': 0.314345}, 2: {'nans': 0.0, 'norm': 0.281475} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 0.523935}, 1: {'nans': 0.0, 'norm': 0.595672}, 2: {'nans': 0.0, 'norm': 0.497603} - } } - self.assertEqual(dict_equal(data, result), True) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_reduced")) + } + self.assertDictEqual(data, result) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "grad_reduced")) tb_result = { 'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), @@ -688,25 +702,25 @@ class TestGradMonitor(unittest.TestCase): (7, 0.4831080138683319), (8, 0.3234719932079315), (9, 0.32385098934173584)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) - + self.assertDictEqual(tb_data, tb_result) + def test_grad_unreduced(self): - data = parse_step_fn(os.path.join(timestamp_dirpath,"grad_unreduced_0-2.csv")) + data = parse_step_fn(os.path.join(self.timestamp_dirpath, "grad_unreduced_0-2.csv")) result = { 'vp0:linear.bias': { 0: {'nans': 0.0, 'norm': 0.244949}, 1: {'nans': 0.0, 'norm': 0.314345}, 2: {'nans': 0.0, 'norm': 0.281475} - }, + }, 'vp0:linear.weight': { 0: {'nans': 0.0, 'norm': 0.523935}, 1: {'nans': 0.0, 'norm': 0.595672}, 2: {'nans': 0.0, 'norm': 0.497603} - } } - self.assertEqual(dict_equal(data, result), True) + } + self.assertDictEqual(data, result) - tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_unreduced")) + tb_data = extract_scalars_from_tensorboard(os.path.join(self.csv2tb_dirpath, "grad_unreduced")) tb_result = { 'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), @@ -748,4 +762,8 @@ class TestGradMonitor(unittest.TestCase): (7, 0.4831080138683319), (8, 0.3234719932079315), (9, 0.32385098934173584)]} - self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True) + self.assertDictEqual(tb_data, tb_result) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py new file mode 100644 index 0000000000000000000000000000000000000000..34204267935cd7691f5bcccce6c1af5451a2c34f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py @@ -0,0 +1,52 @@ +import unittest +from unittest import TestCase +from unittest.mock import patch + +from msprobe.core.monitor.anomaly_processor import AnomalyTurbulence +from msprobe.pytorch.monitor.data_writers import BaseWriterWithAD, WriterInput + + +class TestBaseWriterWithAD(TestCase): + + def setUp(self) -> None: + self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None)) + + def test_get_anomalies(self): + expected = [] + + self.assertEqual(self.BaseWriter.get_anomalies(), expected) + + def test_clear_anomalies(self): + self.BaseWriter.anomalies = ['anomaly1', 'anomaly2'] + self.BaseWriter.clear_anomalies() + + self.assertEqual(self.BaseWriter.anomalies, []) + + @patch("msprobe.pytorch.monitor.data_writers.logger") + def test_add_scalar(self, mock_logger): + AnomalyTurbulence_obj = AnomalyTurbulence(0.2) + self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] + tag = ('0:1.post_attention_norm.weight/rank0/pre_grad', 'mean') + self.BaseWriter.tag2scalars = {tag: {'avg': 1.0, 'count': 1}} + self.BaseWriter.add_scalar(tag, 2.0) + + mock_logger.info.assert_called_once() + + def test_ad(self): + AnomalyTurbulence_obj = AnomalyTurbulence(0.2) + self.BaseWriter.ad_rules = [AnomalyTurbulence_obj] + expected = True, "AnomalyTurbulence" + + self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected) + + def test_update_tag2scalars(self): + self.BaseWriter._update_tag2scalars('tag1', 1.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1) + self.BaseWriter._update_tag2scalars('tag1', 2.0) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.01) + self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py index ff00cf7490d8110f2198df57ee5d91b6b75f5092..f2b447f1694edc1909b3ca4e2f858183d22472ff 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_features.py @@ -1,7 +1,31 @@ import unittest +from unittest.mock import patch + import torch -from msprobe.pytorch.monitor.features import square_sum, get_min, get_mean, get_norm, get_max, get_zeros, \ - get_sign_matches, eff_rank, mNTK, lambda_max_subsample, cal_histc, get_nans +from msprobe.pytorch.monitor.features import ( + cal_avg_token_similarity, + cal_avg_token_similarity_chunk, + cal_dist_diff, + cal_entropy, + cal_histc, + cal_kl_divergence, + cal_qkt, + cal_stable_rank, + cal_svd_entropy, + eff_rank, + get_max, + get_mean, + get_min, + get_nans, + get_norm, + get_sign_matches, + get_zeros, + lambda_max_subsample, + layer_norm_jacobian, + max_eigenvalue, + mNTK, + square_sum, +) class TestMathFunctions(unittest.TestCase): @@ -23,7 +47,8 @@ class TestMathFunctions(unittest.TestCase): def test_get_norm(self): tensor = torch.tensor([1.0, 2.0, 3.0]) result = get_norm(tensor) - self.assertTrue(torch.allclose(result, torch.tensor(3.7417, dtype=torch.float64), atol=1e-4)) + self.assertTrue(torch.allclose(result, torch.tensor( + 3.7417, dtype=torch.float64), atol=1e-4)) def test_get_max(self): tensor = torch.tensor([1.0, 2.0, 3.0]) @@ -44,7 +69,8 @@ class TestMathFunctions(unittest.TestCase): self.assertTrue(res) def test_eff_rank(self): - tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]) + tensor = torch.tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]) result = eff_rank(tensor) res = torch.allclose(result, torch.tensor(2), atol=1e-1) self.assertTrue(res) @@ -87,6 +113,171 @@ class TestMathFunctions(unittest.TestCase): result = get_nans(tensor) self.assertEqual(result, 1) + def test_max_eigenvalue(self): + """测试最大特征值计算""" + # 创建已知特征值的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + + # 测试不同迭代次数 + eigval = max_eigenvalue(A, num_iterations=5) + self.assertAlmostEqual(eigval.item(), 3.0, delta=0.1) + + # 测试全零矩阵 + zero_matrix = torch.zeros(3, 3) + eigval = max_eigenvalue(zero_matrix) + self.assertAlmostEqual(eigval.item(), 0.0) + + # ==================== 注意力机制测试 ==================== + + def test_cal_entropy(self): + """测试注意力熵计算""" + # 创建简单的注意力分数 + qk = torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + + # 无mask + entropy, softmax_max = cal_entropy(qk) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + # 带mask 和默认生成相同 + mask = torch.tensor([[1, 0, 0], + [1, 1, 0], + [1, 1, 1]], dtype=torch.float) + entropy, _ = cal_entropy(qk, mask) + self.assertAlmostEqual(entropy, 0.4715, delta=0.1) + self.assertAlmostEqual(softmax_max, 0.7988, delta=0.1) + + @patch("msprobe.pytorch.monitor.features.logger") + def test_cal_qkt(self, mock_logger): + """测试QK^T计算""" + # 测试s,b,h,d顺序 + q = torch.randn(10, 2, 4, 8) # [s, b, h, d] + k = torch.randn(10, 2, 4, 8) # [s, b, h, d] + q_batch = torch.randn(2, 10, 4, 8) # [b, s, h, d] + qkt = cal_qkt(q, k, order="s,b,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试b,s,h,d顺序 + qkt = cal_qkt(q_batch, q_batch, order="b,s,h,d") + self.assertEqual(qkt.shape, (10, 10)) # [s, s] + + # 测试无效顺序 + cal_qkt(q, k, order="invalid_order") + mock_logger.warning.assert_called_with( + "Calculate qk tensor failed: Order unsupported.") + + def test_cal_stable_rank(self): + """测试谱半径计算""" + # 创建已知谱半径的矩阵 + A = torch.diag(torch.tensor([3.0, 2.0, 1.0])) + sr, eig = cal_stable_rank(A) + + # 验证Frobenius范数 + fro_norm = torch.norm(A, p='fro') + self.assertAlmostEqual(sr, fro_norm / 3.0, delta=.5) # 最大特征值为3 + + # 测试正交矩阵 + ortho = torch.eye(5) + sr, eig = cal_stable_rank(ortho) + self.assertAlmostEqual(sr, torch.tensor(2.23/1), delta=.5) # F范数应为2.23 + self.assertAlmostEqual(eig, torch.tensor(1.0), delta=.1) # 特征值应为1 + + def test_cal_svd_entropy(self): + """测试SVD熵计算""" + # 创建低秩矩阵 + low_rank = torch.ones(10, 10) # 秩为1 + entropy = cal_svd_entropy(low_rank, k=5) + self.assertAlmostEqual(entropy.item(), 0.0, + delta=1e-5) # 所有"质量"集中在一个奇异值上 + + # 创建满秩矩阵 + full_rank = torch.randn(10, 10) + entropy = cal_svd_entropy(full_rank, k=5) + self.assertGreater(entropy, 0) + self.assertLess(entropy, 10) # 熵值应在合理范围内 + + # ==================== 相似度与距离测试 ==================== + + def test_cal_avg_token_similarity(self): + """测试平均token相似度计算""" + # 创建相同token + same_tokens = torch.ones(5, 10) + sim = cal_avg_token_similarity(same_tokens) + self.assertAlmostEqual(sim.item(), 1.0, delta=0.001) + + # 创建正交token + ortho_tokens = torch.eye(5, 10) + sim = cal_avg_token_similarity(ortho_tokens) + # 正交向量相似度接近0,与自己相似度为1 + self.assertAlmostEqual(sim.item(), 1/5, delta=0.001) + + # 测试大矩阵分块计算 + large_tokens = torch.randn(1000, 128) + sim_chunk = cal_avg_token_similarity_chunk( + large_tokens, chunk_size=256) + sim_full = cal_avg_token_similarity(large_tokens) + self.assertAlmostEqual(sim_chunk.item(), sim_full.item(), delta=0.01) + + def test_layer_norm_jacobian(self): + """测试层归一化Jacobian计算""" + # 创建简单输入 + input_tensor = torch.randn(1, 10) + weight = torch.ones(10) + + # 计算Jacobian + std_x, max_eig = layer_norm_jacobian(input_tensor, weight) + + # 验证std计算 + expected_std = torch.std(input_tensor) + self.assertAlmostEqual(std_x.item(), expected_std.item(), delta=0.001) + + # 验证特征值非负 + self.assertGreater(max_eig.item(), 0) + + def test_cal_kl_divergence(self): + """测试KL散度计算""" + # 创建相同分布 + p = torch.randn(100) + q = p.clone() + kl = cal_kl_divergence(p, q) + self.assertAlmostEqual(kl.item(), 0.0, delta=1e-5) + + # 创建不同分布 + p = torch.tensor([0.5, 0.5]) + q = torch.tensor([0.9, 0.1]) + kl = cal_kl_divergence(p, q) + self.assertGreater(kl.item(), 0) + + # 测试空输入 + with self.assertRaises(RuntimeError): + cal_kl_divergence(torch.tensor([]), torch.tensor([1.0])) + + def test_cal_dist_diff(self): + """测试分布差异计算""" + # 相同分布 + p = torch.randn(100) + q = p.clone() + w1, mean_diff, std_diff = cal_dist_diff(p, q) + self.assertAlmostEqual(w1.item(), 0.0, delta=1e-5) + self.assertAlmostEqual(mean_diff.item(), 0.0, delta=1e-5) + self.assertAlmostEqual(std_diff.item(), 0.0, delta=1e-5) + + # 不同分布 + p = torch.tensor([1.0, 2.0, 3.0, 4.0]) + q = torch.tensor([0.0, 6.0, 7.0, 8.0]) + w1, mean_diff, std_diff = cal_dist_diff(p, q) + self.assertGreater(w1.item(), 0) + self.assertGreater(mean_diff.item(), 0) + self.assertGreater(std_diff.item(), 0) + + # 测试不同形状 + p = torch.randn(100) + q = torch.randn(200) + w1, mean_diff, std_diff = cal_dist_diff(p, q) + self.assertTrue(torch.isfinite(w1)) + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py index eefacb73c8e76636086554775b0e6f2e916ddf6e..6c3d2b925a4fbcd7ec7c81b85614b6be0e731b0c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os.path import shutil import unittest @@ -8,10 +23,13 @@ import torch from msprobe.core.common.const import MonitorConst, Const from torch import distributed as dist +from msprobe.pytorch import TrainerMon +from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.monitor.module_hook import CommunicationContext, GradContext, ModuleHookContext, \ param_is_not_tensor_parallel_duplicate, param_is_data_parallel_duplicate from msprobe.test.pytorch_ut.monitor.demo_model import monitor_demo -from msprobe.pytorch import TrainerMon + +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) @@ -72,13 +90,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(actv_grad_0_csv)) # validate columns and lines actv_0 = pd.read_csv(actv_0_csv) - expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans', "shape", "dtype"] self.assertListEqual(list(actv_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([6, 6])) + self.assertEqual(actv_0.shape, tuple([6, 8])) actv_grad_0 = pd.read_csv(actv_grad_0_csv) - expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans', "shape", "dtype"] self.assertListEqual(list(actv_grad_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([6, 6])) + self.assertEqual(actv_0.shape, tuple([6, 8])) def test_wg_distribution(self): self.get_dist_mock(False) @@ -95,13 +113,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(grad_reduced_0_csv)) self.assertTrue(os.path.exists(grad_unreduced_0_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm", "shape", "dtype"] grad_reduced_0 = pd.read_csv(grad_reduced_0_csv) self.assertListEqual(list(grad_reduced_0.columns), expect_columns) - self.assertEqual(grad_reduced_0.shape, tuple([2, 4])) + self.assertEqual(grad_reduced_0.shape, tuple([2, 6])) grad_unreduced_0 = pd.read_csv(grad_unreduced_0_csv) self.assertListEqual(list(grad_unreduced_0.columns), expect_columns) - self.assertEqual(grad_unreduced_0.shape, tuple([2, 4])) + self.assertEqual(grad_unreduced_0.shape, tuple([2, 6])) def test_mv_distribution(self): self.get_dist_mock(False) @@ -118,13 +136,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(exp_avg_1_csv)) self.assertTrue(os.path.exists(exp_avg_sq_1_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm", "shape", "dtype"] exp_avg_1 = pd.read_csv(exp_avg_1_csv) self.assertListEqual(list(exp_avg_1.columns), expect_columns) - self.assertEqual(exp_avg_1.shape, tuple([2, 4])) + self.assertEqual(exp_avg_1.shape, tuple([2, 6])) exp_avg_sq_1 = pd.read_csv(exp_avg_sq_1_csv) self.assertListEqual(list(exp_avg_sq_1.columns), expect_columns) - self.assertEqual(exp_avg_sq_1.shape, tuple([2, 4])) + self.assertEqual(exp_avg_sq_1.shape, tuple([2, 6])) def test_ur_distribution(self): self.get_dist_mock(False) @@ -149,6 +167,18 @@ class TestModuleHook(unittest.TestCase): ) self.assertIsNotNone(hooker) + def test_stack_collect(self): + self.get_dist_mock(False) + stack_monitor_output = "./test_stack_info" + clean_output(stack_monitor_output) + os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = stack_monitor_output + stack_config = os.path.join(base_dir, "config/stack_config.json") + monitor_demo(stack_config) + output_dir_list = os.listdir(stack_monitor_output) + self.assertEqual(len(output_dir_list), 1) + stack_csv_path = os.path.join(stack_monitor_output, output_dir_list[0], "stack_info.csv") + self.assertTrue(os.path.exists(stack_csv_path)) + def test_adhoc_check(self): # mock dist self.get_dist_mock(True) @@ -243,61 +273,6 @@ class TestParamIsDataParallelDuplicate(unittest.TestCase): self.assertFalse(result) -class TestModuleHookContext(unittest.TestCase): - def setUp(self): - self.module_name = "test_module" - self.context = ModuleHookContext(self.module_name) - self.context.struct = { - Const.INPUT: { - "config": "tuple[1]", - "0": "size=(2, 784), dtype=torch.float32", - }, - Const.OUTPUT: { - "config": "tensor", - "tensor": "size=(2, 10), dtype=torch.float32" - }, - MonitorConst.INPUT_GRAD: { - "config": "tuple[1]", - "0": "size=(2, 784), dtype=torch.float32" - }, - MonitorConst.OUTPUT_GRAD: { - "config": "tuple[1]", - "0": "size=(2, 10), dtype=torch.float32" - } - } - self.target_config = { - self.module_name: { - Const.INPUT: "tuple[1]:0", - Const.OUTPUT: "tensor", - MonitorConst.INPUT_GRAD: "tuple[1]:0" - } - } - - def test_set_format_by_arg_module_name_in_target_config(self): - self.context.set_format_by_arg(Const.INPUT, self.target_config) - self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]:0") - self.context.set_format_by_arg(Const.OUTPUT, self.target_config) - self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") - self.context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.INPUT_GRAD], "tuple[1]:0") - self.context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.OUTPUT_GRAD], "tuple[1]") - - def test_set_format_by_arg_module_name_not_in_target_config(self): - target_config = {} - self.context.set_format_by_arg(Const.INPUT, target_config) - self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]") - self.context.set_format_by_arg(Const.OUTPUT, target_config) - self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") - - @patch('msprobe.pytorch.monitor.module_hook.logger') - def test_set_format_by_arg_target_module_config_error(self, mock_logger): - target_config = {self.module_name: {Const.INPUT: 123}} - self.context.set_format_by_arg(Const.INPUT, target_config) - self.assertIsNone(self.context.format_by_arg.get(Const.INPUT)) - mock_logger.warning_on_rank_0.assert_called_once() - - class TestContext(unittest.TestCase): def test_communication_context(self): cc_ctx = CommunicationContext() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py index 0462ac3f39531119b40d3cc5051fad77f687b9b5..87822ab0503bd21e0546d8c846d69f56204eb048 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py @@ -44,12 +44,12 @@ class TestValidationFunctions(unittest.TestCase): def test_validate_ops(self): ops = ['op1', 'op2', 'norm', 'max'] valid_ops = validate_ops(ops) - self.assertEqual(valid_ops, ['norm', 'max']) + self.assertEqual(valid_ops, ['norm', 'max', "shape", "dtype"]) def test_no_valid_ops(self): ops = ['op1', 'op2'] valid_ops = validate_ops(ops) - target_ops = [MonitorConst.OP_LIST[0]] + target_ops = [MonitorConst.OP_LIST[0], "shape", "dtype"] self.assertEqual(valid_ops, target_ops) def test_validate_ranks(self): @@ -104,7 +104,7 @@ class TestValidationFunctions(unittest.TestCase): 'alert': {'rules': [{'rule_name': 'AnomalyTurbulence', 'args': {'threshold': 10.0}}], 'dump': True} } validate_config(config) - target_ops = [MonitorConst.OP_LIST[0]] + target_ops = [MonitorConst.OP_LIST[0], "shape", "dtype"] self.assertEqual(config["ops"], target_ops) del config["targets"] validate_config(config) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py index 793b086b02db03f8a04b159f35f1df55fc1a9d2c..e32e4f860ee40a2bb3198ee30fd522b98ae2e36e 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py @@ -3,18 +3,51 @@ from collections import defaultdict from unittest.mock import Mock, patch, MagicMock import torch +from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, \ - OptimizerMonFactory, DummyOptimizerMon, \ - MixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, MegatronFP32OptimizerMon, \ + OptimizerMonFactory, MixPrecisionOptimizerMon, MegatronDistributedOptimizerMon, \ MegatronChainedDistributedOptimizerMon, MegatronChainedMixPrecisionOptimizerMon, \ - DeepSpeedZeroOptimizerStage0Mon, DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon - -from msprobe.pytorch.monitor.utils import MVResult, MVGradResult - + DeepSpeedZeroOptimizerMon, DeepSpeedZeroOptimizerStage0Mon, \ + DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon +from msprobe.pytorch.monitor.utils import MVResult + + +def setup_param_groups(num_groups=2, params_per_group=5): + bit16_groups = [] + param_names = {} + grad_position = {} + param_slice_mappings = [] + count = 0 + for group_idx in range(num_groups): + group = [] + param_slice_mapping = {} + offset = 0 + for i in range(params_per_group): + name = f'param{group_idx}_{i}' + p = torch.nn.Parameter(torch.randn(2,3, dtype=torch.bfloat16)) + p.ds_tensor = torch.nn.Parameter(torch.randn(1,3, dtype=torch.bfloat16)) + p.ds_id = count + param_slice_mapping[name] = MagicMock(start=offset, numel=p.numel()) + group.append(p) + param_names[p] = name + grad_position[count] = [group_idx, offset, p.numel()] + offset += p.numel() + count += 1 + bit16_groups.append(group) + param_slice_mappings.append(param_slice_mapping) + + return bit16_groups, param_names, param_slice_mappings, grad_position + +def setup_mock_monitor(): + mock_monitor = MagicMock() + mock_monitor.mv_distribution = True + mock_monitor.mg_direction = False + mock_monitor.ur_distribution = False + + return mock_monitor class TestOptimizerMon(unittest.TestCase): def setUp(self) -> None: - # 初始化需要的monitor, torch_opt, params2name等对象 self.monitor = Mock() self.monitor.mv_distribution = True self.monitor.mg_direction = True @@ -23,11 +56,11 @@ class TestOptimizerMon(unittest.TestCase): self.monitor.ratio_heatmap_visualizer = {'param1': Mock(), 'param2': Mock()} def test_fetch_mv(self): - optimizer_mon = OptimizerMon() - res = optimizer_mon.fetch_mv(None, None, None) - self.assertEqual(res, None) + optimizer_mon = OptimizerMon(None) + res = optimizer_mon.fetch_mv(None, {}) + self.assertEqual(res.exp_avg, {}) - def test_fetch_mv_in_adam(self): + def test_fetch_mv(self): self.torch_opt = Mock() self.torch_opt.state = { 'param1': {'exp_avg': torch.tensor(0.1), 'exp_avg_sq': torch.tensor(0.2), 'step': torch.tensor(10)}, @@ -37,48 +70,10 @@ class TestOptimizerMon(unittest.TestCase): self.torch_opt.defaults = {'betas': (0.9, 0.999), 'eps': 1e-8} self.params2name = {'param1': 'param1', 'param2': 'param2'} - self.optimizer_mon = OptimizerMon() - result = self.optimizer_mon._fetch_mv_in_adam(self.monitor, self.torch_opt, self.params2name) + self.optimizer_mon = OptimizerMon(None) + result = self.optimizer_mon.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(result, MVResult) - @patch('msprobe.pytorch.monitor.optimizer_collect.dist') - def test_fetch_mv_grad_in_adam(self, mock_dist): - self.optimizer_mon = OptimizerMon() - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = defaultdict(str) - self.name2indices = defaultdict(tuple) - self.fp32_partitioned_groups_flat = defaultdict(torch.Tensor) - - # Mocking the dist.get_rank() and dist.get_world_size() - mock_dist.get_rank.return_value = 0 - mock_dist.get_world_size.return_value = 1 - - # Mocking the wrapped_optimizer - self.torch_opt.state = defaultdict(dict) - self.torch_opt.averaged_gradients = defaultdict(torch.Tensor) - self.torch_opt.partition_size = defaultdict(int) - self.torch_opt.flatten_dense_tensors_aligned = MagicMock() - self.torch_opt.flatten = MagicMock() - - # Mocking the torch_opt.param_groups - self.torch_opt.param_groups = [{'step': 1, 'betas': (0.9, 0.999)}, - {'step': 2, 'betas': (0.9, 0.999)}, - {'step': 3, 'betas': (0.9, 0.999)}] - - # Mocking the monitor.mv_distribution, monitor.mg_direction, monitor.ur_distribution - self.monitor.mv_distribution = True - self.monitor.mg_direction = True - self.monitor.ur_distribution = True - - # Mocking the monitor.update_heatmap_visualizer and monitor.ratio_heatmap_visualizer - self.monitor.update_heatmap_visualizer = defaultdict(MagicMock) - self.monitor.ratio_heatmap_visualizer = defaultdict(MagicMock) - - result = self.optimizer_mon._fetch_mv_grad_in_adam(self.monitor, self.torch_opt, self.params2name, - self.name2indices, self.fp32_partitioned_groups_flat) - self.assertIsInstance(result, MVGradResult) - class TestMixPrecisionOptimizerMon(unittest.TestCase): def test_fetch_mv_with_fp16_to_fp32_param_and_mix_prec_opt(self): @@ -89,16 +84,16 @@ class TestMixPrecisionOptimizerMon(unittest.TestCase): self.mix_prec_opt = MagicMock() self.mix_prec_opt.float16_groups = [MagicMock()] self.mix_prec_opt.fp32_from_float16_groups = [MagicMock()] - self.optimizer = MixPrecisionOptimizerMon() + self.optimizer = MixPrecisionOptimizerMon(self.torch_opt) self.optimizer.fp16_to_fp32_param = {} - # Mock _fetch_mv_in_adam method and set a fixed return value + # Mock fetch_mv method and set a fixed return value mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.mock_fetch_mv_in_adam.assert_called_once_with(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) + self.mock_fetch_mv.assert_called_once_with(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) @@ -110,17 +105,17 @@ class TestChainedMixPrecisionOptimizerMon(unittest.TestCase): self.params2name = MagicMock() self.torch_opt.float16_groups = [MagicMock()] self.torch_opt.fp32_from_float16_groups = [MagicMock()] - self.optimizer = MegatronChainedMixPrecisionOptimizerMon() + self.optimizer = MegatronChainedMixPrecisionOptimizerMon(self.torch_opt) self.optimizer.optimizer = [MagicMock(), MagicMock()] self.optimizer.fp16_to_fp32_param = {} - # Mock _fetch_mv_in_adam method and set a fixed return value + # Mock fetch_mv method and set a fixed return value mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.mock_fetch_mv_in_adam.assert_called_once_with(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) + self.mock_fetch_mv.assert_called_once_with(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) @@ -129,26 +124,27 @@ class TestMegatronChainedDistributedOptimizerMon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() + self.torch_opt.chained_optimizers = [MagicMock(), MagicMock()] mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer = MegatronChainedDistributedOptimizerMon() + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer = MegatronChainedDistributedOptimizerMon(self.torch_opt) def test_fetch_mv_with_valid_optimizer(self): - self.torch_opt.model_float16_groups = [MagicMock()] - self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + for opt in self.torch_opt.chained_optimizers: + opt.model_float16_groups = [MagicMock()] + opt.shard_fp32_from_float16_groups = [MagicMock()] + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.torch_opt = Mock() - self.torch_opt.model_float16_groups = None - self.torch_opt.shard_fp32_from_float16_groups = None - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + for opt in self.torch_opt.chained_optimizers: + del opt.model_float16_groups + del opt.shard_fp32_from_float16_groups with self.assertRaises(Exception): - self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + self.optimizer.fetch_mv(self.monitor, self.params2name) class TestMegatronDistributedOptimizerMon(unittest.TestCase): @@ -157,25 +153,23 @@ class TestMegatronDistributedOptimizerMon(unittest.TestCase): self.torch_opt = MagicMock() self.params2name = MagicMock() mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) - self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) - self.optimizer = MegatronDistributedOptimizerMon() + self.mock_fetch_mv = MagicMock(return_value=mv_result) + self.optimizer = MegatronDistributedOptimizerMon(self.torch_opt) def test_fetch_mv_with_valid_optimizer(self): self.torch_opt.model_float16_groups = [MagicMock()] self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam + self.optimizer.fetch_mv = self.mock_fetch_mv - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.torch_opt = Mock() self.torch_opt.model_float16_groups = None self.torch_opt.shard_fp32_from_float16_groups = None - self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam with self.assertRaises(Exception): - self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + self.optimizer.fetch_mv(self.monitor, self.params2name) class TestCommonFetchMv(unittest.TestCase): @@ -184,103 +178,189 @@ class TestCommonFetchMv(unittest.TestCase): self.torch_opt = MagicMock() self.params2name = MagicMock() - def test_megatron_fp32_optimizer_mon(self): - self.optimizer = MegatronFP32OptimizerMon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) + def test_optimizer_mon(self): + self.optimizer = OptimizerMon(None) + res = self.optimizer.fetch_mv(self.monitor, self.params2name) self.assertIsInstance(res, MVResult) - def test_deepspeed_zero_optimizer_stage0_mon(self): - self.optimizer = DeepSpeedZeroOptimizerStage0Mon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVResult) - def test_dummy_optimizer_mon(self): - self.optimizer = DummyOptimizerMon() - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVResult) +class TestDeepSpeedZeroOptimizer(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bit16_groups = bit16_groups + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerMon(mock_opt) + self.optimizer_mon.bit16_groups = mock_opt.bit16_groups + self.optimizer_mon.param2group = self.optimizer_mon.get_group_index() + + def test_param_not_in_partition(self): + param_in_partition = list(self.torch_opt.param_names.keys())[0] + param_not_in_partition = torch.randn(2,3) + + self.assertFalse( + self.optimizer_mon.param_not_in_partition(param_in_partition, 0) + ) + self.assertTrue( + self.optimizer_mon.param_not_in_partition(param_not_in_partition, 0) + ) + + def test_get_position(self): + param_in_partition = list(self.torch_opt.param_names.keys())[0] + start, numel = self.optimizer_mon.get_position(param_in_partition, 0) + self.assertEqual(start, 0) + self.assertEqual(numel, 6) -class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): - def test_get_param_index(self): - self.torch_opt = Mock() - self.torch_opt.fp16_partitioned_groups = [ - [Mock(flatten=lambda: [1, 2, 3]), - Mock(flatten=lambda: [4, 5])], - [Mock(flatten=lambda: [6, 7, 8, 9])] - ] - self.params2name = {'param1': 'weight1', 'param2': 'weight2'} - self.name2index = {'weight1': 0, 'weight2': 2} + def test_get_group_index(self): + param = list(self.torch_opt.param_names.keys())[6] + self.assertEqual(self.optimizer_mon.param2group[param], 1) - optimizer_stage3_mon = DeepSpeedZeroOptimizerStage3Mon() - name2indices = optimizer_stage3_mon.get_param_index(self.params2name, self.name2index, self.torch_opt) +class TestDeepSpeedZeroOptimizerStage0Mon(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() - expected_name2indices = {'weight1': (0, 3, 0, None), 'weight2': (5, 9, 1, None)} - self.assertDictEqual(dict(name2indices), expected_name2indices) + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bf16_groups = bit16_groups + mock_opt.fp32_groups_flat_partition = [torch.stack(group,dim=0).flatten().float() \ + for group in bit16_groups]# mock name 2 index in subgroup + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.fp32_groups_flat_partition + } + mock_opt.cpu_offload = False + + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerStage0Mon(mock_opt) + + def test_get_grad_for_param(self): + param = list(self.torch_opt.param_names.keys())[0] + group_idx = 0 + param_id = 2 + grad_expected = torch.randn_like(param) + self.torch_opt.fp32_groups_gradient_dict = [[0, 0, grad_expected, 0]] + grad = self.optimizer_mon.get_grad_for_param(param, group_idx, param_id) + + self.assertTrue(torch.equal(grad_expected, grad)) + + def test_fetch_grad(self): + self.torch_opt.fp32_groups_gradient_dict = [[torch.randn_like(param) for param in group] for group in self.optimizer_mon.bit16_groups] + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for _, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.fp32_groups_gradient_dict[group_index][param_id])) def test_fetch_mv(self): - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = MagicMock() - self.torch_opt.fp16_partitioned_groups = MagicMock() - self.optimizer = DeepSpeedZeroOptimizerStage3Mon() - - # mock _fetch_mv_grad_in_adam - mv_result = MVGradResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}, grad={}) - self.mock_fetch_mv_grad_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_grad_in_adam = self.mock_fetch_mv_grad_in_adam - - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVGradResult) + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) class TestDeepSpeedZeroOptimizerStage1or2Mon(unittest.TestCase): - def test_get_group_index(self): - self.fp32_length = [10, 20, 30, 40] - self.world_size = 4 - self.indexes = [5, 7, 12, 25, 35, 45] - self.expected_results = [(40, 0), (40, 0), (12, 1), (24, 2), (34, 2), (40, 0)] - - optimizer = DeepSpeedZeroOptimizerStage1or2Mon() - results = [optimizer.get_group_index(self.fp32_length, self.world_size, index) for index in self.indexes] - self.assertEqual(results, self.expected_results) + def setUp(self): + bit16_groups, param_names, param_slice_mappings, _ = setup_param_groups() - @patch('msprobe.pytorch.monitor.optimizer_collect.dist') - def test_get_param_index(self, mock_dist): - mock_dist.get_world_size.return_value = 4 + mock_opt = MagicMock() + mock_opt.state_dict.return_value = { + 'param_slice_mappings': param_slice_mappings + } + mock_opt.param_names = param_names + mock_opt.bit16_groups = bit16_groups + mock_opt.single_partition_of_fp32_groups = [torch.stack(group,dim=0).flatten().float() \ + for group in bit16_groups] + mock_opt.averaged_gradients = {group_idx: [torch.randn_like(param) for param in group] for group_idx, group in enumerate(bit16_groups)}# mock name 2 index in subgroup + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.single_partition_of_fp32_groups + } + mock_opt.cpu_offload = False + + self.torch_opt = mock_opt + self.mock_monitor = setup_mock_monitor() + self.optimizer_mon = DeepSpeedZeroOptimizerStage1or2Mon(mock_opt) + + def test_get_grad_for_param(self): + param = list(self.torch_opt.param_names.keys())[0] + group_idx = 0 + param_id = 2 + grad_expected = torch.randn_like(param) + self.torch_opt.averaged_gradients = [[0, 0, grad_expected, 0]] + grad = self.optimizer_mon.get_grad_for_param(param, group_idx, param_id) + + self.assertTrue(torch.equal(grad_expected, grad)) + + def test_fetch_grad(self): + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.averaged_gradients[group_index][param_id])) - self.params2name = {'param1': 'weight', 'param2': 'bias'} - self.name2index = {'weight': 0, 'bias': 1} + def test_fetch_mv(self): + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) - self.optimizer_monitor = DeepSpeedZeroOptimizerStage1or2Mon() - self.torch_opt = MagicMock() - self.torch_opt.groups_padding = [1, 2, 3] - self.torch_opt.single_partition_of_fp32_groups = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])] - self.torch_opt.bit16_groups = [ - [torch.tensor([6, 7]), torch.tensor([8])], - [torch.tensor([9, 10, 11])] - ] - - name2indices = self.optimizer_monitor.get_param_index(self.params2name, self.name2index, self.torch_opt) - for name, indices in name2indices.items(): - self.assertIn(name, self.params2name.values()) - self.assertIsInstance(indices, tuple) - self.assertEqual(len(indices), 4) +class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): + def setUp(self): + bit16_groups, param_names, _, grad_position = setup_param_groups() + + mock_opt = MagicMock() + mock_opt.param_names = param_names + mock_opt.fp16_groups = bit16_groups + mock_opt.fp32_partitioned_groups_flat = [torch.stack(group,dim=0).flatten().float() + for group in bit16_groups] + mock_opt.averaged_gradients = {group_idx: [torch.randn_like(param) for param in group] + for group_idx, group in enumerate(bit16_groups)} + mock_opt.grad_position = grad_position + mock_opt.get_param_id = lambda x: int(param_names[x].split('_')[1]) + mock_opt.state = { + flat_group: { + 'exp_avg': torch.ones_like(flat_group), + 'exp_avg_sq': torch.ones_like(flat_group) + } for flat_group in mock_opt.fp32_partitioned_groups_flat + } + + self.torch_opt = mock_opt + self.optimizer_mon = DeepSpeedZeroOptimizerStage3Mon(mock_opt) + self.mock_monitor = setup_mock_monitor() + + def test_fetch_grad(self): + self.mock_monitor.name2tag = {name:{MonitorConst.POST_GRAD: name} for name in self.torch_opt.param_names.values()} + result = self.optimizer_mon.fetch_grad(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + group_index, param_id = [int(i) for i in name.replace('param','').split('_')] + self.assertTrue(torch.equal(result[name], self.torch_opt.averaged_gradients[group_index][param_id])) def test_fetch_mv(self): - self.monitor = MagicMock() - self.torch_opt = MagicMock() - self.params2name = MagicMock() - self.torch_opt.fp16_partitioned_groups = MagicMock() - self.optimizer = DeepSpeedZeroOptimizerStage1or2Mon() - - # mock _fetch_mv_grad_in_adam - mv_result = MVGradResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}, grad={}) - self.mock_fetch_mv_grad_in_adam = MagicMock(return_value=mv_result) - self.optimizer._fetch_mv_grad_in_adam = self.mock_fetch_mv_grad_in_adam - - res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) - self.assertIsInstance(res, MVGradResult) + del self.torch_opt.chained_optimizers + del self.torch_opt.param_to_cpu_states_map + result = self.optimizer_mon.fetch_mv(self.mock_monitor, self.torch_opt.param_names) + for param, name in self.torch_opt.param_names.items(): + self.assertTrue(torch.equal(result.exp_avg[name], torch.ones_like(param).flatten())) + self.assertTrue(torch.equal(result.exp_avg_sq[name], torch.ones_like(param).flatten())) class TestOptimizerMonFactory(unittest.TestCase): @@ -291,48 +371,48 @@ class TestOptimizerMonFactory(unittest.TestCase): mix_optimizer_class = MagicMock() mix_optimizer_class.__name__ = "Float16OptimizerWithFloat16Params" mix_optimizer.__class__ = mix_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer), MixPrecisionOptimizerMon) dis_optimizer = MagicMock() dis_optimizer_class = MagicMock() dis_optimizer_class.__name__ = "DistributedOptimizer" dis_optimizer.__class__ = dis_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(dis_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(dis_optimizer), MegatronDistributedOptimizerMon) fp32_optimizer = MagicMock() fp32_optimizer_class = MagicMock() fp32_optimizer_class.__name__ = "FP32Optimizer" fp32_optimizer.__class__ = fp32_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(fp32_optimizer)[0], - MegatronFP32OptimizerMon) + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(fp32_optimizer), + OptimizerMon) chained_optimizer = MagicMock() chained_optimizer_class = MagicMock() chained_optimizer_class.__name__ = "ChainedOptimizer" chained_optimizer.__class__ = chained_optimizer_class chained_optimizer.chained_optimizers = [mix_optimizer, mix_optimizer] - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer), MegatronChainedMixPrecisionOptimizerMon) chained_optimizer.chained_optimizers = [dis_optimizer, dis_optimizer] - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer), MegatronChainedDistributedOptimizerMon) deepspeed_optimizer = MagicMock() deepspeed_optimizer_class = MagicMock() deepspeed_optimizer_class.__name__ = "BF16_Optimizer" deepspeed_optimizer.__class__ = deepspeed_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage0Mon) deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer" - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage1or2Mon) deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer_Stage3" - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer), DeepSpeedZeroOptimizerStage3Mon) - # 测试未知的优化器类型,应该返回DummyOptimizerMon + # 测试未知的优化器类型,应该返回OptimizerMon unknown_optimizer = MagicMock() unknown_optimizer_class = MagicMock() unknown_optimizer_class.__name__ = "unknown" unknown_optimizer.__class__ = unknown_optimizer_class - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(unknown_optimizer)[0], DummyOptimizerMon) + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(unknown_optimizer), OptimizerMon) if __name__ == '__main__': diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py index b875bd7e8e17f6b869b2a1b1498982b2a17e1258..3a09d41588a94043f54161023ffbba573c60d76c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_interactive_cli.py @@ -26,63 +26,50 @@ class TestInteractiveCli(unittest.TestCase): @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.prepare', return_value=None) def test_prepare(self, mock_prepare): self.interactive_cli.prepare() - mock_prepare.assert_called_once() - @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.execute_command', return_value=None) - def test_default(self, mock_execute_command): - res = self.interactive_cli.default() - - mock_execute_command.assert_called_once() - self.assertFalse(res) - - @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.execute_command', return_value=None) - def test_do_run(self, mock_execute_command): - self.interactive_cli.do_run() - - mock_execute_command.assert_called_once() + def test_default(self, command='rm'): + res = self.interactive_cli.default(command) + self.assertIsNone(res) @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_compare_converted_dir') @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_vector_compare') def test_do_vc(self, mock_do_vector_compare, mock_do_compare_converted_dir): - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_path_valid'), \ - patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_files_in_path'): - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', return_value=False): - self.interactive_cli.do_vc('-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') - + with (patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_path_valid'), + patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.check_files_in_path')): + with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', + return_value=False): + self.interactive_cli.do_vc( + '-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') mock_do_vector_compare.assert_called_once() - with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', return_value=True): - self.interactive_cli.do_vc('-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') - + with patch('msprobe.pytorch.parse_tool.lib.interactive_cli.Util.dir_contains_only', + return_value=True): + self.interactive_cli.do_vc( + '-m my_dump_path -g golden_dump_path -out output_path -cmp_path msaccucmp_path') mock_do_compare_converted_dir.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_convert_dump', return_value=None) def test_do_dc(self, mock_do_convert_dump): self.interactive_cli.do_dc('-n file_name/file_path -f format -out output_path') - mock_do_convert_dump.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_print_data', return_value=None) def test_do_pt(self, mock_do_print_data): self.interactive_cli.do_pt('-n file_path') - mock_do_print_data.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_parse_pkl', return_value=None) def test_do_pk(self, mock_do_parse_pkl): self.interactive_cli.do_pk('-f pkl_path -n api_name') - mock_do_parse_pkl.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_compare_data', return_value=None) def test_do_cn(self, mock_do_comapre_data): self.interactive_cli.do_cn('-m my_data*.npy -g golden*.npu -p num -al atol -rl rtol') - mock_do_comapre_data.assert_called_once() @patch('msprobe.pytorch.parse_tool.lib.interactive_cli.ParseTool.do_convert_api_dir', return_value=None) def test_do_cad(self, mock_do_convert_api_dir): self.interactive_cli.do_cad('-m my_dump_path -out output_path -asc msaccucmp_path') - mock_do_convert_api_dir.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py index dfec4d20366c6e834939130009dc6d33d1cbe9ed..c148f84d0d20213631e9be039521a14d970849e9 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/parse_tool/test_parse_utils.py @@ -88,7 +88,7 @@ class TestUtils(unittest.TestCase): obj = np.array([1, 2, 3, 4, 5]) res = self.util.get_md5_for_numpy(obj) - self.assertEqual(res, '3cd8e13ca72251bfd8c08e209abcf46f') + self.assertEqual(res, 'baa24928') def test_deal_with_dir_or_file_inconsistency(self): with self.assertRaises(ParseException): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py index c1b8bac47fda100636b55fbc5ad452c2843e8aaa..191b2b6baa7b69e74a33987f7c17545399a6e202 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py @@ -105,9 +105,7 @@ class TestTensorConfig(unittest.TestCase): self.config._check_file_format() self.assertIn(str(context.exception), "file_format is invalid") - @patch('msprobe.pytorch.pt_config.check_crt_valid') - def test_check_online_run_ut(self, mock_check_crt_valid): - mock_check_crt_valid.return_value = True + def test_check_online_run_ut(self): self.config.online_run_ut = "True" with self.assertRaises(Exception) as context: @@ -137,6 +135,10 @@ class TestTensorConfig(unittest.TestCase): file.write("1") with open(os.path.join(self.config.tls_path, "client.crt"), 'w') as file: file.write("1") + with open(os.path.join(self.config.tls_path, "ca.crt"), 'w') as file: + file.write("1") + with open(os.path.join(self.config.tls_path, "crl.pem"), 'w') as file: + file.write("1") self.config._check_online_run_ut() shutil.rmtree(self.config.tls_path) self.config.tls_path = "" @@ -181,7 +183,7 @@ class TestStatisticsConfig(unittest.TestCase): self.config.summary_mode = "invalid_mode" with self.assertRaises(Exception) as context: self.config._check_summary_mode() - self.assertIn(str(context.exception), "summary_mode is invalid") + self.assertIn(str(context.exception), "[msprobe] 无效参数:",) def test_check_summary_mode_none(self): self.config.summary_mode = None @@ -261,14 +263,14 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("fuzz_device is invalid", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_fuzz_device_cpu_mode_invalid(self, mock_error): invalid_config = self.valid_config.copy() invalid_config["fuzz_device"] = "cpu" invalid_config["pert_mode"] = "INVALID_CPU_MODE" config = FreeBenchmarkCheckConfig(invalid_config) - self.assertIn("You neet to and can only set fuzz_device as ", str(mock_error.call_args)) + self.assertIn("You need to and can only set fuzz_device as ", str(mock_error.call_args)) @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_handler_type_invalid(self, mock_error): @@ -277,7 +279,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("handler_type is invalid", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_fuzz_stage_invalid(self, mock_error): invalid_config = self.valid_config.copy() @@ -319,7 +321,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("preheat_step must be greater than 0", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_preheat_max_sample_not_int(self, mock_error): invalid_config = self.valid_config.copy() @@ -328,7 +330,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): config = FreeBenchmarkCheckConfig(invalid_config) mock_error.assert_called_once() self.assertIn("max_sample is invalid, it should be an integer", str(mock_error.call_args)) - + @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_max_sample_invalid_not_great_than_zero(self, mock_error): invalid_config = self.valid_config.copy() @@ -397,19 +399,19 @@ class TestRunUTConfig(unittest.TestCase): def test_check_nfs_path_config_not_exist(self, mock_exists): with self.assertRaises(Exception) as context: RunUTConfig.check_nfs_path_config("./invalid_nfs") - self.assertIn("does not exist", str(context.exception)) + self.assertIn("[msprobe] 非法文件路径:", str(context.exception)) @patch('os.path.exists', return_value=False) def test_check_tls_path_config_not_exist(self, mock_exists): with self.assertRaises(Exception) as context: RunUTConfig.check_tls_path_config("./invalid_tls") - self.assertIn("does not exist", str(context.exception)) + self.assertIn("[msprobe] 非法文件路径:", str(context.exception)) def test_check_run_ut_config(self): with patch.object(RunUTConfig, 'check_filter_list_config') as mock_filter, \ - patch.object(RunUTConfig, 'check_error_data_path_config') as mock_error, \ - patch.object(RunUTConfig, 'check_nfs_path_config') as mock_nfs, \ - patch.object(RunUTConfig, 'check_tls_path_config') as mock_tls: + patch.object(RunUTConfig, 'check_error_data_path_config') as mock_error, \ + patch.object(RunUTConfig, 'check_nfs_path_config') as mock_nfs, \ + patch.object(RunUTConfig, 'check_tls_path_config') as mock_tls: self.config.check_run_ut_config() mock_filter.assert_called() mock_error.assert_called() @@ -442,3 +444,7 @@ class TestGradToolConfig(unittest.TestCase): with self.assertRaises(Exception) as context: GradToolConfig(json_config) self.assertTrue("param_list must be a list" in str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py index 534437260e66d9e586d69d557d30e308a9f4f3ee..e517e1cefe4987b62aa2040f1a4e9db0b8dfbe98 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py @@ -18,6 +18,7 @@ import torch from msprobe.pytorch import PrecisionDebugger from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger class TestPytorchDebuggerSave(TestCase): @@ -36,13 +37,14 @@ class TestPytorchDebuggerSave(TestCase): } common_config = CommonConfig(statistics_task_json) task_config = BaseConfig(statistics_task_json) - with patch("msprobe.pytorch.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)): + with patch.object(BasePrecisionDebugger, "_parse_config_path", return_value=(common_config, task_config)): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") return x * y + x = torch.tensor([1.]) y = torch.tensor([2.]) x.requires_grad = True @@ -53,28 +55,28 @@ class TestPytorchDebuggerSave(TestCase): "framework": "pytorch", "dump_data_dir": None, "data": { - "x_tensor.0": { + "x_tensor.0.debug": { "type": "torch.Tensor", "dtype": "torch.float32", "shape": torch.Size([1]), - "Max": 1.0, - "Min": 1.0, - "Mean": 1.0, - "Norm": 1.0, "requires_grad": True }, - "x_tensor_grad.0": { + "x_tensor_grad.0.debug": { "type": "torch.Tensor", "dtype": "torch.float32", "shape": torch.Size([1]), - "Max": 2.0, - "Min": 2.0, - "Mean": 2.0, - "Norm": 2.0, "requires_grad": False } } } + loss = forward_func(x, y) loss.backward() - self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) \ No newline at end of file + + result = self.debugger.service.data_collector.data_writer.cache_debug + # Remove 'tensor_stat_index' from all entries in the data dictionary + for key in result["data"]: + if 'tensor_stat_index' in result["data"][key]: + del result["data"][key]['tensor_stat_index'] + + self.assertEqual(result, result_json) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c56315ec3b9ddb2bdd1f0724f9edd1f4597b99 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from msprobe.pytorch.pytorch_service import PytorchService +from msprobe.core.common.utils import Const +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser +from msprobe.pytorch.attl_manager import ATTLManager +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +class TestPytorchService(unittest.TestCase): + def setUp(self): + self.config = MagicMock() + self.config.step = [] + self.config.rank = [] + self.config.level = Const.LEVEL_MIX + self.config.task = Const.STATISTICS + self.config.online_run_ut_recompute = False + + with patch('msprobe.core.service.build_data_collector'): + self.service = PytorchService(self.config) + + self.service.logger = MagicMock() + self.service.data_collector = MagicMock() + self.service.module_processor = MagicMock() + self.service.attl_manager = MagicMock(spec=ATTLManager) + self.service.attl_manager.attl = MagicMock() + self.service.api_register = MagicMock() + + def test_framework_type(self): + self.assertEqual(self.service._get_framework_type, Const.PT_FRAMEWORK) + + @patch('msprobe.pytorch.pytorch_service.get_rank_if_initialized') + def test_get_current_rank(self, mock_get_rank): + mock_get_rank.return_value = 5 + self.assertEqual(self.service._get_current_rank(), 5) + + def test_init_specific_components(self): + with patch('msprobe.core.service.build_data_collector'): + service = PytorchService(self.config) + + self.assertIsNotNone(service.logger) + self.assertIsNotNone(service.api_register) + self.assertIsNotNone(service.module_processor) + self.assertIsNotNone(service.attl_manager) + self.assertIsNotNone(service.hook_manager) + + def test_register_hook(self): + self.service._register_hook() + self.service.attl_manager.attl_init.assert_called_once() + + @patch('msprobe.pytorch.pytorch_service.register_optimizer_hook') + def test_register_hook_mix_level(self, mock_register_opt): + self.service.config.level = Const.LEVEL_MIX + self.service._register_hook() + mock_register_opt.assert_called_once_with(self.service.data_collector) + + @patch('msprobe.pytorch.pytorch_service.register_optimizer_hook') + def test_register_hook_not_mix_level(self, mock_register_opt): + self.service.config.level = Const.LEVEL_L1 + self.service._register_hook() + mock_register_opt.assert_not_called() + + @patch('msprobe.pytorch.pytorch_service.wrap_jit_script_func') + def test_register_api_hook(self, mock_wrap_jit): + self.service.config.level = Const.LEVEL_L1 + self.service._register_api_hook() + mock_wrap_jit.assert_called_once() + self.service.api_register.initialize_hook.assert_called_once() + + def test_register_module_hook(self): + model_mock = MagicMock() + self.service.model = model_mock + self.service._register_module_hook() + + self.service.module_processor.register_module_hook.assert_called_once_with( + model_mock, self.service.build_hook + ) + + self.assertTrue(self.service.module_processor.enable_module_dump) + + @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True) + @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch') + def test_run_ut_dispatch(self, mock_run_ut): + status = True + self.service._run_ut_dispatch(status) + mock_run_ut.assert_called_once_with( + self.service.attl_manager.attl, + status, + self.config.online_run_ut_recompute + ) + + @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=False) + @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch') + def test_run_ut_dispatch_torch_version_below_2(self, mock_run_ut): + status = True + self.service._run_ut_dispatch(status) + mock_run_ut.assert_not_called() + + @patch.object(HOOKModule, 'reset_module_stats') + @patch.object(ModuleProcesser, 'reset_module_stats') + def test_reset_status(self, mock_reset_module_processor, mock_reset_hook_module): + self.service._reset_status() + mock_reset_hook_module.assert_called_once() + mock_reset_module_processor.assert_called_once() + self.service.data_collector.reset_status.assert_called_once() + + @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True) + @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch') + def test_start_with_online_run_ut(self, mock_run_ut): + self.service.config.online_run_ut = True + self.service.data_collector.data_processor.is_terminated = False + model_mock = MagicMock() + + self.service.start(model=model_mock) + + mock_run_ut.assert_called_once_with( + self.service.attl_manager.attl, + True, + self.config.online_run_ut_recompute + ) + + @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', return_value=True) + @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch') + def test_stop_with_online_run_ut(self, mock_run_ut, mock_version): + self.service.config.online_run_ut = True + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.attl_manager.attl = MagicMock() + self.service.stop() + + mock_run_ut.assert_called_once_with( + self.service.attl_manager.attl, + False, + self.config.online_run_ut_recompute + ) + + def test_register_module_hook(self): + self.service.model = MagicMock() + self.service._register_module_hook() + self.service.module_processor.register_module_hook.assert_called_once() + + @patch('msprobe.pytorch.pytorch_service.torch_version_above_or_equal_2', new=True) + @patch('msprobe.pytorch.pytorch_service.run_ut_dispatch') + def test_run_ut_dispatch_with_recompute(self, mock_run_ut): + self.service.attl_manager.attl = None + self.service.config.online_run_ut_recompute = True + status = True + self.service._run_ut_dispatch(status) + mock_run_ut.assert_called_once_with( + self.service.attl_manager.attl, + status, + True + ) + + def test_attl_manager_interaction(self): + self.service.config.online_run_ut = True + self.service.data_collector.data_processor.is_terminated = False + self.service.start(model=MagicMock()) + self.service.attl_manager.attl_init.assert_called_once() + + self.service.data_collector.data_processor.is_terminated = True + self.service.start() + self.service.attl_manager.attl_stop.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py deleted file mode 100644 index 6687f3111050ea53e14e62f3afd55ae1eff2b8c0..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py +++ /dev/null @@ -1,150 +0,0 @@ -import unittest -from unittest.mock import patch, mock_open, MagicMock - -from msprobe.core.common.utils import Const -from msprobe.pytorch.debugger.debugger_config import DebuggerConfig -from msprobe.pytorch.pt_config import parse_json_config -from msprobe.pytorch.service import Service - - -class TestService(unittest.TestCase): - def setUp(self): - mock_json_data = { - "dump_path": "./dump/", - } - with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ - patch("msprobe.pytorch.pt_config.load_json", return_value=mock_json_data): - common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) - self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") - self.service = Service(self.config) - - def test_start_success(self): - with patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0), \ - patch("msprobe.pytorch.service.Service.create_dirs", return_value=None): - self.service.start(None) - self.assertEqual(self.service.current_rank, 0) - - def test_start_fail(self): - self.service.config.rank = [1, 2] - self.service.current_rank = 3 - self.assertIsNone(self.service.start(None)) - - self.service.config.step = [1, 2] - self.service.current_iter = 3 - self.assertIsNone(self.service.start(None)) - - @patch("msprobe.core.data_dump.data_collector.DataCollector.write_json") - def test_stop_success(self, mock_write_json): - mock_write_json.return_value = None - self.service.stop() - - self.assertFalse(self.service.switch) - - def test_stop_fail(self): - self.service.switch = True - - self.service.config.rank = [1, 2] - self.service.current_rank = 3 - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.config.step = [1, 2] - self.service.current_iter = 3 - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.config.level = "L2" - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - self.service.should_stop_service = True - res = self.service.stop() - self.assertIsNone(res) - self.assertTrue(self.service.switch) - - def test_step_success(self): - self.service.step() - self.assertEqual(self.service.current_iter, 1) - - def test_step_fail(self): - self.service.should_stop_service = True - self.assertIsNone(self.service.step()) - - def test_register_module_hook_with_level0(self): - self.service.model = MagicMock() - self.service.build_hook = MagicMock() - self.config.level = "L0" - with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.ModuleProcesser.register_module_hook") as mock_register_module_hook: - self.service.register_module_hook() - self.assertEqual(mock_logger.call_count, 1) - mock_register_module_hook.assert_called_once() - - def test_register_api_hook_with_level1(self): - self.service.build_hook = MagicMock() - self.config.level = "L1" - with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.api_register.initialize_hook") as mock_init_hook, \ - patch("msprobe.pytorch.service.api_register.api_modularity") as mock_api_modularity: - self.service.register_api_hook() - self.assertEqual(mock_logger.call_count, 1) - mock_init_hook.assert_called_once() - mock_api_modularity.assert_called_once() - - def test_create_dirs(self): - with patch("msprobe.pytorch.service.create_directory"), \ - patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths"), \ - patch("msprobe.core.data_dump.data_collector.DataCollector.initialize_json_file"): - self.service.create_dirs() - self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0") - - def test_need_end_service(self): - self.service.should_stop_service = True - self.assertTrue(self.service.need_stop_service()) - - self.service.should_stop_service = False - self.service.config.step = [1, 3] - self.service.current_iter = 1 - self.assertFalse(self.service.need_stop_service()) - - self.service.current_iter = 2 - self.assertTrue(self.service.need_stop_service()) - - self.service.current_iter = 4 - self.service.config.level = "L0" - self.service.config.online_run_ut = False - self.assertTrue(self.service.need_stop_service()) - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - - def test_should_execute_hook_return_false(self): - module = MagicMock() - self.service.switch = False - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - self.assertFalse(self.service.should_execute_hook("api", module, True)) - - self.service.switch = True - module.forward_data_collected = False - self.assertFalse(self.service.should_execute_hook("api", module, False)) - - self.service.inner_switch = True - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - - self.service.inner_switch = False - self.service.data_collector = None - self.assertFalse(self.service.should_execute_hook("Module", module, True)) - - def test_should_execute_hook_return_true(self): - module = MagicMock() - self.service.switch = True - self.service.inner_switch = False - self.service.data_collector = MagicMock() - self.service.data_collector.data_processor = MagicMock() - self.service.data_collector.data_processor.is_terminated = False - self.assertTrue(self.service.should_execute_hook("Module", module, True)) - - module.forward_data_collected = True - self.assertTrue(self.service.should_execute_hook("api", module, False)) diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json index b55f9e0699fe6329ceeb09a51fe20118c65545e7..153d84e7d117b5be89dfdb522edc39dc066929cb 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "mindspore", "dump_data_dir": null, "data": { "Cell.network_with_loss.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json index d7dd1c0c38e2d24c8b0d19c346a50eb33437d232..02239176a9d690c4ce70c06cc6ab117a3c122811 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "pytorch", "dump_data_dir": null, "data": { "Module.module.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/run_ut.py b/debug/accuracy_tools/msprobe/test/run_ut.py index c5ebc6e3f052b8ef7d16694c31c22d16f8ec930a..06671c3d0d0e4440736712cb0718873280482781 100644 --- a/debug/accuracy_tools/msprobe/test/run_ut.py +++ b/debug/accuracy_tools/msprobe/test/run_ut.py @@ -2,6 +2,7 @@ import os import shutil import subprocess import sys +import tempfile from msprobe.core.common.log import logger @@ -20,6 +21,23 @@ def run_ut(): shutil.rmtree(report_dir) os.makedirs(report_dir) + tmpdir = tempfile.mkdtemp() + sitecustomize_path = os.path.join(tmpdir, "sitecustomize.py") + + with open(sitecustomize_path, "w") as f: + f.write(""" +import mindspore + +class Distributed: + P2POp = None + +if not hasattr(mindspore.mint, 'distributed'): + setattr(mindspore.mint, 'distributed', Distributed()) + """) + + env = os.environ.copy() + env["PYTHONPATH"] = f"{tmpdir}:{env.get('PYTHONPATH', '')}" + pytest_cmd = [ "python3", "-m", "pytest", ut_path, diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py index 706dc8bf82e59f413c3fd559a39af89c6a70be47..2e41f2a325cf77937884b624c14b9ee7bef6c243 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py @@ -32,7 +32,7 @@ class TestGraphBuilder(unittest.TestCase): self.assertIsInstance(graph, Graph) self.assertEqual(len(graph.node_map), 3) - @patch('msprobe.visualization.builder.graph_builder.save_json_file') + @patch('msprobe.visualization.builder.graph_builder.save_json') def test_to_json(self, mock_save_json_file): GraphBuilder.to_json("step/rank/output.vis", self.config) mock_save_json_file.assert_called_once() @@ -111,3 +111,23 @@ class TestGraphBuilder(unittest.TestCase): self.assertEqual(graph.root.subnodes[2].op, NodeOp.module) self.assertEqual(len(graph.root.subnodes[0].subnodes), 0) self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0') + + def test_add_parameters_grad(self): + graph = Graph('TestNet') + graph.add_node(NodeOp.module, 'Module.a.backward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.b.backward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.a.backward.1', graph.root) + graph.add_node(NodeOp.module, 'Module.aa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aa.backward.1', graph.get_node('Module.a.backward.1')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.1', graph.get_node('Module.a.backward.1')) + + data_dict = {'Module.a.parameters_grad': {}, 'Module.aaa.parameters_grad': {}} + GraphBuilder._add_parameters_grad(graph, data_dict) + root_nodes_id = [node.id for node in graph.get_node('TestNet').subnodes] + sub_nodes_id0 = [node.id for node in graph.get_node('Module.a.backward.0').subnodes] + sub_nodes_id1 = [node.id for node in graph.get_node('Module.a.backward.1').subnodes] + + self.assertEqual(root_nodes_id[-1], 'Module.a.backward.1') + self.assertEqual(sub_nodes_id0[-1], 'Module.aaa.backward.0') + self.assertEqual(sub_nodes_id1[-1], 'Module.a.parameters_grad') diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..d4471b92d5b4534b22374da4b79548c6c064d389 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_merger.py @@ -0,0 +1,424 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock, call +from msprobe.visualization.builder.graph_merger import ( + GraphMerger, BaseGraphMerger, PPMerger, TPMerger, + NoParallelMerger, TPPPMerger, FullMerger +) +from msprobe.core.common.const import Const +from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.visualization.graph.graph import Graph +from msprobe.core.common.exceptions import MsprobeException + + +class TestGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = MagicMock() + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=1) + self.is_bench = False + + def test_select_strategy_no_parallel(self): + self.parallel_param.tp = self.parallel_param.pp = self.parallel_param.rank_size = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, NoParallelMerger) + + def test_select_strategy_tp(self): + self.parallel_param.tp = self.parallel_param.rank_size = 2 + self.parallel_param.pp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPMerger) + + def test_select_strategy_pp(self): + self.parallel_param.pp = self.parallel_param.rank_size = 2 + self.parallel_param.tp = 1 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, PPMerger) + + def test_select_strategy_tp_pp(self): + self.parallel_param.tp = self.parallel_param.pp = 2 + self.parallel_param.rank_size = 4 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, TPPPMerger) + + def test_select_strategy_full(self): + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + self.assertIsInstance(merger.strategy, FullMerger) + + def test_merge_graph(self): + merger = GraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + merger.strategy.merge_graphs = MagicMock() + merger.merge_graph() + merger.strategy.merge_graphs.assert_called_once() + + +class TestBaseGraphMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(2)] + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=2) + self.is_bench = False + self.merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_sort_merged_api_collection(self): + graph = MagicMock() + root = MagicMock() + graph.root = root + subnode1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS}.0", op=NodeOp.api_collection) + subnode1.subnodes = [MagicMock(id="op_Rank1.0"), MagicMock(id="op_Rank0.0")] + root.subnodes = [subnode1] + self.merger.sort_merged_api_collection(graph) + self.assertEqual([n.id for n in subnode1.subnodes], ["op_Rank0.0", "op_Rank1.0"]) + + def test_update_node_data_key(self): + data_dict = { + "old_id.input.0": {"full_op_name": "old_id.op"}, + "other_key": {"value": "test"} + } + new_dict = self.merger._update_node_data_key("old_id", "new_id", data_dict) + self.assertEqual(new_dict, { + "new_id.input.0": {"full_op_name": "new_id.op"}, + "other_key": {"value": "test"} + }) + + def test_compare_value_same(self): + self.assertTrue(self.merger._compare_value_same(1, 1)) + self.assertFalse(self.merger._compare_value_same(1, 2)) + self.assertTrue(self.merger._compare_value_same("a", "a")) + self.assertTrue(self.merger._compare_value_same(1, 1.00000001, has_uncertainty=True)) + self.assertFalse(self.merger._compare_value_same(1, 1.1, has_uncertainty=True)) + + def test_merge_graph_api_collection(self): + results = [MagicMock() for _ in range(2)] + graph0, graph1 = Graph("name1"), Graph("name2") + results[0].graph, results[1].graph = graph0, graph1 + root0, root1 = MagicMock(), MagicMock() + graph0.root, graph1.root = root0, root1 + node0 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node0_sub1 = MagicMock(id="sub_op.0") + node0.subnodes = [node0_sub1] + node1 = MagicMock(id=f"{GraphConst.APIS_BETWEEN_MODULES}.0") + node1_sub1 = MagicMock(id="sub_op.0") + graph0.node_map = {f"{GraphConst.APIS_BETWEEN_MODULES}.0": node0} + node1.subnodes = [node1_sub1] + root0.subnodes = [node0] + root1.subnodes = [node1] + + self.merger.merge_graph_api_collection(results) + + self.assertEqual(len(root0.subnodes), 1) + self.assertTrue(root0.subnodes[0].id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS)) + self.assertEqual(len(root0.subnodes[0].subnodes), 1) + + def test_split_graph_results_by_groups(self): + groups = [[0, 1], [2, 3]] + results = [MagicMock(rank=i) for i in range(4)] + self.merger.build_graph_results = results + split = self.merger.split_graph_results_by_groups(groups) + self.assertEqual(len(split), 2) + self.assertEqual([r.rank for r in split[0]], [0, 1]) + self.assertEqual([r.rank for r in split[1]], [2, 3]) + + def test_compare_node_param_data(self): + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.id = "id" + other_nodes[0].id = "id" + main_node.input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 1}} + other_nodes[0].input_data = {"input.0": {Const.DTYPE: "torch.float16", Const.MAX: 2}} + in_diff, out_diff = self.merger.compare_node_param_data(main_node, other_nodes) + self.assertEqual(list(in_diff.keys()), ["input.0"]) + + def test_compare_param_same(self): + param1 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + param2 = {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1} + self.assertTrue(self.merger.compare_param_same(param1, param2)) + + param2[Const.MAX] = 2 + self.assertFalse(self.merger.compare_param_same(param1, param2)) + + def test_get_default_groups(self): + self.parallel_param.tp = 4 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1, 2, 3], [4, 5, 6, 7]]) + self.assertEqual(pp_groups, [[0, 4], [1, 5], [2, 6], [3, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 2 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + tp_groups, pp_groups = merger.get_default_groups() + self.assertEqual(tp_groups, [[0, 1], [2, 3], [4, 5], [6, 7]]) + self.assertEqual(pp_groups, [[0, 2], [1, 3], [4, 6], [5, 7]]) + + self.parallel_param.tp = 2 + self.parallel_param.pp = 3 + self.parallel_param.rank_size = 8 + merger = BaseGraphMerger(self.build_graph_results, self.parallel_param, self.is_bench) + with self.assertRaises(MsprobeException): + merger.get_default_groups() + + def test_add_all_nodes_rank(self): + graph0, graph1 = MagicMock(), MagicMock() + node0, node1 = MagicMock(), MagicMock() + graph0.node_map.values.return_value = [node0] + graph1.node_map.values.return_value = [node1] + self.build_graph_results[0].graph = graph0 + self.build_graph_results[1].graph = graph1 + + self.merger._add_all_nodes_rank() + + self.assertEqual(node0.rank, 0) + self.assertEqual(node1.rank, 1) + + +class TestPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=1, pp=4, rank_size=4) + self.is_bench = False + self.merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_trace_p2p_mapping(self): + p2p_mapping = {0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5} + chains = self.merger._trace_p2p_mapping(p2p_mapping) + self.assertEqual(len(chains), 2) + self.assertIn([0, 2, 4, 6], chains) + self.assertIn([1, 3, 5, 7], chains) + + @patch('msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes') + def test_merge_nodes(self, mock_merge): + main_graph = MagicMock() + main_node = MagicMock(id="module.layers.0.forward") + other_graphs = [MagicMock() for _ in range(3)] + for i, g in enumerate(other_graphs): + g.get_node.return_value = MagicMock(id=f"module.layers.{i}.forward") + + self.merger._merge_nodes(main_graph, main_node, other_graphs) + mock_merge.assert_called() + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_pp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"Distributed.send.{i}.forward") + node.input_data = {f"Distributed.send.{i}.forward.input.dst": {"value": (i + 1) % 4}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_merge_other_unique_nodes(self): + main_graph = MagicMock() + main_node = MagicMock() + other_nodes = [MagicMock()] + main_node.subnodes = [MagicMock(id="main_sub.0")] + other_nodes[0].subnodes = [MagicMock(id="other_sub.0")] + + self.merger._merge_other_unique_nodes(main_graph, main_node, other_nodes) + self.assertEqual(len(main_node.subnodes), 2) + + def test_sort_nodes(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%0%0") + start_node.op = NodeOp.module + api_node = MagicMock(id="Torch.mul.forward.0%0%0") + graph.node_map = {"module.layers.0.forward%0%0": start_node, "Torch.mul.forward.0%0%0": api_node} + parent_node = MagicMock() + parent_node.subnodes = [start_node, api_node] + start_node.upnode = parent_node + + self.merger._sort_nodes(graph, start_node) + self.assertEqual(parent_node.subnodes[0].id, "module.layers.0.forward") + self.assertEqual(parent_node.subnodes[1].id, "Torch.mul_rank0.forward.0") + + def test_add_node_to_main_graph(self): + graph = MagicMock() + node = MagicMock() + subnode = MagicMock() + node.subnodes = [subnode] + + self.merger._add_node_to_main_graph(graph, node) + graph.node_map.__setitem__.assert_has_calls([call(node.id, node), call(subnode.id, subnode)]) + + def test_get_node_sort_rule(self): + node = MagicMock(id="module.layers.0.forward%1%2") + self.assertEqual(self.merger._get_node_sort_rule(node), (2, 1)) + self.assertEqual(self.merger._get_node_sort_rule(node, rank_ascending=False), (-2, 1)) + + def test_mark_node_id_position_rank(self): + node = MagicMock() + parent_node = MagicMock() + parent_node.subnodes = [MagicMock(), node, MagicMock()] + node.upnode = parent_node + node.id = "module.layers.0.forward" + + self.merger._mark_node_id_position_rank(node, 2) + self.assertEqual(node.id, "module.layers.0.forward%1%2") + + def test_update_node_id(self): + graph = MagicMock() + start_node = MagicMock(id="module.layers.0.forward%1%2") + start_node.op = NodeOp.module + start_node.pp_index = 1 + graph.node_map = {start_node.id: start_node} + + self.merger._update_node_id(graph, start_node) + self.assertEqual(start_node.id, "module.layers.1.forward") + + +class TestTPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=4, pp=1, rank_size=4) + self.is_bench = False + self.merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_params(self): + params = { + "input.0": [ + {Const.MAX: 1, Const.MIN: 0, Const.MEAN: 0.5, Const.NORM: 1}, + {Const.MAX: 2, Const.MIN: 0, Const.MEAN: 0.7, Const.NORM: 1.2} + ] + } + merge_info = self.merger._merge_params(params) + self.assertIn("The Max value merging method for input.0 is: max(1, 2) = 2", merge_info) + self.assertIn("The Mean value merging method for input.0 is: (0.5 + 0.7) / 2 = 0.6", merge_info) + + def test_get_need_merge_node(self): + main_node = MagicMock(id="module.matmul_rank0.forward") + other_graphs = [MagicMock() for _ in range(3)] + tp_merge_mapping = {0: [1, 2, 3]} + + for i, g in enumerate(other_graphs): + g.node_map = {f"module.matmul_rank{i + 1}.forward": MagicMock()} + + nodes = self.merger._get_need_merge_node(main_node, other_graphs, tp_merge_mapping) + self.assertEqual(len(nodes), 0) + + def test_merge_graphs(self): + self.merger.get_groups = MagicMock(return_value=[[0, 1, 2, 3]]) + self.merger.merge_tp_graphs = MagicMock(return_value=self.build_graph_results[:1]) + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + def test_get_groups(self): + for i, result in enumerate(self.build_graph_results): + graph = MagicMock() + node = MagicMock(id=f"all_reduce.{i}") + node.input_data = {f"all_reduce.{i}.input.group": {"group_ranks": [0, 1, 2, 3]}} + graph.node_map.values.return_value = [node] + result.graph = graph + + groups = self.merger.get_groups() + self.assertEqual(len(groups), 1) + self.assertEqual(groups[0], [0, 1, 2, 3]) + + def test_handle_tp_matmul_reduce(self): + node = MagicMock(id=f"module.RowParallelLinear.forward.0") + node.op = NodeOp.module + matmul_node = MagicMock(id="matmul.0") + matmul_node.output_data = {"output.0": {Const.MAX: 1}} + reduce_node = MagicMock(id="all_reduce.0") + reduce_node.input_data = {"input.0": {Const.MAX: 1}} + reduce_node.output_data = {"output.0": {Const.MAX: 2}} + node.subnodes = [matmul_node, reduce_node] + other_graphs = [MagicMock()] + + self.merger._handle_tp_matmul_reduce(node, other_graphs, {}) + self.assertEqual(matmul_node.output_data["output.0"][Const.MAX], 2) + + +class TestNoParallelMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock()] + self.parallel_param = MagicMock(tp=1, pp=1, rank_size=1) + self.is_bench = False + self.merger = NoParallelMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + def test_merge_graphs(self): + self.merger.merge_graph_api_collection = MagicMock() + results = self.merger.merge_graphs() + self.assertEqual(results, self.build_graph_results) + self.merger.merge_graph_api_collection.assert_called_once_with(self.build_graph_results) + + +class TestTPPPMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(4)] + self.parallel_param = MagicMock(tp=2, pp=2, rank_size=4) + self.is_bench = False + self.merger = TPPPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1], [2, 3]] + tp_merger.get_groups.return_value = [[0, 2], [1, 3]] + tp_merger.merge_tp_graphs.return_value = [MagicMock()] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +class TestFullMerger(unittest.TestCase): + def setUp(self): + self.build_graph_results = [MagicMock(rank=i) for i in range(8)] + self.parallel_param = MagicMock(tp=2, pp=4, rank_size=8) + self.is_bench = False + self.merger = FullMerger(self.build_graph_results, self.parallel_param, self.is_bench) + + @patch('msprobe.visualization.builder.graph_merger.TPMerger') + @patch('msprobe.visualization.builder.graph_merger.PPMerger') + def test_merge_graphs(self, mock_pp, mock_tp): + tp_merger = MagicMock() + pp_merger = MagicMock() + mock_tp.return_value = tp_merger + mock_pp.return_value = pp_merger + + pp_merger.get_groups.return_value = [[0, 1, 2, 3], [4, 5, 6, 7]] + tp_merger.get_groups.return_value = [[0, 4], [1, 5], [2, 6], [3, 7]] + + pp_result0 = MagicMock(rank=0) + pp_result1 = MagicMock(rank=4) + pp_merger.merge_pp_graphs.side_effect = [[pp_result0], [pp_result1]] + + tp_merger.merge_tp_graphs.side_effect = [[MagicMock()], [MagicMock()]] + + results = self.merger.merge_graphs() + self.assertEqual(len(results), 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py index bee32a34a0509d5559b47d7a1625f618dc132d4e..e2ca516542a9840e0230a58eca5d0ad20c6f7579 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_msprobe_adapter.py @@ -11,6 +11,7 @@ from msprobe.visualization.builder.msprobe_adapter import ( _format_data ) from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode import torch from msprobe.core.common.const import Const @@ -55,11 +56,9 @@ class TestMsprobeAdapter(unittest.TestCase): @patch('msprobe.visualization.builder.msprobe_adapter.get_accuracy') def test_compare_node(self, mock_get_accuracy): - node_ids = ["node1", "node2"] - data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}}, - {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}] - stack_json_data = {} - result = compare_node(node_ids, data_dicts, stack_json_data, GraphConst.REAL_DATA_COMPARE) + node_n = BaseNode('', 'node1') + node_b = BaseNode('', 'node2') + result = compare_node(node_n, node_b, GraphConst.REAL_DATA_COMPARE) mock_get_accuracy.assert_called_once() self.assertIsInstance(result, list) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py index f4d68ccb530919dbdfedaa12bea716b2c70e278d..4accdacd76a434b6329a9fce378e38927092e9ae 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_graph_comparator.py @@ -1,5 +1,6 @@ import os import unittest +from typing import Any from dataclasses import dataclass from unittest.mock import patch from unittest.mock import MagicMock @@ -12,7 +13,7 @@ from msprobe.visualization.utils import GraphConst class Args: input_path: str = None output_path: str = None - layer_mapping: str = None + layer_mapping: Any = None framework: str = None overflow_check: bool = False fuzzy_match: bool = False @@ -39,7 +40,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - self.comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + self.comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) self.comparator._parse_param(self.dump_path_param, self.output_path) self.assertEqual(self.comparator.dump_path_param, { @@ -57,7 +58,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator._compare_nodes = MagicMock() comparator._postcompare = MagicMock() @@ -76,7 +77,7 @@ class TestGraphComparator(unittest.TestCase): node = MagicMock() compare_result_list = [("output1", "data1"), ("input1", "data2")] - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.ma = MagicMock() comparator.ma.prepare_real_data.return_value = True @@ -100,7 +101,7 @@ class TestGraphComparator(unittest.TestCase): mock_run_real_data.return_value = mock_df mock_get_csv_df.return_value = mock_df mock_get_node_error_status.return_value = True - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.ma = MagicMock() comparator.ma.compare_mode = GraphConst.REAL_DATA_COMPARE comparator._handle_api_collection_index = MagicMock() @@ -118,7 +119,7 @@ class TestGraphComparator(unittest.TestCase): mock_load_data_json_file.return_value = "data_dict" mock_load_json_file.return_value = "construct_dict" mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0') api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0') api1.data = {GraphConst.JSON_INDEX_KEY: 0.9} @@ -145,11 +146,12 @@ class TestGraphComparator(unittest.TestCase): mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE mock_mapping_match.return_value = (node_b, [], []) mock_compare_node.return_value = ['result'] - comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path, layer_mapping=True), True) comparator.mapping_dict = True comparator._compare_nodes(node_n) self.assertEqual(node_n.matched_node_link, ['Tensor.b.0']) self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False) comparator.mapping_dict = False node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0') @@ -185,6 +187,6 @@ class TestGraphComparator(unittest.TestCase): 'stack_json_path': os.path.join(dir_name, 'input', 'step0', 'rank0', 'stack.json'), 'is_print_compare_log': True } - comparator = GraphComparator(self.graphs, dump_path_param, Args(output_path=self.output_path)) + comparator = GraphComparator(self.graphs, dump_path_param, Args(output_path=self.output_path), False) comparator.add_compare_result_to_node(node, compare_result_list) self.assertEqual(node.data, {'precision_index': 0}) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py index 87d1f9ee5f01c7c9b2f264f3e6ec16b5155c1f8e..5f9a64f04dd7d4bffe0881519c2aa1264c105898 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py @@ -2,7 +2,8 @@ import json import unittest from unittest.mock import patch, MagicMock from msprobe.visualization.compare.mode_adapter import ModeAdapter -from msprobe.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst, ToolTip from msprobe.core.common.const import CompareConst @@ -225,27 +226,6 @@ class TestModeAdapter(unittest.TestCase): self.adapter.add_csv_data(compare_result_list) self.assertEqual(self.adapter.csv_data, compare_result_list) - def test_add_error_key(self): - node_data = {'key': {}} - self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], - [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]) - node_data = {'key': {}} - self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], - [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]) - node_data = {'key': []} - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'], []) - - node_data = {'key': {}} - self.adapter.compare_mode = '111' - self.adapter.add_error_key(node_data) - self.assertEqual(node_data['key'], {'error_key': []}) - def test_get_tool_tip(self): self.adapter.compare_mode = GraphConst.MD5_COMPARE tips = self.adapter.get_tool_tip() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_multi_mapping.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_multi_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe14317b2af7334693270d060c58af2dada4cbc --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_multi_mapping.py @@ -0,0 +1,114 @@ +import unittest +from msprobe.visualization.compare.multi_mapping import MultiMapping +from msprobe.visualization.graph.graph import Graph +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.visualization.utils import GraphConst + + +class TestMultiMapping(unittest.TestCase): + + def setUp(self): + pass + + def test_validate_yaml(self): + multi_mapping = MultiMapping.validate_yaml({}) + self.assertEqual(multi_mapping, {}) + + multi_mapping = MultiMapping.validate_yaml([]) + self.assertEqual(multi_mapping, {}) + + multi_mapping = MultiMapping.validate_yaml({'a': 'b'}) + self.assertEqual(multi_mapping, {('a',): ('b',)}) + + multi_mapping = MultiMapping.validate_yaml({'a': 'b c d'}) + self.assertEqual(multi_mapping, {('a',): ('b c d',)}) + + multi_mapping = MultiMapping.validate_yaml({'a': 'b, c, d'}) + self.assertEqual(multi_mapping, {('a',): ('b', 'd')}) + + def test_validate_ids_in_graph(self): + graph = Graph("model_name") + graph.node_map = {'node1': BaseNode(NodeOp.module, 'node1'), + 'node2': BaseNode(NodeOp.module, 'node2'), + 'node3': BaseNode(NodeOp.module, 'node3')} + result = MultiMapping.validate_ids_in_graph(['node1', 'node3'], graph) + self.assertTrue(result) + + result = MultiMapping.validate_ids_in_graph(['node1', 'node5'], graph) + self.assertFalse(result) + + def test_get_merged_nodes_data(self): + node_ids = ['Module.layer1.Linear.forward.0', 'Module.layer3.Linear.forward.0'] + dump_data = {'Module.layer1.Linear.forward.0': {'input_args': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 10], 'Max': 3.029174327850342, + 'Min': -3.405808448791504, 'Mean': -0.08760099112987518, 'Norm': 31.511741638183594, + 'requires_grad': False}], 'input_kwargs': {}, 'output': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 20], 'Max': 2.280996561050415, + 'Min': -2.6040544509887695, 'Mean': -0.05008987337350845, 'Norm': 26.9143123626709, + 'requires_grad': True}], 'parameters': { + 'weight': {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [20, 10], 'Max': 0.31333038210868835, + 'Min': -0.3147874176502228, 'Mean': -0.007642852142453194, 'Norm': 2.594407558441162, + 'requires_grad': True}, + 'bias': {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [20], 'Max': 0.3160688579082489, + 'Min': -0.31076428294181824, 'Mean': -0.05035770684480667, 'Norm': 0.8817608952522278, + 'requires_grad': True}}, 'is_recompute': False}, + 'Module.layer3.Linear.forward.0': {'input_args': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 30], 'Max': 1.8936877250671387, + 'Min': -1.60052490234375, 'Mean': -0.05550510436296463, 'Norm': 21.1639404296875, + 'requires_grad': True}], 'input_kwargs': {}, 'output': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 1], 'Max': 0.8175169229507446, + 'Min': -0.3781408369541168, 'Mean': 0.16728776693344116, 'Norm': 2.627354145050049, + 'requires_grad': True}], 'parameters': { + 'weight': {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [1, 30], + 'Max': 0.17745383083820343, 'Min': -0.11874081194400787, 'Mean': 0.013812449760735035, + 'Norm': 0.48705562949180603, 'requires_grad': True}, + 'bias': {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [1], 'Max': 0.1430283486843109, + 'Min': 0.1430283486843109, 'Mean': 0.1430283486843109, 'Norm': 0.1430283486843109, + 'requires_grad': True}}, 'is_recompute': False}} + multi_node_data = {'input_args': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 10], 'Max': 3.029174327850342, + 'Min': -3.405808448791504, 'Mean': -0.08760099112987518, 'Norm': 31.511741638183594, + 'requires_grad': False}], 'input_kwargs': {}, 'output': [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [100, 1], 'Max': 0.8175169229507446, + 'Min': -0.3781408369541168, 'Mean': 0.16728776693344116, 'Norm': 2.627354145050049, + 'requires_grad': True}]} + result = MultiMapping.get_merged_nodes_data(node_ids, dump_data, 'multi_node0') + self.assertEqual(result, {'multi_node0': multi_node_data}) + result = MultiMapping.get_merged_nodes_data([], dump_data, 'multi_node0') + self.assertEqual(result, {}) + + def test_merge_nodes(self): + graph = Graph('graph') + graph.add_node(NodeOp.module, 'Module.layer1.Linear.forward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.layer2.Linear.forward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.layer3.Linear.forward.0', graph.root) + result = MultiMapping.merge_nodes(['Module.layer1.Linear.forward.0', 'Module.layer3.Linear.forward.0'], + graph) + self.assertTrue(isinstance(result.multi_node, BaseNode)) + self.assertEqual(result.multi_node.subnodes, [graph.get_node('Module.layer1.Linear.forward.0'), + graph.get_node('Module.layer2.Linear.forward.0'), + graph.get_node('Module.layer3.Linear.forward.0')]) + self.assertEqual(result.multi_node.upnode, graph.get_node('graph')) + self.assertEqual(result.multi_node.id, GraphConst.MERGE_NODES + '.forward.0') + + result = MultiMapping.merge_nodes(['Module.layer1.Linear.forward.0'], graph) + self.assertEqual(result.multi_node, graph.get_node('Module.layer1.Linear.forward.0')) + + result = MultiMapping.merge_nodes(['Module.layer5.Linear.forward.0', 'Module.layer6.Linear.forward.0'], + graph) + self.assertIsNone(result.multi_node) + + result = MultiMapping.merge_nodes(['Module.layer3.Linear.forward.0', 'Module.layer1.Linear.forward.0'], + graph) + self.assertIsNone(result.multi_node) + + def test_split_mapping_str(self): + result = MultiMapping._split_mapping_str('a, b,c, d') + self.assertEqual(result, ('a', 'd')) + + result = MultiMapping._split_mapping_str('a') + self.assertEqual(result, ('a',)) + + result = MultiMapping._split_mapping_str('a b* c ') + self.assertEqual(result, ('a b* c',)) diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py index 480b95620e6a81577d825b7af55b45fc0a04c34c..64b7101c6b036113e018faec649974753acdaec3 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py @@ -1,6 +1,6 @@ import unittest -from msprobe.visualization.graph.base_node import BaseNode, NodeOp -from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp class TestBaseNode(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py index 81f9fdca5277de6e1670da409bcf93e56ece3206..24f39cbb808234cfce6af02046755d3df3a1a5e4 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py @@ -55,17 +55,6 @@ class TestGraph(unittest.TestCase): self.assertIsNotNone(matched_node) self.assertEqual(ancestors, ['node_id_a']) - def test_dfs(self): - graph = Graph("model_name") - graph.add_node(NodeOp.module, "node_a") - graph.add_node(NodeOp.module, "node_b") - node_a = BaseNode(self.node_op, self.node_id) - result = {} - graph.dfs(node_a, result) - self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, - 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], - 'matched_node_link': [], 'suggestions': {}, 'stack_info': []}}) - def test_split_nodes_by_micro_step(self): nodes = [BaseNode(NodeOp.module, 'a.forward.0'), BaseNode(NodeOp.module, 'a.backward.0'), BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.forward.1'), diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..18501445cf403fecabace4e817f79e3b29edace0 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input/step0/rank0/dump.json @@ -2,5 +2,5 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}} } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step0/rank1/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step1/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json index 330122252bd65cb01bbf9f0cd6c912f407b32a28..d40eabd5eeddb1c0eb723e49b5674a9bb0635fa7 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/input_format_correct/step2/rank0/dump.json @@ -2,5 +2,6 @@ "task": "statistics", "level": "mix", "dump_data_dir": null, - "data": {} + "data": {"api": {"input": [{}]}}, + "framework": "pytorch" } diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py index 7dfd9564ebc21327f3e7e29be90da7f78c3b0393..8be6eda8ac7ff86d7d6bde86182a4038d3e1ac9a 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py @@ -7,7 +7,7 @@ import argparse from dataclasses import dataclass from unittest.mock import patch -from msprobe.visualization.graph_service import _compare_graph, _build_graph, _compare_graph_ranks, \ +from msprobe.visualization.graph_service import _compare_graph_result, _build_graph_result, _compare_graph_ranks, \ _compare_graph_steps, _build_graph_ranks, _build_graph_steps, _graph_service_command, _graph_service_parser from msprobe.core.common.utils import CompareException @@ -21,6 +21,9 @@ class Args: overflow_check: bool = False fuzzy_match: bool = False complete_stack: bool = False + multi_mapping: str = None + parallel_merge: bool = False + parallel_params: tuple = None class TestGraphService(unittest.TestCase): @@ -45,30 +48,31 @@ class TestGraphService(unittest.TestCase): last_call_args = mock_log_info.call_args[0][0] self.assertIn(log_info, last_call_args) matches = re.findall(self.pattern, last_call_args) - self.assertTrue(os.path.exists(os.path.join(self.output, matches[0]))) + if matches: + self.assertTrue(os.path.exists(os.path.join(self.output, matches[0]))) @patch('msprobe.core.common.log.logger.info') - def test_compare_graph(self, mock_log_info): + def test_compare_graph_result(self, mock_log_info): args = Args(output_path=self.output, framework='pytorch') - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertEqual(mock_log_info.call_count, 2) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='mindspore') - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='pytorch', layer_mapping=self.layer_mapping) - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) args = Args(output_path=self.output, framework='pytorch', overflow_check=True) - result = _compare_graph(self.input_param, args) + result = _compare_graph_result(self.input_param, args) self.assertIsNotNone(result) @patch('msprobe.core.common.log.logger.info') - def test_build_graph(self, mock_log_info): - result = _build_graph(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True)) + def test_build_graph_result(self, mock_log_info): + result = _build_graph_result(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True)) self.assertEqual(mock_log_info.call_count, 1) self.assertIsNotNone(result) @@ -81,7 +85,7 @@ class TestGraphService(unittest.TestCase): } args = Args(output_path=self.output, framework='pytorch') _compare_graph_ranks(input_param, args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param1 = { 'npu_path': os.path.join(self.input, 'step0'), @@ -101,7 +105,7 @@ class TestGraphService(unittest.TestCase): } args = Args(output_path=self.output, framework='pytorch') _compare_graph_steps(input_param, args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param1 = { 'npu_path': self.input, @@ -115,12 +119,12 @@ class TestGraphService(unittest.TestCase): @patch('msprobe.core.common.log.logger.info') def test_build_graph_ranks(self, mock_log_info): _build_graph_ranks(os.path.join(self.input, 'step0'), Args(output_path=self.output)) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") @patch('msprobe.core.common.log.logger.info') def test_build_graph_steps(self, mock_log_info): _build_graph_steps(self.input, Args(output_path=self.output)) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") @patch('msprobe.core.common.log.logger.info') def test_graph_service_command(self, mock_log_info): @@ -129,7 +133,7 @@ class TestGraphService(unittest.TestCase): args = Args(input_path=self.output_json[0], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Exporting compare graph result successfully, the result file is saved in') input_param1 = { 'npu_path': os.path.join(self.input, 'step0', 'rank0'), @@ -139,7 +143,7 @@ class TestGraphService(unittest.TestCase): json.dump(input_param1, f, indent=4) args = Args(input_path=self.output_json[1], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Model graph exported successfully, the result file is saved in") input_param2 = { 'npu_path': os.path.join(self.input, 'step0'), @@ -150,7 +154,7 @@ class TestGraphService(unittest.TestCase): json.dump(input_param2, f, indent=4) args = Args(input_path=self.output_json[2], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param3 = { 'npu_path': self.input, @@ -161,7 +165,7 @@ class TestGraphService(unittest.TestCase): json.dump(input_param3, f, indent=4) args = Args(input_path=self.output_json[3], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info) + self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.') input_param4 = { 'npu_path': os.path.join(self.input, 'step0'), @@ -171,7 +175,7 @@ class TestGraphService(unittest.TestCase): json.dump(input_param4, f, indent=4) args = Args(input_path=self.output_json[4], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") input_param5 = { 'npu_path': self.input, @@ -181,7 +185,7 @@ class TestGraphService(unittest.TestCase): json.dump(input_param5, f, indent=4) args = Args(input_path=self.output_json[5], output_path=self.output, framework='pytorch') _graph_service_command(args) - self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in") + self.assert_log_info(mock_log_info, "Successfully exported build graph results.") input_param6 = { 'npu_path': self.input, diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py index e5b0afaadf9def910c248b945ad15084300a65c0..974fcdf8a19d81a6c22a0396f45fc1725b00a39a 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_visualization_utils.py @@ -1,7 +1,7 @@ import os import unittest from msprobe.visualization.utils import (load_json_file, load_data_json_file, str2float, check_directory_content, - GraphConst) + GraphConst, SerializableArgs) class TestMappingConfig(unittest.TestCase): @@ -37,6 +37,20 @@ class TestMappingConfig(unittest.TestCase): input_type = check_directory_content(os.path.join(self.input, "step0", "rank0")) self.assertEqual(input_type, GraphConst.FILES) + def test_serializable_args(self): + class TmpArgs: + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + input_args1 = TmpArgs('a', 123, [1, 2, 3]) + serializable_args1 = SerializableArgs(input_args1) + self.assertEqual(serializable_args1.__dict__, input_args1.__dict__) + input_args2 = TmpArgs('a', 123, lambda x: print(x)) + serializable_args2 = SerializableArgs(input_args2) + self.assertNotEqual(serializable_args2.__dict__, input_args2.__dict__) + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py index 814882e6b819e9e6b6b421aec5f8f0b89f03f7c6..f56c6f9a10e79f834e4f4e96aa5c242cc69d7ac4 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py @@ -14,21 +14,23 @@ # limitations under the License. import re +from dataclasses import dataclass from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_json +from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.common.utils import load_stack_json from msprobe.visualization.builder.msprobe_adapter import get_input_output from msprobe.visualization.builder.msprobe_adapter import op_patterns from msprobe.visualization.graph.graph import Graph from msprobe.visualization.graph.node_op import NodeOp -from msprobe.visualization.utils import save_json_file, GraphConst +from msprobe.visualization.utils import GraphConst class GraphBuilder: backward_pattern = re.compile(r"(\.backward\.)(\d+)$") forward_pattern = re.compile(r"(\.forward\.)(\d+)$") - # 匹配以大写字母开头,后接任意字母,并以Template(结尾 - template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(') + # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串 + template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(') @staticmethod def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False): @@ -44,13 +46,14 @@ class GraphBuilder: """ construct_dict = load_json(construct_path) dump_dict = load_json(data_path) - stack_dict = load_json(stack_path) + stack_dict = load_stack_json(stack_path) if not complete_stack: GraphBuilder._simplify_stack(stack_dict) data_dict = dump_dict.get(GraphConst.DATA_KEY, {}) graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict) GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict) GraphBuilder._collect_apis_between_modules(graph) + GraphBuilder._add_parameters_grad(graph, data_dict) return graph @staticmethod @@ -60,10 +63,10 @@ class GraphBuilder: """ result = {} if config.graph_b: - result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict() - result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict() + result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict(config.compare_mode) + result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict(config.compare_mode) else: - result = config.graph_n.to_dict() + result = config.graph_n.to_dict(config.compare_mode) if config.tool_tip: result[GraphConst.JSON_TIP_KEY] = config.tool_tip if config.node_colors: @@ -73,7 +76,7 @@ class GraphBuilder: if config.task: result[GraphConst.JSON_TASK_KEY] = config.task result[GraphConst.OVERFLOW_CHECK] = config.overflow_check - save_json_file(filename, result) + save_json(filename, result, indent=4) @staticmethod def _simplify_stack(stack_dict): @@ -186,6 +189,8 @@ class GraphBuilder: # 数据格式:"output": [[{param1}, {param2}, ...]] if GraphBuilder._is_valid_batch_p2p_output(param_list): for param in param_list[0]: + if not isinstance(param, dict): + continue info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER), GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)} node.batch_p2p_info.append(info) @@ -235,10 +240,46 @@ class GraphBuilder: graph.root.subnodes = output + @staticmethod + def _add_parameters_grad(graph, data_dict): + """ + 将parameters_grad信息添加到graph中, + 对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点 + + 例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2 + 则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点 + """ + prefixes = [] + suffix = Const.SEP + Const.PARAMS_GRAD + for node_id in data_dict.keys(): + if node_id not in graph.node_map and node_id.endswith(suffix): + prefixes.append(node_id.replace(suffix, '')) + + max_info = {prefix: 0 for prefix in prefixes} + + for key in graph.node_map.keys(): + parts = key.split(Const.SEP) + if len(parts) > 2 and parts[-2] == Const.BACKWARD: + num = int(parts[-1]) + prefix = Const.SEP.join(parts[:-2]) + if prefix in max_info and num > max_info[prefix]: + max_info[prefix] = num + + for prefix, num in max_info.items(): + node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num) + node = graph.get_node(node_id) + if node: + parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node) + # 添加输入输出数据 + node_data = data_dict.get(parameters_grad_node_id, {}) + input_data, output_data = get_input_output(node_data, parameters_grad_node_id) + # 更新数据 + graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data) + class GraphExportConfig: def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='', - overflow_check=False): + overflow_check=False, compare_mode=None): self.graph_n = graph_n self.graph_b = graph_b self.tool_tip = tool_tip @@ -246,3 +287,21 @@ class GraphExportConfig: self.micro_steps = micro_steps self.task = task self.overflow_check = overflow_check + self.compare_mode = compare_mode + + +@dataclass +class GraphInfo: + graph: Graph + construct_path: str + data_path: str + stack_path: str + + +@dataclass +class BuildGraphTaskInfo: + graph_info_n: GraphInfo + graph_info_b: GraphInfo + npu_rank: str + bench_rank: str + time_str: str diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..96940b99c95050396d6f84be108466e0406fdbd5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_merger.py @@ -0,0 +1,854 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import math + +from msprobe.core.common.const import Const +from msprobe.visualization.graph.graph import Graph, BaseNode +from msprobe.visualization.graph.node_op import NodeOp +from msprobe.core.common.log import logger +from msprobe.visualization.utils import GraphConst +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.decorator import recursion_depth_decorator + +MAX_INFO = 'The Max value merging method for ' +MIN_INFO = 'The Min value merging method for ' +MEAN_INFO = 'The Mean value merging method for ' +NORM_INFO = 'The Norm value merging method for ' + + +class GraphMerger: + def __init__(self, build_graph_results, parallel_param, is_bench=False): + self.strategy = self._select_strategy(build_graph_results, parallel_param, is_bench) + + @staticmethod + def _select_strategy(results, param, is_bench): + if param.tp == param.pp == param.rank_size == 1: + return NoParallelMerger(results, param, is_bench) + elif param.tp == param.rank_size: + return TPMerger(results, param, is_bench) + elif param.pp == param.rank_size: + return PPMerger(results, param, is_bench) + elif param.pp == 1: + return TPMerger(results, param, is_bench) + elif param.tp == 1: + return PPMerger(results, param, is_bench) + elif param.tp * param.pp == param.rank_size: + return TPPPMerger(results, param, is_bench) + else: + return FullMerger(results, param, is_bench) + + def merge_graph(self): + return self.strategy.merge_graphs() + + +class BaseGraphMerger: + def __init__(self, build_graph_results, parallel_param, is_bench): + self.unmerged_module = [Const.CLIP_GRAD, Const.OPTIMIZER] + self.dtype_list = Const.TORCH_INT_DTYPE + Const.TORCH_FLOAT_DTYPE + [Const.FLOAT16, Const.FLOAT32, + Const.BFLOAT16] + self.build_graph_results = build_graph_results + self.parallel_param = parallel_param + self.is_bench = is_bench + self.log_prefix = '[Bench]' if self.is_bench else '[NPU]' + self._add_all_nodes_rank() + + @staticmethod + def sort_merged_api_collection(graph): + def extract_rank(node): + match = re.search(r'_Rank(\d+)', node.id) + return int(match.group(1)) if match else None + + for sub_node in graph.root.subnodes: + if sub_node.op == NodeOp.api_collection and sub_node.id.startswith( + GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + sub_node.subnodes = sorted(sub_node.subnodes, key=extract_rank) + + @staticmethod + def _update_node_data_key(old_id, new_id, data_dict): + new_dict = {} + for key, value in data_dict.items(): + new_key = key.replace(old_id, new_id) + if 'full_op_name' in value: + value['full_op_name'] = value.get('full_op_name').replace(old_id, new_id) + new_dict[new_key] = value + return new_dict + + @staticmethod + def _compare_value_same(main_value, other_value, has_uncertainty=False): + if not isinstance(main_value, (int, float)) or not isinstance(other_value, (int, float)): + return True + # 没开启确定性计算,各rank的mean和norm有细微差异,如果相对误差在阈值内则认为是相同的 + if has_uncertainty: + diff = abs(main_value - other_value) + if math.isnan(diff): + return math.isnan(main_value) and math.isnan(other_value) + elif math.isinf(diff): + return math.isinf(main_value) and math.isinf(other_value) + else: + return diff < GraphConst.UNCERTAINTY_THRESHOLD if main_value == 0 else \ + abs(diff / main_value) < GraphConst.UNCERTAINTY_THRESHOLD + else: + return main_value == other_value + + def merge_graphs(self): + raise NotImplementedError("This method should be implemented by subclasses.") + + def merge_graph_api_collection(self, results: list): + """ + graph合并时,将各rank的游离api集合合并为一个总的游离api集合 + example: + rank0: Apis_Between_Modules.0 rank1: Apis_Between_Modules.0 + Module.module.Float16Module.forward.0 Module.module.Float16Module.forward.0 + Apis_Between_Modules.1 Apis_Between_Modules.1 + + merged: Apis_Between_Modules_All_Ranks.0 + |_ Apis_Between_Modules_Rank0.0 + |_ Apis_Between_Modules_Rank1.0 + Module.module.Float16Module.forward.0 + Apis_Between_Modules_All_Ranks.1 + |_ Apis_Between_Modules_Rank0.1 + |_ Apis_Between_Modules_Rank1.1 + """ + main_graph_result = results[0] + main_root_sub_nodes = main_graph_result.graph.root.subnodes + new_main_root_sub_nodes = [] + for main_node in main_root_sub_nodes: + # 如果游离api集合已合并为一个总的游离api集合,总的游离api集合之间还要再合并 + if main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + new_main_root_sub_nodes.append(main_node) + for other_graph_result in results[1:]: + other_node = other_graph_result.graph.get_node(main_node.id) + if not other_node: + continue + for sub_node in other_node.subnodes: + sub_node.upnode = main_node + main_graph_result.graph.node_map[sub_node.id] = sub_node + for sub_sub_node in sub_node.subnodes: + main_graph_result.graph.node_map[sub_sub_node.id] = sub_sub_node + main_node.subnodes.extend(other_node.subnodes) + # 游离api集合合并为一个总的游离api集合 + elif main_node.id.startswith(GraphConst.APIS_BETWEEN_MODULES): + all_collection_node_id = main_graph_result.graph.add_node(NodeOp.api_collection, + GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS, + id_accumulation=True) + all_collection_node = main_graph_result.graph.get_node(all_collection_node_id) + new_main_root_sub_nodes.append(all_collection_node) + # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank0.0 + origin_main_node_id = main_node.id + main_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{main_graph_result.rank}.' + \ + main_node.id.split(Const.SEP)[-1] + all_collection_node.subnodes = [main_node] + main_node.upnode = all_collection_node + main_graph_result.graph.node_map[main_node.id] = main_node + del main_graph_result.graph.node_map[origin_main_node_id] + for other_graph_result in results[1:]: + other_node = other_graph_result.graph.get_node(origin_main_node_id) + if not other_node: + continue + # Apis_Between_Modules.0 --> Apis_Between_Modules_Rank1.0 + other_node.id = GraphConst.APIS_BETWEEN_MODULES + f'_Rank{other_graph_result.rank}.' + \ + other_node.id.split(Const.SEP)[-1] + main_graph_result.graph.node_map[other_node.id] = other_node + for sub_node in other_node.subnodes: + # api节点,在api名称上添加rank信息 + old_id = sub_node.id + parts = sub_node.id.split(Const.SEP) + parts[1] += f'_rank{other_graph_result.rank}' + sub_node.id = Const.SEP.join(parts) + sub_node.input_data = self._update_node_data_key(old_id, sub_node.id, sub_node.input_data) + sub_node.output_data = self._update_node_data_key(old_id, sub_node.id, sub_node.output_data) + main_graph_result.graph.node_map[sub_node.id] = sub_node + all_collection_node.subnodes.append(other_node) + other_node.upnode = all_collection_node + else: + new_main_root_sub_nodes.append(main_node) + main_graph_result.graph.root.subnodes = new_main_root_sub_nodes + + def split_graph_results_by_groups(self, groups): + """ + 基于pp或tp域,划分待合并的graph + """ + rank_results_mapping = {result.rank: result for result in self.build_graph_results} + return [[rank_results_mapping.get(rank) for rank in ranks] for ranks in groups] + + def compare_node_param_data(self, main_node, other_nodes, compare_data=True): + """ + 当前节点与若干其他节点比较输入输出参数的数据是否一致,如果发现有不一致的参数,将参数暂存于列表中 + """ + if not other_nodes: + return {}, {} + data_types = {'input_data': {}, 'output_data': {}} + for data_type, data_dict in data_types.items(): + main_data_dict = getattr(main_node, data_type) + for key, main_param in main_data_dict.items(): + same_flag = compare_data + if main_param.get(Const.DTYPE) not in self.dtype_list: + continue + tp_need_merge_params = [main_param] + for other_node in other_nodes: + param_key = key.replace(main_node.id, other_node.id) if main_node.id != other_node.id else key + other_param = getattr(other_node, data_type).get(param_key, {}) + if other_param.get(Const.DTYPE) not in self.dtype_list: + break + tp_need_merge_params.append(other_param) + if compare_data and not self.compare_param_same(main_param, other_param, has_uncertainty=True): + same_flag = False + if not same_flag: + # {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]} + data_dict[key.replace(main_node.id + Const.SEP, '')] = tp_need_merge_params + return data_types.get('input_data'), data_types.get('output_data') + + def compare_param_same(self, main_param, other_param, has_uncertainty=False): + if not self._compare_value_same(main_param.get(Const.MAX), other_param.get(Const.MAX)): + return False + if not self._compare_value_same(main_param.get(Const.MIN), other_param.get(Const.MIN)): + return False + if not self._compare_value_same(main_param.get(Const.MEAN), other_param.get(Const.MEAN), has_uncertainty): + return False + if not self._compare_value_same(main_param.get(Const.NORM), other_param.get(Const.NORM), has_uncertainty): + return False + return True + + def get_default_groups(self): + """ + 根据GPU总数、TP数、PP数初始化并行组 + + return: + tp_groups: 张量并行组列表,每个元素是一个包含组内rank的列表 + pp_groups: 流水线并行组列表,每个元素是一个包含组内rank的列表 + """ + rank_size = self.parallel_param.rank_size + tp_size = self.parallel_param.tp + pp_size = self.parallel_param.pp + + if rank_size % (tp_size * pp_size) != 0: + logger.error(f'{self.log_prefix} The parallel param "rank_size" must be divisible by "tp * pp"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + dp_size = int(rank_size / tp_size / pp_size) + + # 存储并行组信息 + tp_groups = [] + pp_groups = [] + + # 创建张量并行组 + for dp_rank in range(dp_size): + for pp_rank in range(pp_size): + # 计算当前DP组和PP组组合下的TP组的第一个rank + base_rank = dp_rank * tp_size * pp_size + pp_rank * tp_size + group_ranks = [base_rank + tp_rank for tp_rank in range(tp_size)] + tp_groups.append(group_ranks) + + # 创建流水线并行组 + for dp_rank in range(dp_size): + for tp_rank in range(tp_size): + # 计算当前DP组和TP组组合下的PP组的第一个rank + base_rank = dp_rank * tp_size * pp_size + tp_rank + group_ranks = [base_rank + pp_rank * tp_size for pp_rank in range(pp_size)] + pp_groups.append(group_ranks) + + return tp_groups, pp_groups + + def _add_all_nodes_rank(self): + for result in self.build_graph_results: + for node in result.graph.node_map.values(): + node.rank = result.rank + + +class PPMerger(BaseGraphMerger): + LAYERS_PATTERN = re.compile(r"(layers\.|layer\.)\d+(\.)") + MARK_PATTERN = re.compile(r"%(\d+)%(\d+)$") + MARK = '%' + + @staticmethod + def _trace_p2p_mapping(p2p_mapping: dict): + """ + 将字典分组为独立的链,每个链都从未访问过的键开始,按照字典中的映射关系进行追踪 + p2p_mapping内容为p2p通信的send映射,追踪映射关系建立pp域 + example: p2p_mapping={0: 2, 1: 3, 2: 4, 3: 5, 4: 6, 5: 7, 6: 4, 7: 5}, return=[[0, 2, 4, 6], [1, 3, 5, 7]] + """ + visited = set() + result = [] + + def collect_keys(start_key): + """ + 追踪从某一个键开始的所有“连续”键,直到无法再找到下一个键为止 + """ + current_key = start_key + chain = [] + while current_key in p2p_mapping and current_key not in visited: + chain.append(current_key) + visited.add(current_key) + current_key = p2p_mapping[current_key] + return chain + + for key in p2p_mapping: + if key not in visited: + chain_result = collect_keys(key) + if chain_result: + result.append(chain_result) + return result + + @recursion_depth_decorator("msprobe.visualization.builder.graph_merger.PPMerger._merge_nodes", 1000) + def _merge_nodes(self, main_graph, main_node, other_graphs): + """ + 其他rank graph中被pp切分的节点,需要合并到main graph + """ + other_nodes = [] + for other_graph in other_graphs: + other_node = other_graph.get_node(main_node.id) + # 表明此节点只有main graph有 + if not other_node: + other_nodes.clear() + return + other_nodes.append(other_node) + if other_nodes: + param_in, param_out = self.compare_node_param_data(main_node, other_nodes) + # 各个rank都有的模块,且输入输出都不一致,且节点id符合正则,判定为被pp切分的模块,需要合并结构 + pp_merged_condition = param_in and param_out and self.LAYERS_PATTERN.search(main_node.id) + # backward可能没有output,是否要pp合并从对应的forward节点判断 + if Const.SEP + Const.BACKWARD + Const.SEP in main_node.id: + f_node = main_graph.node_map.get( + main_node.id.replace(Const.SEP + Const.BACKWARD + Const.SEP, Const.SEP + Const.FORWARD + Const.SEP)) + if f_node and hasattr(f_node, 'is_pp_merged'): + pp_merged_condition = True + if pp_merged_condition: + main_node.is_pp_merged = True + main_up_node = main_node.upnode + for other_node in other_nodes: + # pp切分中被切分的层在各rank的名称是一样的,这里给其他rank的同名层增加位置和rank标记 + self._mark_node_id_position_rank(other_node, other_node.rank) + self._add_node_to_main_graph(main_graph, other_node) + # 其他rank被pp切分的模块节点添加到当前rank的graph + other_node.upnode = main_up_node + main_up_node.subnodes.append(other_node) + # 已找到被pp切分的模块节点,不再递归其内部 + return + # 各个rank都有的forward模块,且输入一致,输出不一致,判定为模块内部包含被pp切分的模块,此模块的输出要使用最后一个rank的输出 + elif not param_in and param_out and Const.SEP + Const.FORWARD + Const.SEP in main_node.id: + main_node.output_data = other_nodes[-1].output_data + # 各个rank都有的backward模块,且输出一致,输入不一致,判定为模块内部包含被pp切分的模块,此模块的输入要使用最后一个rank的输入 + elif param_in and not param_out and Const.SEP + Const.BACKWARD + Const.SEP in main_node.id: + main_node.input_data = other_nodes[-1].input_data + self._merge_other_unique_nodes(main_graph, main_node, other_nodes) + for sub_node in main_node.subnodes: + if sub_node.op == NodeOp.module: + self._merge_nodes(main_graph, sub_node, other_graphs) + + def merge_graphs(self): + results_groups = self.split_graph_results_by_groups(self.get_groups()) + results = [] + for result_groups in results_groups: + self.merge_graph_api_collection(result_groups) + results.extend(self.merge_pp_graphs(result_groups)) + return results + + def merge_pp_graphs(self, results): + if not results or len(results) < 2: + return results + graphs = [x.graph for x in results] + main_graph_result = results[0] + for main_node in main_graph_result.graph.root.subnodes: + if main_node.op == NodeOp.module and main_node.id not in self.unmerged_module: + self._merge_nodes(main_graph_result.graph, main_node, graphs[1:]) + self._sort_nodes(main_graph_result.graph, main_node) + return [main_graph_result] + + def get_groups(self): + """ + 在各rank寻找p2p通信节点,建立各rank之间p2p的映射关系 + """ + p2p_mapping = {} + for result in self.build_graph_results: + rank = result.rank + pp_rank = None + for node in result.graph.node_map.values(): + if not node.id.startswith(Const.DISTRIBUTED + Const.SEP): + continue + if '.batch_isend_irecv.' in node.id: + for p2p_info in node.batch_p2p_info: + target_rank = p2p_info.get(GraphConst.PEER) + if target_rank is not None and target_rank != rank and p2p_info.get(GraphConst.OP) == 'isend': + pp_rank = target_rank + break + elif '.send.' in node.id or '.isend.' in node.id: + # example: Distributed.isend.0.forward --> Distributed.isend.0.forward.input.dst + dst_kwarg = f'{node.id}{Const.SEP}{Const.INPUT}{Const.SEP}{GraphConst.DST}' + dst = node.input_data.get(dst_kwarg, {}).get('value') + if dst is not None: + pp_rank = dst + break + if pp_rank is not None: + break + if pp_rank is None: + logger.warning(f'{self.log_prefix} Unable to get pp groups because ' + f'the batch_isend_irecv, send, or isend were not found.') + else: + p2p_mapping[rank] = pp_rank + pp_groups = self._trace_p2p_mapping(p2p_mapping) + if not pp_groups: + logger.info('Unable to get pp groups based on Distributed Api, ' + 'generate pp groups using parallel param "rank_size", "tp" and "pp".') + _, pp_groups = self.get_default_groups() + logger.info(f'{self.log_prefix} All pp groups is {pp_groups}.') + return pp_groups + + def _merge_other_unique_nodes(self, main_graph, main_node, other_nodes): + """ + 其他rank graph中other_node的子节点列表如果包含独有的节点,需要合并到main graph + """ + lists = [main_node.subnodes] + for other_node in other_nodes: + lists.append(other_node.subnodes) + dicts = [{node.id: node for node in lst} for lst in lists] + unique_node_ids = {} + # 计算每个集合的独有元素 + for i, current_dict in enumerate(dicts): + other_ids = set() + for j, other_dict in enumerate(dicts): + if i != j: + # 更新并集,添加当前遍历到的集合的元素 + other_ids.update(other_dict.keys()) + result = set(current_dict.keys()) - other_ids + if i != 0 and result: + # 计算当前集合与其他集合并集的差集,即独有元素,保持原始顺序 + unique_node_ids[i] = [node_id for node_id in current_dict if node_id in result] + unique_nodes = [] + if unique_node_ids: + for i, items in unique_node_ids.items(): + for item in items: + unique_nodes.append(dicts[i].get(item)) + if unique_nodes: + for unique_node in unique_nodes: + self._mark_node_id_position_rank(unique_node, unique_node.rank) + self._add_node_to_main_graph(main_graph, unique_node) + main_node.subnodes.append(unique_node) + unique_node.upnode = main_node + + def _sort_nodes(self, main_graph, start_node): + stack = [start_node] + while stack: + node = stack.pop() + if self.MARK_PATTERN.search(node.id): + is_forward = (Const.SEP + Const.FORWARD + Const.SEP in node.id or + Const.SEP + Const.FORWARD + self.MARK in node.id) + new_sub_nodes1, new_sub_nodes2 = [], [] + for item in node.upnode.subnodes: + new_sub_nodes2.append(item) if self.MARK_PATTERN.search(item.id) else new_sub_nodes1.append(item) + + order = True if is_forward else False + new_sub_nodes2.sort(key=lambda n: self._get_node_sort_rule(n, rank_ascending=order)) + new_sub_nodes = new_sub_nodes1 + new_sub_nodes2 if is_forward else new_sub_nodes2 + new_sub_nodes1 + + index = -1 + node_iter = new_sub_nodes if is_forward else reversed(new_sub_nodes) + for item in node_iter: + if self.LAYERS_PATTERN.search(item.id): + index += 1 + if self.MARK_PATTERN.search(item.id): + item.pp_index = index + for item in new_sub_nodes2: + self._update_node_id(main_graph, item) + node.upnode.subnodes = new_sub_nodes + stack.extend(node.subnodes) + + def _add_node_to_main_graph(self, main_graph: Graph, node: BaseNode): + if node.id in main_graph.node_map: + logger.warning(f'{node.id} is exist!') + else: + main_graph.node_map[node.id] = node + for sub_node in node.subnodes: + self._add_node_to_main_graph(main_graph, sub_node) + + def _get_node_sort_rule(self, node, rank_ascending=True): + match = self.MARK_PATTERN.search(node.id) + if match: + # position代表当前节点在父节点中的位置序号 + position, rank = int(match.group(1)), int(match.group(2)) + if rank_ascending: + return rank, position + else: + return -rank, position + return (float('inf'), float('inf')) if rank_ascending else (-float('inf'), -float('inf')) + + def _mark_node_id_position_rank(self, node: BaseNode, rank): + position = 0 + for index, item in enumerate(node.upnode.subnodes): + if item.id == node.id: + position = index + break + # 各rank重复节点添加所处层级位置排序信息position和rank号,用%分隔 + node.id = node.id + f'{self.MARK}{position}' + f'{self.MARK}{rank}' + for sub_node in node.subnodes: + self._mark_node_id_position_rank(sub_node, rank) + + def _update_node_id(self, graph, start_node: BaseNode, pp_index=""): + stack = [(start_node, pp_index)] + while stack: + node, pp_index = stack.pop() + # 修改节点id之前删除node_map的信息,修改完再添加回去 + if node.id not in graph.node_map: + logger.warning(f'Update node id {node.id} fail!') + else: + del graph.node_map[node.id] + old_id = self.MARK_PATTERN.sub("", node.id) + if node.op == NodeOp.module: + # 被pp切分的模块节点,基于位置和rank信息修改模块名称计数信息 + if self.LAYERS_PATTERN.search(node.id) and self.MARK_PATTERN.search(node.id): + if hasattr(node, 'pp_index'): + pp_index = str(node.pp_index) + node.id = self.LAYERS_PATTERN.sub(r"\g<1>" + pp_index + r"\g<2>", node.id) + else: + # api节点,在api名称上添加rank信息 + parts = node.id.split(Const.SEP) + parts[1] += f'_rank{node.id.split(PPMerger.MARK)[-1]}' + node.id = Const.SEP.join(parts) + # 把之前添加的位置和rank信息删掉 + node.id = self.MARK_PATTERN.sub("", node.id) + # node id更新了,那么data的key中包含node id也要更新 + node.input_data = self._update_node_data_key(old_id, node.id, node.input_data) + node.output_data = self._update_node_data_key(old_id, node.id, node.output_data) + graph.node_map[node.id] = node + # 将子节点加入栈中 + for sub_node in node.subnodes: + stack.append((sub_node, pp_index)) + + +class TPMerger(BaseGraphMerger): + RANK_PATTERN = re.compile(r"_rank(\d+)\.") + OPERATION_TABLE = { + Const.MAX: { + 'initial': lambda p: p.get(Const.MAX), + 'merge': lambda current, other: max(current, other.get(Const.MAX)), + 'finalize': lambda current, count: current, + 'formula': lambda key, values: f'{MAX_INFO}{key} is: max({", ".join(map(str, values))})' + }, + Const.MIN: { + 'initial': lambda p: p.get(Const.MIN), + 'merge': lambda current, other: min(current, other.get(Const.MIN)), + 'finalize': lambda current, count: current, + 'formula': lambda key, values: f'{MIN_INFO}{key} is: min({", ".join(map(str, values))})' + }, + Const.MEAN: { + 'initial': lambda p: p.get(Const.MEAN), + 'merge': lambda current, other: current + other.get(Const.MEAN), + 'finalize': lambda current, count: current / count, + 'formula': lambda key, values: f'{MEAN_INFO}{key} is: ({" + ".join(map(str, values))}) / {len(values)}' + }, + Const.NORM: { + 'initial': lambda p: pow(p.get(Const.NORM), 2.0), + 'merge': lambda current, other: current + pow(other.get(Const.NORM), 2.0), + 'finalize': lambda current, count: pow(current, 1 / 2.0), + 'formula': lambda key, values: f'{NORM_INFO}{key} is: ({" + ".join([f"{v} ** 2" for v in values])}) ** 0.5' + } + } + TP_MERGED_INFO = f'This data is the merged data after tensor parallelism(TP), and the data is merged from rank ' + + @staticmethod + def _merge_params(tp_need_merge_param: dict): + """ + 合并tp切分的各rank参数统计值 + tp_need_merge_param: {input.0: [{"Max": 0, "Min": 0, ...}, {"Max": 0.1, "Min": 0, ...}, ...]} + """ + merge_info = [] + for key, param_list in tp_need_merge_param.items(): + if len(param_list) < 2: + continue + main_param = param_list[0] + + for stat, ops in TPMerger.OPERATION_TABLE.items(): + current_value = ops['initial'](main_param) + value_list = [current_value if stat != Const.NORM else main_param.get(Const.NORM)] + + for other_param in param_list[1:]: + current_value = ops['merge'](current_value, other_param) + value_list.append(other_param.get(stat) if stat != Const.NORM else other_param.get(Const.NORM)) + + final_value = ops['finalize'](current_value, len(param_list)) + main_param[stat] = final_value + formula_base = f'{ops["formula"](key, value_list)}' + f' = {final_value}' + + merge_info.append(formula_base) + + return merge_info + + @staticmethod + def _get_need_merge_node(main_node, other_graphs, tp_merge_mapping): + """ + 获取需要TP合并的节点列表 + 如果是TP+PP的混合并行,此时数据已经被PP合并过,一些node_id被标记上rank信息,此时需要基于rank映射才能获取到需要TP合并的节点列表,例如: + main_node = Torch.matmul_rank4.32.forward other_node = Torch.matmul_rank5.32.forward + 需要建立4->5的映射,才能基于Torch.matmul_rank4.32.forward找到Torch.matmul_rank5.32.forward + """ + other_nodes = [] + match = TPMerger.RANK_PATTERN.search(main_node.id) + # 节点名称被标记rank信息,且提供了映射 + if match and tp_merge_mapping: + rank = int(match.group(1)) + tp_mapping_ranks = tp_merge_mapping.get(rank) + if not tp_mapping_ranks: + return other_nodes + if len(tp_mapping_ranks) != len(other_graphs): + return other_nodes + for i, graph in enumerate(other_graphs): + # 基于映射得到目标rank,替换node_id当前rank信息后去目标graph取node + tp_mapping_id = TPMerger.RANK_PATTERN.sub(f"_rank{tp_mapping_ranks[i]}.", main_node.id) + other_node = graph.node_map.get(tp_mapping_id) + if not other_node or main_node.get_ancestors() != other_node.get_ancestors(): + other_nodes.clear() + break + other_nodes.append(other_node) + else: + for graph in other_graphs: + other_node = graph.node_map.get(main_node.id) + if not other_node or main_node.get_ancestors() != other_node.get_ancestors(): + other_nodes.clear() + break + other_nodes.append(other_node) + + return other_nodes + + @staticmethod + def _slice_list_at_id(node_list, target_id1, target_id2): + start_index, end_index = -1, -1 + for index, node in enumerate(node_list): + if target_id1 in node.id: + start_index = index + elif target_id2 in node.id: + end_index = index + return [] if start_index == -1 or end_index == -1 else node_list[start_index:end_index + 1] + + def merge_graphs(self): + results_groups = self.split_graph_results_by_groups(self.get_groups()) + results = [] + for result_groups in results_groups: + self.merge_graph_api_collection(result_groups) + results.extend(self.merge_tp_graphs(result_groups)) + return results + + def merge_tp_graphs(self, results, tp_merge_mapping=None): + if not results or len(results) < 2: + return results + graphs = [x.graph for x in results] + main_graph_result = results[0] + for main_node in main_graph_result.graph.node_map.values(): + should_continue = ( + not main_node.upnode or main_node.upnode.op != NodeOp.module or + main_node.upnode.id in self.unmerged_module or main_node.id.startswith(Const.DISTRIBUTED) or + main_node.parallel_merge_info != []) + if should_continue: + continue + self._handle_tp_matmul_reduce(main_node, graphs[1:], tp_merge_mapping) + other_nodes = self._get_need_merge_node(main_node, graphs[1:], tp_merge_mapping) + tp_need_merge_param_in, tp_need_merge_param_out = self.compare_node_param_data(main_node, other_nodes) + if tp_need_merge_param_in or tp_need_merge_param_out: + ranks = [main_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + main_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + merge_info_in = self._merge_params(tp_need_merge_param_in) + merge_info_out = self._merge_params(tp_need_merge_param_out) + main_node.parallel_merge_info.extend(merge_info_in + merge_info_out) + for main_node in main_graph_result.graph.node_map.values(): + self._merge_tp_megatron_column_row_parallel(main_node, graphs[1:], tp_merge_mapping) + return [main_graph_result] + + def get_groups(self): + tp_groups = [] + for result in self.build_graph_results: + for node in result.graph.node_map.values(): + if any(op in node.id for op in GraphConst.REDUCE_OPERATIONS): + group_ranks = node.input_data.get(f'{node.id}.input.group', {}).get('group_ranks') + if group_ranks and group_ranks not in tp_groups: + tp_groups.append(group_ranks) + break + if not tp_groups: + logger.info('Unable to get tp groups based on Distributed Api, ' + 'generate tp groups using parallel param "rank_size", "tp" and "pp".') + tp_groups, _ = self.get_default_groups() + logger.info(f'{self.log_prefix} All tp groups is {tp_groups}.') + return tp_groups + + def _handle_tp_matmul_reduce(self, node, other_graphs, tp_merge_mapping): + """ + 前向RowParallel和反向ColumnParallel层的matmul输出需要替换成matmul计算完成后all_reduce/reduce_scatter的输出 + """ + if node.op != NodeOp.module: + return + splits = node.id.split(Const.SEP) + if len(splits) < 4: + return + is_forward_with_row_parallel = splits[-2] == Const.FORWARD and 'RowParallelLinear' in splits[-3] + is_backward_with_column_parallel = splits[-2] == Const.BACKWARD and 'ColumnParallelLinear' in splits[-3] + if not is_forward_with_row_parallel and not is_backward_with_column_parallel: + return + matmul_list = [] + reduce_list = [] + for sub_node in node.subnodes: + if 'matmul' in sub_node.id: + matmul_list.append(sub_node) + if ('_reduce_scatter_base' in sub_node.id or 'reduce_scatter_tensor' in sub_node.id or + 'all_reduce' in sub_node.id): + reduce_list.append(sub_node) + if not matmul_list or not reduce_list: + return + for matmul_node in matmul_list: + if not matmul_node.output_data: + continue + # matmul的output0,将传递给all_reduce/reduce_scatter,作为all_reduce的input0,或作为reduce_scatter的input1 + matmul_node_output_param = list(matmul_node.output_data.values())[0] + for reduce_node in reduce_list: + if not reduce_node.output_data: + continue + if 'all_reduce' in reduce_node.id: + if not reduce_node.input_data: + continue + reduce_node_input_param = list(reduce_node.input_data.values())[0] + else: + if len(reduce_node.input_data) < 2: + continue + reduce_node_input_param = list(reduce_node.input_data.values())[1] + if not self.compare_param_same(matmul_node_output_param, reduce_node_input_param): + continue + # matmul的input统计值与其他rank的数据进行合并 + other_nodes = self._get_need_merge_node(matmul_node, other_graphs, tp_merge_mapping) + tp_need_merge_param_in, _ = self.compare_node_param_data(matmul_node, other_nodes) + if tp_need_merge_param_in: + ranks = [matmul_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + matmul_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + merge_info_in = self._merge_params(tp_need_merge_param_in) + matmul_node.parallel_merge_info.extend(merge_info_in) + # matmul的output0替换为all_reduce/reduce_scatter的output0 + reduce_node_output_param = list(reduce_node.output_data.values())[0] + keys = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] + matmul_node_output_param.update({k: reduce_node_output_param.get(k) for k in keys}) + full_op_name = reduce_node_output_param.get('full_op_name') + param_name = full_op_name if full_op_name else reduce_node.id + matmul_node.parallel_merge_info.append(f'The output of this data is merged from {param_name}') + reduce_list.remove(reduce_node) + break + + def _merge_tp_megatron_column_row_parallel(self, node, other_graphs, tp_merge_mapping): + if node.op != NodeOp.module or node.parallel_merge_info: + return + splits = node.id.split(Const.SEP) + if len(splits) < 4: + return + is_forward_with_column_parallel = splits[-2] == Const.FORWARD and 'ColumnParallelLinear' in splits[-3] + if not is_forward_with_column_parallel: + return + if not node.upnode: + return + # 获取[ColumnParallelLinear, RowParallelLinear]结构 + nodes = self._slice_list_at_id(node.upnode.subnodes, node.id, 'RowParallelLinear') + if len(nodes) < 2: + return + stack = nodes[:] + while stack: + current_node = stack.pop() + stack.extend(reversed(current_node.subnodes)) + + if current_node.parallel_merge_info or current_node.id.startswith(Const.DISTRIBUTED): + continue + + other_nodes = self._get_need_merge_node(current_node, other_graphs, tp_merge_mapping) + param_in, param_out = self.compare_node_param_data(current_node, other_nodes, False) + + if param_in or param_out: + ranks = [current_node.rank] + for other_node in other_nodes: + ranks.append(other_node.rank) + current_node.parallel_merge_info.append(f'{self.TP_MERGED_INFO}{ranks}.') + # ColumnParallelLinear层的输入、其中的matmul输入不需要合并 + if current_node == nodes[0] or ('matmul' in current_node.id and current_node.upnode == nodes[0]): + param_in.pop('input.0', None) + # RowParallelLinear层的输出、其中的matmul输出不需要合并, bias不需要合并 + elif current_node == nodes[-1] or ('matmul' in current_node.id and current_node.upnode == nodes[-1]): + param_out = {} + param_in.pop('parameters.bias', None) + + merge_info_in = self._merge_params(param_in) + merge_info_out = self._merge_params(param_out) + current_node.parallel_merge_info.extend(merge_info_in + merge_info_out) + + +class NoParallelMerger(BaseGraphMerger): + def merge_graphs(self): + self.merge_graph_api_collection(self.build_graph_results) + return self.build_graph_results + + +class TPPPMerger(BaseGraphMerger): + def merge_graphs(self): + tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_groups = pp_merger.get_groups() + tp_groups = tp_merger.get_groups() + # 进入TP+PP混合处理器,PP和TP必然大于1 + tp_merge_mapping = {} + for tp_group in tp_groups[1:]: + tp_merge_mapping[tp_group[0]] = tp_group[1:] + self.merge_graph_api_collection(self.build_graph_results) + # 先合并pp,需要知道pp域,在各自pp域中合并 + results_groups_pp = self.split_graph_results_by_groups(pp_groups) + pp_results = [] + for results in results_groups_pp: + pp_results.extend(pp_merger.merge_pp_graphs(results)) + # pp合并完成后,直接进行tp合并,最终得到一个graph + tp_result = tp_merger.merge_tp_graphs(pp_results, tp_merge_mapping) + self.sort_merged_api_collection(tp_result[0].graph) + return tp_result + + +class FullMerger(BaseGraphMerger): + def merge_graphs(self): + tp_merger = TPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_merger = PPMerger(self.build_graph_results, self.parallel_param, self.is_bench) + pp_groups = pp_merger.get_groups() + tp_groups = tp_merger.get_groups() + tp_merge_mapping = {} + if len(tp_groups) < 1: + raise RuntimeError(f'Graph merged error, and tp_groups is {tp_groups}.') + for tp_group in tp_groups[1:]: + if len(tp_group) < 1: + raise RuntimeError(f'Graph merged error, and tp_group is {tp_group}.') + tp_merge_mapping[tp_group[0]] = tp_group[1:] + # 先合并pp,需要知道pp域,在各自pp域中合并 + results_groups_pp = self.split_graph_results_by_groups(pp_groups) + pp_results = {} + for pp_result in results_groups_pp: + self.merge_graph_api_collection(pp_result) + pp_result = pp_merger.merge_pp_graphs(pp_result)[0] + pp_results[pp_result.rank] = pp_result + # pp合并完成后,基于tp域划分pp合并结果 + lists_to_be_tp_merged = [] + for tp_group in tp_groups: + list_to_be_tp_merged = [] + for rank in tp_group: + pp_result = pp_results.get(rank) + if pp_result: + list_to_be_tp_merged.append(pp_result) + if list_to_be_tp_merged: + lists_to_be_tp_merged.append(list_to_be_tp_merged) + tp_results = [] + for list_to_be_tp_merged in lists_to_be_tp_merged: + self.merge_graph_api_collection(list_to_be_tp_merged) + tp_merged_result = tp_merger.merge_tp_graphs(list_to_be_tp_merged, tp_merge_mapping) + self.sort_merged_api_collection(tp_merged_result[0].graph) + tp_results.extend(tp_merged_result) + return tp_results diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index ee5e3f519ed126b2aaa493e0d3a3b7fce33313e4..5e59ac9b036aa4cf875007a1db33b8e035abfb10 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import re -import math -from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy + +from msprobe.core.compare.acc_compare import ModeConfig +from msprobe.core.compare.multiprocessing_compute import CompareRealData +from msprobe.core.compare.utils import read_op, merge_tensor, get_accuracy, make_result_table from msprobe.core.common.utils import set_dump_path, get_dump_mode from msprobe.visualization.utils import GraphConst from msprobe.core.common.const import Const -from msprobe.core.compare.acc_compare import ModeConfig +from msprobe.core.common.file_utils import load_json # 用于将节点名字解析成对应的NodeOp的规则 op_patterns = [ @@ -53,14 +56,37 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False): """ mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL) + if framework == Const.PT_FRAMEWORK: + from msprobe.pytorch.compare.pt_compare import read_real_data + return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path) + else: + from msprobe.mindspore.compare.ms_compare import read_real_data + return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path) + + +def run_real_data_single(op_names, op_name_mapping_dict, input_param, framework, is_cross_frame=False): + """ + 单进程运行生成真实数据 + Args: + op_names: [npu_op_name, bench_op_name], excel中的NPU_Name和Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 + op_name_mapping_dict: op_name和npy或pt文件的映射关系 + input_param: npu_json_path/bench_json_path/stack_json_path等参数 + framework: 框架类型, pytorch或mindspore + is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆 + """ + if not isinstance(op_names, list) or len(op_names) != 2: + return [] + mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL) + set_dump_path(input_param) + if framework == Const.PT_FRAMEWORK: from msprobe.pytorch.compare.pt_compare import PTComparator - return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path) + return PTComparator(mode_config).compare_by_op(op_names[0], op_names[1], op_name_mapping_dict, input_param) else: from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig ms_comparator = MSComparator(mode_config, MappingConfig()) ms_comparator.cross_frame = is_cross_frame - return ms_comparator.do_multi_process(dump_path_param, csv_path) + return ms_comparator.compare_by_op(op_names[0], op_names[1], op_name_mapping_dict, input_param) def get_input_output(node_data, node_id): @@ -120,11 +146,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2): return True -def format_node_data(data_dict, node_id=None): +def format_node_data(data_dict, node_id=None, compare_mode=None): """ 删除节点数据中不需要展示的字段 """ del_list = ['requires_grad', 'full_op_name'] + if GraphConst.MD5_COMPARE != compare_mode: + del_list.append(Const.MD5) if node_id and GraphConst.BATCH_P2P in node_id: del_list.extend(['op', 'peer', 'tag', 'group_id']) for _, value in data_dict.items(): @@ -137,21 +165,21 @@ def format_node_data(data_dict, node_id=None): return data_dict -def compare_node(node_ids, data_dicts, stack_json_data, compare_mode): +def compare_node_by_dump_data(node_ids, data_dicts, stack_json_data, compare_mode): """ 调用acc_compare.py中的get_accuracy获得精度对比指标 真实数据对比模式无法获得精度对比指标,需要调用多进程比对接口 Returns: 包含参数信息和对比指标(真实数据对比模式除外)的list """ - merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode) - merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode) + merge_n = _parse_node_by_dump_data(node_ids[0], data_dicts[0], stack_json_data, compare_mode) + merge_b = _parse_node_by_dump_data(node_ids[1], data_dicts[1], stack_json_data, compare_mode) result = [] dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) get_accuracy(result, merge_n, merge_b, dump_mode) return result -def _parse_node(node_id, data_dict, stack_json_data, compare_mode): +def _parse_node_by_dump_data(node_id, data_dict, stack_json_data, compare_mode): """ 转换节点,使其能够作为acc_compare.py中的get_accuracy的入参 """ @@ -168,11 +196,30 @@ def _parse_node(node_id, data_dict, stack_json_data, compare_mode): return result +def compare_node(node_n, node_b, compare_mode): + dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) + merge_n = _parse_node(node_n, dump_mode) + merge_b = _parse_node(node_b, dump_mode) + result = [] + get_accuracy(result, merge_n, merge_b, dump_mode) + return result + + +def _parse_node(node, dump_mode): + op_parsed_list = [] + op_parsed_list.extend(node.input_data.values()) + op_parsed_list.extend(node.output_data.values()) + result = merge_tensor(op_parsed_list, dump_mode) + if not result: + result['op_name'] = [] + return result + + def _format_decimal_string(s): """ 使用正则表达式匹配包含数字、小数点和可选的百分号的字符串 """ - pattern = re.compile(r'\d{1,20}\.\d{1,20}%?') + pattern = re.compile(r'^\d{1,20}\.\d{1,20}%?$') matches = pattern.findall(s) for match in matches: is_percent = match.endswith('%') @@ -227,3 +274,12 @@ def _format_data(data_dict): if all_null: data_dict.clear() data_dict[GraphConst.VALUE] = GraphConst.NULL + + +def get_csv_df(stack_mode, csv_data, compare_mode): + """ + 调用acc接口写入csv + """ + + dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) + return make_result_table(csv_data, dump_mode, stack_mode) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 902d721a8d1047b687b878eb45a802a1df4154bd..20c97d717b1cc87e69523144aca0ec93d564e682 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,34 +14,111 @@ # limitations under the License. import re -from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data -from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df +from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, \ + run_real_data_single, get_csv_df, compare_node_by_dump_data +from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file from msprobe.visualization.graph.graph import Graph, NodeOp -from msprobe.visualization.graph.node_colors import NodeColors from msprobe.visualization.compare.mode_adapter import ModeAdapter -from msprobe.core.common.const import Const +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import load_yaml +from msprobe.visualization.compare.multi_mapping import MultiMapping class GraphComparator: - def __init__(self, graphs, dump_path_param, args, mapping_dict=None): + MAX_DEPTH = 1000 + + def __init__(self, graphs, dump_path_param, args, is_cross_framework, mapping_dict=None): self.graph_n = graphs[0] self.graph_b = graphs[1] self._parse_param(dump_path_param, args.output_path) self.framework = args.framework + self.layer_mapping = args.layer_mapping self.mapping_dict = mapping_dict self.fuzzy_match = args.fuzzy_match self.pattern = re.compile(r'\.\d+\.') + self.is_cross_framework = is_cross_framework + self.parallel_merge = args.parallel_merge if hasattr(args, 'parallel_merge') else False + self.rank_pattern = re.compile(r"_rank\d+") def compare(self): """ 比较函数,初始化结束后单独调用。比较结果写入graph_n """ if self.fuzzy_match: - self._compare_nodes_fuzzy(self.graph_n.root) + self._compare_nodes_fuzzy(self.graph_n.root, False if self.parallel_merge else True) else: self._compare_nodes(self.graph_n.root) self._postcompare() - + + def multi_compare(self, multi_yaml_path): + """ + 多对多节点比对,需建立数量n与数量m节点之间的映射关系 + Args: + multi_yaml_path: 映射文件路径 + """ + multi_mapping = MultiMapping.validate_yaml(load_yaml(multi_yaml_path)) + if not multi_mapping: + logger.warning( + f'The multi mapping file {multi_yaml_path} content is incorrect, and the mapping is not effective.') + return + if self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: + # 获取真实数据指标在真实数据表头的索引 + id_list = [CompareConst.COMPARE_RESULT_HEADER.index(x) for x in CompareConst.ALL_COMPARE_INDEX] + for node_n_ids, node_b_ids in multi_mapping.items(): + if not MultiMapping.validate_ids_in_graph(node_n_ids, self.graph_n): + continue + if not MultiMapping.validate_ids_in_graph(node_b_ids, self.graph_b, GraphConst.JSON_BENCH_KEY): + continue + merged_items_n = MultiMapping.merge_nodes(node_n_ids, self.graph_n) + merged_items_b = MultiMapping.merge_nodes(node_b_ids, self.graph_b) + node_n = merged_items_n.multi_node + node_n_data = self.data_n_dict + node_b = merged_items_b.multi_node + node_b_data = self.data_b_dict + + if node_n.op == NodeOp.multi_collection: + node_n_data = MultiMapping.get_merged_nodes_data(node_n_ids, self.data_n_dict, node_n.id) + if node_b.op == NodeOp.multi_collection: + node_b_data = MultiMapping.get_merged_nodes_data(node_b_ids, self.data_b_dict, node_b.id) + + node = self._compare_node_with_mapping(node_n, {node_n.id: node_b.id}) + if not node: + continue + compare_result_list = compare_node_by_dump_data([node_n.id, node_b.id], + [node_n_data, node_b_data], + self.stack_json_data, self.ma.compare_mode) + if not compare_result_list: + continue + # 真实数据模式,compare_result_list里没有精度指标,需要调用真实数据的比对接口得到指标 + if self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: + for compare_result in compare_result_list: + # 准备真实数据比对接口需要的参数 + full_param_name_n = compare_result[0] + full_param_name_b = compare_result[1] + + data_name_n = MultiMapping.get_dump_data_name(merged_items_n, full_param_name_n) + data_name_b = MultiMapping.get_dump_data_name(merged_items_b, full_param_name_b) + op_name_mapping_dict = {full_param_name_n: [data_name_n, data_name_b]} + + real_compare_result = run_real_data_single([full_param_name_n, full_param_name_b], + op_name_mapping_dict, self.dump_path_param, + self.framework, self.is_cross_framework) + if len(real_compare_result) < len(id_list): + continue + for i, index in enumerate(id_list): + # 根据索引,将真实数据指标插入表头相应位置 + compare_result[index] = real_compare_result[i] + compare_dict = {} + for item in compare_result_list: + if not isinstance(item, (list, tuple)) or not item: + continue + compare_dict[MultiMapping.replace_param_name(item[0], node_n.id)] = item + precision_index, _ = self.ma.parse_result(node_n, [compare_dict]) + node_n.data[GraphConst.JSON_INDEX_KEY] = precision_index + else: + self.add_compare_result_to_node(node_n, compare_result_list) + def add_compare_result_to_node(self, node, compare_result_list): """ 将比对结果添加到节点的输入输出数据中 @@ -66,7 +143,63 @@ class GraphComparator: self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) node.data[GraphConst.JSON_INDEX_KEY] = precision_index node.data.update(other_dict) - + + def _compare_nodes(self, node_root): + """ + 遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 + 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 + """ + def compare_single_node(node_n): + if self.layer_mapping: + node_b = self._compare_node_with_mapping(node_n, self.mapping_dict) + else: + node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) + if node_b: + ancestors.append(node_b.id) + node_n.add_link(node_b, ancestors) + if node_b: + # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 + self._get_and_add_result(node_n, node_b) + node_list.extend(node_n.subnodes) + + node_list = [node_root] + while node_list: + compare_single_node(node_list.pop(0)) + + def _compare_nodes_fuzzy(self, node_root, check_shape=True): + def compare_single_node_fuzzy(node_n): + if node_n.op != NodeOp.function_api: + # 模块经过模糊匹配 + node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id), + check_shape) + if node_b: + self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b) + # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配 + recount_result_n = self._recount_api_node(node_n) + recount_result_b = self._recount_api_node(node_b) + for recount_node_id, node_id_n in recount_result_n.items(): + api_node_n = self.graph_n.node_map.get(node_id_n) + if not api_node_n: + continue + api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match( + api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)), check_shape) + if api_node_b: + self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b) + node_list.extend(node_n.subnodes) + + node_list = [node_root] + while node_list: + compare_single_node_fuzzy(node_list.pop(0)) + + def _compare_node_with_mapping(self, node_n, mapping_dict): + node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, mapping_dict) + if node_b: + ancestors_n.append(node_n.id) + ancestors_b.append(node_b.id) + node_n.matched_node_link = ancestors_b + node_b.matched_node_link = ancestors_n + return node_b + def _parse_param(self, dump_path_param, output_path): self.dump_path_param = dump_path_param self.output_path = output_path @@ -81,7 +214,7 @@ class GraphComparator: if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: return df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode) - df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False) + df = run_real_data(self.dump_path_param, df, self.framework, self.is_cross_framework) compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) @@ -92,64 +225,26 @@ class GraphComparator: api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标 md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差 """ + def handle_api_collection_index(api_collection_node): + precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else GraphConst.MIN_INDEX_KEY + for api in api_collection_node.subnodes: + precision_index = min(precision_index, + api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ + if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) + api_collection_node.data[GraphConst.JSON_INDEX_KEY] = precision_index + for node in self.graph_n.root.subnodes: - if node.op == NodeOp.api_collection: - precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else GraphConst.MIN_INDEX_KEY - for api in node.subnodes: - precision_index = min(precision_index, - api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ - if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - - def _compare_nodes(self, node_n): - """ - 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 - 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 - """ - if self.mapping_dict: - node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict) - if node_b: - ancestors_n.append(node_n.id) - ancestors_b.append(node_b.id) - node_n.matched_node_link = ancestors_b - node_b.matched_node_link = ancestors_n - else: - node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) - if node_b: - ancestors.append(node_b.id) - node_n.add_link(node_b, ancestors) - if node_b: - # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 - self._get_and_add_result(node_n, node_b) - for subnode in node_n.subnodes: - self._compare_nodes(subnode) - - def _compare_nodes_fuzzy(self, node_n): - if node_n.op != NodeOp.function_api: - # 模块经过模糊匹配 - node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id)) - if node_b: - self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b) - # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配 - recount_result_n = self._recount_api_node(node_n) - recount_result_b = self._recount_api_node(node_b) - for recount_node_id, node_id_n in recount_result_n.items(): - api_node_n = self.graph_n.node_map.get(node_id_n) - if not api_node_n: - continue - api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match( - api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id))) - if api_node_b: - self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b) - for sub_node in node_n.subnodes: - self._compare_nodes_fuzzy(sub_node) + if node.op == NodeOp.api_collection and node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS): + for sub_node in node.subnodes: + handle_api_collection_index(sub_node) + handle_api_collection_index(node) + elif node.op == NodeOp.api_collection: + handle_api_collection_index(node) def _get_and_add_result(self, node_n, node_b): - compare_result_list = compare_node([node_n.id, node_b.id], - [self.data_n_dict, self.data_b_dict], - self.stack_json_data, self.ma.compare_mode) + compare_result_list = compare_node(node_n, node_b, self.ma.compare_mode) if compare_result_list: self.ma.add_csv_data(compare_result_list) self.add_compare_result_to_node(node_n, compare_result_list) @@ -166,6 +261,8 @@ class GraphComparator: if sub_node.op == NodeOp.function_api: # 忽略dump调用次数 count_removed_id = self.pattern.sub(Const.SEP, sub_node.id) + if self.rank_pattern.search(count_removed_id): + count_removed_id = self.rank_pattern.sub('', count_removed_id) node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1 # 赋予模块中的调用顺序 recount_node_id = count_removed_id + str(node_count.get(count_removed_id)) diff --git a/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py index 535192d80c566c48cedde4ea5b4474b6dc82dec0..2f1c7d5721accb1165e57181573dc9cd9715746f 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/mode_adapter.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import math +import json from msprobe.core.common.const import CompareConst, Const from msprobe.visualization.utils import ToolTip, GraphConst, str2float @@ -25,6 +25,12 @@ class ModeAdapter: self.csv_data = [] self.compare_nodes = [] + @staticmethod + def _is_invalid(value): + if not isinstance(value, float): + return False + return math.isnan(value) or math.isinf(value) + @staticmethod def _add_md5_compare_data(node_data, compare_data_dict): precision_index = GraphConst.MAX_INDEX_KEY @@ -49,6 +55,8 @@ class ModeAdapter: for key, value in node_data.items(): if not isinstance(value, dict): continue + if value.get(Const.MAX) is None: + continue compare_data = compare_data_dict.get(key) if compare_data: headers = CompareConst.COMPARE_RESULT_HEADER @@ -67,9 +75,13 @@ class ModeAdapter: if thousandth is not None: numbers.append(thousandth) node_data[key] = value + if ModeAdapter._is_invalid(value.get(Const.MAX)) or ModeAdapter._is_invalid(value.get(Const.MIN)): + numbers.append(CompareConst.N_A) # 双千指标都是None的异常情况 if not numbers: min_thousandth = None + elif CompareConst.N_A in numbers: + min_thousandth = CompareConst.N_A else: min_thousandth = min(numbers + [min_thousandth]) return min_thousandth @@ -81,6 +93,8 @@ class ModeAdapter: for key, data_info in node_data.items(): if not isinstance(data_info, dict): continue + if data_info.get(Const.MAX) is None: + continue compare_data = compare_data_dict.get(key) if compare_data: # 对应比对结果csv的列 @@ -92,6 +106,8 @@ class ModeAdapter: relative_err = str2float(data_info.get(item)) max_relative_err = max(max_relative_err, relative_err) node_data[key] = data_info + if ModeAdapter._is_invalid(data_info.get(Const.MAX)) or ModeAdapter._is_invalid(data_info.get(Const.MIN)): + max_relative_err = GraphConst.MAX_INDEX_KEY max_relative_err = 1 if max_relative_err > 1 else max_relative_err return max_relative_err @@ -133,7 +149,11 @@ class ModeAdapter: ModeAdapter._check_list_len(compare_data_dict_list, 1) min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0]) min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0]) - if min_thousandth_in is not None and min_thousandth_out is not None: + if CompareConst.N_A == min_thousandth_out: + change_percentage = GraphConst.MAX_INDEX_KEY + elif CompareConst.N_A == min_thousandth_in: + change_percentage = GraphConst.MIN_INDEX_KEY + elif min_thousandth_in is not None and min_thousandth_out is not None: change_percentage = min_thousandth_in - min_thousandth_out else: change_percentage = GraphConst.MIN_INDEX_KEY @@ -141,6 +161,7 @@ class ModeAdapter: else change_percentage precision_index = GraphConst.MAX_INDEX_KEY \ if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage + precision_index = self._ignore_precision_index(node.id, precision_index) return precision_index, other_dict def prepare_real_data(self, node): @@ -157,24 +178,6 @@ class ModeAdapter: return self.csv_data.extend(compare_result_list) - def add_error_key(self, node_data): - """ - 根据不同的模式进行提供不同错误信息 - """ - for key, value in node_data.items(): - if not isinstance(value, dict): - continue - if self.compare_mode == GraphConst.SUMMARY_COMPARE: - message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] - elif self.compare_mode == GraphConst.REAL_DATA_COMPARE: - message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] - else: - # 输出件优化 - message = [] - value[GraphConst.ERROR_KEY] = message - node_data[key] = value - def get_tool_tip(self): """ 用于前端展示字段的具体含义 @@ -195,3 +198,11 @@ class ModeAdapter: CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR} return json.dumps(tips) + + def _ignore_precision_index(self, node_id, precision_index): + node_id_split = node_id.split(Const.SEP) + if len(node_id_split) < 2: + return precision_index + if node_id.split(Const.SEP)[1] in GraphConst.IGNORE_PRECISION_INDEX: + return GraphConst.MAX_INDEX_KEY if self.compare_mode == GraphConst.MD5_COMPARE else GraphConst.MIN_INDEX_KEY + return precision_index diff --git a/debug/accuracy_tools/msprobe/visualization/compare/multi_mapping.py b/debug/accuracy_tools/msprobe/visualization/compare/multi_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc7c0f31351a52e40acfd6824c6b2f8f49ffd52 --- /dev/null +++ b/debug/accuracy_tools/msprobe/visualization/compare/multi_mapping.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.graph import NodeOp, BaseNode +from msprobe.core.compare.utils import get_name_and_state + + +@dataclass +class MergedItems: + multi_node: BaseNode = None + start_node: BaseNode = None + end_node: BaseNode = None + + +class MultiMapping: + + @staticmethod + def validate_yaml(yaml_file): + multi_mapping = {} + if not yaml_file: + logger.warning(f'The multi mapping file cannot be empty.') + return multi_mapping + if not isinstance(yaml_file, dict): + logger.warning(f'The multi mapping file format must be a dict.') + return multi_mapping + for key, value in yaml_file.items(): + multi_mapping[MultiMapping._split_mapping_str(key)] = MultiMapping._split_mapping_str(value) + return multi_mapping + + @staticmethod + def validate_ids_in_graph(node_ids, graph, graph_type=GraphConst.JSON_NPU_KEY): + in_graph = True + for node_id in node_ids: + if node_id not in graph.node_map: + logger.warning(f'{node_id} does not exist in the {graph_type} graph, and the mapping is not effective.') + in_graph = False + return in_graph + + @staticmethod + def get_merged_nodes_data(node_ids: (list, tuple), dump_data: dict, multi_node_id: str): + if len(node_ids) < 2: + return {} + multi_node_data = {} + for k, v in dump_data.get(node_ids[0], {}).items(): + if k in [Const.INPUT, Const.INPUT_ARGS, Const.INPUT_KWARGS]: + multi_node_data[k] = v + for k, v in dump_data.get(node_ids[-1], {}).items(): + if k == Const.OUTPUT: + multi_node_data[k] = v + return {multi_node_id: multi_node_data} + + @staticmethod + def replace_param_name(param_name: str, multi_node_id): + try: + api, _ = get_name_and_state(param_name) + except Exception: + return param_name + return param_name.replace(api, multi_node_id + Const.SEP) + + @staticmethod + def merge_nodes(node_ids, graph): + """ + 根据传入的节点名称列表,将列表中的节点合并为一个节点,并取列表中的首节点输入数据作为融合节点的输入,尾节点的输出数据作为融合节点的输出 + Args: + node_ids: 节点名称列表 + graph: 图 + + Returns: 融合节点,首节点,尾节点 + + """ + if not node_ids or not isinstance(node_ids, (list, tuple)): + return MergedItems() + if len(node_ids) == 1: + return MergedItems(graph.get_node(node_ids[0])) + # 根据映射文件中配置的首尾节点id,得到首尾节点id之间的所有节点id列表 + node0 = graph.get_node(node_ids[0]) + node1 = graph.get_node(node_ids[-1]) + if not node0 or not node1: + return MergedItems() + current_node_list = node0.upnode.subnodes + + start_index = end_index = 0 + for i, node in enumerate(current_node_list): + if node.id == node_ids[0]: + start_index = i + elif node.id == node_ids[-1]: + end_index = i + + if start_index > end_index: + logger.warning(f'{node_ids[0]} and {node_ids[-1]} are in the wrong order, {node_ids[0]} should come first, ' + f'and the mapping is not effective.') + return MergedItems() + + current_node_list = current_node_list[start_index:end_index + 1] + + # 创建一个新的节点,作为被映射多个节点的集合,输入使用第一个节点的输入,输出使用最后一个节点的输出 + multi_node_name = GraphConst.MERGE_NODES + Const.SEP + Const.FORWARD \ + if Const.SEP + Const.FORWARD + Const.SEP in node0.id \ + else GraphConst.MERGE_NODES + Const.SEP + Const.BACKWARD + multi_node_id = graph.add_node(NodeOp.multi_collection, multi_node_name, id_accumulation=True) + multi_node = graph.get_node(multi_node_id) + multi_node.subnodes = current_node_list + multi_node.upnode = node0.upnode + # 重新确立父子关系 + for node in current_node_list: + node.upnode = multi_node + + multi_node.upnode.subnodes[start_index:end_index + 1] = [multi_node] + + # 给节点添加输入输出数据, parameters信息不添加, 因为多对多节点之间的parameters的shape会不一致导致无法比对 + input_data = {} + output_data = {} + for key, value in node0.input_data.items(): + if any(s in key for s in [Const.INPUT, Const.INPUT_ARGS, Const.INPUT_KWARGS]): + input_data[MultiMapping.replace_param_name(key, multi_node_id)] = value + for key, value in node1.output_data.items(): + output_data[MultiMapping.replace_param_name(key, multi_node_id)] = value + multi_node.input_data = input_data + multi_node.output_data = output_data + + return MergedItems(multi_node, node0, node1) + + @staticmethod + def get_dump_data_name(merged_items, full_param_name): + """ + 根据节点参数名称,从融合节点信息中获取此参数的真实数据名称 + Args: + merged_items: 融合节点信息 + full_param_name: 参数名称,例如Module.layer.Linear.forward.0.input.0 + + Returns: 真实数据名称,例如Module.layer.Linear.forward.0.input.0.pt + + """ + try: + _, state = get_name_and_state(full_param_name) + except Exception: + return "-1" + node = merged_items.multi_node + # 如果是融合节点,那么其真实数据的存盘data_name需要从融合节点的首节点和尾节点中获取 + if node.op == NodeOp.multi_collection: + data = merged_items.end_node.output_data \ + if Const.OUTPUT == state \ + else merged_items.start_node.input_data + else: + data = node.output_data \ + if Const.OUTPUT == state \ + else node.input_data + + return data.get(full_param_name, {}).get("data_name", "-1") + + @staticmethod + def _split_mapping_str(x: str): + if Const.COMMA in x: + split_list = x.split(Const.COMMA) + return split_list[0].strip(), split_list[-1].strip() + return (x.strip(),) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py index 2642ff1e97ebcc055212d4d776eb7c8a08866dc8..96a16eb8f00c2a7c05d366a735502b86816a2ea1 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from msprobe.core.overflow_check.level import OverflowLevel -from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy +from msprobe.core.common.log import logger class BaseNode: @@ -35,6 +36,8 @@ class BaseNode: self.overflow_level = None self.matched_distributed = {} self.batch_p2p_info = [] + self.rank = 0 + self.parallel_merge_info = [] def __str__(self): info = f'id:\t{self.id}' @@ -86,15 +89,15 @@ class BaseNode: self.matched_node_link = ancestors node.matched_node_link = ancestors - def to_dict(self): + def to_dict(self, compare_mode=None): """ 输出数据 """ result = { 'id': self.id, 'node_type': self.op.value, - 'output_data': format_node_data(self.output_data, self.id), - 'input_data': format_node_data(self.input_data, self.id), + 'output_data': format_node_data(self.output_data, self.id, compare_mode), + 'input_data': format_node_data(self.input_data, self.id, compare_mode), 'upnode': self.upnode.id if self.upnode else 'None', 'subnodes': [node.id for node in self.subnodes], 'matched_node_link': self.matched_node_link, @@ -106,6 +109,8 @@ class BaseNode: result['data'] = self.data if self.matched_distributed: result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed + if self.parallel_merge_info: + result['parallel_merge_info'] = self.parallel_merge_info return result def get_ancestors(self): @@ -114,7 +119,13 @@ class BaseNode: """ ancestors = [] current_node = self.upnode + seen_nodes = set() while current_node: + if current_node.id in seen_nodes: + logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, ' + f'current node is {current_node.id}.') + return [] + seen_nodes.add(current_node.id) ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py index 5e68d6b2528aea4d6645da2885fa76a7b9bb97b2..a4b709a1ed1e57fd34330e403992d5fdb781c4f5 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py @@ -107,15 +107,6 @@ class DistributedAnalyzer: return None, None return group_ranks, group_id - @staticmethod - def _get_batch_group_info(node, rank): - for data in node.input_data.values(): - group_id = data.get('group_id') - if group_id is not None: - return group_id - logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}') - return None - def distributed_match(self): for rank, graph in self.graphs.items(): nodes = graph.node_map @@ -377,7 +368,7 @@ class DistributedAnalyzer: target_api_name = self.config.get(api_name)[0] target_rank = int(id_info[1].replace(Const.RANK, '')) except Exception as e: - logger.warning(f'Failed to parsing batch p2p parameter with error info: {e}.') + logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.') continue target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name) if not target_node: diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index 5ce12d1cadb9aec2cc7c65954bb861b85032212d..f4caec221f4168e73b7414b3493f3d3f6f79265c 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -20,9 +20,6 @@ from msprobe.core.common.log import logger from msprobe.core.common.const import Const -MAX_RECUR_LEVEL = 100 - - class Graph: def __init__(self, model_name, data_path='', dump_data=None): self.node_map = {} @@ -67,22 +64,16 @@ class Graph: ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod - def fuzzy_match(node_n, node_b): - if not node_n or not node_b or not node_n.fuzzy_eq(node_b): + def fuzzy_match(node_n, node_b, check_shape=True): + if not node_n or not node_b: + return None, [], [] + if check_shape and not node_n.fuzzy_eq(node_b): return None, [], [] ancestors_n = node_n.get_ancestors() ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod - def dfs(node, result): - info = node.to_dict() - result[node.id] = info - for subnode in node.subnodes: - Graph.dfs(subnode, result) - @staticmethod def split_nodes_by_micro_step(nodes): """ @@ -157,7 +148,7 @@ class Graph: """ return self.node_map.get(node_id, None) - def to_dict(self): + def to_dict(self, compare_mode=None): """ 用于数据输出 """ @@ -166,7 +157,7 @@ class Graph: result[GraphConst.JSON_DATA_KEY] = self.data_path result[GraphConst.JSON_NODE_KEY] = {} for node_id in self.node_map: - info = self.node_map.get(node_id).to_dict() + info = self.node_map.get(node_id).to_dict(compare_mode) result[GraphConst.JSON_NODE_KEY][node_id] = info return result diff --git a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py index 33bfa9cc2e34a0960c3ff236a1bd183a5753a0ab..12072fff032ee1e26c5e8274cd1676679d531331 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py @@ -22,9 +22,9 @@ from msprobe.core.common.log import logger class NodeOp(Enum): module = 0 function_api = 1 + multi_collection = 8 api_collection = 9 - @staticmethod def get_node_op(node_name: str): """ @@ -37,5 +37,5 @@ class NodeOp(Enum): pattern = op_patterns[index] if re.match(pattern, node_name): return op - logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.") + logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.") return NodeOp.module diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py index 75b0014c1c09abb8dfecf285fed5eed3063827a0..69718d9c26fa4e3d8b787a936d01f624e7411915 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,90 +15,114 @@ import os import time -import json +from copy import deepcopy +from multiprocessing import cpu_count, Pool from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker, check_file_or_directory_path, load_json) from msprobe.core.common.const import FileCheckConst, Const -from msprobe.core.common.utils import CompareException -from msprobe.core.overflow_check.checker import AnomalyDetector +from msprobe.core.common.utils import CompareException, get_dump_mode from msprobe.visualization.compare.graph_comparator import GraphComparator -from msprobe.visualization.utils import GraphConst, check_directory_content -from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig +from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \ + sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, extract_rank_number +from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo from msprobe.core.common.log import logger from msprobe.visualization.graph.node_colors import NodeColors from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping from msprobe.core.compare.utils import check_and_return_dir_contents +from msprobe.core.common.utils import detect_framework_by_dump_json from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer +from msprobe.visualization.builder.graph_merger import GraphMerger current_time = time.strftime("%Y%m%d%H%M%S") -def _compare_graph(input_param, args): - logger.info('Start building model graphs...') - # 对两个数据进行构图 - dump_path_n = input_param.get('npu_path') - dump_path_b = input_param.get('bench_path') - construct_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE), - FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - construct_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE), - FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - data_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.DUMP_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - data_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.DUMP_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - stack_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.STACK_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE, - FileCheckConst.READ_ABLE).common_check() - graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack) - graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack) - logger.info('Model graphs built successfully, start Comparing graphs...') - # 基于graph、stack和data进行比较 +def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args): dump_path_param = { - 'npu_json_path': data_path_n, - 'bench_json_path': data_path_b, - 'stack_json_path': stack_path_n, + 'npu_json_path': graph_n.data_path, + 'bench_json_path': graph_b.data_path, + 'stack_json_path': graph_n.stack_path, 'is_print_compare_log': input_param.get("is_print_compare_log", True) } - mapping_dict = None + mapping_dict = {} if args.layer_mapping: - yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() try: - mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, yaml_path) + mapping_dict = generate_api_mapping_by_layer_mapping(graph_n.data_path, graph_b.data_path, + args.layer_mapping) except Exception: logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.') - graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, mapping_dict=mapping_dict) + + is_cross_framework = detect_framework_by_dump_json(graph_n.data_path) != \ + detect_framework_by_dump_json(graph_b.data_path) + if is_cross_framework and not args.layer_mapping: + logger.error('The cross_frame graph comparison failed. ' + 'Please specify -lm or --layer_mapping when performing cross_frame graph comparison.') + raise CompareException(CompareException.CROSS_FRAME_ERROR) + + graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework, + mapping_dict=mapping_dict) graph_comparator.compare() - micro_steps = graph_n.paging_by_micro_step(graph_b) + return graph_comparator + + +def _compare_graph_result(input_param, args): + logger.info('Start building model graphs...') + # 对两个数据进行构图 + graph_n = _build_graph_info(input_param.get('npu_path'), args) + graph_b = _build_graph_info(input_param.get('bench_path'), args) + logger.info('Model graphs built successfully, start Comparing graphs...') + # 基于graph、stack和data进行比较 + graph_comparator = _compare_graph(graph_n, graph_b, input_param, args) + # 增加micro step标记 + micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph) # 开启溢出检测 if args.overflow_check: - graph_n.overflow_check() - graph_b.overflow_check() + graph_n.graph.overflow_check() + graph_b.graph.overflow_check() + + if args.multi_mapping: + graph_comparator.multi_compare(args.multi_mapping) - return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps) + return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps) -def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps, - output_file_name=f'compare_{current_time}.vis'): - create_directory(args.output_path) +def _export_compare_graph_result(args, result): + graphs = [result.graph_n, result.graph_b] + graph_comparator = result.graph_comparator + micro_steps = result.micro_steps + output_file_name = result.output_file_name + if not output_file_name: + output_file_name = f'compare_{current_time}.vis' + logger.info(f'Start exporting compare graph result, file name: {output_file_name}...') output_path = os.path.join(args.output_path, output_file_name) task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode) export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(), NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task, - args.overflow_check) - GraphBuilder.to_json(output_path, export_config) - logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}') + args.overflow_check, graph_comparator.ma.compare_mode) + try: + GraphBuilder.to_json(output_path, export_config) + logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}') + return '' + except RuntimeError as e: + logger.error(f'Failed to export compare graph result, file: {output_file_name}, error: {e}') + return output_file_name -def _build_graph(dump_path, args): - logger.info('Start building model graph...') +def _build_graph_info(dump_path, args, graph=None): construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check() - graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack) + if not graph: + graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack) + return GraphInfo(graph, construct_path, data_path, stack_path) + + +def _build_graph_result(dump_path, args): + logger.info('Start building model graphs...') + graph = _build_graph_info(dump_path, args).graph + # 增加micro step标记 micro_steps = graph.paging_by_micro_step() # 开启溢出检测 if args.overflow_check: @@ -106,15 +130,132 @@ def _build_graph(dump_path, args): return BuildGraphResult(graph, micro_steps) -def _export_build_graph_result(out_path, graph, micro_steps, overflow_check, - output_file_name=f'build_{current_time}.vis'): - create_directory(out_path) +def _run_build_graph_compare(input_param, args, nr, br): + logger.info(f'Start building graph for {nr}...') + graph_n = _build_graph_info(input_param.get('npu_path'), args) + graph_b = _build_graph_info(input_param.get('bench_path'), args) + logger.info(f'Building graph for {nr} finished.') + return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time) + + +def _run_build_graph_single(dump_ranks_path, rank, step, args): + logger.info(f'Start building graph for {rank}...') + dump_path = os.path.join(dump_ranks_path, rank) + output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis' + result = _build_graph_result(dump_path, args) + result.output_file_name = output_file_name + if rank != Const.RANK: + try: + result.rank = int(rank.replace(Const.RANK, "")) + except Exception as e: + logger.error('The folder name format is incorrect, expected rank+number.') + raise CompareException(CompareException.INVALID_PATH_ERROR) from e + logger.info(f'Building graph for step: {step}, rank: {rank} finished.') + return result + + +def _run_graph_compare(graph_task_info, input_param, args, output_file_name): + logger.info(f'Start comparing data for {graph_task_info.npu_rank}...') + graph_n = graph_task_info.graph_info_n + graph_b = graph_task_info.graph_info_b + nr = graph_task_info.npu_rank + graph_comparator = _compare_graph(graph_n, graph_b, input_param, args) + micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph) + # 开启溢出检测 + if args.overflow_check: + graph_n.graph.overflow_check() + graph_b.graph.overflow_check() + + if args.multi_mapping: + graph_comparator.multi_compare(args.multi_mapping) + + graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps) + graph_result.output_file_name = output_file_name + if nr != Const.RANK: + try: + graph_result.rank = int(nr.replace(Const.RANK, "")) + except Exception as e: + logger.error('The folder name format is incorrect, expected rank+number.') + raise CompareException(CompareException.INVALID_PATH_ERROR) from e + logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.') + return graph_result + + +def _export_build_graph_result(args, result): + out_path = args.output_path + graph = result.graph + micro_steps = result.micro_steps + overflow_check = args.overflow_check + output_file_name = result.output_file_name + if not output_file_name: + output_file_name = f'build_{current_time}.vis' + logger.info(f'Start exporting graph for {output_file_name}...') output_path = os.path.join(out_path, output_file_name) - GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check)) - logger.info(f'Model graph built successfully, the result file is saved in {output_path}') + try: + GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, + overflow_check=overflow_check)) + logger.info(f'Model graph exported successfully, the result file is saved in {output_path}') + return None + except RuntimeError as e: + logger.error(f'Failed to export model graph, file: {output_file_name}, error: {e}') + return output_file_name + + +def is_real_data_compare(input_param, npu_ranks, bench_ranks): + dump_rank_n = input_param.get('npu_path') + dump_rank_b = input_param.get('bench_path') + has_real_data = False + for nr, br in zip(npu_ranks, bench_ranks): + dump_path_param = { + 'npu_json_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE, + FileCheckConst.READ_ABLE).common_check(), + 'bench_json_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE, + FileCheckConst.READ_ABLE).common_check() + } + has_real_data |= get_dump_mode(dump_path_param) == Const.ALL + return has_real_data + + +def _mp_compare(input_param, serializable_args, output_file_name, nr, br): + graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br) + return _run_graph_compare(graph_task_info, input_param, serializable_args, output_file_name) def _compare_graph_ranks(input_param, args, step=None): + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') + try: + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + serializable_args = SerializableArgs(args) + # 暂存所有rank的graph,用于匹配rank间的分布式节点 + compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call) + + # 匹配rank间的分布式节点 + if len(compare_graph_results) > 1: + DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, + args.overflow_check).distributed_match() + DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, + args.overflow_check).distributed_match() + + export_res_task_list = [] + create_directory(args.output_path) + for result in compare_graph_results: + export_res_task_list.append(pool.apply_async(_export_compare_graph_result, + args=(serializable_args, result), + error_callback=err_call)) + export_res_list = [res.get() for res in export_res_task_list] + if any(export_res_list): + failed_names = list(filter(lambda x: x, export_res_list)) + logger.error(f'Unable to export compare graph results: {failed_names}.') + else: + logger.info('Successfully exported compare graph results.') + + +def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call): dump_rank_n = input_param.get('npu_path') dump_rank_b = input_param.get('bench_path') npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK)) @@ -123,32 +264,33 @@ def _compare_graph_ranks(input_param, args, step=None): logger.error('The number of ranks in the two runs are different. Unable to match the ranks.') raise CompareException(CompareException.INVALID_PATH_ERROR) compare_graph_results = [] - for nr, br in zip(npu_ranks, bench_ranks): - logger.info(f'Start processing data for {nr}...') - input_param['npu_path'] = os.path.join(dump_rank_n, nr) - input_param['bench_path'] = os.path.join(dump_rank_b, br) - output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis' - result = _compare_graph(input_param, args) - result.output_file_name = output_file_name - if nr != Const.RANK: - try: - result.rank = int(nr.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e - # 暂存所有rank的graph,用于匹配rank间的分布式节点 - compare_graph_results.append(result) - - # 匹配rank间的分布式节点 - if len(compare_graph_results) > 1: - DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, - args.overflow_check).distributed_match() - DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, - args.overflow_check).distributed_match() - - for result in compare_graph_results: - _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator, - result.micro_steps, output_file_name=result.output_file_name) + if is_real_data_compare(input_param, npu_ranks, bench_ranks): + mp_task_dict = {} + for nr, br in zip(npu_ranks, bench_ranks): + input_param['npu_path'] = os.path.join(dump_rank_n, nr) + input_param['bench_path'] = os.path.join(dump_rank_b, br) + output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis' + input_param_copy = deepcopy(input_param) + mp_task_dict[output_file_name] = pool.apply_async(_run_build_graph_compare, + args=(input_param_copy, serializable_args, nr, br), + error_callback=err_call) + + mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()} + for output_file_name, mp_res in mp_res_dict.items(): + compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args, output_file_name)) + else: + compare_graph_tasks = [] + for nr, br in zip(npu_ranks, bench_ranks): + input_param['npu_path'] = os.path.join(dump_rank_n, nr) + input_param['bench_path'] = os.path.join(dump_rank_b, br) + output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis' + input_param_copy = deepcopy(input_param) + compare_graph_tasks.append(pool.apply_async(_mp_compare, + args=(input_param_copy, serializable_args, output_file_name, nr, + br), + error_callback=err_call)) + compare_graph_results = [task.get() for task in compare_graph_tasks] + return compare_graph_results def _compare_graph_steps(input_param, args): @@ -159,7 +301,7 @@ def _compare_graph_steps(input_param, args): bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP)) if npu_steps != bench_steps: - logger.error('The number of steps in the two runs are different. Unable to match the steps.') + logger.error('The number of steps in the two runs is different. Unable to match the steps.') raise CompareException(CompareException.INVALID_PATH_ERROR) for folder_step in npu_steps: @@ -167,33 +309,51 @@ def _compare_graph_steps(input_param, args): input_param['npu_path'] = os.path.join(dump_step_n, folder_step) input_param['bench_path'] = os.path.join(dump_step_b, folder_step) - _compare_graph_ranks(input_param, args, step=folder_step) + _compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \ + else _compare_graph_ranks_parallel(input_param, args, step=folder_step) def _build_graph_ranks(dump_ranks_path, args, step=None): - ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) - build_graph_results = [] - for rank in ranks: - logger.info(f'Start processing data for {rank}...') - dump_path = os.path.join(dump_ranks_path, rank) - output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis' - result = _build_graph(dump_path, args) - result.output_file_name = output_file_name - if rank != Const.RANK: + ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK)) + serializable_args = SerializableArgs(args) + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') try: - result.rank = int(rank.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e - build_graph_results.append(result) - - if len(build_graph_results) > 1: - DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, - args.overflow_check).distributed_match() - - for result in build_graph_results: - _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check, - result.output_file_name) + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + build_graph_tasks = [] + for rank in ranks: + build_graph_tasks.append(pool.apply_async(_run_build_graph_single, + args=(dump_ranks_path, rank, step, serializable_args), + error_callback=err_call)) + build_graph_results = [task.get() for task in build_graph_tasks] + + if args.parallel_params: + validate_parallel_param(args.parallel_params[0], dump_ranks_path) + build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph() + + if len(build_graph_results) > 1 and not args.parallel_merge: + DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, + args.overflow_check).distributed_match() + + create_directory(args.output_path) + export_build_graph_tasks = [] + for i, result in enumerate(build_graph_results): + if args.parallel_params: + result.output_file_name = f'build_{step}_merged{i}_{current_time}.vis' \ + if step else f'build_merged{i}_{current_time}.vis' + export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result, + args=(serializable_args, result), + error_callback=err_call)) + export_build_graph_result = [task.get() for task in export_build_graph_tasks] + if any(export_build_graph_result): + failed_names = list(filter(lambda x: x, export_build_graph_result)) + logger.error(f'Unable to export build graph results: {failed_names}.') + else: + logger.info(f'Successfully exported build graph results.') def _build_graph_steps(dump_steps_path, args): @@ -204,12 +364,82 @@ def _build_graph_steps(dump_steps_path, args): _build_graph_ranks(dump_ranks_path, args, step) +def _compare_and_export_graph(graph_task_info, input_param, args, output_file_name): + result = _run_graph_compare(graph_task_info, input_param, args, output_file_name) + return _export_compare_graph_result(args, result) + + +def _compare_graph_ranks_parallel(input_param, args, step=None): + args.fuzzy_match = True + npu_path = input_param.get('npu_path') + bench_path = input_param.get('bench_path') + ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK)) + ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK)) + parallel_params = load_parallel_param(input_param) + if len(parallel_params) != 2: + raise RuntimeError('Parallel params error in compare graph!') + validate_parallel_param(parallel_params[0], npu_path) + validate_parallel_param(parallel_params[1], bench_path, '[Bench]') + serializable_args = SerializableArgs(args) + + with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool: + def err_call(err): + logger.error(f'Error occurred while comparing graph ranks: {err}') + try: + pool.close() + except OSError as e: + logger.error(f'Error occurred while terminating the pool: {e}') + + # 1.并行构图 + build_graph_tasks_n = [] + build_graph_tasks_b = [] + for rank in ranks_n: + build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single, + args=(npu_path, rank, step, serializable_args), + error_callback=err_call)) + for rank in ranks_b: + build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single, + args=(bench_path, rank, step, serializable_args), + error_callback=err_call)) + graph_results_n = [task.get() for task in build_graph_tasks_n] + graph_results_b = [task.get() for task in build_graph_tasks_b] + + # 2.图合并 + build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph() + build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph() + if len(build_graph_results_n) != len(build_graph_results_b): + raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} ' + f'is inconsistent with that of bench: {len(build_graph_results_b)}!') + # 3.并行图比对和输出 + export_res_task_list = [] + create_directory(args.output_path) + for i, result_n in enumerate(build_graph_results_n): + graph_n = result_n.graph + graph_b = build_graph_results_b[i].graph + graph_task_info = BuildGraphTaskInfo( + _build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n), + _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b), + f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time) + output_file_name = f'compare_{step}_merged{i}_{current_time}.vis' \ + if step else f'compare_merged{i}_{current_time}.vis' + export_res_task_list.append(pool.apply_async(_compare_and_export_graph, + args=(graph_task_info, input_param, serializable_args, + output_file_name), + error_callback=err_call)) + export_res_list = [res.get() for res in export_res_task_list] + if any(export_res_list): + failed_names = list(filter(lambda x: x, export_res_list)) + logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.') + else: + logger.info('Successfully exported compare graph results.') + + def _graph_service_parser(parser): parser.add_argument("-i", "--input_path", dest="input_path", type=str, help=" The compare input path, a dict json.", required=True) parser.add_argument("-o", "--output_path", dest="output_path", type=str, help=" The compare task result out path.", required=True) - parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, + parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True, help=" The layer mapping file path.", required=False) parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true", help=" whether open overflow_check for graph.", required=False) @@ -217,12 +447,16 @@ def _graph_service_parser(parser): help=" Whether to perform a fuzzy match on the api name.", required=False) parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true", help=" Whether to use complete stack information.", required=False) + parser.add_argument("-mm", "--multi_mapping", dest="multi_mapping", type=str, + help=" The multi mapping file path.", required=False) def _graph_service_command(args): input_param = load_json(args.input_path) npu_path = input_param.get("npu_path") bench_path = input_param.get("bench_path") + args.parallel_merge = check_whether_parallel_merge(input_param) + args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None check_file_or_directory_path(npu_path, isdir=True) if bench_path: check_file_or_directory_path(bench_path, isdir=True) @@ -233,21 +467,29 @@ def _graph_service_command(args): elif content == GraphConst.STEPS: _build_graph_steps(npu_path, args) else: - result = _build_graph(npu_path, args) - _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check) + result = _build_graph_result(npu_path, args) + create_directory(args.output_path) + file_name = _export_build_graph_result(args, result) + if file_name: + logger.error('Failed to export model build graph.') elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: content_n = check_directory_content(npu_path) content_b = check_directory_content(bench_path) if content_n != content_b: raise ValueError('The directory structures of npu_path and bench_path are inconsistent.') if content_n == GraphConst.RANKS: - _compare_graph_ranks(input_param, args) + if args.parallel_merge: + _compare_graph_ranks_parallel(input_param, args) + else: + _compare_graph_ranks(input_param, args) elif content_n == GraphConst.STEPS: _compare_graph_steps(input_param, args) else: - result = _compare_graph(input_param, args) - _export_compare_graph_result(args, [result.graph_n, result.graph_b], - result.graph_comparator, result.micro_steps) + result = _compare_graph_result(input_param, args) + create_directory(args.output_path) + file_name = _export_compare_graph_result(args, result) + if file_name: + logger.error('Failed to export model compare graph.') else: logger.error("The npu_path or bench_path should be a folder.") raise CompareException(CompareException.INVALID_COMPARE_MODE) @@ -280,7 +522,7 @@ class CompareGraphResult: class BuildGraphResult: - def __init__(self, graph, micro_steps, rank=0, output_file_name=''): + def __init__(self, graph, micro_steps=0, rank=0, output_file_name=''): self.graph = graph self.micro_steps = micro_steps self.rank = rank diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 623bcd11c45f1ff8e9c283d30a982af239706ce4..5a08921392dac136b0437f878a8b30710e113b00 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,9 +16,12 @@ import os import re import json +import pickle from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst, Const -from msprobe.core.compare.acc_compare import Comparator, ModeConfig +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.compare.utils import check_and_return_dir_contents def load_json_file(file_path): @@ -42,23 +45,6 @@ def load_data_json_file(file_path): return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) -def save_json_file(file_path, data): - """ - 保存json文件 - """ - with FileOpen(file_path, 'w') as f: - f.write(json.dumps(data, indent=4)) - - -def get_csv_df(stack_mode, csv_data, compare_mode): - """ - 调用acc接口写入csv - """ - dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode) - mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=dump_mode) - return Comparator(mode_config).make_result_table(csv_data) - - def str2float(percentage_str): """ 百分比字符串转换转换为浮点型 @@ -73,14 +59,6 @@ def str2float(percentage_str): return 0 -def is_integer(s): - try: - int(s) - return True - except Exception: - return False - - def check_directory_content(input_path): """ 检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件 @@ -126,6 +104,73 @@ def check_directory_content(input_path): "all rank{number} named folders (such as rank0), or all files.") +def extract_rank_number(rank_str): + try: + return int(rank_str[4:]) + except ValueError: + return 0 + + +def sort_rank_number_strings(rank_number_strings): + sorted_list = sorted(rank_number_strings, key=extract_rank_number) + return sorted_list + + +def check_whether_parallel_merge(input_param): + parallel_merge = input_param.get("parallel_merge") + if not isinstance(parallel_merge, dict) or not parallel_merge: + return False + if not parallel_merge.get('npu'): + return False + return True + + +def load_parallel_param(input_param): + parallel_merge = input_param.get("parallel_merge", {}) + config_n = parallel_merge.get('npu', {}) + config_b = parallel_merge.get('bench', {}) + return (ParallelParam(config_n.get('rank_size'), config_n.get('tp'), config_n.get('pp')),) if not config_b else \ + (ParallelParam(config_n.get('rank_size'), config_n.get('tp'), config_n.get('pp')), + ParallelParam(config_b.get('rank_size'), config_b.get('tp'), config_b.get('pp'))) + + +def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'): + params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size] + ranks = check_and_return_dir_contents(dump_path, Const.RANK) + if len(ranks) != parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "rank_size" error, ' + f'you set {parallel_param.rank_size} but expected {len(ranks)}.') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if any(x is None for x in params): + logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if any(x <= 0 for x in params): + logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must be greater than 0!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.tp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "tp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.pp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel param "pp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.rank_size % parallel_param.tp != 0: + logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "tp"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.rank_size % parallel_param.pp != 0: + logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "pp"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + if parallel_param.tp * parallel_param.pp > parallel_param.rank_size: + logger.error(f'{log_prefix} The parallel params "tp * pp" must be less than or equal to "rank_size"!') + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) + + +class ParallelParam: + def __init__(self, rank_size, tp, pp): + self.rank_size = rank_size + self.tp = tp + self.pp = pp + + class ToolTip: MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值' MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值' @@ -143,14 +188,12 @@ class ToolTip: '当最大相对误差越接近0表示其计算的误差越小。' '当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象' ) - SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差' class GraphConst: CONSTRUCT_FILE = 'construct.json' DUMP_FILE = 'dump.json' STACK_FILE = 'stack.json' - GRAPH_FILE = 'graph.vis' ERROR_KEY = 'error_key' SUMMARY_COMPARE = 0 MD5_COMPARE = 1 @@ -164,35 +207,24 @@ class GraphConst: JSON_DATA_KEY = 'dump_data_dir' JSON_TASK_KEY = 'task' DATA_KEY = 'data' - REAL_DATA_TH = 0.1 - MAX_RELATIVE_ERR_TH = 0.5 ROUND_TH = 6 JSON_INDEX_KEY = 'precision_index' MATCHED_DISTRIBUTED = 'matched_distributed' OVERFLOW_LEVEL = 'overflow_level' MAX_INDEX_KEY = 1 MIN_INDEX_KEY = 0 - SUGGEST_KEY = 'text' - TAG_NA = 'na' - OUTPUT_INDEX_TWO = -2 - OUTPUT_INDEX_THREE = -3 - OUTPUT_MIN_LEN = 3 INPUT = '.input.' OUTPUT = '.output.' STR_MAX_LEN = 50 - SMALL_VALUE = 1e-3 MD5_INDEX_LIST = [CompareConst.RESULT] - REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] - SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, - CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] - VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] + REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX + SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX APIS_BETWEEN_MODULES = 'Apis_Between_Modules' + APIS_BETWEEN_MODULES_ALL_RANKS = 'Apis_Between_Modules_All_Ranks' + MERGE_NODES = 'Merged_Nodes' NULL = 'null' NONE = 'None' VALUE = 'value' - BRACE = '{}' DESCRIPTION = 'description' COLORS = 'Colors' MICRO_STEPS = 'MicroSteps' @@ -223,3 +255,30 @@ class GraphConst: OP = 'op' PEER = 'peer' GROUP_ID = 'group_id' + + UNCERTAINTY_THRESHOLD = 1e-6 + REDUCE_OPERATIONS = ['reduce_scatter', 'all_reduce'] + + IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty', + 'empty_strided'} + + +def is_serializable(obj): + """ + Check if an object is serializable + """ + try: + pickle.dumps(obj) + return True + except (pickle.PicklingError, AttributeError, TypeError): + return False + except Exception as e: + logger.error('Unexpected error occurred while pickling obj.') + raise RuntimeError('Unexpected error occurred while pickling obj.') from e + + +class SerializableArgs: + def __init__(self, args): + for k, v in vars(args).items(): + if is_serializable(v): + setattr(self, k, v) diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index 2da7fcf667765a841b9db1bbf5628fad5b1cf8a9..21b389886c42b0f8e3a837295050d4fee9344d6d 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -14,7 +14,7 @@ # limitations under the License. -__version__ = '1.2.2' +__version__ = '1.3.0' import subprocess import platform @@ -24,17 +24,19 @@ import setuptools INSTALL_REQUIRED = [ "wheel", "einops", - "numpy < 2.0", + "numpy >=1.23.0, < 2.0", "pandas >= 1.3.5, < 2.1", "pyyaml", "rich", "tqdm", - "openpyxl", - "pyopenssl", + "openpyxl >= 3.0.6", + "pyopenssl==24.2.1", "twisted", "matplotlib", "tensorboard", - "tabulate" + "tabulate", + "pwinput", + "psutil" ] EXCLUDE_PKGS = [ diff --git a/debug/resources/training_process.png b/debug/resources/training_process.png new file mode 100644 index 0000000000000000000000000000000000000000..e1cf2f20471624cd86edbf45444bb431086d6065 Binary files /dev/null and b/debug/resources/training_process.png differ diff --git a/dynolog_npu/README.md b/dynolog_npu/README.md deleted file mode 100644 index 9cc015e66c656c65fa48ad73a8246487a2016bef..0000000000000000000000000000000000000000 --- a/dynolog_npu/README.md +++ /dev/null @@ -1,148 +0,0 @@ -# Ascend Extension for dynolog - -## 安装方式 - -### 1. clone 代码 - -```bash -git clone https://gitee.com/ascend/mstt.git -``` - -### 2. 安装依赖 -dynolog的编译依赖,确保安装了以下依赖: - - - - - - - - - - - - - -
Language - Toolchain -
C++ - gcc 8.5.0+ -
Rust - Rust 1.58.1 (1.56+ required for clap dependency) -
- -- 安装rust - -```bash -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - -source $HOME/.cargo/env -``` - -- 安装ninja - -```bash -# debian -sudo apt-get install -y cmake ninja-build - -# centos -sudo yum install -y cmake ninja -``` - -### 3. 编译 - -默认编译生成dyno和dynolog二进制文件, -t参数可以支持将二进制文件打包成deb包或rpm包. - -```bash -# 编译dyno和dynolog二进制文件 -bash scripts/build.sh - -# 编译deb包, 当前支持amd64和aarch64平台, 默认为amd64, 编译aarch64平台需要修改third_party/dynolog/scripts/debian/control文件中的Architecture改为aarch64 -bash scripts/build.sh -t deb - -# 编译rpm包, 当前只支持amd64平台 -bash scripts/build.sh -t rpm -``` - -## 使用方式 - -### Profiler trace dump功能 -Profiler trace dump功能基于dynolog开发,实现类似于动态profiling的动态触发Ascend Torch Profiler采集profiling的功能。用户基于dyno CLI命令行可以动态触发指定节点的训练进程trace dump。 - -- 查看nputrace支持的命令和帮助 - -```bash -dyno nputrace --help -``` - -- nputrace使用方式 - -```bash -dyno nputrace [SUBCOMMANDS] --log-file -``` - -nputrace子命令支持的参数选项 - -| 子命令 | 参数类型 | 说明 | -|-------|-------|-------| -| record_shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | -| profile_memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | -| with_stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | -| with_flops | action | 是否采集算子flops,设置参数采集,默认不采集 | -| with_modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | -| analyse | action | 采集后是否自动解析,设置参数解析,默认不解析 | -| l2_cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | -| op_attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | -| data_simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | -| activities | String | 控制CPU、NPU事件采集范围,可选值范围[`CPU,NPU`, `NPU,CPU`, `CPU`, `NPU`],默认值`CPU,NPU` | -| profiler_level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0`| -| aic_metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| -| export_type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| -| gc_detect_threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | - -- nputrace示例命令 - -```bash -# 示例1:采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities CPU,NPU --analyse --data_simplification false --log-file /tmp/profile_data - -# 示例2:只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --analyse --data_simplification true --log-file /tmp/profile_data - -# 示例3:只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --log-file /tmp/profile_data -``` - -### NPU Monitor功能 -NPU Monitor基于MSPTI/MSTX能力开发,实现了轻量级在线监控能力,能够用于性能问题的初步定位。 - -```bash -dyno npu-monitor --help -``` - -- npu-monitor使用方式 - -```bash -dyno npu-monitor [SUBCOMMANDS] -``` - -npu-monitor子命令支持的参数选项 -| 子命令 | 参数类型 | 说明 | -|-------|-------|-------| -| npu_monitor_start | action | 开启性能监控,设置参数开启,默认不采集 | -| npu_monitor_stop | action | 停止性能监控,设置参数开启,默认不采集 | -| report_interval_s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | -| mspti_activity_kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| - -- npu-monitor示例命令 - -```bash -# 示例1:开启性能监控,使用默认配置 -dyno npu-monitor --npu_monitor_start - -# 示例2:暂停性能监控 -dyno npu-monitor --npu_monitor_stop - -# 示例3:开启性能监控,上报周期30s, 上报数据类型Marker和Kernel -dyno npu-monitor --npu_monitor_start 30 --mspti_activity_kind Marker,Kernel -``` \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/mod.rs b/dynolog_npu/dynolog_npu/cli/src/commands/mod.rs deleted file mode 100644 index 18950d3c1a01d972db58a614a46f08176b02c725..0000000000000000000000000000000000000000 --- a/dynolog_npu/dynolog_npu/cli/src/commands/mod.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -// Export all command submodules to be used in main.rs -// Note: This "intermediate" commands module is purely for organizational purposes. -// This allows for a clear distinction between the command dispatching code and the command -// handling code. Additionally, explicitly "exporting" all the command modules here allows -// us to avoid having to explicitly list all the command modules in main.rs. - -pub mod dcgm; -pub mod gputrace; -pub mod nputrace; -pub mod npumonitor; -pub mod status; -pub mod version; -// ... add new command modules here \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs b/dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs deleted file mode 100644 index 1edfaea5939f5cee5df8618720d1bfa16d0071b5..0000000000000000000000000000000000000000 --- a/dynolog_npu/dynolog_npu/cli/src/commands/npumonitor.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::net::TcpStream; - -use anyhow::Result; - -#[path = "utils.rs"] -mod utils; - -#[derive(Debug)] -pub struct NpuMonitorConfig { - pub npu_monitor_start: bool, - pub npu_monitor_stop: bool, - pub report_interval_s: u32, - pub mspti_activity_kind: String, -} - -impl NpuMonitorConfig { - fn config(&self) -> String { - format!( - r#" -NPU_MONITOR_START={} -NPU_MONITOR_STOP={} -REPORT_INTERVAL_S={} -MSPTI_ACTIVITY_KIND={}"#, - self.npu_monitor_start, - self.npu_monitor_stop, - self.report_interval_s, - self.mspti_activity_kind - ) - } -} - -pub fn run_npumonitor( - client: TcpStream, - config: NpuMonitorConfig, -) -> Result<()> { - let config_str = config.config(); - println!("Npu monitor config = \n{}", config_str); - let config_str = config_str.replace('\n', "\\n"); - - let request_json = format!( - r#" -{{ - "fn": "setKinetOnDemandRequest", - "config": "{}", - "job_id": 0, - "pids": [0], - "process_limit": 3 -}}"#, - config_str - ); - - utils::send_msg(&client, &request_json).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - - Ok(()) -} diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs b/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs deleted file mode 100644 index 4bf7132de338d8eee0de556449269712617772e2..0000000000000000000000000000000000000000 --- a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs +++ /dev/null @@ -1,242 +0,0 @@ -use std::net::TcpStream; - -use anyhow::Result; -use serde_json::Value; - -#[path = "utils.rs"] -mod utils; - -#[derive(Debug)] -pub enum NpuTraceTriggerConfig { - DurationBased { - profile_start_time: u64, - duration_ms: u64, - }, - IterationBased { - start_step: u64, - iterations: i64, - }, -} - -impl NpuTraceTriggerConfig { - fn config(&self) -> String { - match *self { - NpuTraceTriggerConfig::DurationBased { - profile_start_time, - duration_ms, - } => format!( - "PROFILE_START_TIME={}\nACTIVITIES_DURATION_MSECS={}", - profile_start_time, duration_ms - ), - NpuTraceTriggerConfig::IterationBased { - start_step, - iterations, - } => format!( - r#"PROFILE_START_ITERATION=0 -PROFILE_START_STEP={} -ACTIVITIES_ITERATIONS={}"#, - start_step, iterations - ), - } - } -} - -// torch npu profiler config -#[derive(Debug)] -pub struct NpuTraceOptions { - pub record_shapes: bool, - pub profile_memory: bool, - pub with_stack: bool, - pub with_flops: bool, - pub with_modules: bool, - pub activities: String, - pub analyse: bool, - pub profiler_level: String, - pub aic_metrics: String, - pub l2_cache: bool, - pub op_attr: bool, - pub gc_detect_threshold: Option, - pub data_simplification: String, - pub export_type: String, -} - -impl NpuTraceOptions { - fn config(&self) -> String { - format!( - r#" -PROFILE_RECORD_SHAPES={} -PROFILE_PROFILE_MEMORY={} -PROFILE_WITH_STACK={} -PROFILE_WITH_FLOPS={} -PROFILE_WITH_MODULES={} -PROFILE_ACTIVITIES={} -PROFILE_ANALYSE={} -PROFILE_PROFILER_LEVEL={} -PROFILE_AIC_METRICS={} -PROFILE_L2_CACHE={} -PROFILE_OP_ATTR={} -PROFILE_GC_DETECT_THRESHOLD={} -PROFILE_DATA_SIMPLIFICATION={} -PROFILE_EXPORT_TYPE={}"#, - self.record_shapes, - self.profile_memory, - self.with_stack, - self.with_flops, - self.with_modules, - self.activities, - self.analyse, - self.profiler_level, - self.aic_metrics, - self.l2_cache, - self.op_attr, - self.gc_detect_threshold.map_or("None".to_string(), |v| v.to_string()), - self.data_simplification, - self.export_type - ) - } -} - -#[derive(Debug)] -pub struct NpuTraceConfig { - pub log_file: String, - pub trigger_config: NpuTraceTriggerConfig, - pub trace_options: NpuTraceOptions, -} - -impl NpuTraceConfig { - fn config(&self) -> String { - format!( - "ACTIVITIES_LOG_FILE={}\n{}{}", - self.log_file, - self.trigger_config.config(), - self.trace_options.config() - ) - } -} - -pub fn run_nputrace( - client: TcpStream, - job_id: u64, - pids: &str, - process_limit: u32, - config: NpuTraceConfig, -) -> Result<()> { - let config_str = config.config(); - println!("NpuTrace config = \n{}", config_str); - let config_str = config_str.replace('\n', "\\n"); - - let request_json = format!( - r#" -{{ - "fn": "setKinetOnDemandRequest", - "config": "{}", - "job_id": {}, - "pids": [{}], - "process_limit": {} -}}"#, - config_str, job_id, pids, process_limit - ); - - utils::send_msg(&client, &request_json).expect("Error sending message to service"); - - let resp_str = utils::get_resp(&client).expect("Unable to decode output bytes"); - - println!("response = {}", resp_str); - - let resp_v: Value = serde_json::from_str(&resp_str)?; - let processes = resp_v["processesMatched"].as_array().unwrap(); - - if processes.is_empty() { - println!("No processes were matched, please check --job-id or --pids flags"); - } else { - println!("Matched {} processes", processes.len()); - println!("Trace output files will be written to:"); - - for pid in processes { - let pid = pid.as_i64().unwrap(); - println!( - " {}", - config.log_file.replace(".json", &format!("_{}.json", pid)) - ); - } - } - - Ok(()) -} - - -#[cfg(test)] -mod test { - use crate::*; - - #[test] - fn test_nputrace_trigger_config() { - let trigger_config = NpuTraceTriggerConfig::DurationBased { - profile_start_time: 1000, - duration_ms: 1000, - }; - assert_eq!( - trigger_config.config(), - r#"PROFILE_START_TIME=1000 -ACTIVITIES_DURATION_MSECS=1000"# - ); - - let trigger_config = NpuTraceTriggerConfig::IterationBased { - profile_start_step: 1000, - iterations: 1000, - }; - assert_eq!( - trigger_config.config(), - r#"PROFILE_START_ITERATION=0 -PROFILE_START_STEP=1000 -ACTIVITIES_ITERATIONS=1000"# - ); - } - - #[test] - fn test_nputrace_config() { - let config = NpuTraceConfig { - log_file: "test.json".to_string(), - trigger_config: NpuTraceTriggerConfig::DurationBased { - profile_start_time: 1000, - duration_ms: 1000, - }, - trace_options: NpuTraceOptions { - record_shapes: true, - profile_memory: false, - with_stack: true, - with_flops: true, - with_modules: true, - activities: "CPU,NPU".to_string(), - analyse: false, - profiler_level: "Level0".to_string(), - aic_metrics: "AiCoreNone".to_string(), - l2_cache: true, - op_attr: true, - gc_detect_threshold: 0.1, - data_simplification: "true", - export_type: "Text".to_string(), - }, - }; - assert_eq!( - config.config(), - r#"ACTIVITIES_LOG_FILE=test.json -PROFILE_START_TIME=1000 -ACTIVITIES_DURATION_MSECS=1000 -PROFILE_RECORD_SHAPES=true -PROFILE_PROFILE_MEMORY=false -PROFILE_WITH_STACK=true -PROFILE_WITH_FLOPS=true -PROFILE_WITH_MODULES=true -PROFILE_ACTIVITIES=CPU,NPU -PROFILE_ANALYSE=false -PROFILE_PROFILER_LEVEL=Level0 -PROFILE_AIC_METRICS=AiCoreNone -PROFILE_L2_CACHE=true -PROFILE_OP_ATTR=true -PROFILE_GC_DETECT_THRESHOLD=0.1 -PROFILE_DATA_SIMPLIFICATION=true -PROFILE_EXPORT_TYPE=Text"# - ); - } -} diff --git a/dynolog_npu/dynolog_npu/cli/src/main.rs b/dynolog_npu/dynolog_npu/cli/src/main.rs deleted file mode 100644 index 8bc4a2af0e2c19d6e783663924578e3c2ad7408a..0000000000000000000000000000000000000000 --- a/dynolog_npu/dynolog_npu/cli/src/main.rs +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -use std::net::TcpStream; -use std::net::ToSocketAddrs; - -use anyhow::Result; -use clap::Parser; -use std::collections::HashSet; - -// Make all the command modules accessible to this file. -mod commands; -use commands::gputrace::GpuTraceConfig; -use commands::gputrace::GpuTraceOptions; -use commands::gputrace::GpuTraceTriggerConfig; -use commands::nputrace::NpuTraceConfig; -use commands::nputrace::NpuTraceOptions; -use commands::nputrace::NpuTraceTriggerConfig; -use commands::npumonitor::NpuMonitorConfig; -use commands::*; - -/// Instructions on adding a new Dyno CLI command: -/// -/// 1. Add a new variant to the `Command` enum. -/// Please include a description of the command and, if applicable, its flags/subcommands. -/// -/// 2. Create a new file for the command's implementation in the commands/ directory (ie -/// commands/status.rs). This new file is where the command should be implemented. -/// Make the new command's module accessible from this file by adding -/// a new line with `pub mod ;` to commands/mod.rs. -/// -/// -/// 3. Add a branch to the match statement in main() to handle the new enum variant (from step 1). -/// From here, invoke the handling logic defined in the new file (from step 2). In an effort to keep -/// the command dispatching logic clear and concise, please keep the code in the match branch to a minimum. - -const DYNO_PORT: u16 = 1778; - -#[derive(Debug, Parser)] -struct Opts { - #[clap(long, default_value = "localhost")] - hostname: String, - #[clap(long, default_value_t = DYNO_PORT)] - port: u16, - #[clap(subcommand)] - cmd: Command, -} - -const ALLOWED_VALUES: &[&str] = &["Marker", "Kernel", "API", "Hccl", "Memory", "MemSet", "MemCpy"]; - -fn parse_mspti_activity_kinds(src: &str) -> Result{ - let allowed_values: HashSet<&str> = ALLOWED_VALUES.iter().cloned().collect(); - - let kinds: Vec<&str> = src.split(',').map(|s| s.trim()).collect(); - - for kind in &kinds { - if !allowed_values.contains(kind) { - return Err(format!("Invalid MSPTI activity kind: {}, Possible values: {:?}.]", kind, allowed_values)); - } - } - - Ok(src.to_string()) -} - -#[derive(Debug, Parser)] -enum Command { - /// Check the status of a dynolog process - Status, - /// Check the version of a dynolog process - Version, - /// Capture gputrace - Gputrace { - /// Job id of the application to trace. - #[clap(long, default_value_t = 0)] - job_id: u64, - /// List of pids to capture trace for (comma separated). - #[clap(long, default_value = "0")] - pids: String, - /// Duration of trace to collect in ms. - #[clap(long, default_value_t = 500)] - duration_ms: u64, - /// Training iterations to collect, this takes precedence over duration. - #[clap(long, default_value_t = -1)] - iterations: i64, - /// Log file for trace. - #[clap(long)] - log_file: String, - /// Unix timestamp used for synchronized collection (milliseconds since epoch). - #[clap(long, default_value_t = 0)] - profile_start_time: u64, - /// Start iteration roundup, starts an iteration based trace at a multiple - /// of this value. - #[clap(long, default_value_t = 1)] - profile_start_iteration_roundup: u64, - /// Max number of processes to profile. - #[clap(long, default_value_t = 3)] - process_limit: u32, - /// Record PyTorch operator input shapes and types. - #[clap(long, action)] - record_shapes: bool, - /// Profile PyTorch memory. - #[clap(long, action)] - profile_memory: bool, - /// Capture Python stacks in traces. - #[clap(long, action)] - with_stacks: bool, - /// Annotate operators with analytical flops. - #[clap(long, action)] - with_flops: bool, - /// Capture PyTorch operator modules in traces. - #[clap(long, action)] - with_modules: bool, - }, - /// Capture nputrace. Subcommand functions aligned with Ascend Torch Profiler. - Nputrace { - /// Job id of the application to trace. - #[clap(long, default_value_t = 0)] - job_id: u64, - /// List of pids to capture trace for (comma separated). - #[clap(long, default_value = "0")] - pids: String, - /// Duration of trace to collect in ms. - #[clap(long, default_value_t = 500)] - duration_ms: u64, - /// Training iterations to collect, this takes precedence over duration. - #[clap(long, default_value_t = -1)] - iterations: i64, - /// Log file for trace. - #[clap(long)] - log_file: String, - /// Unix timestamp used for synchronized collection (milliseconds since epoch). - #[clap(long, default_value_t = 0)] - profile_start_time: u64, - /// Number of steps to start profile. - #[clap(long, default_value_t = 0)] - start_step: u64, - /// Max number of processes to profile. - #[clap(long, default_value_t = 3)] - process_limit: u32, - /// Whether to record PyTorch operator input shapes and types. - #[clap(long, action)] - record_shapes: bool, - /// Whether to profile PyTorch memory. - #[clap(long, action)] - profile_memory: bool, - /// Whether to profile the Python call stack in trace. - #[clap(long, action)] - with_stack: bool, - /// Annotate operators with analytical flops. - #[clap(long, action)] - with_flops: bool, - /// Whether to profile PyTorch operator modules in traces. - #[clap(long, action)] - with_modules: bool, - /// The scope of the profile's events. - #[clap(long, value_parser = ["CPU,NPU", "NPU,CPU", "CPU", "NPU"], default_value = "CPU,NPU")] - activities: String, - /// Profiler level. - #[clap(long, value_parser = ["Level0", "Level1", "Level2", "Level_none"], default_value = "Level0")] - profiler_level: String, - /// AIC metrics. - #[clap(long, value_parser = ["AiCoreNone", "PipeUtilization", "ArithmeticUtilization", "Memory", "MemoryL0", "ResourceConflictRatio", "MemoryUB", "L2Cache", "MemoryAccess"], default_value = "AiCoreNone")] - aic_metrics: String, - /// Whether to analyse the data after collection. - #[clap(long, action)] - analyse: bool, - /// Whether to collect L2 cache. - #[clap(long, action)] - l2_cache: bool, - /// Whether to collect op attributes. - #[clap(long, action)] - op_attr: bool, - /// GC detect threshold. - #[clap(long)] - gc_detect_threshold: Option, - /// Whether to streamline data after analyse is complete. - #[clap(long, value_parser = ["true", "false"], default_value = "true")] - data_simplification: String, - /// Types of data exported by the profiler. - #[clap(long, value_parser = ["Text", "Db"], default_value = "Text")] - export_type: String, - }, - /// Ascend MSPTI Monitor - NpuMonitor { - /// Start NPU monitor. - #[clap(long, action)] - npu_monitor_start: bool, - /// Stop NPU monitor. - #[clap(long, action)] - npu_monitor_stop: bool, - /// NPU monitor report interval in seconds. - #[clap(long, default_value_t = 60)] - report_interval_s: u32, - /// MSPTI collect activity kind - #[clap(long, value_parser = parse_mspti_activity_kinds, default_value = "Marker")] - mspti_activity_kind: String, - }, - /// Pause dcgm profiling. This enables running tools like Nsight compute and avoids conflicts. - DcgmPause { - /// Duration to pause dcgm profiling in seconds - #[clap(long, default_value_t = 300)] - duration_s: i32, - }, - /// Resume dcgm profiling - DcgmResume, -} - -/// Create a socket connection to dynolog -fn create_dyno_client(host: &str, port: u16) -> Result { - let addr = (host, port) - .to_socket_addrs()? - .next() - .expect("Failed to connect to the server"); - - TcpStream::connect(addr).map_err(|err| err.into()) -} - -fn main() -> Result<()> { - let Opts { - hostname, - port, - cmd, - } = Opts::parse(); - - let dyno_client = - create_dyno_client(&hostname, port).expect("Couldn't connect to the server..."); - - match cmd { - Command::Status => status::run_status(dyno_client), - Command::Version => version::run_version(dyno_client), - Command::Gputrace { - job_id, - pids, - log_file, - duration_ms, - iterations, - profile_start_time, - profile_start_iteration_roundup, - process_limit, - record_shapes, - profile_memory, - with_stacks, - with_flops, - with_modules, - } => { - let trigger_config = if iterations > 0 { - GpuTraceTriggerConfig::IterationBased { - profile_start_iteration_roundup, - iterations, - } - } else { - GpuTraceTriggerConfig::DurationBased { - profile_start_time, - duration_ms, - } - }; - let trace_options = GpuTraceOptions { - record_shapes, - profile_memory, - with_stacks, - with_flops, - with_modules, - }; - let trace_config = GpuTraceConfig { - log_file, - trigger_config, - trace_options, - }; - gputrace::run_gputrace(dyno_client, job_id, &pids, process_limit, trace_config) - } - Command::Nputrace { - job_id, - pids, - log_file, - duration_ms, - iterations, - profile_start_time, - start_step, - process_limit, - record_shapes, - profile_memory, - with_stack, - with_flops, - with_modules, - activities, - analyse, - profiler_level, - aic_metrics, - l2_cache, - op_attr, - gc_detect_threshold, - data_simplification, - export_type, - } => { - let trigger_config = if iterations > 0 { - NpuTraceTriggerConfig::IterationBased { - start_step, - iterations, - } - } else { - NpuTraceTriggerConfig::DurationBased { - profile_start_time, - duration_ms, - } - }; - - let trace_options = NpuTraceOptions { - record_shapes, - profile_memory, - with_stack, - with_flops, - with_modules, - activities, - analyse, - profiler_level, - aic_metrics, - l2_cache, - op_attr, - gc_detect_threshold, - data_simplification, - export_type, - }; - let trace_config = NpuTraceConfig { - log_file, - trigger_config, - trace_options, - }; - nputrace::run_nputrace(dyno_client, job_id, &pids, process_limit, trace_config) - } - Command::NpuMonitor { - npu_monitor_start, - npu_monitor_stop, - report_interval_s, - mspti_activity_kind, - } => { - let npu_mon_config = NpuMonitorConfig { - npu_monitor_start, - npu_monitor_stop, - report_interval_s, - mspti_activity_kind - }; - npumonitor::run_npumonitor(dyno_client, npu_mon_config) - } - Command::DcgmPause { duration_s } => dcgm::run_dcgm_pause(dyno_client, duration_s), - Command::DcgmResume => dcgm::run_dcgm_resume(dyno_client), - // ... add new commands here - } -} \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/dynolog/src/Main.cpp b/dynolog_npu/dynolog_npu/dynolog/src/Main.cpp deleted file mode 100644 index 8e5177768327e37173d4e7661e334a9400bd6172..0000000000000000000000000000000000000000 --- a/dynolog_npu/dynolog_npu/dynolog/src/Main.cpp +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -// Dynolog : A portable telemetry monitoring daemon. - -#include -#include -#include -#include -#include -#include "dynolog/src/CompositeLogger.h" -#include "dynolog/src/FBRelayLogger.h" -#include "dynolog/src/KernelCollector.h" -#include "dynolog/src/Logger.h" -#include "dynolog/src/ODSJsonLogger.h" -#include "dynolog/src/PerfMonitor.h" -#include "dynolog/src/ScubaLogger.h" -#include "dynolog/src/ServiceHandler.h" -#include "dynolog/src/gpumon/DcgmGroupInfo.h" -#include "dynolog/src/rpc/SimpleJsonServer.h" -#include "dynolog/src/rpc/SimpleJsonServerInl.h" -#include "dynolog/src/tracing/IPCMonitor.h" -#include "hbt/src/perf_event/BuiltinMetrics.h" - -#ifdef USE_PROMETHEUS -#include "dynolog/src/PrometheusLogger.h" -#endif - -using namespace dynolog; -using json = nlohmann::json; -namespace hbt = facebook::hbt; - -DEFINE_int32(port, 1778, "Port for listening RPC requests."); -DEFINE_bool(use_JSON, false, "Emit metrics to JSON file through JSON logger"); -#ifdef USE_PROMETHEUS -DEFINE_bool(use_prometheus, false, "Emit metrics to Prometheus"); -#endif -DEFINE_bool(use_fbrelay, false, "Emit metrics to FB Relay on Lab machines"); -DEFINE_bool(use_ODS, false, "Emit metrics to ODS through ODS logger"); -DEFINE_bool(use_scuba, false, "Emit metrics to Scuba through Scuba logger"); -DEFINE_int32( - kernel_monitor_reporting_interval_s, - 60, - "Duration in seconds to read and report metrics for kernel monitor"); -DEFINE_int32( - perf_monitor_reporting_interval_s, - 60, - "Duration in seconds to read and report metrics for performance monitor"); -DEFINE_int32( - dcgm_reporting_interval_s, - 10, - "Duration in seconds to read and report metrics for DCGM"); -DEFINE_bool( - enable_ipc_monitor, - false, - "Enabled IPC monitor for on system tracing requests."); -DEFINE_bool( - enable_gpu_monitor, - false, - "Enabled GPU monitorng, currently supports NVIDIA GPUs."); -DEFINE_bool(enable_perf_monitor, false, "Enable heartbeat perf monitoring."); - -std::unique_ptr getLogger(const std::string& scribe_category = "") { - std::vector> loggers; -#ifdef USE_PROMETHEUS - if (FLAGS_use_prometheus) { - loggers.push_back(std::make_unique()); - } -#endif - if (FLAGS_use_fbrelay) { - loggers.push_back(std::make_unique()); - } - if (FLAGS_use_ODS) { - loggers.push_back(std::make_unique()); - } - if (FLAGS_use_JSON) { - loggers.push_back(std::make_unique()); - } - if (FLAGS_use_scuba && !scribe_category.empty()) { - loggers.push_back(std::make_unique(scribe_category)); - } - return std::make_unique(std::move(loggers)); -} - -auto next_wakeup(int sec) { - return std::chrono::steady_clock::now() + std::chrono::seconds(sec); -} - -void kernel_monitor_loop() { - KernelCollector kc; - - LOG(INFO) << "Running kernel monitor loop : interval = " - << FLAGS_kernel_monitor_reporting_interval_s << " s."; - - while (1) { - auto logger = getLogger(); - auto wakeup_timepoint = - next_wakeup(FLAGS_kernel_monitor_reporting_interval_s); - - kc.step(); - kc.log(*logger); - logger->finalize(); - - /* sleep override */ - std::this_thread::sleep_until(wakeup_timepoint); - } -} - -void perf_monitor_loop() { - PerfMonitor pm( - hbt::CpuSet::makeAllOnline(), - std::vector{"instructions", "cycles"}, - getDefaultPmuDeviceManager(), - getDefaultMetrics()); - - LOG(INFO) << "Running perf monitor loop : interval = " - << FLAGS_perf_monitor_reporting_interval_s << " s."; - - while (1) { - auto logger = getLogger(); - auto wakeup_timepoint = - next_wakeup(FLAGS_perf_monitor_reporting_interval_s); - - pm.step(); - pm.log(*logger); - - logger->finalize(); - /* sleep override */ - std::this_thread::sleep_until(wakeup_timepoint); - } -} - -auto setup_server(std::shared_ptr handler) { - return std::make_unique>( - handler, FLAGS_port); -} - -void gpu_monitor_loop(std::shared_ptr dcgm) { - auto logger = getLogger(FLAGS_scribe_category); - - LOG(INFO) << "Running DCGM loop : interval = " - << FLAGS_dcgm_reporting_interval_s << " s."; - LOG(INFO) << "DCGM fields: " << gpumon::FLAGS_dcgm_fields; - - while (1) { - auto wakeup_timepoint = next_wakeup(FLAGS_dcgm_reporting_interval_s); - - dcgm->update(); - dcgm->log(*logger); - - /* sleep override */ - std::this_thread::sleep_until(wakeup_timepoint); - } -} - -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - FLAGS_logtostderr = 1; - google::InitGoogleLogging(argv[0]); - - LOG(INFO) << "Starting Ascend Extension for dynolog, version = " DYNOLOG_VERSION - << ", build git-hash = " DYNOLOG_GIT_REV; - - std::shared_ptr dcgm; - - std::unique_ptr ipcmon; - std::unique_ptr ipcmon_thread, gpumon_thread, pm_thread; - - if (FLAGS_enable_ipc_monitor) { - LOG(INFO) << "Starting IPC Monitor"; - ipcmon = std::make_unique(); - ipcmon_thread = - std::make_unique([&ipcmon]() { ipcmon->loop(); }); - } - - if (FLAGS_enable_gpu_monitor) { - dcgm = gpumon::DcgmGroupInfo::factory( - gpumon::FLAGS_dcgm_fields, FLAGS_dcgm_reporting_interval_s * 1000); - gpumon_thread = std::make_unique(gpu_monitor_loop, dcgm); - } - std::thread km_thread{kernel_monitor_loop}; - if (FLAGS_enable_perf_monitor) { - pm_thread = std::make_unique(perf_monitor_loop); - } - - // setup service - auto handler = std::make_shared(dcgm); - - // use simple json RPC server for now - auto server = setup_server(handler); - server->run(); - - km_thread.join(); - if (pm_thread) { - pm_thread->join(); - } - if (gpumon_thread) { - gpumon_thread->join(); - } - - server->stop(); - - return 0; -} \ No newline at end of file diff --git a/dynolog_npu/plugin/Readme.md b/dynolog_npu/plugin/Readme.md deleted file mode 100644 index c59bfffad5aaac5383b407e3ff3d23ed126131f5..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/Readme.md +++ /dev/null @@ -1,17 +0,0 @@ - - -# Build and Install npu-dynolog-plugin -``` -# install pybind11 -pip install pybind11 - -# build dynolog_npu_plugin wheel -python3 setup.py bdist_wheel -# install -pip install dist/{dynolog-npu-plugin-xxx.wheel} - -# example -import IPCMonitor -dyno_worker = IPCMonitor.PyDynamicMonitorProxy() -dyno_worker.init_dyno(0) -``` diff --git a/dynolog_npu/plugin/bindings.cpp b/dynolog_npu/plugin/bindings.cpp deleted file mode 100644 index c0cdaa4d577b3a76ec2d6f3eae4b426556a56532..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/bindings.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include "ipc_monitor/PyDynamicMonitorProxy.h" - -namespace py = pybind11; - -PYBIND11_MODULE(IPCMonitor, m) { - py::class_(m, "PyDynamicMonitorProxy") - .def(py::init<>()) - .def("init_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::InitDyno, py::arg("npuId")) - .def("poll_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::PollDyno); -} \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp deleted file mode 100644 index 940f5aae167f088361057fe2a7a389a76f5bb2b4..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "DynoLogNpuMonitor.h" - -#include - -#include "utils.h" - -namespace dynolog_npu { -namespace ipc_monitor { - -bool DynoLogNpuMonitor::Init() -{ - if (isInitialized_) { - std::cout << "[WRARNING] DynoLog npu monitor already initialized" << std::endl; - return true; - } - bool res = ipcClient_.RegisterInstance(npuId_); - if (res) { - isInitialized_ = true; - std::cout << "[INFO] DynoLog npu monitor initialized success !" << std::endl; - } - return res; -} - -std::string DynoLogNpuMonitor::Poll() -{ - std::string res = ipcClient_.IpcClientNpuConfig(); - if (res.empty()) { - std::cout << "[INFO] Request for dynolog server is empty !" << std::endl; - return ""; - } - std::cout << "[INFO] Received NPU configuration successfully" << std::endl; - return res; -} - -} // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h deleted file mode 100644 index 8b5f88abf9d2cf589bec685cd3a520729afe8dd5..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef PYDYNAMIC_MONITOR_PROXY_H -#define PYDYNAMIC_MONITOR_PROXY_H - -#include -#include -#include "MonitorBase.h" -#include "DynoLogNpuMonitor.h" - -namespace dynolog_npu { -namespace ipc_monitor { - -class PyDynamicMonitorProxy { -public: - PyDynamicMonitorProxy() = default; - bool InitDyno(int npuId) - { - try { - monitor_ = DynoLogNpuMonitor::GetInstance(); - monitor_->SetNpuId(npuId); - bool res = monitor_->Init(); - return res; - } catch (const std::exception &e) { - std::cout << "[ERROR] Error when init dyno " << e.what() << std::endl; - return false; - } - } - - std::string PollDyno() - { - return monitor_->Poll(); - }; - -private: - MonitorBase *monitor_ = nullptr; -}; - -} // namespace ipc_monitor -} // namespace dynolog_npu - -#endif diff --git a/dynolog_npu/plugin/ipc_monitor/utils.cpp b/dynolog_npu/plugin/ipc_monitor/utils.cpp deleted file mode 100644 index 936821fd34bc34bc9db9e09515132e8af39ba57a..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/utils.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "utils.h" - -namespace dynolog_npu { -namespace ipc_monitor { -std::unordered_map submoduleMap = { - {SubModule::IPC, "IPC"}, -}; - -std::unordered_map errCodeMap = { - {ErrCode::SUC, "success"}, - {ErrCode::PARAM, "invalid parameter"}, - {ErrCode::TYPE, "invalid type"}, - {ErrCode::VALUE, "invalid value"}, - {ErrCode::PTR, "invalid pointer"}, - {ErrCode::INTERNAL, "internal error"}, - {ErrCode::MEMORY, "memory error"}, - {ErrCode::NOT_SUPPORT, "feature not supported"}, - {ErrCode::NOT_FOUND, "resource not found"}, - {ErrCode::UNAVAIL, "resource unavailable"}, - {ErrCode::SYSCALL, "system call failed"}, - {ErrCode::TIMEOUT, "timeout error"}, - {ErrCode::PERMISSION, "permission error"}, -}; - -std::string getCurrentTimestamp() -{ - auto now = std::chrono::system_clock::now(); - auto micros = std::chrono::duration_cast(now.time_since_epoch()); - - std::time_t currentTime = std::chrono::system_clock::to_time_t(now); - std::tm* timeInfo = std::localtime(¤tTime); - - auto milli_time = std::chrono::duration_cast(micros).count() % 1000; - auto micro_time = micros.count() % 1000; - - std::ostringstream oss; - oss << std::put_time(timeInfo, "%Y-%m-%d-%H:%M:%S"); - return oss.str(); -} - -std::string formatErrorCode(SubModule submodule, ErrCode errorCode) -{ - std::ostringstream oss; - oss << "\n[ERROR] " << getCurrentTimestamp() << " (PID:" << getpid() << ")"; - oss << "ERR" << std::setw(2) << std::setfill('0') << static_cast(submodule); // 2: 字段宽度 - oss << std::setw(3) << std::setfill('0') << static_cast(errorCode); // 3: 字段宽度 - oss << " " << submoduleMap[submodule] << " " << errCodeMap[errorCode]; - - return oss.str(); -}; - - -int32_t GetProcessId() -{ - return static_cast(getpid()); -} - -std::pair GetParentPidAndCommand(int32_t pid) -{ - std::string fileName = "/proc/" + std::to_string(pid) + "/stat"; - std::ifstream statFile(fileName); - if (!statFile) { - return std::make_pair(0, ""); - } - int32_t parentPid = 0; - std::string command; - std::string line; - if (std::getline(statFile, line)) { - int ret = sscanf(line.c_str(), "%*d (%[^)]) %*c %d", command.data(), &parentPid); - if (ret == 2) { // 2: 接收到2个字符 - std::cout << "[INFO] Success to get parent pid: " << parentPid << std::endl; - return std::make_pair(parentPid, command); - } - } - std::cout << "[WARNING] Failed to parse /proc/" << pid << "/stat" << std::endl; - return std::make_pair(0, ""); -} - -std::vector> GetPidCommandPairsofAncestors() -{ - std::vector> process_pids_and_cmds; - process_pids_and_cmds.reserve(MaxParentPids + 1); - int32_t current_pid = GetProcessId(); - for (int i = 0; i <= MaxParentPids && (i == 0 || current_pid > 1); i++) { - std::pair parent_pid_and_cmd = GetParentPidAndCommand(current_pid); - process_pids_and_cmds.push_back(std::make_pair(current_pid, parent_pid_and_cmd.second)); - current_pid = parent_pid_and_cmd.first; - } - return process_pids_and_cmds; -} - -std::vector GetPids() -{ - const auto &pids = GetPidCommandPairsofAncestors(); - std::vector res; - res.reserve(pids.size()); - for (const auto &pidPair : pids) { - res.push_back(pidPair.first); - } - return res; -} -std::string GenerateUuidV4() -{ - static std::random_device randomDevice; - static std::mt19937 gen(randomDevice()); - static std::uniform_int_distribution<> dis(0, 15); // range (0, 15) - static std::uniform_int_distribution<> dis2(8, 11); // range (8, 11) - - std::stringstream stringStream; - stringStream << std::hex; - for (int i = 0; i < 8; i++) { // 8 times - stringStream << dis(gen); - } - stringStream << "-"; - for (int j = 0; j < 4; j++) { // 4 times - stringStream << dis(gen); - } - stringStream << "-4"; // add -4 - for (int k = 0; k < 3; k++) { // 3 times - stringStream << dis(gen); - } - stringStream << "-"; - stringStream << dis2(gen); - for (int m = 0; m < 3; m++) { // 3 times - stringStream << dis(gen); - } - stringStream << "-"; - for (int n = 0; n < 12; n++) { // 12 times - stringStream << dis(gen); - } - return stringStream.str(); -} - -} // namespace ipc_monitor -} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/utils.h b/dynolog_npu/plugin/ipc_monitor/utils.h deleted file mode 100644 index 0d8ceb8cfd0bf81b6d8b807c6ac1b505276ddf83..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/utils.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef IPC_MONITOR_UTILS_H -#define IPC_MONITOR_UTILS_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace dynolog_npu { -namespace ipc_monitor { - -constexpr int MaxParentPids = 5; -int32_t GetProcessId(); -std::string GenerateUuidV4(); -std::vector GetPids(); -std::pair GetParentPidAndCommand(int32_t pid); -std::vector> GetPidCommandPairsofAncestors(); -std::string getCurrentTimestamp(); - -enum class SubModule { - IPC = 0 -}; - -enum class ErrCode { - SUC = 0, - PARAM = 1, - TYPE = 2, - VALUE = 3, - PTR = 4, - INTERNAL = 5, - MEMORY = 6, - NOT_SUPPORT = 7, - NOT_FOUND = 8, - UNAVAIL = 9, - SYSCALL = 10, - TIMEOUT = 11, - PERMISSION = 12, -}; - - -std::string formatErrorCode(SubModule submodule, ErrCode errorCode); - -#define IPC_ERROR(error) formatErrorCode(SubModule::IPC, error) - -template -inline T ReinterpretConvert(V ptr) { - return reinterpret_cast(ptr); -} - - -} // namespace ipc_monitor -} // namespace dynolog_npu - -#endif - diff --git a/dynolog_npu/scripts/apply_dyno_patches.sh b/dynolog_npu/scripts/apply_dyno_patches.sh deleted file mode 100644 index c492db74a2a56948433a47e9cffcccd4ac71e098..0000000000000000000000000000000000000000 --- a/dynolog_npu/scripts/apply_dyno_patches.sh +++ /dev/null @@ -1,36 +0,0 @@ -#! /bin/bash -set -e - -apply_ascend_patches() { - cd ./third_party/dynolog || return 1 - - if [ ! -d "../../patches" ]; then - echo "ERROR: patches directory not found" - cd ../.. - return 1 - fi - - for patch_file in ../../patches/*.patch; do - if [ -f "$patch_file" ]; then - echo "Applying patch: $patch_file" - git apply --check -p1 "$patch_file" - if [ $? -ne 0 ]; then - echo "ERROR: Failed to apply patch: $(basename $patch_file)" - cd ../.. - return 1 - fi - git apply -p1 "$patch_file" - if [ $? -ne 0 ]; then - echo "ERROR: Failed to apply patch: $(basename $patch_file)" - cd ../.. - return 1 - fi - fi - done - - cd ../.. - echo "Successfully applied all Ascend patches" - return 0 -} - -apply_ascend_patches \ No newline at end of file diff --git a/dynolog_npu/scripts/build.sh b/dynolog_npu/scripts/build.sh deleted file mode 100644 index aa3508e14faa6bfea06afe0cd3083ad1a5317037..0000000000000000000000000000000000000000 --- a/dynolog_npu/scripts/build.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash -set -e - -check_gcc_version() { - if ! command -v gcc >/dev/null 2>&1; then - echo "ERROR: gcc command not found" - return 1 - fi - - local GCC_VERSION=$(gcc -dumpversion) - local GCC_MAJOR=$(echo $GCC_VERSION | cut -d. -f1) - local GCC_MINOR=$(echo $GCC_VERSION | cut -d. -f2) - - if [ "$GCC_MAJOR" -lt 8 ] || ([ "$GCC_MAJOR" -eq 8 ] && [ "$GCC_MINOR" -lt 5 ]); then - echo "ERROR: gcc version must be greater than or equal to 8.5.0" - echo "Current gcc version: $GCC_VERSION" - return 1 - fi - echo "Check pass: current gcc version is $GCC_VERSION" - return 0 -} - -check_rust_version() { - if ! command -v rustc >/dev/null 2>&1; then - echo "ERROR: rustc command not found" - return 1 - fi - - local RUST_VERSION=$(rustc --version | cut -d' ' -f2) - local RUST_MAJOR=$(echo $RUST_VERSION | cut -d. -f1) - local RUST_MINOR=$(echo $RUST_VERSION | cut -d. -f2) - - if [ "$RUST_MAJOR" -lt 1 ] || ([ "$RUST_MAJOR" -eq 1 ] && [ "$RUST_MINOR" -lt 56 ]); then - echo "ERROR: Rust version must be greater than or equal to 1.56.0" - echo "Current Rust version: $RUST_VERSION" - return 1 - fi - echo "Check pass: current Rust version is $RUST_VERSION" - return 0 -} - -update_and_checkout_submodule() { - DYNLOG_COMMIT_ID="a9b6aeddcd6363252f5388cb0dd942981a09a24b" - - git submodule update --init --recursive - if [ $? -ne 0 ]; then - echo "ERROR: update git submodule failed" - return 1 - fi - - cd ./third_party/dynolog - git checkout ${DYNLOG_COMMIT_ID} - if [ $? -ne 0 ]; then - echo "ERROR: switch to dynolog specified commit failed" - cd .. - return 1 - fi - echo "Check pass: switch to dynolog specified commit ${DYNLOG_COMMIT_ID}" - cd ../../ - return 0 -} - -PACKAGE_TYPE="" -while getopts "t:" opt; do - case $opt in - t) - PACKAGE_TYPE="$OPTARG" - if [[ "$PACKAGE_TYPE" != "deb" && "$PACKAGE_TYPE" != "rpm" ]]; then - echo "ERROR: Invalid package type. Supported types: deb, rpm" - exit 1 - fi - ;; - \?) - echo "Usage: $0 [-t package_type]" - echo "package_type: deb or rpm (optional, if not specified will only build)" - exit 1 - ;; - esac -done - -echo "------------------ Check GCC and Rust version ----------------------" -check_gcc_version -check_rust_version - -echo "------------------ Update and checkout submodule -------------------" -update_and_checkout_submodule - -echo "------------------ Generate patch for Ascend -----------------------" -bash scripts/gen_dyno_patches.sh - -echo "------------------ Apply patch for Ascend --------------------------" -bash scripts/apply_dyno_patches.sh - -echo "------------------ Build dynolog patch for Ascend-------------------" -cd third_party/dynolog -rm -rf build -if [ -z "$PACKAGE_TYPE" ]; then - bash scripts/build.sh - echo "Build dynolog success without packaging" -elif [ "$PACKAGE_TYPE" = "deb" ]; then - bash scripts/debian/make_deb.sh - mv dynolog_*.deb ../../ - echo "Build dynolog deb package success" -elif [ "$PACKAGE_TYPE" = "rpm" ]; then - bash scripts/rpm/make_rpm.sh - mv dynolog_*.rpm ../../ - echo "Build dynolog rpm package success" -fi diff --git a/dynolog_npu/scripts/gen_dyno_patches.sh b/dynolog_npu/scripts/gen_dyno_patches.sh deleted file mode 100644 index 5ade74dbcfcf88dfbc072c9de790ec4f3ec451d9..0000000000000000000000000000000000000000 --- a/dynolog_npu/scripts/gen_dyno_patches.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -set -e - -WORK_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -PATCHES_DIR="${WORK_DIR}/patches" -DYNOLOG_DIR="${WORK_DIR}/third_party/dynolog" -MODIFIED_FILES_DIR="${WORK_DIR}/dynolog_npu" - -mkdir -p "${PATCHES_DIR}" - -generate_patches() { - echo "Generating patches from modified files..." - - # 检查修改后的文件目录是否存在 - if [ ! -d "${MODIFIED_FILES_DIR}" ]; then - echo "ERROR: dynolog_npu directory not found" - return 1 - fi - - # 清理旧的patch文件 - rm -f "${PATCHES_DIR}"/*.patch - - # 遍历修改后的文件目录 - find "${MODIFIED_FILES_DIR}" -type f | while read modified_file; do - # 获取相对路径 - rel_path=$(realpath --relative-to="${MODIFIED_FILES_DIR}" "${modified_file}") - original_file="${DYNOLOG_DIR}/${rel_path}" - - echo "original_file: ${original_file}" - # 检查原始文件是否存在 - if [ ! -f "${original_file}" ]; then - echo "WARN: Original file not found: ${original_file}" - - cp "${modified_file}" "${original_file}" - echo "Copied ${modified_file} to ${original_file}" - continue - fi - - # 生成patch文件名(将路径中的斜杠替换为下划线) - patch_name=$(echo "${rel_path}" | sed 's/\//_/g') - patch_file="${PATCHES_DIR}/${patch_name}.patch" - - echo "Generating patch for: ${rel_path}" - - ( - cd "${WORK_DIR}" - diff -u "third_party/dynolog/${rel_path}" "dynolog_npu/${rel_path}" > "${patch_file}" || true - ) - - # 检查patch文件大小 - if [ ! -s "${patch_file}" ]; then - rm "${patch_file}" - echo "No differences found for: ${rel_path}" - else - echo "Successfully generated patch: ${patch_file}" - fi - done - - echo "Patch generation completed" - return 0 -} - -generate_patches \ No newline at end of file diff --git a/dynolog_npu/third_party/dynolog b/dynolog_npu/third_party/dynolog deleted file mode 160000 index d5d37bc182bc2aa8fa60ba7d5ee897bacb5cbd4b..0000000000000000000000000000000000000000 --- a/dynolog_npu/third_party/dynolog +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d5d37bc182bc2aa8fa60ba7d5ee897bacb5cbd4b diff --git a/flight_recoder/analysis_flight.py b/flight_recoder/analysis_flight.py deleted file mode 100644 index f81f771ab1c81ad79cb93401e200b600a4b17af3..0000000000000000000000000000000000000000 --- a/flight_recoder/analysis_flight.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - -import os -import pickle -import sys -import logging -from collections import defaultdict - -from check_path import get_valid_read_path - - -logging.basicConfig( - level=logging.INFO, # 设置日志级别为 INFO - format="%(asctime)s - %(levelname)s - %(message)s", # 设置日志格式 - handlers=[logging.StreamHandler()], # 输出到控制台 -) - - -SAFE_CLASSES = { - # 内置安全类型 - "builtins": {"str", "int", "float", "list", "dict", "tuple"}, -} - - -class SafeUnpickler(pickle.Unpickler): - def find_class(self, module, name): - # 检查模块和类是否在白名单中 - if module in SAFE_CLASSES and name in SAFE_CLASSES[module]: - return super().find_class(module, name) - raise pickle.UnpicklingError(f"Forbidden class: {module}.{name}") - - -def load_recorder_data(path, world_size): - """加载所有 rank 的 recorder 数据""" - recorder_dict = {} - for rank in range(world_size): - file_path = os.path.join(path, str(rank)) if not path.endswith("/") else path + str(rank) - file_path = get_valid_read_path(file_path) - try: - with open(file_path, "rb") as f: - res = SafeUnpickler(f).load() - recorder_dict[str(rank)] = res - except Exception as e: - logging.error(f"Failed to load data from {file_path}: {e}") - return recorder_dict - - -def extract_hccl_info(recorder_dict): - """从 recorder 数据中提取 HCCL 相关信息""" - hccl_dict = {} - for rank, recorder in recorder_dict.items(): - entries = recorder.get("entries", []) - if not entries: - continue - last_entry = entries[-1] - hccl_dict[rank] = { - "state": last_entry.get("state", None), - "record_id": last_entry.get("record_id", None), - "pg_id": last_entry.get("pg_id", None), - "time_discovered_completed_ns": last_entry.get("time_discovered_completed_ns", None), - "name": last_entry.get("frames", [{}])[0].get("name", None), - } - return hccl_dict - - -def analyze_pg_groups(hccl_dict): - """分析 HCCL 数据,按 pg_id 分组并检查问题""" - pg_groups = defaultdict(list) - for _, op in hccl_dict.items(): - pg_groups[op["pg_id"]].append(op) - - for pg_id, group in pg_groups.items(): - scheduled_ops = [op for op in group if op["state"] == "scheduled"] - completed_ops = [op for op in group if op["state"] == "completed"] - - # 情况 1: 所有卡都是 scheduled,且 record_id 和 name 相同 - if len(scheduled_ops) == len(group): - record_id = scheduled_ops[0]["record_id"] - name = scheduled_ops[0]["name"] - all_same = all(op["record_id"] == record_id and op["name"] == name for op in scheduled_ops) - if all_same: - logging.info( - f"The pg_id {pg_id}'s Communication Operator {name}" - " executed too slowly, causing the HCCL to time out." - ) - - # 情况 2: 存在 completed 算子且 该算子的record_id 比其他 scheduled 算子少 1 - elif completed_ops and scheduled_ops: - completed_op = completed_ops[0] - scheduled_record_id = scheduled_ops[0]["record_id"] - if completed_op["record_id"] == scheduled_record_id - 1: - logging.info( - f"The pg_id {pg_id}'s rank {completed_op['pg_id']}'s " - "Computational task took too long, causing the other ranks' " - "HCCL task to time out." - ) - - # 情况 3: 所有算子均为 completed - elif not scheduled_ops and completed_ops: - latest_op = max(completed_ops, key=lambda x: x["time_discovered_completed_ns"] or 0) - logging.info( - f"The computational task of the pg_id {pg_id} " - f"after the communication operator {latest_op['name']} " - "took too long." - ) - - else: - logging.info(f"The situation cannot be recognized!") - - -def get_int_arg(args, idx, default): - if len(args) > idx: - try: - return int(args[idx]) - except ValueError: - logging.warning(f"Invalid input {args[idx]}, using default: {default}") - return default - - -def main(): - # 设置默认值 - default_path = os.getenv("TORCH_HCCL_DEBUG_INFO_TEMP_FILE") - default_world_size = 8 - - # 获取命令行参数,如果未提供则使用默认值 - path = sys.argv[1] if len(sys.argv) > 1 else default_path - world_size = get_int_arg(sys.argv, 2, default_world_size) - - if not path: - raise ValueError("Path is required and cannot be empty.") - - logging.info(f"Path: {path}") - logging.info(f"World Size: {world_size}") - - # 加载数据 - recorder_dict = load_recorder_data(path, world_size) - if not recorder_dict: - logging.error("No valid recorder data found.") - return - - # 提取 HCCL 信息 - hccl_dict = extract_hccl_info(recorder_dict) - - # 分析 HCCL 数据 - analyze_pg_groups(hccl_dict) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/flight_recoder/check_path.py b/flight_recoder/check_path.py deleted file mode 100644 index b34e4dcdb68b28b44f387cb14919ad127658ca8f..0000000000000000000000000000000000000000 --- a/flight_recoder/check_path.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import os -import sys -import stat - - -PATH_WHITE_LIST_REGEX = re.compile(r"[^_A-Za-z0-9/.-]") -MAX_READ_FILE_SIZE_4G = 4294967296 # 4G, 4 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_32G = 34359738368 # 32G, 32 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_512G = 549755813888 # 512G, 512 * 1024 * 1024 * 1024 - -# group not writable, others no permission, max stat is 750 -WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH | stat.S_IROTH | stat.S_IXOTH -# group not writable, others not writable, max stat is 755 -READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH - - -def type_to_str(value_type): - return ' or '.join([ii.__name__ for ii in value_type]) if isinstance(value_type, tuple) else value_type.__name__ - - -def check_type(value, value_type, param_name="value"): - if not isinstance(value, value_type): - raise TypeError('{} must be {}, not {}.'.format(param_name, type_to_str(value_type), type(value).__name__)) - - -def get_valid_path(path): - check_type(path, str, "path") - if not path or len(path) == 0: - raise ValueError("The value of the path cannot be empty.") - if PATH_WHITE_LIST_REGEX.search(path): # Check special char - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - path = os.path.expanduser(path) # Consider paths starting with "~" - if os.path.islink(os.path.abspath(path)): # when checking link, get rid of the "/" at the path tail if any - raise ValueError("The value of the path cannot be soft link: {}.".format(path)) - - real_path = os.path.realpath(path) - - if len(real_path) > 4096: - raise ValueError("The length of file path should be less than 4096.") - - if real_path != path and PATH_WHITE_LIST_REGEX.search(real_path): # Check special char again - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - - return real_path - - -def is_belong_to_user_or_group(file_stat): - return file_stat.st_uid == os.getuid() or file_stat.st_gid in os.getgroups() - - -def get_valid_read_path(path, size_max=MAX_READ_FILE_SIZE_4G, check_user_stat=True, is_dir=False): - real_path = get_valid_path(path) - if not os.path.isfile(real_path): - raise ValueError("The path {} doesn't exists or not a file.".format(path)) - - file_stat = os.stat(real_path) - if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat): - raise ValueError("The file {} doesn't belong to the current user or group.".format(path)) - if check_user_stat and os.stat(path).st_mode & READ_FILE_NOT_PERMITTED_STAT > 0: - raise ValueError("The file {} is group writable, or is others writable.".format(path)) - if not os.access(real_path, os.R_OK) or file_stat.st_mode & stat.S_IRUSR == 0: # At least been 400 - raise ValueError("Current user doesn't have read permission to the file {}.".format(path)) - if not is_dir and size_max > 0 and file_stat.st_size > size_max: - raise ValueError("The file {} exceeds size limitation of {}.".format(path, size_max)) - return real_path \ No newline at end of file diff --git a/flight_recoder/flight_recoder.md b/flight_recoder/flight_recoder.md deleted file mode 100644 index 8b398a6730bae0823b04c20a22258a81392922c9..0000000000000000000000000000000000000000 --- a/flight_recoder/flight_recoder.md +++ /dev/null @@ -1,49 +0,0 @@ -# 飞行记录器超时类问题分析 - -训练任务卡住是阻塞AI大规模分布式集群训练任务的主要和关键问题,当前需要等待集合通信超时才能感知,影响集群可用性。框架需要支持检测训练任务卡住问题,做到提前识别并保存必要的诊断信息,提高问题定位效率和集群设备可用性。当HeartbeatMonitor长时间未检测到心跳时,即可认为训练任务已经卡住,需要触发诊断信息保存。 - -本工具提供torch npu上飞行记录器flight recorder记录日志的读取解析能力,并根据解析后的日志提供超时类问题的初步分析能力,主要支持以下三种情况的超时类问题的识别和分析 - -|问题| 具体内容 | -| --- | --- | -|类型一 | 同通信域内的某张卡计算超时,导致其他卡等待触发飞行记录器和hccl time out | -|类型二 | 同通信域内的通信算子之后的非通信任务耗时过长| -|类型三 | 同通信域内的某个通信算子进行通信时执行超时 | - -## 使用方法 - -### 1 飞行记录器开启方法 - -按照如下方法设置环境变量开启飞行记录器 - -``` -export TORCH_HCCL_ENABLE_MONITORING=1 #用于检测是否开启卡住问题检测 -export TORCH_HCCL_DUMP_ON_TIMEOUT=1 # 用于控制是否保存诊断信息 -export TORCH_HCCL_TRACE_BUFFER_SIZE=1 # 用于控制保存的集合通信状态数量 -export TORCH_HCCL_HEARTBEAT_TIMEOUT_SEC=20 # 用于控制心跳超时时间,即训练业务多久未下发集合通信算子时需要判定为卡住,默认10分钟,单位s。(需要小于HCCL_EXEC_TIMEOUT,避免集合通信先报超时错误) -export TORCH_HCCL_DEBUG_INFO_TEMP_FILE=/tmp/ #保存诊断信息的文件路径 -``` - -### 2 工具使用方法 - -``` -python analysis_flight.py path world_size -``` - -脚本从命令行参数获取 `path` 和 `world_size` 的值,并记录日志。如果未提供命令行参数,则使用默认值。 - -* `path`:从命令行第一个参数获取,如果未提供则使用 `default_path`, default_path从TORCH_HCCL_DEBUG_INFO_TEMP_FILE获取。 -* `world_size`:从命令行第二个参数获取,如果未提供则使用 `default_world_size`,默认为8。 - -| 参数名| 含义 | 使用限制 | -| --- | --- | --- | -| path | 飞行记录器的日志 | 可选。数据类型:string 默认为环境变量中的TORCH_HCCL_DEBUG_INFO_TEMP_FILE,若设置日志格式指定有前缀,则需要在路径中加入前缀 | -| world_size | 同一个通信域中的卡数 | 可选。数据类型:int 默认为8 | - -### 3 输出示例 - -``` -2025-02-19 08:10:07,160 - INFO - Path: /tmp/ -2025-02-19 08:10:07,160 - INFO - World Size: 8 -2025-02-19 08:10:07,162 - INFO - The pg_id 0's rank 0's Computational task took too long, causing the other ranks' HCCL task to time out. -``` diff --git a/msmonitor/plugin/CMakeLists.txt b/msmonitor/plugin/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9abfa9a951d732a9dca97438152abdf580a78418 --- /dev/null +++ b/msmonitor/plugin/CMakeLists.txt @@ -0,0 +1,68 @@ +cmake_minimum_required(VERSION 3.16) +project(IPCMonitor) + +set(CMAKE_SKIP_RPATH TRUE) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +find_package(pybind11 REQUIRED) +find_package(Python REQUIRED COMPONENTS Interpreter Development) + +include_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/metric + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/mspti_monitor + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/securec/include + ${DYNOLOG_PATH}/third_party/glog/src + ${DYNOLOG_PATH}/build/third_party/glog + ${DYNOLOG_PATH}/third_party/json/single_include +) + +file(GLOB_RECURSE IPC_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/metric/*.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ipc_monitor/mspti_monitor/*.cpp +) + +file(GLOB_RECURSE SECUREC_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/third_party/securec/src/*.c) + +set(SOURCES + bindings.cpp + ${IPC_SOURCES} + ${SECUREC_SOURCES} +) + +add_library(IPCMonitor MODULE ${SOURCES}) + +set_target_properties(IPCMonitor + PROPERTIES + OUTPUT_NAME IPCMonitor + PREFIX "" +) + +target_link_libraries(IPCMonitor PRIVATE + pybind11::module + pthread + ${CMAKE_CURRENT_SOURCE_DIR}/stub/libmspti.so +) + +target_link_libraries(IPCMonitor PRIVATE ${DYNOLOG_PATH}/build/third_party/glog/libglog.a) + +target_compile_options(IPCMonitor PRIVATE + -fPIC + -fstack-protector-all + -ftrapv + $<$>:-O2> +) +add_compile_options(-D_FORITFY_SOURCE=2 -O2) + +target_link_options(IPCMonitor PRIVATE + -Wl,-z,relro,-z,now,-z,noexecstack + -s +) + +install(TARGETS IPCMonitor + DESTINATION ${CMAKE_INSTALL_PREFIX}/python-package +) diff --git a/msmonitor/plugin/README.md b/msmonitor/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..43784413c9582f15f8014e863415d9ec2f422ed6 --- /dev/null +++ b/msmonitor/plugin/README.md @@ -0,0 +1,49 @@ + + +# Plugins for msMonitor +## 模块说明 +### IPCMonitor +提供IPC(Inter-Process Communication)通信接口,用于实现 +1. IPC控制通道: profiler backend向dynolog daemon获取profiler配置 +2. IPC数据通道: mspti monitor向dynolog daemon发送性能数据 + +__PyDynamicMonitorProxy__: +* `init_dyno` 向dynolog daemon发送注册请求 + * input: npuId(int) + * return:None +* `poll_dyno` 向dynolog daemon获取Profiler控制参数 + * input: None + * return: str, 返回控制参数 +* `enable_dyno_npu_monitor` 开启mspti监控 + * input: cfg_map(Dict[str,str]) 配置 + * return: None + +## 安装方式 +### 1. 通过shell脚本一键安装 +``` +chmod +x build.sh +./build.sh +``` +### 2. 手动安装 +* 安装依赖 +``` +pip install wheel +pip install pybind11 +``` +* 编译whl包 +``` +bash ./stub/build_stub.sh +python3 setup.py bdist_wheel +``` +以上命令执行完成后在dist目录下生成msMonitor插件whl安装包msmonitor-plugin-{version}.whl +* 安装 +``` +pip install dist/{msmonitor-plugin-{version}.whl} +``` +* 卸载 +``` +pip uninstall msmonitor-plugin +``` + +## 日志 +* 用户可以通过配置MSMONITOR_LOG_PATH环境变量,指定日志文件路径,默认路径为当前目录下的msmonitor_log diff --git a/msmonitor/plugin/bindings.cpp b/msmonitor/plugin/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..626e72157e25d290d28b1cd2706625b97caa9048 --- /dev/null +++ b/msmonitor/plugin/bindings.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "ipc_monitor/PyDynamicMonitorProxy.h" + +namespace py = pybind11; + +PYBIND11_MODULE(IPCMonitor, m) { + py::class_(m, "PyDynamicMonitorProxy") + .def(py::init<>()) + .def("init_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::InitDyno, py::arg("npuId")) + .def("poll_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::PollDyno) + .def("enable_dyno_npu_monitor", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::EnableMsptiMonitor, py::arg("cfg_map")) + .def("finalize_dyno", &dynolog_npu::ipc_monitor::PyDynamicMonitorProxy::FinalizeDyno); +} \ No newline at end of file diff --git a/dynolog_npu/plugin/build.sh b/msmonitor/plugin/build.sh old mode 100755 new mode 100644 similarity index 84% rename from dynolog_npu/plugin/build.sh rename to msmonitor/plugin/build.sh index ce20d9d2be546afbc63e3aace524f74858eff6ff..ec20536715a9b2bd1fd8ab7a694ca9eac26f3101 --- a/dynolog_npu/plugin/build.sh +++ b/msmonitor/plugin/build.sh @@ -3,7 +3,10 @@ # install pybind11 pip install pybind11 -# build dynolog_npu_plugin wheel +# build stub +sh ./stub/build_stub.sh + +# build msmonitor_plugin wheel python3 setup.py bdist_wheel # find .whl files in dist diff --git a/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7774358177b087019ed3c5c64b2adc981bc11b73 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "DynoLogNpuMonitor.h" +#include +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +DynoLogNpuMonitor::DynoLogNpuMonitor() +{ + // init glog + if (!google::IsGoogleLoggingInitialized()) { + std::string logPath; + if (CreateMsmonitorLogPath(logPath)) { + fprintf(stderr, "[INFO] [%d] Msmonitor log will record to %s\n", GetProcessId(), logPath.c_str()); + logPath = logPath + "/msmonitor_"; + google::InitGoogleLogging("MsMonitor"); + google::SetLogDestination(google::GLOG_INFO, logPath.c_str()); + google::SetLogFilenameExtension(".log"); + } else { + fprintf(stderr, "Failed to create log path, log will not record\n"); + } + } +} + +bool DynoLogNpuMonitor::Init() +{ + if (isInitialized_) { + LOG(WARNING) << "DynoLog npu monitor already initialized"; + return true; + } + if (!ipcClient_.Init()) { + LOG(ERROR) << "DynoLog npu monitor ipcClient init failed"; + return false; + } + bool res = ipcClient_.RegisterInstance(npuId_); + if (res) { + isInitialized_ = true; + LOG(INFO) << "DynoLog npu monitor initialized successfully"; + } + return res; +} + +ErrCode DynoLogNpuMonitor::DealMonitorReq(MsptiMonitorCfg& cmd) +{ + if (cmd.monitorStop) { + if (msptiMonitor_.IsStarted()) { + LOG(INFO) << "Stop mspti monitor thread successfully"; + msptiMonitor_.Stop(); + } + return ErrCode::SUC; + } + + if (cmd.reportIntervals <= 0) { + cmd.reportIntervals = DEFAULT_FLUSH_INTERVAL; + LOG(WARNING) << "Invalid report interval, set to 60"; + } + if (cmd.reportIntervals != 0) { + msptiMonitor_.SetFlushInterval(cmd.reportIntervals); + } + + if (cmd.monitorStart && !msptiMonitor_.IsStarted()) { + LOG(INFO) << "Start mspti monitor thread successfully"; + msptiMonitor_.Start(); + } + + if (msptiMonitor_.IsStarted() && !cmd.enableActivities.empty()) { + auto curActivities = msptiMonitor_.GetEnabledActivities(); + std::vector enableKinds; + std::vector disableKinds; + std::set_difference(cmd.enableActivities.begin(), cmd.enableActivities.end(), curActivities.begin(), curActivities.end(), + std::back_inserter(enableKinds)); + std::set_difference(curActivities.begin(), curActivities.end(), cmd.enableActivities.begin(), cmd.enableActivities.end(), + std::back_inserter(disableKinds)); + for (auto activity : enableKinds) { + msptiMonitor_.EnableActivity(activity); + } + for (auto activity : disableKinds) { + msptiMonitor_.DisableActivity(activity); + } + } + return ErrCode::SUC; +} + +std::string DynoLogNpuMonitor::Poll() +{ + std::string res = ipcClient_.IpcClientNpuConfig(); + if (res.size() == 4) { // res为4,表示dynolog注册进程成功 + LOG(INFO) << "Regist to dynolog daemon successfully"; + return ""; + } + if (res.empty()) { + return ""; + } + LOG(INFO) << "Received NPU configuration successfully"; + return res; +} + +void DynoLogNpuMonitor::EnableMsptiMonitor(std::unordered_map& cfg_map) +{ + auto cmd = InputParser::GetInstance()->DynoLogGetOpts(cfg_map); + if (cmd.isMonitor) { + auto ans = DealMonitorReq(cmd); + if (ans != ErrCode::SUC) { + LOG(ERROR) << "Deal monitor request failed, because" << IPC_ERROR(ans); + } + } +} + +void DynoLogNpuMonitor::Finalize() +{ + msptiMonitor_.Uninit(); +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h similarity index 37% rename from dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h rename to msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h index 40ee21072710312a86cd75befdcefa67e24efb8f..5ffec3bd9667cefb400addaa49cdc2ea6d2ccb8a 100644 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.h +++ b/msmonitor/plugin/ipc_monitor/DynoLogNpuMonitor.h @@ -1,9 +1,26 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef DYNOLOG_NPU_MONITOR_H #define DYNOLOG_NPU_MONITOR_H #include "MonitorBase.h" #include "NpuIpcClient.h" +#include "MsptiMonitor.h" #include "singleton.h" +#include "InputParser.h" namespace dynolog_npu { namespace ipc_monitor { @@ -12,22 +29,30 @@ class DynoLogNpuMonitor : public MonitorBase, public Singleton; public: - DynoLogNpuMonitor() = default; + DynoLogNpuMonitor(); bool Init() override; + ErrCode DealMonitorReq(MsptiMonitorCfg& cmd); std::string Poll() override; + void EnableMsptiMonitor(std::unordered_map& cfg_map); + void Finalize(); void SetNpuId(int id) override { npuId_ = id; } + IpcClient *GetIpcClient() + { + return &ipcClient_; + } + private: bool isInitialized_ = false; int32_t npuId_ = 0; IpcClient ipcClient_; + MsptiMonitor msptiMonitor_; }; } // namespace ipc_monitor } // namespace dynolog_npu -#endif - +#endif // DYNOLOG_NPU_MONITOR_H diff --git a/msmonitor/plugin/ipc_monitor/InputParser.cpp b/msmonitor/plugin/ipc_monitor/InputParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9dd9f4bb46412040825b050bbec5489ad387961c --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/InputParser.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "InputParser.h" +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { + +const std::string MSPTI_ACTIVITY_KIND_KEY = "MSPTI_ACTIVITY_KIND"; +const std::string REPORT_INTERVAL_S_KEY = "REPORT_INTERVAL_S"; +const std::string NPU_MONITOR_START_KEY = "NPU_MONITOR_START"; +const std::string NPU_MONITOR_STOP_KEY = "NPU_MONITOR_STOP"; + +const std::unordered_set cfgMap { + "MSPTI_ACTIVITY_KIND", + "REPORT_INTERVAL_S", + "NPU_MONITOR_START", + "NPU_MONITOR_STOP", + "REQUEST_TRACE_ID" +}; + +const std::unordered_map kindStrMap { + {"Marker", MSPTI_ACTIVITY_KIND_MARKER}, + {"Kernel", MSPTI_ACTIVITY_KIND_KERNEL}, + {"API", MSPTI_ACTIVITY_KIND_API}, + {"Hccl", MSPTI_ACTIVITY_KIND_HCCL}, + {"Memory", MSPTI_ACTIVITY_KIND_MEMORY}, + {"MemSet", MSPTI_ACTIVITY_KIND_MEMSET}, + {"MemCpy", MSPTI_ACTIVITY_KIND_MEMCPY} +}; + +std::set str2Kinds(const std::string& kindStrs) +{ + std::set res; + auto kindStrList = split(kindStrs, ','); + for (auto& kindStr : kindStrList) { + auto kind = kindStrMap.find(kindStr); + if (kind == kindStrMap.end()) { + return {MSPTI_ACTIVITY_KIND_INVALID}; + } + res.insert(kind->second); + } + return res; +} + +MsptiMonitorCfg InputParser::DynoLogGetOpts(std::unordered_map& cmd) +{ + if (cmd.count("NPU_MONITOR_SRART")) { + return {{MSPTI_ACTIVITY_KIND_INVALID}, 0, false, false, false}; + } + auto activityKinds = str2Kinds(cmd[MSPTI_ACTIVITY_KIND_KEY]); + uint32_t reportTimes = 0; + Str2Uint32(reportTimes, cmd[REPORT_INTERVAL_S_KEY]); + bool startSwitch = false; + Str2Bool(startSwitch, cmd[NPU_MONITOR_START_KEY]); + bool endSwitch = false; + Str2Bool(endSwitch, cmd[NPU_MONITOR_STOP_KEY]); + return {activityKinds, reportTimes, startSwitch, endSwitch, true}; +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/InputParser.h b/msmonitor/plugin/ipc_monitor/InputParser.h new file mode 100644 index 0000000000000000000000000000000000000000..5288a76e2a8c8c597e973e33cd04a14fc345e6a6 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/InputParser.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INPUT_PARSER_H +#define INPUT_PARSER_H + +#include +#include +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { + +struct MsptiMonitorCfg { + std::set enableActivities; + uint32_t reportIntervals; + bool monitorStart; + bool monitorStop; + bool isMonitor; +}; + + +class InputParser : public dynolog_npu::ipc_monitor::Singleton { +public: + MsptiMonitorCfg DynoLogGetOpts(std::unordered_map& cmd); +}; + +} // namespace ipc_monitor +} // namespace dynolog_npu + +#endif \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h b/msmonitor/plugin/ipc_monitor/MonitorBase.h similarity index 31% rename from dynolog_npu/plugin/ipc_monitor/MonitorBase.h rename to msmonitor/plugin/ipc_monitor/MonitorBase.h index 108023c7624b747e5987be9184d6c594decd360a..a46fc7fe31e9464c00cfecd01a29b85f3977705b 100644 --- a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h +++ b/msmonitor/plugin/ipc_monitor/MonitorBase.h @@ -1,5 +1,21 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef MONITOR_BASE_H #define MONITOR_BASE_H + #include namespace dynolog_npu { @@ -14,5 +30,4 @@ public: } // namespace ipc_monitor } // namespace dynolog_npu - -#endif \ No newline at end of file +#endif // MONITOR_BASE_H diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp b/msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp similarity index 64% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp rename to msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp index 97966e8eeacc7276426feb237aa122eb8dee046f..93fc9370a6216839120713ea250c12984f72e322 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp +++ b/msmonitor/plugin/ipc_monitor/NpuIpcClient.cpp @@ -1,57 +1,82 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "NpuIpcClient.h" - -#include +#include namespace dynolog_npu { namespace ipc_monitor { +bool IpcClient::Init() +{ + pids_ = GetPids(); + return true; +} -bool IpcClient::RegisterInstance(int32_t id) +bool IpcClient::RegisterInstance(int32_t npu) { NpuContext context{ - .npu = id, + .npu = npu, .pid = getpid(), .jobId = JOB_ID, }; - std::unique_ptr message = Message::ConstructMessage(context, "ctxt"); + std::unique_ptr message = Message::ConstructMessage(context, MSG_TYPE_CONTEXT); try { - if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING]Failed to send register ctxt for pid " << context.pid << " with dyno" << std::endl; + if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { + LOG(WARNING) << "Failed to send register ctxt for pid " << context.pid << " with dyno"; return false; } } catch (const std::exception &e) { - std::cout << "[WARNING] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(WARNING) << "Error when SyncSendMessage: " << e.what(); return false; } - std::cout << "[INFO] Resigter pid " << context.pid << " for dynolog success !" << std::endl; + LOG(INFO) << "Resigter pid " << context.pid << " for dynolog success!"; return true; } + std::string IpcClient::IpcClientNpuConfig() { auto size = pids_.size(); - auto *req = (NpuRequest *)malloc(sizeof(NpuRequest) + sizeof(int32_t) * size); + auto *req = ReinterpretConvert(malloc(sizeof(NpuRequest) + sizeof(int32_t) * size)); + if (req == nullptr) { + LOG(ERROR) << " Malloc for NpuRequest failed !"; + return ""; + } req->type = DYNO_IPC_TYPE; req->pidSize = size; req->jobId = JOB_ID; - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { req->pids[i] = pids_[i]; } - std::unique_ptr message = Message::ConstructMessage(*req, "req", size); - if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING] Failed to send config to dyno server fail !" << std::endl; + std::unique_ptr message = Message::ConstructMessage(*req, MSG_TYPE_REQUEST, size); + if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { + LOG(WARNING) << "Failed to send config to dyno server"; free(req); req = nullptr; return ""; } free(req); + req = nullptr; message = PollRecvMessage(MAX_IPC_RETRIES, MAX_SLEEP_US); if (!message) { - std::cout << "[WARNING] Failed to receive on-demand config !" << std::endl; + LOG(WARNING) << "Failed to receive on-demand config"; return ""; } std::string res = std::string(ReinterpretConvert(message->buf.get()), message->metadata.size); - return res; } + std::unique_ptr IpcClient::ReceiveMessage() { std::lock_guard wguard(dequeLock_); @@ -62,10 +87,11 @@ std::unique_ptr IpcClient::ReceiveMessage() msgDynoDeque_.pop_front(); return message; } + bool IpcClient::SyncSendMessage(const Message &message, const std::string &destName, int numRetry, int seepTimeUs) { if (destName.empty()) { - std::cout << "[WARNING] Can not send to empty socket name !" << std::endl; + LOG(WARNING) << "Can not send to empty socket name!"; return false; } int i = 0; @@ -79,11 +105,12 @@ bool IpcClient::SyncSendMessage(const Message &message, const std::string &destN seepTimeUs *= 2; // 2: double sleep time } } catch (const std::exception &e) { - std::cout << "[ERROR] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when SyncSendMessage: " << e.what(); return false; } return i < numRetry; } + bool IpcClient::Recv() { try { @@ -94,7 +121,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryPeekMessage(*peekCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryPeekMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryPeekMessage: " << e.what(); return false; } if (successFlag) { @@ -108,7 +135,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryRcvMessage(*recvCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryRecvMsg: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryRecvMsg: " << e.what(); return false; } if (successFlag) { @@ -118,11 +145,12 @@ bool IpcClient::Recv() } } } catch (std::exception &e) { - std::cout << "[ERROR] Error in Recv(): " << e.what() << std::endl; + LOG(ERROR) << "Error in Recv(): " << e.what(); return false; } return false; } + std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs) { for (int i = 0; i < maxRetry; i++) { @@ -133,6 +161,5 @@ std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs } return nullptr; } - } // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h b/msmonitor/plugin/ipc_monitor/NpuIpcClient.h similarity index 50% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h rename to msmonitor/plugin/ipc_monitor/NpuIpcClient.h index ae7b00eb51b935db4e799fab470c3343e78bcb6f..4b4937bd6886c169faa2cfe76aeaf1ed10c85592 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h +++ b/msmonitor/plugin/ipc_monitor/NpuIpcClient.h @@ -1,40 +1,59 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef NPU_IPC_CLIENT_H #define NPU_IPC_CLIENT_H -#include -#include + +#include #include #include -#include -#include -#include -#include #include "NpuIpcEndPoint.h" #include "utils.h" +#include "securec.h" namespace dynolog_npu { namespace ipc_monitor { constexpr int TYPE_SIZE = 32; constexpr int JOB_ID = 0; -constexpr const char *DYNO_IPC_NAME = "dynolog"; constexpr const int DYNO_IPC_TYPE = 3; constexpr const int MAX_IPC_RETRIES = 5; constexpr const int MAX_SLEEP_US = 10000; +const std::string DYNO_IPC_NAME = "dynolog"; +const std::string MSG_TYPE_REQUEST = "req"; +const std::string MSG_TYPE_CONTEXT = "ctxt"; +const std::string MSG_TYPE_DATA = "data"; + struct NpuRequest { int type; int pidSize; int64_t jobId; int32_t pids[0]; }; + struct NpuContext { int32_t npu; pid_t pid; int64_t jobId; }; + struct Metadata { size_t size = 0; char type[TYPE_SIZE] = ""; }; + struct Message { Metadata metadata; std::unique_ptr buf; @@ -45,19 +64,26 @@ struct Message { if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1); + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } #if __cplusplus >= 201703L if constexpr (std::is_same::value == true) { ipcNpuMessage->metadata.size = data.size(); ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), data.c_str(), sizeof(data)); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, data.c_str(), data.size()) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } #endif static_assert(std::is_trivially_copyable::value); ipcNpuMessage->metadata.size = sizeof(data); ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), &data, sizeof(data)); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, &data, sizeof(data)) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } @@ -68,36 +94,61 @@ struct Message { if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1); + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } static_assert(std::is_trivially_copyable::value); static_assert(std::is_trivially_copyable::value); ipcNpuMessage->metadata.size = sizeof(data) + sizeof(U) * n; ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); - memcpy(ipcNpuMessage->buf.get(), &data, ipcNpuMessage->metadata.size); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, + &data, ipcNpuMessage->metadata.size) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } + return ipcNpuMessage; + } + + static std::unique_ptr ConstructStrMessage(const std::string &data, const std::string &type) + { + std::unique_ptr ipcNpuMessage = std::make_unique(Message()); + if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) { + throw std::runtime_error("Type string is too long to fit in metadata.type" + IPC_ERROR(ErrCode::PARAM)); + } + if (memcpy_s(ipcNpuMessage->metadata.type, sizeof(ipcNpuMessage->metadata.type), + type.c_str(), type.size() + 1) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } + ipcNpuMessage->metadata.size = data.size(); + ipcNpuMessage->buf = std::make_unique(ipcNpuMessage->metadata.size); + if (memcpy_s(ipcNpuMessage->buf.get(), ipcNpuMessage->metadata.size, data.c_str(), data.size()) != EOK) { + throw std::runtime_error("memcpy_s failed" + IPC_ERROR(ErrCode::MEMORY)); + } return ipcNpuMessage; } }; + class IpcClient { public: IpcClient(const IpcClient &) = delete; IpcClient &operator = (const IpcClient &) = delete; IpcClient() = default; + bool Init(); bool RegisterInstance(int32_t npu); std::string IpcClientNpuConfig(); + bool SyncSendMessage(const Message &message, const std::string &destName, int numRetry = 10, + int seepTimeUs = 10000); private: - std::vector pids_ = GetPids(); + std::vector pids_; NpuIpcEndPoint<0> ep_{ "dynoconfigclient" + GenerateUuidV4() }; std::mutex dequeLock_; std::deque> msgDynoDeque_; std::unique_ptr ReceiveMessage(); - bool SyncSendMessage(const Message &message, const std::string &destName, int numRetry = 10, - int seepTimeUs = 10000); bool Recv(); std::unique_ptr PollRecvMessage(int maxRetry, int sleeTimeUs); }; - } // namespace ipc_monitor } // namespace dynolog_npu -#endif \ No newline at end of file +#endif // NPU_IPC_CLIENT_H diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h b/msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h similarity index 77% rename from dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h rename to msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h index 6560fa515646226ddbffbca49c4f818eb0d0ebcf..22c43905fe242c4e5012836f54e1ae041189fe7b 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h +++ b/msmonitor/plugin/ipc_monitor/NpuIpcEndPoint.h @@ -1,16 +1,30 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef NPU_IPC_ENDPOINT_H #define NPU_IPC_ENDPOINT_H -#include + #include #include #include #include #include +#include #include -#include -#include -#include #include "utils.h" +#include "securec.h" namespace dynolog_npu { namespace ipc_monitor { @@ -46,23 +60,34 @@ public: if (socketFd == -1) { throw std::runtime_error(std::strerror(errno) + IPC_ERROR(ErrCode::PARAM)); } + int ret = 0; struct sockaddr_un address; size_t addressLen = SetSocketAdress(addressName, address); if (address.sun_path[0] != STR_END_CHAR) { - unlink(address.sun_path); + ret = unlink(address.sun_path); + } + if (ret == -1) { + throw std::runtime_error("Unlink failed, error is " + std::string(strerror(errno)) + IPC_ERROR(ErrCode::PARAM)); } - int res = bind(socketFd, ReinterpretConvert(&address), addressLen); - if (res == -1) { + + ret = bind(socketFd, ReinterpretConvert(&address), addressLen); + if (ret == -1) { throw std::runtime_error("Bind socket failed." + IPC_ERROR(ErrCode::PARAM)); } + if (address.sun_path[0] != STR_END_CHAR) { - chmod(address.sun_path, SOCKET_FD_CHMOD); + ret = chmod(address.sun_path, SOCKET_FD_CHMOD); + } + if (ret == -1) { + throw std::runtime_error("Chmod failed, error is " + std::string(strerror(errno)) + IPC_ERROR(ErrCode::PARAM)); } } + ~NpuIpcEndPoint() { close(socketFd); } + [[nodiscard]] auto BuildSendNpuCtxt(const std::string &desAddrName, const std::vector &npuPayLoad, const std::vector &fileDes) { @@ -80,7 +105,11 @@ public: throw std::runtime_error("Memcpy failed when fileDes size large than ctxt fileDesPtr " + IPC_ERROR(ErrCode::PARAM)); } - memcpy(ctxt->fileDesPtr, fileDes.data(), fileDes.size() * sizeof(fileDesT)); + if (memcpy_s(ctxt->fileDesPtr, sizeof(ctxt->fileDesPtr), + fileDes.data(), fileDes.size() * sizeof(fileDesT)) != EOK) { + throw std::runtime_error("Memcpy failed when fileDes size large than ctxt fileDesPtr " + + IPC_ERROR(ErrCode::MEMORY)); + } } return ctxt; } @@ -137,7 +166,7 @@ public: throw std::runtime_error("TryPeekMessage occur " + std::string(std::strerror(errno))); } - const char *GetName(Ctxt const & ctxt) const noexcept + const char *GetName(Ctxt const & ctxt) const { if (ctxt.messageName.sun_path[0] != STR_END_CHAR) { throw std::runtime_error("GetName() want to got abstract socket, but got " + @@ -173,8 +202,10 @@ protected: auto BuildNpuCtxt_(const std::vector &npuPayLoad, unsigned numFileDes) { auto ctxt = std::make_unique(npuPayLoad.size()); - std::memset(&ctxt->msghdr, 0, sizeof(ctxt->msghdr)); - for (auto i = 0; i < npuPayLoad.size(); i++) { + if (memset_s(&ctxt->msghdr, sizeof(ctxt->msghdr), 0, sizeof(ctxt->msghdr)) != EOK) { + throw std::runtime_error("Memset failed when build ctxt " + IPC_ERROR(ErrCode::MEMORY)); + } + for (size_t i = 0; i < npuPayLoad.size(); i++) { ctxt->iov[i] = {npuPayLoad[i].data, npuPayLoad[i].size}; } ctxt->msghdr.msg_name = &ctxt->messageName; @@ -197,8 +228,7 @@ protected: return ctxt; } }; - } // namespace ipc_monitor } // namespace dynolog_npu -#endif +#endif // NPU_IPC_ENDPOINT_H diff --git a/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h new file mode 100644 index 0000000000000000000000000000000000000000..03aa1d08105e419838fdfdce07b79c41ecf38ec3 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/PyDynamicMonitorProxy.h @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef PYDYNAMIC_MONITOR_PROXY_H +#define PYDYNAMIC_MONITOR_PROXY_H + +#include +#include "MonitorBase.h" +#include "DynoLogNpuMonitor.h" + +namespace dynolog_npu { +namespace ipc_monitor { + +class PyDynamicMonitorProxy { +public: + PyDynamicMonitorProxy() = default; + bool InitDyno(int npuId) + { + try { + monitor_ = DynoLogNpuMonitor::GetInstance(); + monitor_->SetNpuId(npuId); + bool res = monitor_->Init(); + return res; + } catch (const std::exception &e) { + LOG(ERROR) << "Error when init dyno " << e.what(); + return false; + } + } + + std::string PollDyno() + { + return monitor_->Poll(); + } + + void EnableMsptiMonitor(std::unordered_map& config_map) + { + DynoLogNpuMonitor::GetInstance()->EnableMsptiMonitor(config_map); + } + + void FinalizeDyno() + { + DynoLogNpuMonitor::GetInstance()->Finalize(); + } +private: + MonitorBase *monitor_ = nullptr; +}; + +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // PYDYNAMIC_MONITOR_PROXY_H diff --git a/msmonitor/plugin/ipc_monitor/TimerTask.h b/msmonitor/plugin/ipc_monitor/TimerTask.h new file mode 100644 index 0000000000000000000000000000000000000000..7ddc5d28ada0dd9ae255ef4748d58f8cd410190a --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/TimerTask.h @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TIMER_TASK_H +#define TIMER_TASK_H + +#include +#include +#include +#include +#include +#include + +namespace dynolog_npu { +namespace ipc_monitor { +class TimerTask { +public: + TimerTask(const std::string& name, int interval) + : interval(interval), name(name), manual_trigger(false), running(false) {} + + ~TimerTask() + { + Stop(); + } + + void Run() + { + if (running) { + LOG(ERROR) << name << " Timer task is already running."; + return; + } + running = true; + taskThread = std::thread(&TimerTask::TaskRun, this); + } + + void Trigger() + { + std::unique_lock lock(cv_mutex); + manual_trigger = true; + if (running.load()) { + cv.notify_one(); + } + } + + // 停止定时任务 + void Stop() + { + if (!running) { + LOG(ERROR) << name << "Timer task is not running."; + return; + } + + running = false; + cv.notify_one(); + if (taskThread.joinable()) { + taskThread.join(); + } + } + + void SetInterval(int intervalTimes) + { + interval.store(intervalTimes); + } + + virtual void InitResource() {}; + virtual void ReleaseResource() {}; + virtual void ExecuteTask() = 0; +private: + // 定时任务线程函数 + void TaskRun() + { + LOG(INFO) << name << " Timer task started."; + InitResource(); + while (running) { + std::unique_lock lock(cv_mutex); + if (interval.load()) { + cv.wait_for(lock, std::chrono::seconds(interval.load()), [&] {return manual_trigger || !running;}); + } else { + cv.wait(lock, [&] {return manual_trigger || !running;}); + } + if (!running) { + break; + } + if (manual_trigger) { + manual_trigger = false; + } + if (running) { + ExecuteTask(); + } + } + ReleaseResource(); + LOG(INFO) << name << " Timer task stopped."; + } + + std::atomic interval; + std::string name; + std::condition_variable cv; + std::mutex cv_mutex; + std::atomic manual_trigger; + std::atomic running; + std::thread taskThread; +}; + +} +} +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41850426c5212e7b28d6f04d90c8897014cb203d --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.cpp @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricApiProcess.h" + +#include +#include + +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string ApiMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "API"; + jsonMsg["deviceId"] = -1; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricApiProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityApi* apiData = ReinterpretConvert(record); + msptiActivityApi* tmp = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityApi), ALIGN_SIZE)); + if (memcpy_s(tmp, sizeof(msptiActivityApi), apiData, sizeof(msptiActivityApi)) != EOK) { + MsptiFree(ReinterpretConvert(tmp)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(tmp); + } +} + +std::vector MetricApiProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + ApiMetric apiMetric{}; + auto ans = std::accumulate(copyRecords.begin(), copyRecords.end(), 0ULL, + [](uint64_t acc, std::shared_ptr api) { + return acc + api->end - api->start; + }); + apiMetric.duration = ans; + apiMetric.deviceId = -1; + apiMetric.timestamp = getCurrentTimestamp64(); + return {apiMetric}; +} + +void MetricApiProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricApiProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..c9357e58eec78ebf4b67941c14c16c3747daa46f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricApiProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_API_PROCESS_H +#define METRIC_API_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct ApiMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricApiProcess : public MetricProcessBase { +public: + MetricApiProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b28e4e7d5bd587f93a89f31050d2c1a5d5246fd6 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricHcclProcess.h" +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string HcclMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Hccl"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricHcclProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityHccl* hcclData = ReinterpretConvert(record); + msptiActivityHccl* tmp = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityHccl), ALIGN_SIZE)); + if (memcpy_s(tmp, sizeof(msptiActivityHccl), hcclData, sizeof(msptiActivityHccl)) != EOK) { + MsptiFree(ReinterpretConvert(tmp)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(tmp); + } +} + +std::vector MetricHcclProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2HcclData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->ds.deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2HcclData) { + HcclMetric hcclMetric{}; + auto& hcclDatas = pair.second; + hcclMetric.duration = std::accumulate(hcclDatas.begin(), hcclDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr hccl) { + return acc + hccl->end - hccl->start; + }); + hcclMetric.deviceId = pair.first; + hcclMetric.timestamp = curTimestamp; + ans.emplace_back(hcclMetric); + } + return ans; +} + +void MetricHcclProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricHcclProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..2c846949d35f1dc3b0c5d359e15dc8d2818db6b5 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricHcclProcess.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_HCCL_PROCESS_H +#define METRIC_HCCL_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct HcclMetric { + std::string kindName; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricHcclProcess : public MetricProcessBase { +public: + MetricHcclProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0d31e99fb56f04de4afc71533a34063eddf86ad --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricKernelProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string KernelMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Kernel"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricKernelProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityKernel* kernel = ReinterpretConvert(record); + msptiActivityKernel* ptr = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityKernel), ALIGN_SIZE)); + if (memcpy_s(ptr, sizeof(msptiActivityKernel), kernel, sizeof(msptiActivityKernel)) != EOK) { + MsptiFree(ReinterpretConvert(ptr)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(ptr); + } +} + +std::vector MetricKernelProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2KernelData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->ds.deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2KernelData) { + auto deviceId = pair.first; + auto& kernelDatas = pair.second; + KernelMetric kernelMetric{}; + kernelMetric.duration = std::accumulate(kernelDatas.begin(), kernelDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr kernel) { + return acc + kernel->end - kernel->start; + }); + kernelMetric.deviceId = deviceId; + kernelMetric.timestamp = curTimestamp; + ans.emplace_back(kernelMetric); + } + + return ans; +} + +void MetricKernelProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricKernelProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..9bd034283ece0ba3cd5cfc5f5215b104ef37334c --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricKernelProcess.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_KERNEL_PROCESS_H +#define METRIC_KERNEL_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct KernelMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricKernelProcess : public MetricProcessBase { +public: + MetricKernelProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36029313ca87e33540f5f06f3356445dc1157926 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricManager.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricManager.h" +#include "MetricKernelProcess.h" +#include "MetricApiProcess.h" +#include "MetricMemCpyProcess.h" +#include "MetricHcclProcess.h" +#include "MetricMarkProcess.h" +#include "MetricMemSetProcess.h" +#include "MetricMemProcess.h" +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +MetricManager::MetricManager(): TimerTask("MetricManager", DEFAULT_FLUSH_INTERVAL), + kindSwitchs_(MSPTI_ACTIVITY_KIND_COUNT), consumeStatus_(MSPTI_ACTIVITY_KIND_COUNT) { + metrics.resize(MSPTI_ACTIVITY_KIND_COUNT); + metrics[MSPTI_ACTIVITY_KIND_KERNEL] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_API] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMCPY] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MARKER] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMSET] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_HCCL] = std::make_shared(); + metrics[MSPTI_ACTIVITY_KIND_MEMORY] = std::make_shared(); +} + +void MetricManager::ReleaseResource() +{ + for (int i = 0; i < MSPTI_ACTIVITY_KIND_COUNT; i++) { + if (kindSwitchs_[i].load()) { + kindSwitchs_[i] = false; + metrics[i]->Clear(); + } + } +} + +ErrCode MetricManager::ConsumeMsptiData(msptiActivity *record) +{ + if (!kindSwitchs_[record->kind]) { + return ErrCode::PERMISSION; + } + auto metricProcess = metrics[record->kind]; + consumeStatus_[record->kind] = true; + metricProcess->ConsumeMsptiData(record); + consumeStatus_[record->kind] = false; + return ErrCode::SUC; +} + +void MetricManager::SetReportInterval(uint32_t intervalTimes) +{ + if (reportInterval_.load() != intervalTimes) { + SendMetricMsg(); + SetInterval(intervalTimes); + reportInterval_.store(intervalTimes); + } +} + +void MetricManager::ExecuteTask() +{ + SendMetricMsg(); +} + +void MetricManager::SendMetricMsg() +{ + for (int i = 0; i < MSPTI_ACTIVITY_KIND_COUNT; i++) { + if (kindSwitchs_[i].load()) { + metrics[i]->SendProcessMessage(); + } + } +} + +void MetricManager::EnableKindSwitch_(msptiActivityKind kind, bool flag) +{ + kindSwitchs_[kind] = flag; +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricManager.h b/msmonitor/plugin/ipc_monitor/metric/MetricManager.h new file mode 100644 index 0000000000000000000000000000000000000000..262dc19b8660c3f5d48ad81b59eb1c47d4edbd1f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricManager.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MANAGER_H +#define METRIC_MANAGER_H + +#include +#include + +#include "utils.h" +#include "singleton.h" +#include "mspti.h" +#include "TimerTask.h" +#include "MetricProcessBase.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { +class MetricManager : public ipc_monitor::Singleton, public TimerTask { +public: + MetricManager(); + ~MetricManager() = default; + ErrCode ConsumeMsptiData(msptiActivity *record); + void SetReportInterval(uint32_t intervalTimes); + void SendMetricMsg(); + void ExecuteTask() override; + void EnableKindSwitch_(msptiActivityKind kind, bool flag); + void ReleaseResource() override; +private: + std::vector> kindSwitchs_; + std::vector> consumeStatus_; + std::atomic reportInterval_; + std::vector> metrics; +}; +} +} +} +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..14a4145197be5c2787663cd2dec7e198c3607cd1 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMarkProcess.h" + +#include +#include +#include + +#include "utils.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +constexpr size_t COMPLETE_RANGE_DATA_SIZE = 4; + +std::string MarkMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Marker"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["domain"] = domain; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +bool MetricMarkProcess::TransMarkData2Range(const std::vector>& markDatas, + RangeMarkData& rangemarkData) +{ + if (markDatas.size() != COMPLETE_RANGE_DATA_SIZE) { + return false; + } + + for (auto& activityMarker: markDatas) { + if (activityMarker->flag == MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE) { + if (activityMarker->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_DEVICE) { + rangemarkData.deviceId = activityMarker->objectId.ds.deviceId; + rangemarkData.deviceStart = activityMarker->timestamp; + } else { + rangemarkData.start = activityMarker->timestamp; + } + } + if (activityMarker->flag == MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE) { + if (activityMarker->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_DEVICE) { + rangemarkData.deviceEnd = activityMarker->timestamp; + } else { + rangemarkData.end = activityMarker->timestamp; + } + } + } + auto markId = markDatas[0]->id; + std::string domainName = "default"; + auto it = domainMsg.find(markId); + if (it != domainMsg.end()) { + domainName = *it->second; + } + rangemarkData.domain = domainName; + id2Marker.erase(markId); + domainMsg.erase(markId); + return true; +} + +void MetricMarkProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMarker* markerData = ReinterpretConvert(record); + msptiActivityMarker* tmp = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityMarker), ALIGN_SIZE)); + if (memcpy_s(tmp, sizeof(msptiActivityMarker), markerData, sizeof(msptiActivityMarker)) != EOK) { + MsptiFree(ReinterpretConvert(tmp)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(tmp); + if (markerData->flag == MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE && + markerData->sourceKind == MSPTI_ACTIVITY_SOURCE_KIND_HOST) { + std::string domainStr = markerData->domain; + auto markId = markerData->id; + domainMsg.emplace(markId, std::make_shared(domainStr)); + } + } +} + +std::vector MetricMarkProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + for (auto& record: copyRecords) { + id2Marker[record->id].emplace_back(std::move(record)); + } + std::vector rangeDatas; + for (auto pair = id2Marker.rbegin(); pair != id2Marker.rend(); ++pair) { + auto markId = pair->first; + auto markDatas = pair->second; + RangeMarkData rangeMark{}; + if (TransMarkData2Range(markDatas, rangeMark)) { + rangeDatas.emplace_back(rangeMark); + } + } + + std::unordered_map> domain2RangeData = + groupby(rangeDatas, [](const RangeMarkData& data) -> std::string { + return data.domain + std::to_string(data.deviceId); + }); + std::vector ans; + for (auto& pair: domain2RangeData) { + MarkMetric markMetric{}; + auto domainName = pair.first; + auto rangeDatas = pair.second; + markMetric.deviceId = rangeDatas[0].deviceId; + markMetric.domain = domainName; + markMetric.timestamp = getCurrentTimestamp64(); + markMetric.duration = std::accumulate(rangeDatas.begin(), rangeDatas.end(), 0ULL, + [](uint64_t acc, const RangeMarkData& rangeData) { + return acc + rangeData.deviceEnd - rangeData.deviceStart; + }); + ans.emplace_back(markMetric); + } + return ans; +} + +void MetricMarkProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMarkProcess::Clear() +{ + records.clear(); + domainMsg.clear(); + id2Marker.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..3835bda859b9e4f7a530f1165684146a122f4b4e --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMarkProcess.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MARK_PROCESS_H +#define METRIC_MARK_PROCESS_H + +#include +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MarkMetric { + std::string name; + std::string domain; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +struct RangeMarkData { + std::string domain; + uint64_t duration; + uint64_t start{0}; + uint64_t end{0}; + uint64_t deviceStart{0}; + uint64_t deviceEnd{0}; + uint32_t deviceId; +}; + + +class MetricMarkProcess : public MetricProcessBase { +public: + MetricMarkProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + bool TransMarkData2Range(const std::vector>& markDatas, + RangeMarkData& rangemarkData); +private: + std::mutex dataMutex; + std::unordered_map> domainMsg; + std::vector> records; + std::map>> id2Marker; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c53f1741248732deddbc203adae75d4a643ddf38 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemCpyProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemCpyMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "MemCpy"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemCpyProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemcpy* kernel = ReinterpretConvert(record); + msptiActivityMemcpy* ptr = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityMemcpy), ALIGN_SIZE)); + if (memcpy_s(ptr, sizeof(msptiActivityMemcpy), kernel, sizeof(msptiActivityMemcpy)) != EOK) { + MsptiFree(ReinterpretConvert(ptr)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(ptr); + } +} + +std::vector MetricMemCpyProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2Memcpy = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2Memcpy) { + auto deviceId = pair.first; + MemCpyMetric memCpyMetric{}; + auto& memCpyDatas = pair.second; + memCpyMetric.duration = std::accumulate(memCpyDatas.begin(), memCpyDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr memcpy) { + return acc + memcpy->end - memcpy->start; + }); + memCpyMetric.deviceId = deviceId; + memCpyMetric.timestamp = curTimestamp; + ans.emplace_back(memCpyMetric); + } + return ans; +} + +void MetricMemCpyProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemCpyProcess::Clear() +{ + records.clear(); +} +} +} +} diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..9b3b845f31a669190ac6b480d40090c7dc2785ad --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemCpyProcess.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEMCPY_PROCESS_H +#define METRIC_MEMCPY_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemCpyMetric { + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemCpyProcess : public MetricProcessBase { +public: + MetricMemCpyProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7422de52ada7c4acf2a902a2a6a88ec8099dc1f --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "Memory"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemory* mem = ReinterpretConvert(record); + msptiActivityMemory* ptr = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityMemory), ALIGN_SIZE)); + if (memcpy_s(ptr, sizeof(msptiActivityMemory), mem, sizeof(msptiActivityMemory)) != EOK) { + MsptiFree(ReinterpretConvert(ptr)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(ptr); + } +} + +std::vector MetricMemProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2MemData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2MemData) { + auto deviceId = pair.first; + auto& memDatas = pair.second; + MemMetric memMetric{}; + memMetric.duration = std::accumulate(memDatas.begin(), memDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr mem) { + return acc + mem->end - mem->start; + }); + memMetric.deviceId = deviceId; + memMetric.timestamp = curTimestamp; + ans.emplace_back(memMetric); + } + return ans; +} + +void MetricMemProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemProcess::Clear() +{ + records.clear(); +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..c6548c18de1f0cf68ca0490b2640b15d8025ea29 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEM_PROCESS_H +#define METRIC_MEM_PROCESS_H + +#include +#include "MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemMetric { + std::string name; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemProcess : public MetricProcessBase { +public: + MetricMemProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b5870cd391937cbfdeaa603f4aed4533e3b00ba --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MetricMemSetProcess.h" + +#include + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +std::string MemSetMetric::seriesToJson() +{ + nlohmann::json jsonMsg; + jsonMsg["kind"] = "MemSet"; + jsonMsg["deviceId"] = deviceId; + jsonMsg["duration"] = duration; + jsonMsg["timestamp"] = timestamp; + return jsonMsg.dump(); +} + +void MetricMemSetProcess::ConsumeMsptiData(msptiActivity *record) +{ + msptiActivityMemset* memSet = ReinterpretConvert(record); + msptiActivityMemset* ptr = ReinterpretConvert(MsptiMalloc(sizeof(msptiActivityMemset), ALIGN_SIZE)); + if (memcpy_s(ptr, sizeof(msptiActivityMemset), memSet, sizeof(msptiActivityMemset)) != EOK) { + MsptiFree(ReinterpretConvert(ptr)); + LOG(ERROR) << "memcpy_s failed" << IPC_ERROR(ErrCode::MEMORY); + return; + } + { + std::unique_lock lock(dataMutex); + records.emplace_back(ptr); + } +} + +std::vector MetricMemSetProcess::AggregatedData() +{ + std::vector> copyRecords; + { + std::unique_lock lock(dataMutex); + copyRecords = std::move(records); + records.clear(); + } + if (copyRecords.empty()) { + return {}; + } + std::unordered_map>> deviceId2MemsetData = + groupby(copyRecords, [](const std::shared_ptr& data) -> std::uint32_t { + return data->deviceId; + }); + std::vector ans; + auto curTimestamp = getCurrentTimestamp64(); + for (auto& pair: deviceId2MemsetData) { + MemSetMetric memSetMetric{}; + auto deviceId = pair.first; + auto& memSetDatas = pair.second; + memSetMetric.duration = std::accumulate(memSetDatas.begin(), memSetDatas.end(), 0ULL, + [](uint64_t acc, std::shared_ptr memSet) { + return acc + memSet->end - memSet->start; + }); + memSetMetric.deviceId = deviceId; + memSetMetric.timestamp = curTimestamp; + ans.emplace_back(memSetMetric); + } + return ans; +} + +void MetricMemSetProcess::SendProcessMessage() +{ + auto afterAggregated = AggregatedData(); + for (auto& metric: afterAggregated) { + SendMessage(metric.seriesToJson()); + } +} + +void MetricMemSetProcess::Clear() +{ + records.clear(); +} +} +} +} \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h new file mode 100644 index 0000000000000000000000000000000000000000..5d725e6edf5c4bd074d9cc1751a7dde263b52f67 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricMemSetProcess.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_MEM_SET_PROCESS_H +#define METRIC_MEM_SET_PROCESS_H + +#include +#include "metric/MetricProcessBase.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { + +struct MemSetMetric { + std::string name; + uint64_t duration; + uint64_t timestamp; + uint32_t deviceId; +public: + std::string seriesToJson(); +}; + +class MetricMemSetProcess : public MetricProcessBase { +public: + MetricMemSetProcess() = default; + void ConsumeMsptiData(msptiActivity *record) override; + std::vector AggregatedData(); + void SendProcessMessage() override; + void Clear() override; +private: + std::mutex dataMutex; + std::vector> records; +}; +} +} +} + +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h b/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h new file mode 100644 index 0000000000000000000000000000000000000000..2d066a9b27080116e6d87c908a0d08e837c5add9 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/metric/MetricProcessBase.h @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef METRIC_PROCESS_BASE_H +#define METRIC_PROCESS_BASE_H + +#include +#include + +#include "DynoLogNpuMonitor.h" +#include "NpuIpcClient.h" +#include "mspti.h" + +namespace dynolog_npu { +namespace ipc_monitor { +namespace metric { +class MetricProcessBase { +public: + void SendMessage(std::string message) + { + if (message.empty()) { + LOG(ERROR) << "SendMessage message is empty"; + return; + } + static const std::string destName = DYNO_IPC_NAME + "_data"; + static const int maxRetry = 5; + static const int retryWaitTimeUs = 1000; + auto msg = Message::ConstructStrMessage(message, MSG_TYPE_DATA); + if (!msg) { + LOG(ERROR) << "ConstructStrMessage failed, message: " << message; + return; + } + auto ipcClient = DynoLogNpuMonitor::GetInstance()->GetIpcClient(); + if (!ipcClient) { + LOG(ERROR) << "DynoLogNpuMonitor ipcClient is nullptr"; + return; + } + if (!ipcClient->SyncSendMessage(*msg, destName, maxRetry, retryWaitTimeUs)) { + LOG(ERROR) << "send mspti message failed: " << message; + } + } + virtual void ConsumeMsptiData(msptiActivity *record) = 0; + virtual void Clear() = 0; + virtual void SendProcessMessage() = 0; +}; +} +} +} +#endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33abb8fe382ecfb1d14f70317e6c2138e2ebdf04 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "MsptiMonitor.h" + +#include +#include +#include +#include + +#include "DynoLogNpuMonitor.h" +#include "MetricManager.h" +#include "utils.h" + +namespace { +constexpr size_t DEFAULT_BUFFER_SIZE = 8 * 1024 * 1024; +constexpr size_t MAX_BUFFER_SIZE = 256 * 1024 * 1024; +constexpr uint32_t MAX_ALLOC_CNT = MAX_BUFFER_SIZE / DEFAULT_BUFFER_SIZE; +} + +namespace dynolog_npu { +namespace ipc_monitor { + +MsptiMonitor::MsptiMonitor() + : start_(false), + subscriber_(nullptr), + checkFlush_(false), + flushInterval_(0) {} + +MsptiMonitor::~MsptiMonitor() +{ + Uninit(); +} + +void MsptiMonitor::Start() +{ + if (start_.load()) { + return; + } + SetThreadName("MsptiMonitor"); + if (Thread::Start() != 0) { + LOG(ERROR) << "MsptiMonitor start failed"; + return; + } + start_.store(true); + metric::MetricManager::GetInstance()->Run(); + LOG(INFO) << "MsptiMonitor start successfully"; +} + +void MsptiMonitor::Stop() +{ + if (!start_.load()) { + LOG(WARNING) << "MsptiMonitor is not running"; + return; + } + Uninit(); + if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { + LOG(WARNING) << "MsptiMonitor stop msptiActivityFlushAll failed"; + } + LOG(INFO) << "MsptiMonitor stop successfully"; +} + +void MsptiMonitor::Uninit() +{ + if (!start_.load()) { + return; + } + metric::MetricManager::GetInstance()->Stop(); + start_.store(false); + cv_.notify_one(); + Thread::Stop(); +} + +void MsptiMonitor::EnableActivity(msptiActivityKind kind) +{ + if (MSPTI_ACTIVITY_KIND_INVALID < kind && kind < MSPTI_ACTIVITY_KIND_COUNT) { + std::lock_guard lock(activityMtx_); + if (msptiActivityEnable(kind) == MSPTI_SUCCESS) { + enabledActivities_.insert(kind); + } else { + LOG(ERROR) << "MsptiMonitor enableActivity failed, kind: " << static_cast(kind); + } + metric::MetricManager::GetInstance()->EnableKindSwitch_(kind, true); + } +} + +void MsptiMonitor::DisableActivity(msptiActivityKind kind) +{ + if (MSPTI_ACTIVITY_KIND_INVALID < kind && kind < MSPTI_ACTIVITY_KIND_COUNT) { + std::lock_guard lock(activityMtx_); + if (msptiActivityDisable(kind) == MSPTI_SUCCESS) { + enabledActivities_.erase(kind); + } else { + LOG(ERROR) << "MsptiMonitor disableActivity failed, kind: " << static_cast(kind); + } + metric::MetricManager::GetInstance()->EnableKindSwitch_(kind, false); + } +} + +void MsptiMonitor::SetFlushInterval(uint32_t interval) +{ + flushInterval_.store(interval); + checkFlush_.store(true); + if (start_.load()) { + cv_.notify_one(); + } + metric::MetricManager::GetInstance()->SetReportInterval(interval); +} + +bool MsptiMonitor::IsStarted() +{ + return start_.load(); +} + +std::set MsptiMonitor::GetEnabledActivities() +{ + std::lock_guard lock(activityMtx_); + return enabledActivities_; +} + +void MsptiMonitor::Run() +{ + if (msptiSubscribe(&subscriber_, nullptr, nullptr) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiSubscribe failed"; + return; + } + if (msptiActivityRegisterCallbacks(BufferRequest, BufferComplete) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiActivityRegisterCallbacks failed"; + return; + } + while (true) { + std::unique_lock lock(cvMtx_); + if (flushInterval_.load() > 0) { + cv_.wait_for(lock, std::chrono::seconds(flushInterval_.load()), + [&]() { return checkFlush_.load() || !start_.load();}); + } else { + cv_.wait(lock, [&]() { return checkFlush_.load () || !start_.load();}); + } + if (!start_.load()) { + break; + } + if (checkFlush_.load()) { + checkFlush_.store(false); + } + if (flushInterval_.load() > 0) { + if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run msptiActivityFlushAll failed"; + } + } + } + if (msptiUnsubscribe(subscriber_) != MSPTI_SUCCESS) { + LOG(ERROR) << "MsptiMonitor run failed, msptiUnsubscribe failed"; + } + { + std::lock_guard lock(activityMtx_); + for (auto kind : enabledActivities_) { + msptiActivityDisable(kind); + } + enabledActivities_.clear(); + } + checkFlush_.store(false); + flushInterval_.store(0); +} + +std::atomic MsptiMonitor::allocCnt{0}; + +void MsptiMonitor::BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumRecords) +{ + if (buffer == nullptr || size == nullptr || maxNumRecords == nullptr) { + return; + } + *maxNumRecords = 0; + if (allocCnt.load() >= MAX_ALLOC_CNT) { + *buffer = nullptr; + *size = 0; + LOG(ERROR) << "MsptiMonitor BufferRequest failed, allocCnt: " << allocCnt.load(); + return; + } + uint8_t *pBuffer = ReinterpretConvert(MsptiMalloc(DEFAULT_BUFFER_SIZE, ALIGN_SIZE)); + if (pBuffer == nullptr) { + *buffer = nullptr; + *size = 0; + } else { + *buffer = pBuffer; + *size = DEFAULT_BUFFER_SIZE; + allocCnt++; + LOG(INFO) << "MsptiMonitor BufferRequest, size: " << *size; + } +} + +void MsptiMonitor::BufferComplete(uint8_t *buffer, size_t size, size_t validSize) +{ + if (validSize > 0 && buffer != nullptr) { + LOG(INFO) << "MsptiMonitor BufferComplete, size: " << size << ", validSize: " << validSize; + msptiActivity *record = nullptr; + msptiResult status = MSPTI_SUCCESS; + do { + status = msptiActivityGetNextRecord(buffer, validSize, &record); + if (status == MSPTI_SUCCESS) { + BufferConsume(record); + } else if (status == MSPTI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + LOG(ERROR) << "MsptiMonitor BufferComplete failed, status: " << static_cast(status); + break; + } + } while (true); + allocCnt--; + } + MsptiFree(buffer); +} + +void MsptiMonitor::BufferConsume(msptiActivity *record) +{ + if (record == nullptr) { + return; + } + metric::MetricManager::GetInstance()->ConsumeMsptiData(record); +} +} // namespace ipc_monitor +} // namespace dynolog_npu diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h new file mode 100644 index 0000000000000000000000000000000000000000..d1b73e581e1c26ed267ab8b414c695b8da4df8cf --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/MsptiMonitor.h @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MSPTI_MONITOR_H +#define MSPTI_MONITOR_H + +#include +#include +#include +#include +#include "mspti.h" +#include "thread.h" + + +namespace dynolog_npu { +namespace ipc_monitor { +class MsptiMonitor : public Thread { +public: + explicit MsptiMonitor(); + virtual ~MsptiMonitor(); + void Start(); + void Stop(); + void EnableActivity(msptiActivityKind kind); + void DisableActivity(msptiActivityKind kind); + void SetFlushInterval(uint32_t interval); + bool IsStarted(); + std::set GetEnabledActivities(); + void Uninit(); + +private: + static void BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumRecords); + static void BufferComplete(uint8_t *buffer, size_t size, size_t validSize); + static void BufferConsume(msptiActivity *record); + static std::atomic allocCnt; + +private: + void Run() override; + +private: + std::atomic start_; + std::mutex cvMtx_; + std::condition_variable cv_; + msptiSubscriberHandle subscriber_; + std::mutex activityMtx_; + std::set enabledActivities_; + std::atomic checkFlush_; + std::atomic flushInterval_; +}; +} // namespace ipc_monitor +} // namespace dynolog_npu +#endif // MSPTI_MONITOR_H diff --git a/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h b/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h new file mode 100644 index 0000000000000000000000000000000000000000..3ff0d2b8d259fc427c69c0d3d82c4514ce385d29 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/mspti_monitor/mspti.h @@ -0,0 +1,259 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MSPTI_STUB_H +#define MSPTI_STUB_H + +constexpr int ACTIVITY_STRUCT_ALIGNMENT = 8; +#if defined(_WIN32) +#define START_PACKED_ALIGNMENT __pragma(pack(push, 1)) +#define PACKED_ALIGNMENT __declspec(align(ACTIVITY_STRUCT_ALIGNMENT)) +#define END_PACKED_ALIGNMENT __pragma(pack(pop)) +#elif defined(__GNUC__) +#define START_PACKED_ALIGNMENT +#define PACKED_ALIGNMENT __attribute__((__packed__)) __attribute__((aligned(ACTIVITY_STRUCT_ALIGNMENT))) +#define END_PACKED_ALIGNMENT +#else +#define START_PACKED_ALIGNMENT +#define PACKED_ALIGNMENT +#define END_PACKED_ALIGNMENT +#endif + +#include +#include + +#define MSPTI_INVALID_DEVICE_ID ((uint32_t) 0xFFFFFFFFU) +#define MSPTI_INVALID_STREAM_ID ((uint32_t) 0xFFFFFFFFU) +#define MSPTI_INVALID_CORRELATION_ID ((uint64_t) 0) +using msptiCallbackId = uint32_t; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + MSPTI_SUCCESS = 0, + MSPTI_ERROR_INVALID_PARAMETER = 1, + MSPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED = 2, + MSPTI_ERROR_MAX_LIMIT_REACHED = 3, + MSPTI_ERROR_DEVICE_OFFLINE = 4, + MSPTI_ERROR_QUERY_EMPTY = 5, + MSPTI_ERROR_INNER = 999, + MSPTI_ERROR_FOECE_INT = 0x7fffffff +} msptiResult; + +typedef enum { + MSPTI_CB_DOMAIN_INVALID = 0, + MSPTI_CB_DOMAIN_RUNTIME = 1, + MSPTI_CB_DOMAIN_HCCL = 2, + MSPTI_CB_DOMAIN_SIZE, + MSPTI_CB_DOMAIN_FORCE_INT = 0x7fffffff +} msptiCallbackDomain; + +typedef enum { + MSPTI_API_ENTER = 0, + MSPTI_API_EXIT = 1, + MSPTI_API_CBSITE_FORCE_INT = 0x7fffffff +} msptiApiCallbackSite; + +typedef struct { + msptiApiCallbackSite callbackSite; + const char *functionName; + const void *functionParams; + const void *functionReturnValue; + const char *symbolName; + uint64_t correlationId; + uint64_t reserved1; + uint64_t reserved2; + uint64_t *correlationData; +} msptiCallbackData; + +typedef enum { + MSPTI_ACTIVITY_KIND_INVALID = 0, + MSPTI_ACTIVITY_KIND_MARKER = 1, + MSPTI_ACTIVITY_KIND_KERNEL = 2, + MSPTI_ACTIVITY_KIND_API = 3, + MSPTI_ACTIVITY_KIND_HCCL = 4, + MSPTI_ACTIVITY_KIND_MEMORY = 5, + MSPTI_ACTIVITY_KIND_MEMSET = 6, + MSPTI_ACTIVITY_KIND_MEMCPY = 7, + MSPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION = 8, + MSPTI_ACTIVITY_KIND_COUNT, + MSPTI_ACTIVITY_KIND_FORCE_INT = 0x7fffffff +} msptiActivityKind; + +typedef enum { + MSPTI_ACTIVITY_FLAG_NONE = 0, + MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS = 1 << 0, + MSPTI_ACTIVITY_FLAG_MARKER_START = 1 << 1, + MSPTI_ACTIVITY_FLAG_MARKER_END = 1 << 2, + MSPTI_ACTIVITY_FLAG_MARKER_INSTANTANEOUS_WITH_DEVICE = 1 << 3, + MSPTI_ACTIVITY_FLAG_MARKER_START_WITH_DEVICE = 1 << 4, + MSPTI_ACTIVITY_FLAG_MARKER_END_WITH_DEVICE = 1 << 5 +} msptiActivityFlag; + +typedef enum { + MSPTI_ACTIVITY_SOURCE_KIND_HOST = 0, + MSPTI_ACTIVITY_SOURCE_KIND_DEVICE = 1 +} msptiActivitySourceKind; + +typedef enum { + MSPTI_ACTIVITY_MEMORY_OPERATION_TYPE_ALLOCATATION = 0, + MSPTI_ACTIVITY_MEMORY_OPERATION_TYPE_RELEASE = 1 +} msptiActivityMemoryOperationType; + +typedef enum { + MSPTI_ACTIVITY_MEMORY_KIND_UNKNOWN = 0, + MSPTI_ACTIVITY_MEMORY_KIND_DEVICE = 1 +} msptiActivityMemoryKind; + +typedef enum { + MSPTI_ACTIVITY_MEMCPY_KIND_UNKNOWN = 0, + MSPTI_ACTIVITY_MEMCPY_KIND_HTOH = 1, + MSPTI_ACTIVITY_MEMCPY_KIND_HTOD = 2, + MSPTI_ACTIVITY_MEMCPY_KIND_DTOH = 3, + MSPTI_ACTIVITY_MEMCPY_KIND_DTOD = 4, + MSPTI_ACTIVITY_MEMCPY_KIND_DEFAULT = 5 +} msptiActivityMemcpyKind; + +START_PACKED_ALIGNMENT + +typedef union PACKED_ALIGNMENT { + struct { + uint32_t processId; + uint32_t threadId; + } pt; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; +} msptiObjectId; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; +} msptiActivity; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t processId; + uint32_t threadId; + } pt; + uint64_t correlationId; + const char* name; +} msptiActivityApi; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; + uint64_t correlationId; + const char *type; + const char *name; +} msptiActivityKernel; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityFlag flag; + msptiActivitySourceKind sourceKind; + uint64_t timestamp; + uint64_t id; + msptiObjectId objectId; + const char *name; + const char *domain; +} msptiActivityMarker; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint64_t start; + uint64_t end; + struct { + uint32_t deviceId; + uint32_t streamId; + } ds; + double bandWidth; + const char *name; + const char *commName; +} msptiActivityHccl; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityMemoryOperationType memoryOperationType; + msptiActivityMemoryKind memoryKind; + uint64_t correlationId; + uint64_t start; + uint64_t end; + uint64_t address; + uint64_t bytes; + uint32_t processId; + uint32_t deviceId; + uint32_t streamId; +} msptiActivityMemory; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + uint32_t value; + uint64_t bytes; + uint64_t start; + uint64_t end; + uint32_t deviceId; + uint32_t streamId; + uint64_t correlationId; + uint8_t isAsync; +} msptiActivityMemset; + +typedef struct PACKED_ALIGNMENT { + msptiActivityKind kind; + msptiActivityMemcpyKind copyKind; + uint64_t bytes; + uint64_t start; + uint64_t end; + uint32_t deviceId; + uint32_t streamId; + uint64_t correlationId; + uint8_t isAsync; +} msptiActivityMemcpy; + +END_PACKED_ALIGNMENT + +typedef void(*msptiCallbackFunc)(void* userdata, msptiCallbackDomain domain, msptiCallbackId cbid, const msptiCallbackData *cbdata); +typedef void(*msptiBuffersCallbackRequestFunc)(uint8_t **buffer, size_t *size, size_t *maxNumRecords); +typedef void(*msptiBuffersCallbackCompleteFunc)(uint8_t *buffer, size_t size, size_t validSize); + +struct msptiSubscriber_st { + msptiCallbackFunc callback; + void *userdata; +}; + +typedef struct msptiSubscriber_st *msptiSubscriberHandle; + +msptiResult msptiSubscribe(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void *userdata); +msptiResult msptiUnsubscribe(msptiSubscriberHandle subscriber); +msptiResult msptiActivityRegisterCallbacks(msptiBuffersCallbackRequestFunc funcBufferRequested, msptiBuffersCallbackCompleteFunc funcBufferCompleted); +msptiResult msptiActivityEnable(msptiActivityKind kind); +msptiResult msptiActivityDisable(msptiActivityKind kind); +msptiResult msptiActivityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record); +msptiResult msptiActivityFlushAll(uint32_t flag); + +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // MSPTI_STUB_H diff --git a/dynolog_npu/plugin/ipc_monitor/singleton.h b/msmonitor/plugin/ipc_monitor/singleton.h similarity index 48% rename from dynolog_npu/plugin/ipc_monitor/singleton.h rename to msmonitor/plugin/ipc_monitor/singleton.h index 8bb106f3adc8b365ef81feb603c6aaac917a00e2..5143f404f19a2c6d2da96e171bb39e1cd2b549b6 100644 --- a/dynolog_npu/plugin/ipc_monitor/singleton.h +++ b/msmonitor/plugin/ipc_monitor/singleton.h @@ -1,31 +1,47 @@ -#ifndef SINGLETON_H -#define SINGLETON_H -#include - -namespace dynolog_npu { -namespace ipc_monitor { - -template -class Singleton { -public: - static T *GetInstance() noexcept(std::is_nothrow_constructible::value) { - static T instance; - return &instance; - } - - virtual ~Singleton() = default; - -protected: - explicit Singleton() = default; - -private: - explicit Singleton(const Singleton &obj) = delete; - Singleton& operator=(const Singleton &obj) = delete; - explicit Singleton(Singleton &&obj) = delete; - Singleton& operator=(Singleton &&obj) = delete; -}; - -} // ipc_monitor -} // dynolog_npu - +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGLETON_H +#define SINGLETON_H +#include + +namespace dynolog_npu { +namespace ipc_monitor { + +template +class Singleton { +public: + static T *GetInstance() noexcept(std::is_nothrow_constructible::value) + { + static T instance; + return &instance; + } + + virtual ~Singleton() = default; + +protected: + explicit Singleton() = default; + +private: + explicit Singleton(const Singleton &obj) = delete; + Singleton& operator=(const Singleton &obj) = delete; + explicit Singleton(Singleton &&obj) = delete; + Singleton& operator=(Singleton &&obj) = delete; +}; + +} // ipc_monitor +} // dynolog_npu + #endif \ No newline at end of file diff --git a/msmonitor/plugin/ipc_monitor/thread.h b/msmonitor/plugin/ipc_monitor/thread.h new file mode 100644 index 0000000000000000000000000000000000000000..b674cbb6cb86c3bf6177da81684549211ef92ee8 --- /dev/null +++ b/msmonitor/plugin/ipc_monitor/thread.h @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef IPC_MONITOR_THREAD_H +#define IPC_MONITOR_THREAD_H + +#include +#include +#include +#include +#include "utils.h" + +namespace dynolog_npu { +namespace ipc_monitor { +class Thread { +public: + Thread() + : is_alive_(false), + pid_(0), + thread_name_("IPCMonitor") {} + + ~Thread() + { + if (is_alive_) { + (void)pthread_cancel(pid_); + (void)pthread_join(pid_, nullptr); + } + } + + void SetThreadName(const std::string &name) + { + if (!name.empty()) { + thread_name_ = name; + } + } + + std::string GetThreadName() + { + return thread_name_; + } + + int Start() + { + int ret = pthread_create(&pid_, nullptr, Execute, ReinterpretConvert(this)); + is_alive_ = (ret == 0) ? true : false; + return ret; + } + + int Stop() + { + return Join(); + } + + int Join() + { + int ret = pthread_join(pid_, nullptr); + is_alive_ = (ret == 0) ? false : true; + return ret; + } + +private: + static void* Execute(void *args) + { + Thread *thr = ReinterpretConvert(args); + prctl(PR_SET_NAME, ReinterpretConvert(thr->GetThreadName().data())); + thr->Run(); + return nullptr; + } + virtual void Run() = 0; + +private: + bool is_alive_; + pthread_t pid_; + std::string thread_name_; +}; +} // ipc_monitor +} // dynolog_npu +#endif // IPC_MONITOR_THREAD_H diff --git a/msmonitor/plugin/setup.py b/msmonitor/plugin/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2e257a48ada719a56d3cd0299f56f61351f249f4 --- /dev/null +++ b/msmonitor/plugin/setup.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +import subprocess +import pybind11 + +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + super().__init__(name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + def run(self): + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + cfg = 'Debug' if self.debug else 'Release' + build_args = ['--config', cfg] + + ext_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + cmake_args = [ + '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + ext_dir, + '-DPYTHON_EXECUTABLE=' + sys.executable, + '-DCMAKE_PREFIX_PATH=' + pybind11.get_cmake_dir(), + '-DCMAKE_INSTALL_PREFIX=' + ext_dir, + '-DDYNOLOG_PATH=' + os.path.join(os.path.dirname(BASE_DIR), "third_party", "dynolog"), + '-DCMAKE_BUILD_TYPE=' + cfg + ] + + env = os.environ.copy() + env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''), + self.distribution.get_version()) + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(['cmake', '--build', '.', '--target', 'install', '-j', '8'] + build_args, + cwd=self.build_temp) + +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) + +setup( + name="msmonitor_plugin", + version="0.1", + description="msMonitor plugins", + ext_modules=[CMakeExtension('IPCMonitor')], + cmdclass=dict(build_ext=CMakeBuild), + install_requires=["pybind11"], +) diff --git a/msmonitor/plugin/stub/build_stub.sh b/msmonitor/plugin/stub/build_stub.sh new file mode 100644 index 0000000000000000000000000000000000000000..97ec0699aec5923497ee32a7252b0337db059f7f --- /dev/null +++ b/msmonitor/plugin/stub/build_stub.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +CDIR="$(cd "$(dirname "$0")" ; pwd -P)" + +cd ${CDIR} + +gcc -fPIC -shared -o libmspti.so -I../ipc_monitor/mspti_monitor mspti.cpp diff --git a/msmonitor/plugin/stub/mspti.cpp b/msmonitor/plugin/stub/mspti.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db05f209275f0702f50f6d33fe5d5fb6aa1b2732 --- /dev/null +++ b/msmonitor/plugin/stub/mspti.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mspti.h" + +msptiResult msptiSubscribe(msptiSubscriberHandle *subscriber, msptiCallbackFunc callback, void *userdata) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiUnsubscribe(msptiSubscriberHandle subscriber) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityRegisterCallbacks(msptiBuffersCallbackRequestFunc funcBufferRequested, msptiBuffersCallbackCompleteFunc funcBufferCompleted) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityEnable(msptiActivityKind kind) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityDisable(msptiActivityKind kind) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes, msptiActivity **record) +{ + return MSPTI_SUCCESS; +} + +msptiResult msptiActivityFlushAll(uint32_t flag) +{ + return MSPTI_SUCCESS; +} diff --git a/plugins/mindstudio-vscode-plugins/OWNERS b/plugins/mindstudio-vscode-plugins/OWNERS deleted file mode 100644 index 2c4ada94aa198321313f24bc0b0f289eba360c33..0000000000000000000000000000000000000000 --- a/plugins/mindstudio-vscode-plugins/OWNERS +++ /dev/null @@ -1,9 +0,0 @@ -options: - no_parent_owners: true -approvers: -- lee314 -- linxi9527 -reviewers: -- jzc_23 -- duanhaomiao -- yangqingliang4 \ No newline at end of file diff --git a/plugins/tensorboard-plugins/OWNERS b/plugins/tensorboard-plugins/ OWNERS similarity index 67% rename from plugins/tensorboard-plugins/OWNERS rename to plugins/tensorboard-plugins/ OWNERS index 8dd996262b04faf778976324fa4221e51c4bfa30..34c383beaf138da92df0991b472135496450a827 100644 --- a/plugins/tensorboard-plugins/OWNERS +++ b/plugins/tensorboard-plugins/ OWNERS @@ -3,8 +3,7 @@ options: approvers: - wo-wenjie - ly-qianxiao -- leo920320 -- ninghuang reviewers: +- wo-wenjie +- ly-qianxiao - leo920320 -- ninghuang diff --git a/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml b/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..3133d6400fb0b3ca0ee9b38c311c2db6d1167c7e --- /dev/null +++ b/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml @@ -0,0 +1,56 @@ +name: LIBKINETOCI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + + steps: + - uses: actions/checkout@v2 + - name: Checkout submodules + shell: bash + run: | + auth_header="$(git config --local --get http.https://github.com/.extraheader)" + git submodule sync --recursive + git -c "http.extraheader=$auth_header" -c protocol.version=2 submodule update --init --force --recursive --depth=1 + + - name: Get env vars + run: | + echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW + echo HOME = $HOME + echo GITHUB_ACTION = $GITHUB_ACTION + echo GITHUB_ACTIONS = $GITHUB_ACTIONS + echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY + echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME + echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH + echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE + echo GITHUB_SHA = $GITHUB_SHA + echo GITHUB_REF = $GITHUB_REF + c++ --verbose + + # TODO: Figure out how to install cupti headers T84637671 + - name: Build static lib + run: | + set -e + mkdir build_static + cd build_static + cmake -DKINETO_LIBRARY_TYPE=static ../libkineto/ + make -j + + - name: Build shared lib + run: | + set -e + mkdir build_shared + cd build_shared + cmake -DKINETO_LIBRARY_TYPE=shared ../libkineto/ + make -j diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml new file mode 100644 index 0000000000000000000000000000000000000000..9bdafcc442635eaff19fc7a7505f5231cf6e5cf7 --- /dev/null +++ b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml @@ -0,0 +1,19 @@ +name: Build torch-tb-profiler Pip Package + +on: + # TODO: Add an on_release trigger to build on tags + workflow_dispatch: + +jobs: + build-package: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: build pip package + run: | + set -e + cd tb_plugin + python setup.py sdist bdist_wheel + cd dist/ + pip install *.whl + python -c "import torch_tb_profiler;print(torch_tb_profiler.__version__)" diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..1b59a7bf90a6009caa41d4ac0e3d5545dc8b6c7c --- /dev/null +++ b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml @@ -0,0 +1,57 @@ +name: TB_Plugin_CI + +on: + push: + branches: + - main + - release/** + - plugin/** + + pull_request: + branches: + - main + - release/** + - plugin/** + +jobs: + generate-matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + run: | + echo $GITHUB_BASE_REF + if [ $GITHUB_BASE_REF == "plugin/vnext" ] + then + echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\"]}" + else + echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\", \"1.11rc\", \"stable\"]}" + fi + + build: + needs: generate-matrix + runs-on: ubuntu-latest + strategy: + matrix: ${{fromJSON(needs.generate-matrix.outputs.matrix)}} + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + architecture: 'x64' + - name: Test + env: + CUDA_VERSION: ${{ matrix.cuda-version }} + PYTORCH_VERSION: ${{ matrix.pytorch-version }} + TORCH_PROFILER_LOG_LEVEL: DEBUG + GRPC_VERBOSITY: DEBUG + GRPC_ENABLE_FORK_SUPPORT: 'False' + run: | + set -e + cd tb_plugin + sh ./ci_scripts/install_env.sh + pip install .[gs] + cd test + pytest diff --git a/plugins/tensorboard-plugins/.gitignore b/plugins/tensorboard-plugins/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ce186381c0b566e0ca225be70cbf8ac233d7aa6b --- /dev/null +++ b/plugins/tensorboard-plugins/.gitignore @@ -0,0 +1,3 @@ +# ignore common items +.idea +.vscode diff --git a/plugins/tensorboard-plugins/.gitmodules b/plugins/tensorboard-plugins/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..4660ee8bc9e6a4be4f4fbb007b8e66058122d716 --- /dev/null +++ b/plugins/tensorboard-plugins/.gitmodules @@ -0,0 +1,6 @@ +[submodule "libkineto/third_party/googletest"] + path = libkineto/third_party/googletest + url = https://github.com/google/googletest.git +[submodule "libkineto/third_party/fmt"] + path = libkineto/third_party/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md b/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..a0cbeaab7650bf08267fbdbc9bb54e845c88f392 --- /dev/null +++ b/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/plugins/tensorboard-plugins/CONTRIBUTING.md b/plugins/tensorboard-plugins/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..a2e931bb6f0cc82ff030cee10ee1c99fbbbda07b --- /dev/null +++ b/plugins/tensorboard-plugins/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing to Kineto +We want to make contributing to this project as easy and transparent as +possible. + +## Code of Conduct +The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md). + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to Kineto, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/plugins/tensorboard-plugins/LICENSE b/plugins/tensorboard-plugins/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..edb179715b5213644cfe903d43294f54892e707e --- /dev/null +++ b/plugins/tensorboard-plugins/LICENSE @@ -0,0 +1,33 @@ +BSD License + +For Kineto software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +All contributions by Microsoft: +Copyright (c) Microsoft Corporation. (The Azure AI Platform team) + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/plugins/tensorboard-plugins/README.md b/plugins/tensorboard-plugins/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a18f4c6239f353c10362c9e0ba5aae052cb2c07 --- /dev/null +++ b/plugins/tensorboard-plugins/README.md @@ -0,0 +1,38 @@ +# Kineto + +Kineto is part of the PyTorch Profiler. + +The Kineto project was started to help enable +- **performance observability and diagnostics** across common ML bottleneck components +- **actionable recommendations** for common issues +- integration of external system-level profiling tools +- integration with popular visualization platforms and analysis pipelines + +A central component is libkineto, a profiling library with special focus on low-overhead GPU timeline tracing. + +The PyTorch Profiler TensorBoard plugin provides powerful and intuitive visualizations of profiling results, as well as actionable recommendations, and is the best way to experience the new PyTorch Profiler. + +## Libkineto +Libkineto is an in-process profiling library integrated with the PyTorch Profiler. Please refer to the [README](libkineto/README.md) file in the `libkineto` folder as well as documentation on the [new PyTorch Profiler API](https://pytorch.org/docs/master/profiler.html). + +## PyTorch TensorBoard Profiler NPU Plugin +The goal of the PyTorch TensorBoard Profiler is to provide a seamless and intuitive end-to-end profiling experience, including straightforward collection from PyTorch and insightful visualizations and recommendations in the TensorBoard UI. +Please refer to the [README](tb_plugin/README.md) file in the `tb_plugin` folder. + +## Future Development Direction: +Some areas we're currently working on: +- Support for tracing distributed workloads +- Trace processing, analysis and recommendation engine +- System-level activities, multiple tracing sources +- Profiling and monitoring daemon for larger scale deployments + +## Releases and Contributing +We will follow the PyTorch release schedule which roughly happens on a 3 month basis. + +We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. + +If you plan to contribute new features, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the infrastructure in a different direction than you might be aware of. We expect the architecture to keep evolving. + +## License +Kineto has a BSD-style license, as found in the [LICENSE](LICENSE) file. + diff --git a/plugins/tensorboard-plugins/libkineto/CMakeLists.txt b/plugins/tensorboard-plugins/libkineto/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..63966de803a786913b104419776aa94bb00b74b0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/CMakeLists.txt @@ -0,0 +1,198 @@ +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") + +#install libraries into correct locations on all platforms +include(GNUInstallDirs) + +# function to extract filelists from libkineto_defs.bzl file +find_package(PythonInterp) +function(get_filelist name outputvar) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c + "exec(open('libkineto_defs.bzl').read());print(';'.join(${name}))" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" + OUTPUT_VARIABLE _tempvar) + string(REPLACE "\n" "" _tempvar "${_tempvar}") + set(${outputvar} ${_tempvar} PARENT_SCOPE) +endfunction() + +project(kineto VERSION 0.1 LANGUAGES CXX C) + +set(KINETO_LIBRARY_TYPE "default" CACHE STRING + "Type of library (default, static or shared) to build") +set_property(CACHE KINETO_LIBRARY_TYPE PROPERTY STRINGS default shared) +option(KINETO_BUILD_TESTS "Build kineto unit tests" ON) + +set(LIBKINETO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src") +set(LIBKINETO_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include") +set(LIBKINETO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(LIBKINETO_THIRDPARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +#We should default to a Release build +if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE) +endif() + +if (NOT CUDA_SOURCE_DIR) + set(CUDA_SOURCE_DIR "$ENV{CUDA_SOURCE_DIR}") + message(INFO " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}") +endif() + +if (NOT ROCM_SOURCE_DIR) + set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}") + message(INFO " ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}") +endif() + +# Set LIBKINETO_NOCUPTI to explicitly disable CUPTI +# Otherwise, CUPTI is disabled if not found +IF (NOT CUDA_SOURCE_DIR OR NOT CUPTI_INCLUDE_DIR OR NOT CUDA_cupti_LIBRARY) + set(LIBKINETO_NOCUPTI ON CACHE BOOL "" FORCE) +endif() + +IF (NOT ROCM_SOURCE_DIR AND NOT ROCTRACER_INCLUDE_DIR) + set(LIBKINETO_NOROCTRACER ON CACHE BOOL "" FORCE) +endif() + +# Define file lists +if (LIBKINETO_NOCUPTI AND LIBKINETO_NOROCTRACER) + get_filelist("get_libkineto_cpu_only_srcs(with_api=False)" LIBKINETO_SRCS) + message(INFO " CUPTI unavailable or disabled - not building GPU profilers") +elseif(NOT LIBKINETO_NOROCTRACER) + get_filelist("get_libkineto_roctracer_srcs()" LIBKINETO_SRCS) + message(INFO " Building with roctracer") +else() + get_filelist("get_libkineto_cupti_srcs(with_api=False)" LIBKINETO_SRCS) +endif() +get_filelist("get_libkineto_public_headers()" LIBKINETO_PUBLIC_HEADERS) +get_filelist("get_libkineto_api_srcs()" LIBKINETO_API_SRCS) + +add_library(kineto_base OBJECT ${LIBKINETO_SRCS}) +add_library(kineto_api OBJECT ${LIBKINETO_API_SRCS}) + +# Make libraries depend on libkineto_defs.bzl +add_custom_target(libkineto_defs.bzl DEPENDS libkineto_defs.bzl) +add_dependencies(kineto_base libkineto_defs.bzl) + +set_target_properties(kineto_base kineto_api PROPERTIES + CXX_STANDARD 14 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO + CXX_VISIBILITY_PRESET hidden) + +set(KINETO_COMPILE_OPTIONS "-DKINETO_NAMESPACE=libkineto") +list(APPEND KINETO_COMPILE_OPTIONS "-DFMT_HEADER_ONLY") +if(NOT MSVC) + list(APPEND KINETO_COMPILE_OPTIONS "-std=c++14") +else() + list(APPEND KINETO_COMPILE_OPTIONS "/std:c++14") + list(APPEND KINETO_COMPILE_OPTIONS "-DWIN32_LEAN_AND_MEAN") + list(APPEND KINETO_COMPILE_OPTIONS "-DNOGDI") +endif() +if (NOT LIBKINETO_NOCUPTI) + list(APPEND KINETO_COMPILE_OPTIONS "-DHAS_CUPTI") +endif() +if (NOT LIBKINETO_NOROCTRACER) + target_compile_options(kineto_base PRIVATE "-DHAS_ROCTRACER") + target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_HCC__") + target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_AMD__") +endif() + +target_compile_options(kineto_base PRIVATE "${KINETO_COMPILE_OPTIONS}") +target_compile_options(kineto_api PRIVATE "${KINETO_COMPILE_OPTIONS}") + +if(NOT TARGET fmt) + if(NOT FMT_SOURCE_DIR) + set(FMT_SOURCE_DIR "${LIBKINETO_THIRDPARTY_DIR}/fmt" + CACHE STRING "fmt source directory from submodules") + endif() + + # Build FMT. + # FMT and some other libraries use BUILD_SHARED_LIBS to control + # the library type. + # Save and restore the value after configuring FMT + set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) + set(FMT_LIBRARY_TYPE static CACHE STRING "Set lib type to static") + add_subdirectory("${FMT_SOURCE_DIR}" "${LIBKINETO_BINARY_DIR}/fmt") + set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON) + set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) +endif() + +set(FMT_INCLUDE_DIR "${FMT_SOURCE_DIR}/include") +message(STATUS "Kineto: FMT_SOURCE_DIR = ${FMT_SOURCE_DIR}") +message(STATUS "Kineto: FMT_INCLUDE_DIR = ${FMT_INCLUDE_DIR}") +if (NOT CUPTI_INCLUDE_DIR) + set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include") +endif() +if (NOT CUDA_INCLUDE_DIRS) + set(CUDA_INCLUDE_DIRS "${CUDA_SOURCE_DIR}/include") +endif() +if (NOT ROCTRACER_INCLUDE_DIR) + set(ROCTRACER_INCLUDE_DIR "${ROCM_SOURCE_DIR}/roctracer/include") +endif() +if (NOT ROCM_INCLUDE_DIRS) + set(ROCM_INCLUDE_DIRS "${ROCM_SOURCE_DIR}/include") +endif() + +message(INFO " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}") +message(INFO " ROCTRACER_INCLUDE_DIR = ${ROCTRACER_INCLUDE_DIR}") + +target_include_directories(kineto_base PUBLIC + $ + $ + $ + $ + $ + $ + $) + +target_include_directories(kineto_api PUBLIC + $ + $) + +if(KINETO_LIBRARY_TYPE STREQUAL "default") + add_library(kineto + $ + $) +elseif(KINETO_LIBRARY_TYPE STREQUAL "static") + add_library(kineto STATIC + $ + $) +elseif(KINETO_LIBRARY_TYPE STREQUAL "shared") + add_library(kineto SHARED + $) + set_property(TARGET kineto_base PROPERTY POSITION_INDEPENDENT_CODE ON) + set_target_properties(kineto PROPERTIES + CXX_VISIBILITY_PRESET hidden) +else() + message(FATAL_ERROR "Unsupported library type ${KINETO_LIBRARY_TYPE}") +endif() + +if(NOT LIBKINETO_NOROCTRACER) + find_library(ROCTRACER_LIBRARY NAMES libroctracer64.so HINTS /opt/rocm/roctracer/lib) + target_link_libraries(kineto "${ROCTRACER_LIBRARY}") + find_library(KINETO_HIP_LIBRARY NAMES libamdhip64.so HINTS /opt/rocm/lib) + target_link_libraries(kineto "${KINETO_HIP_LIBRARY}") +endif() + +if(NOT LIBKINETO_NOCUPTI) + target_link_libraries(kineto "${CUDA_cupti_LIBRARY}") +endif() +target_link_libraries(kineto $) +add_dependencies(kineto fmt::fmt-header-only) + +install(TARGETS kineto EXPORT kinetoLibraryConfig + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +install(FILES ${LIBKINETO_PUBLIC_HEADERS} + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kineto") + +install(EXPORT kinetoLibraryConfig DESTINATION share/cmake/kineto + FILE kinetoLibraryConfig.cmake) + +if(KINETO_BUILD_TESTS) + add_subdirectory(test) +endif() diff --git a/plugins/tensorboard-plugins/libkineto/README.md b/plugins/tensorboard-plugins/libkineto/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37127ca5aa821217da48aad38cb82eb36f8735c2 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/README.md @@ -0,0 +1,65 @@ +# Libkineto + +Libkineto is an in-process profiling library, part of the Kineto performance +tools project. + +The library provides a way to collect GPU traces and metrics from the host +process, either via the library public API or by sending a signal, if enabled. + +Currently only NVIDIA GPUs are supported. + +## Build Notes +Libkineto uses the standard CMAKE-based build flow. + +### Dependencies +Libkineto requires gcc 5+ and: + +- NVIDIA CUPTI: used to collect traces and metrics from NVIDIA GPUs. +- fmt: used for its convenient and lightweight string formatting functionality. +- googletest: required to build and run Kineto's tests. + - **googletest is not required** if you don't want to run Kineto tests. +By default, building of tests is **on**. Turn it off by setting `KINETO_BUILD_TESTS` to **off**. + +You can download [NVIDIA CUPTI][1], [fmt][2], [googletest][3] and set +`CUDA_SOURCE_DIR`, `FMT_SOURCE_DIR`, `GOOGLETEST_SOURCE_DIR` respectively for +cmake to find these libraries. If the fmt and googletest variables are not set, cmake will +build the git submodules found in the `third_party` directory. +If `CUDA_SOURCE_DIR` is not set, libkineto will fail to build. + +### Building Libkineto + +``` +# Check out repo and sub modules +git clone --recursive https://github.com/pytorch/kineto.git +# Build libkineto with cmake +cd kineto/libkineto +mkdir build && cd build +cmake .. +make +``` + +To run the tests after building libkineto (if tests are built), use the following +command: +``` +make test +``` + +### Installing Libkineto +``` +make install +``` + +## How Libkineto works +We will provide a high-level overview, design philosophy and brief descriptions of various +parts of Libkineto in upcoming blogs. + +## Full documentation +We strive to keep our source files readable. The best and up-to-date +documentation is available in the source files. + +## License +Libkineto is BSD licensed, as detailed in the [LICENSE](../LICENSE) file. + +[1]:https://developer.nvidia.com/CUPTI-CTK10_2 +[2]:https://github.com/fmt +[3]:https://github.com/google/googletest diff --git a/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h b/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..1cadf4906c11c3b5f59e290295048cee7fd63acf --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h @@ -0,0 +1,113 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +class AbstractConfig { + public: + AbstractConfig& operator=(const AbstractConfig&) = delete; + AbstractConfig(AbstractConfig&&) = delete; + AbstractConfig& operator=(AbstractConfig&&) = delete; + + virtual ~AbstractConfig() { + for (const auto& p : featureConfigs_) { + delete p.second; + } + } + + // Return a copy of the full derived class + virtual AbstractConfig* cloneDerived(AbstractConfig& parent) const = 0; + + // Returns true if successfully parsed the config string + bool parse(const std::string& conf); + + // Default setup for signal-triggered profiling + virtual void setSignalDefaults() { + for (auto& p : featureConfigs_) { + p.second->setSignalDefaults(); + } + } + + // Default setup for client-triggered profiling + virtual void setClientDefaults() { + for (auto& p : featureConfigs_) { + p.second->setClientDefaults(); + } + } + + // Time config was created / updated + std::chrono::time_point timestamp() const { + return timestamp_; + } + + // Source config string that this was parsed from + const std::string& source() const { + return source_; + } + + AbstractConfig& feature(std::string name) const { + const auto& pos = featureConfigs_.find(name); + return *pos->second; + } + + // Transfers ownership of cfg arg + void addFeature(const std::string& name, AbstractConfig* cfg) { + featureConfigs_[name] = cfg; + } + + protected: + AbstractConfig() {} + AbstractConfig(const AbstractConfig& other) = default; + + // Return true if the option was recognized and successfully parsed. + // Throw std::invalid_argument if val is invalid. + virtual bool handleOption(const std::string& name, std::string& val); + + // Perform post-validation checks, typically conditons involving + // multiple options. + // Throw std::invalid_argument if automatic correction can not be made. + // + // @param fallbackProfileStartTime Specify a fallback profile start timestamp in case it was never specified by the client + virtual void validate(const std::chrono::time_point& fallbackProfileStartTime) = 0; + + // TODO: Separate out each profiler type into features? + virtual void printActivityProfilerConfig(std::ostream& s) const; + + // Helpers for use in handleOption + // Split a string by delimiter and remove external white space + std::vector splitAndTrim(const std::string& s, char delim) const; + // Lowercase for case-insensitive comparisons + std::string toLower(std::string& s) const; + // Does string end with suffix + bool endsWith(const std::string& s, const std::string& suffix) const; + // Conversions + int64_t toIntRange(const std::string& val, int64_t min, int64_t max) const; + int32_t toInt32(const std::string& val) const; + int64_t toInt64(const std::string& val) const; + bool toBool(std::string& val) const; + + void cloneFeaturesInto(AbstractConfig& cfg) const { + for (const auto& feature : featureConfigs_) { + cfg.featureConfigs_[feature.first] = feature.second->cloneDerived(cfg); + } + } + + private: + // Time config was created / updated + std::chrono::time_point timestamp_{}; + + // Original configuration string, used for comparison + std::string source_{""}; + + // Configuration objects for optional features + std::map featureConfigs_{}; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h b/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..29871e47ab8af87888ccb8e20403bc26c433b5cc --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h @@ -0,0 +1,91 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include + +#include "ActivityType.h" +#include "ActivityTraceInterface.h" +#include "IActivityProfiler.h" + +namespace libkineto { + +class ActivityProfilerController; +struct CpuTraceBuffer; +class Config; + +class ActivityProfilerInterface { + + public: + virtual ~ActivityProfilerInterface() {}; + + virtual void init() {} + virtual bool isInitialized() { + return false; + } + virtual bool isActive(){ + return false; + } + + // *** Asynchronous API *** + // Instead of starting and stopping the trace manually, provide a start time + // and duration and / or iteration stop criterion. + // Tracing terminates when either condition is met. + virtual void scheduleTrace(const std::string& configStr) {} + + // *** Synchronous API *** + // These must be called in order: + // prepareTrace -> startTrace -> stopTrace. + + // Many tracing structures are lazily initialized during trace collection, + // with potentially high overhead. + // Call prepareTrace to enable tracing, then run the region to trace + // at least once (and ideally run the same code that is to be traced) to + // allow tracing structures to be initialized. + virtual void prepareTrace( + const std::set& activityTypes, + const std::string& configStr = "") {} + + // Start recording, potentially reusing any buffers allocated since + // prepareTrace was called. + virtual void startTrace() {} + + // Stop and process trace, producing an in-memory list of trace records. + // The processing will be done synchronously (using the calling thread.) + virtual std::unique_ptr stopTrace() { + return nullptr; + } + + // Re-evaluate internal state to allow for triggering operations based + // on number of iteration. each implicitly increments the iteration count + virtual void step() {} + + // *** TraceActivity API *** + // FIXME: Pass activityProfiler interface into clientInterface? + virtual void pushCorrelationId(uint64_t id){} + virtual void popCorrelationId(){} + virtual void transferCpuTrace( + std::unique_ptr traceBuffer){} + + // Correlation ids for user defined spans + virtual void pushUserCorrelationId(uint64_t){} + virtual void popUserCorrelationId(){} + + // Saves information for the current thread to be used in profiler output + // Client must record any new kernel thread where the activity has occured. + virtual void recordThreadInfo() {} + + // Record trace metadata, currently supporting only string key and values, + // values with the same key are overwritten + virtual void addMetadata(const std::string& key, const std::string& value) = 0; + + // Add a child activity profiler, this enables frameworks in the application + // to enable custom framework events. + virtual void addChildActivityProfiler( + std::unique_ptr profiler) {} +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h b/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..23d4edab00ce2fa90427e13818ac09c8541835ac --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h @@ -0,0 +1,21 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +namespace libkineto { + +struct ITraceActivity; + +class ActivityTraceInterface { + public: + virtual ~ActivityTraceInterface() {} + virtual const std::vector* activities() { + return nullptr; + } + virtual void save(const std::string& path) {} +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityType.h b/plugins/tensorboard-plugins/libkineto/include/ActivityType.h new file mode 100644 index 0000000000000000000000000000000000000000..74c6a2531d6a9cee3196f9f889517926afea823f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ActivityType.h @@ -0,0 +1,34 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +namespace libkineto { + +enum class ActivityType { + CPU_OP = 0, // cpu side ops + USER_ANNOTATION, + GPU_USER_ANNOTATION, + GPU_MEMCPY, + GPU_MEMSET, + CONCURRENT_KERNEL, // on-device kernels + EXTERNAL_CORRELATION, + CUDA_RUNTIME, // host side cuda runtime events + CUDA_PROFILER_RANGE, // CUPTI Profiler range for performance metrics + GLOW_RUNTIME, // host side glow runtime events + CPU_INSTANT_EVENT, // host side point-like events + PYTHON_FUNCTION, + OVERHEAD, // CUPTI induced overhead events sampled from its overhead API. + ENUM_COUNT // This is to add buffer and not used for any profiling logic. Add your new type before it. +}; + +const char* toString(ActivityType t); +ActivityType toActivityType(const std::string& str); + +// Return an array of all activity types except COUNT +constexpr int activityTypeCount = (int)ActivityType::ENUM_COUNT; +const std::array activityTypes(); + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h b/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..06dc075838164f80e9481b34a5d5d3c136b92efd --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h @@ -0,0 +1,16 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +namespace libkineto { + +class ClientInterface { + public: + virtual ~ClientInterface() {} + virtual void init() = 0; + virtual void warmup(bool setupOpInputsCollection) = 0; + virtual void start() = 0; + virtual void stop() = 0; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/Config.h b/plugins/tensorboard-plugins/libkineto/include/Config.h new file mode 100644 index 0000000000000000000000000000000000000000..040e96c9f75ab3ab768aaebac28f959f12a3ea06 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/Config.h @@ -0,0 +1,433 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include "AbstractConfig.h" +#include "ActivityType.h" + +#include +#include +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +class Config : public AbstractConfig { + public: + Config(); + Config& operator=(const Config&) = delete; + Config(Config&&) = delete; + Config& operator=(Config&&) = delete; + + // Return a full copy including feature config object + std::unique_ptr clone() const { + auto cfg = std::unique_ptr(new Config(*this)); + cloneFeaturesInto(*cfg); + return cfg; + } + + bool handleOption(const std::string& name, std::string& val) override; + + void setClientDefaults() override; + + // Log events to this file + const std::string& eventLogFile() const { + return eventLogFile_; + } + + bool activityProfilerEnabled() const { + return activityProfilerEnabled_ || + activitiesOnDemandTimestamp_.time_since_epoch().count() > 0; + } + + // Log activitiy trace to this file + const std::string& activitiesLogFile() const { + return activitiesLogFile_; + } + + // Log activitiy trace to this url + const std::string& activitiesLogUrl() const { + return activitiesLogUrl_; + } + + void setActivitiesLogUrl(const std::string& url) { + activitiesLogUrl_ = url; + } + + bool activitiesLogToMemory() const { + return activitiesLogToMemory_; + } + + // Is profiling enabled for the given device? + bool eventProfilerEnabledForDevice(uint32_t dev) const { + return 0 != (eventProfilerDeviceMask_ & (1 << dev)); + } + + // Take a sample (read hardware counters) at this frequency. + // This controls how often counters are read - if all counters cannot + // be collected simultaneously then multiple samples are needed to + // collect all requested counters - see multiplex period. + std::chrono::milliseconds samplePeriod() const { + return samplePeriod_; + } + + void setSamplePeriod(std::chrono::milliseconds period) { + samplePeriod_ = period; + } + + // When all requested counters cannot be collected simultaneously, + // counters will be multiplexed at this frequency. + // Multiplexing can have a large performance impact if done frequently. + // To avoid a perf impact, keep this at 1s or above. + std::chrono::milliseconds multiplexPeriod() const { + return multiplexPeriod_; + } + + void setMultiplexPeriod(std::chrono::milliseconds period) { + multiplexPeriod_ = period; + } + + // Report counters at this frequency. Note that several samples can + // be reported each time, see samplesPerReport. + std::chrono::milliseconds reportPeriod() const { + return reportPeriod_; + } + + void setReportPeriod(std::chrono::milliseconds msecs); + + // Number of samples dispatched each report period. + // Must be in the range [1, report period / sample period]. + // In other words, aggregation is supported but not interpolation. + int samplesPerReport() const { + return samplesPerReport_; + } + + void setSamplesPerReport(int count) { + samplesPerReport_ = count; + } + + // The names of events to collect + const std::set& eventNames() const { + return eventNames_; + } + + // Add additional events to be profiled + void addEvents(const std::set& names) { + eventNames_.insert(names.begin(), names.end()); + } + + // The names of metrics to collect + const std::set& metricNames() const { + return metricNames_; + } + + // Add additional metrics to be profiled + void addMetrics(const std::set& names) { + metricNames_.insert(names.begin(), names.end()); + } + + const std::vector& percentiles() const { + return eventReportPercentiles_; + } + + // Profile for this long, then revert to base config + std::chrono::seconds eventProfilerOnDemandDuration() const { + return eventProfilerOnDemandDuration_; + } + + void setEventProfilerOnDemandDuration(std::chrono::seconds duration) { + eventProfilerOnDemandDuration_ = duration; + } + + // Too many event profilers on a single system can overload the driver. + // At some point, latencies shoot through the roof and collection of samples + // becomes impossible. To avoid this situation we have a limit of profilers + // per GPU. + // NOTE: Communication with a daemon is needed for this feature. + // Library must be built with an active DaemonConfigLoader. + int maxEventProfilersPerGpu() const { + return eventProfilerMaxInstancesPerGpu_; + } + + // On Cuda11 we've seen occasional hangs when reprogramming counters + // Monitor profiling threads and report when a thread is not responding + // for a given number of seconds. + // A period of 0 means disable. + std::chrono::seconds eventProfilerHeartbeatMonitorPeriod() const { + return eventProfilerHeartbeatMonitorPeriod_; + } + + // The types of activities selected in the configuration file + const std::set& selectedActivityTypes() const { + return selectedActivityTypes_; + } + + void setSelectedActivityTypes(const std::set& types) { + selectedActivityTypes_ = types; + } + + bool isOpInputsCollectionEnabled() const { + return enableOpInputsCollection_; + } + + // Trace for this long + std::chrono::milliseconds activitiesDuration() const { + return activitiesDuration_; + } + + // Trace for this many iterations, determined by external API + int activitiesRunIterations() const { + return activitiesRunIterations_; + } + + std::chrono::milliseconds activitiesDurationDefault() const; + + void setActivitiesDuration(std::chrono::milliseconds duration) { + activitiesDuration_ = duration; + } + + int activitiesMaxGpuBufferSize() const { + return activitiesMaxGpuBufferSize_; + } + + std::chrono::seconds activitiesWarmupDuration() const { + return activitiesWarmupDuration_; + } + + int activitiesWarmupIterations() const { + return activitiesWarmupIterations_; + } + + // Timestamp at which the profiling to start, requested by the user. + const std::chrono::time_point requestTimestamp() + const { + if (profileStartTime_.time_since_epoch().count()) { + return profileStartTime_; + } + + // TODO(T94634890): Deperecate requestTimestamp + return requestTimestamp_ + maxRequestAge() + activitiesWarmupDuration(); + } + + bool hasProfileStartTime() const { + return requestTimestamp_.time_since_epoch().count() > 0 || + profileStartTime_.time_since_epoch().count() > 0; + } + + int profileStartIteration() const { + return profileStartIteration_; + } + + bool hasProfileStartIteration() const { + return profileStartIteration_ >= 0 && activitiesRunIterations_ > 0; + } + + void setProfileStartIteration(int iter) { + profileStartIteration_ = iter; + } + + int profileStartIterationRoundUp() const { + return profileStartIterationRoundUp_; + } + + // calculate the start iteration accounting for warmup + int startIterationIncludingWarmup() const { + if (!hasProfileStartIteration()) { + return -1; + } + return profileStartIteration_ - activitiesWarmupIterations_; + } + + const std::chrono::seconds maxRequestAge() const; + + // All VLOG* macros will log if the verbose log level is >= + // the verbosity specified for the verbose log message. + // Default value is -1, so messages with log level 0 will log by default. + int verboseLogLevel() const { + return verboseLogLevel_; + } + + // Modules for which verbose logging is enabled. + // If empty, logging is enabled for all modules. + const std::vector& verboseLogModules() const { + return verboseLogModules_; + } + + bool sigUsr2Enabled() const { + return enableSigUsr2_; + } + + bool ipcFabricEnabled() const { + return enableIpcFabric_; + } + + static std::chrono::milliseconds alignUp( + std::chrono::milliseconds duration, + std::chrono::milliseconds alignment) { + duration += alignment; + return duration - (duration % alignment); + } + + std::chrono::time_point + eventProfilerOnDemandStartTime() const { + return eventProfilerOnDemandTimestamp_; + } + + std::chrono::time_point + eventProfilerOnDemandEndTime() const { + return eventProfilerOnDemandTimestamp_ + eventProfilerOnDemandDuration_; + } + + std::chrono::time_point + activityProfilerRequestReceivedTime() const { + return activitiesOnDemandTimestamp_; + } + + // Users may request and set trace id and group trace id. + const std::string& requestTraceID() const { + return requestTraceID_; + } + + void setRequestTraceID(const std::string& tid) { + requestTraceID_ = tid; + } + + const std::string& requestGroupTraceID() const { + return requestGroupTraceID_; + } + + void setRequestGroupTraceID(const std::string& gtid) { + requestGroupTraceID_ = gtid; + } + + void updateActivityProfilerRequestReceivedTime(); + + void printActivityProfilerConfig(std::ostream& s) const override; + + void validate( + const std::chrono::time_point& fallbackProfileStartTime) override; + + static void addConfigFactory( + std::string name, + std::function factory); + + void print(std::ostream& s) const; + + private: + explicit Config(const Config& other) = default; + + AbstractConfig* cloneDerived(AbstractConfig& parent) const override { + // Clone from AbstractConfig not supported + assert(false); + return nullptr; + } + + uint8_t createDeviceMask(const std::string& val); + + // Adds valid activity types from the user defined string list in the + // configuration file + void setActivityTypes(const std::vector& selected_activities); + + // Sets the default activity types to be traced + void selectDefaultActivityTypes() { + // If the user has not specified an activity list, add all types + for (ActivityType t : activityTypes()) { + // Do no enable this by default + // TODO: introduce optional types + if (t != ActivityType::OVERHEAD) { + selectedActivityTypes_.insert(t); + } + } + } + + int verboseLogLevel_; + std::vector verboseLogModules_; + + // Event profiler + // These settings are also supported in on-demand mode + std::chrono::milliseconds samplePeriod_; + std::chrono::milliseconds reportPeriod_; + int samplesPerReport_; + std::set eventNames_; + std::set metricNames_; + + // On-demand duration + std::chrono::seconds eventProfilerOnDemandDuration_; + // Last on-demand request + std::chrono::time_point + eventProfilerOnDemandTimestamp_; + + int eventProfilerMaxInstancesPerGpu_; + + // Monitor whether event profiler threads are stuck + // at this frequency + std::chrono::seconds eventProfilerHeartbeatMonitorPeriod_; + + // These settings can not be changed on-demand + std::string eventLogFile_; + std::vector eventReportPercentiles_ = {5, 25, 50, 75, 95}; + uint8_t eventProfilerDeviceMask_ = ~0; + std::chrono::milliseconds multiplexPeriod_; + + // Activity profiler + bool activityProfilerEnabled_; + std::set selectedActivityTypes_; + + // The activity profiler settings are all on-demand + std::string activitiesLogFile_; + + std::string activitiesLogUrl_; + + // Log activities to memory buffer + bool activitiesLogToMemory_{false}; + + int activitiesMaxGpuBufferSize_; + std::chrono::seconds activitiesWarmupDuration_; + int activitiesWarmupIterations_; + + // Client Interface + // Enable inputs collection when tracing ops + bool enableOpInputsCollection_{true}; + + // Profile for specified iterations and duration + std::chrono::milliseconds activitiesDuration_; + int activitiesRunIterations_; + + // Below are not used + // Use this net name for iteration count + std::string activitiesExternalAPIIterationsTarget_; + // Only profile nets that includes this in the name + std::vector activitiesExternalAPIFilter_; + // Only profile nets with at least this many operators + int activitiesExternalAPINetSizeThreshold_; + // Only profile nets with at least this many GPU operators + int activitiesExternalAPIGpuOpCountThreshold_; + // Last activity profiler request + std::chrono::time_point + activitiesOnDemandTimestamp_; + + // Synchronized start timestamp + std::chrono::time_point profileStartTime_; + // or start iteration + int profileStartIteration_; + int profileStartIterationRoundUp_; + + // DEPRECATED + std::chrono::time_point requestTimestamp_; + + // Enable profiling via SIGUSR2 + bool enableSigUsr2_; + + // Enable IPC Fabric instead of thrift communication + bool enableIpcFabric_; + + // Logger Metadata + std::string requestTraceID_; + std::string requestGroupTraceID_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h b/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h new file mode 100644 index 0000000000000000000000000000000000000000..4272cf1efa4e7613a46c3684270b4e803853345b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h @@ -0,0 +1,125 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include + +#include "ThreadUtil.h" +#include "ITraceActivity.h" +#include "TraceSpan.h" + +namespace libkineto { + +// Link type, used in GenericTraceActivity.flow.type +constexpr unsigned int kLinkFwdBwd = 1; +constexpr unsigned int kLinkAsyncCpuGpu = 2; + +// @lint-ignore-every CLANGTIDY cppcoreguidelines-non-private-member-variables-in-classes +// @lint-ignore-every CLANGTIDY cppcoreguidelines-pro-type-member-init +class GenericTraceActivity : public ITraceActivity { + + public: + GenericTraceActivity() : activityType(ActivityType::ENUM_COUNT), traceSpan_(NULL) {} + + GenericTraceActivity( + const TraceSpan& trace, ActivityType type, const std::string& name) + : activityType(type), activityName(name), traceSpan_(&trace) { + } + + int64_t deviceId() const override { + return device; + } + + int64_t resourceId() const override { + return resource; + } + + int32_t getThreadId() const override { + return threadId; + } + + int64_t timestamp() const override { + return startTime; + } + + int64_t duration() const override { + return endTime - startTime; + } + + int64_t correlationId() const override { + return id; + } + + ActivityType type() const override { + return activityType; + } + + const ITraceActivity* linkedActivity() const override { + return nullptr; + } + + int flowType() const override { + return flow.type; + } + + int flowId() const override { + return flow.id; + } + + bool flowStart() const override { + return flow.start; + } + + const std::string name() const override { + return activityName; + } + + const TraceSpan* traceSpan() const override { + return traceSpan_; + } + + void log(ActivityLogger& logger) const override; + + //Encode client side metadata as a key/value + template + void addMetadata(const std::string& key, const ValType& value) { + metadata_.push_back(fmt::format("\"{}\": {}", key, value)); + } + + void addMetadataQuoted(const std::string& key, const std::string& value) { + metadata_.push_back(fmt::format("\"{}\": \"{}\"", key, value)); + } + + const std::string metadataJson() const override { + return fmt::format("{}", fmt::join(metadata_, ", ")); + } + + virtual ~GenericTraceActivity() {}; + + int64_t startTime{0}; + int64_t endTime{0}; + int32_t id{0}; + int32_t device{0}; + int32_t resource{0}; + int32_t threadId{0}; + ActivityType activityType; + std::string activityName; + struct Flow { + Flow(): id(0), type(0), start(0) {} + // Ids must be unique within each type + uint32_t id : 27; + // Type will be used to connect flows between profilers, as + // well as look up flow information (name etc) + uint32_t type : 4; + uint32_t start : 1; + } flow; + + private: + const TraceSpan* traceSpan_; + std::vector metadata_; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h b/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..f5d4b3fb828a3348d948c6487acc6a9e5a18f836 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h @@ -0,0 +1,104 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include + +#include "Config.h" +#include "GenericTraceActivity.h" + +/* This file includes an abstract base class for an activity profiler + * that can be implemented by multiple tracing agents in the application. + * The high level Kineto profiler can co-ordinate start and end of tracing + * and combine together events from multiple such activity profilers. + */ + +namespace libkineto { + +using namespace KINETO_NAMESPACE; + +#ifdef _MSC_VER +// workaround for the predefined ERROR macro on Windows +#undef ERROR +#endif // _MSC_VER + +enum class TraceStatus { + READY, // Accepting trace requests + WARMUP, // Performing trace warmup + RECORDING, // Actively collecting activities + PROCESSING, // Recording is complete, preparing results + ERROR, // One or more errors (and possibly also warnings) occurred. + WARNING, // One or more warnings occurred. +}; + +/* IActivityProfilerSession: + * an opaque object that can be used by a high level profiler to + * start/stop and return trace events. + */ +class IActivityProfilerSession { + + public: + virtual ~IActivityProfilerSession() {} + + // start the trace collection synchronously + virtual void start() = 0; + + // stop the trace collection synchronously + virtual void stop() = 0; + + TraceStatus status() { + return status_; + } + + // returns list of Trace Activities + virtual std::vector& activities() = 0; + + // returns errors with this trace + virtual std::vector errors() = 0; + + // processes trace activities using logger + virtual void processTrace(ActivityLogger& logger) = 0; + + // XXX define trace formats + // virtual save(string name, TraceFormat format) + + protected: + TraceStatus status_ = TraceStatus::READY; +}; + + +/* Activity Profiler Plugins: + * These allow other frameworks to integrate into Kineto's primariy + * activity profiler. While the primary activity profiler handles + * timing the trace collections and correlating events the plugins + * can become source of new trace activity types. + */ +class IActivityProfiler { + + public: + + virtual ~IActivityProfiler() {} + + // name of profiler + virtual const std::string& name() const = 0; + + // returns activity types this profiler supports + virtual const std::set& availableActivities() const = 0; + + // Calls prepare() on registered tracer providers passing in the relevant + // activity types. Returns a profiler session handle + virtual std::unique_ptr configure( + const std::set& activity_types, + const Config& config) = 0; + + // asynchronous version of the above with future timestamp and duration. + virtual std::unique_ptr configure( + int64_t ts_ms, + int64_t duration_ms, + const std::set& activity_types, + const Config& config) = 0; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h b/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h new file mode 100644 index 0000000000000000000000000000000000000000..4fce7851b9669ff93a3f3a772140b0466674853c --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h @@ -0,0 +1,50 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +// Stages in libkineto used when pushing logs to UST Logger. +constexpr char kWarmUpStage[] = "Warm Up"; +constexpr char kCollectionStage[] = "Collection"; +constexpr char kPostProcessingStage[] = "Post Processing"; + +#if !USE_GOOGLE_LOG + +#include +#include + +namespace libkineto { + +enum LoggerOutputType { + VERBOSE = 0, + INFO = 1, + WARNING = 2, + ERROR = 3, + STAGE = 4, + ENUM_COUNT = 5 +}; + +const char* toString(LoggerOutputType t); +LoggerOutputType toLoggerOutputType(const std::string& str); + +constexpr int LoggerTypeCount = (int) LoggerOutputType::ENUM_COUNT; + +class ILoggerObserver { + public: + virtual ~ILoggerObserver() = default; + virtual void write(const std::string& message, LoggerOutputType ot) = 0; + virtual const std::map> extractCollectorMetadata() = 0; + virtual void reset() = 0; + virtual void addDevice(const int64_t device) = 0; + virtual void setTraceDurationMS(const int64_t duration) = 0; + virtual void addEventCount(const int64_t count) = 0; + virtual void setTraceID(const std::string&) {} + virtual void setGroupTraceID(const std::string&) {} + virtual void addDestination(const std::string& dest) = 0; + +}; + +} // namespace libkineto + +#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h b/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h new file mode 100644 index 0000000000000000000000000000000000000000..a477ed814662cb4c57738b7e40ec6052e9f65288 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h @@ -0,0 +1,53 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +#include "ActivityType.h" + +namespace libkineto { + +class ActivityLogger; +struct TraceSpan; + +// Generic activity interface is borrowed from tensorboard protobuf format. +struct ITraceActivity { + virtual ~ITraceActivity() {} + // Device is a physical or logical entity, e.g. CPU, GPU or process + virtual int64_t deviceId() const = 0; + // A resource is something on the device, h/w thread, + // functional units etc. + virtual int64_t resourceId() const = 0; + // s/w thread + virtual int32_t getThreadId() const = 0; + // Start timestamp in mucrosecond + virtual int64_t timestamp() const = 0; + // Duration in microseconds + virtual int64_t duration() const = 0; + // Used to link up async activities + virtual int64_t correlationId() const = 0; + // Part of a flow, identified by flow id and type + virtual int flowType() const = 0; + virtual int flowId() const = 0; + virtual bool flowStart() const = 0; + virtual ActivityType type() const = 0; + virtual const std::string name() const = 0; + // Optional linked activity + virtual const ITraceActivity* linkedActivity() const = 0; + // Optional containing trace object + virtual const TraceSpan* traceSpan() const = 0; + // Log activity + virtual void log(ActivityLogger& logger) const = 0; + // Return json formatted metadata + // FIXME: Return iterator to dynamic type map here instead + virtual const std::string metadataJson() const = 0; + + static int64_t nsToUs(int64_t ns) { + // It's important that this conversion is the same everywhere. + // No rounding! + return ns / 1000; + } +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h b/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..d1dc80ad2ab0dfd3bea313363fb0e6565349889c --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include +#include + +namespace libkineto { + +int32_t systemThreadId(); +int32_t threadId(); +bool setThreadName(const std::string& name); +std::string getThreadName(); + +int32_t processId(); +std::string processName(int32_t pid); + +// Return a list of pids and process names for the current process +// and its parents. +std::vector> pidCommandPairsOfAncestors(); + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h b/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h new file mode 100644 index 0000000000000000000000000000000000000000..af9a9d5ee556830ac34568e6c81ec4f8f00da2e3 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h @@ -0,0 +1,36 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include + +namespace libkineto { + +struct TraceSpan { + TraceSpan() = delete; + TraceSpan( + int64_t startTime, int64_t endTime, std::string name) + : startTime(startTime), endTime(endTime), name(std::move(name)) { + } + TraceSpan( + int opCount, int it, std::string name, std::string prefix) + : opCount(opCount), + iteration(it), + name(std::move(name)), + prefix(std::move(prefix)) { + } + + // FIXME: change to duration? + int64_t startTime{0}; + int64_t endTime{0}; + int opCount{0}; + int iteration{-1}; + // Name is used to identify timeline + std::string name; + // Prefix used to distinguish trace spans on the same timeline + std::string prefix; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/libkineto.h b/plugins/tensorboard-plugins/libkineto/include/libkineto.h new file mode 100644 index 0000000000000000000000000000000000000000..87c3d64f638dad9d1c2d24c013135db60d477642 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/libkineto.h @@ -0,0 +1,138 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +// Mediator for initialization and profiler control + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ActivityProfilerInterface.h" +#include "ActivityType.h" +#include "ClientInterface.h" +#include "GenericTraceActivity.h" +#include "TraceSpan.h" +#include "IActivityProfiler.h" +#include "ActivityTraceInterface.h" + +#include "ThreadUtil.h" + +extern "C" { + void suppressLibkinetoLogMessages(); + int InitializeInjection(void); + bool libkineto_init(bool cpuOnly, bool logOnError); +} + +namespace libkineto { + +class Config; +class ConfigLoader; + +struct CpuTraceBuffer { + TraceSpan span{0, 0, "none"}; + int gpuOpCount; + std::deque activities; +}; + +using ChildActivityProfilerFactory = + std::function()>; + +class LibkinetoApi { + public: + + explicit LibkinetoApi(ConfigLoader& configLoader) + : configLoader_(configLoader) { + } + + // Called by client that supports tracing API. + // libkineto can still function without this. + void registerClient(ClientInterface* client); + + // Called by libkineto on init + void registerProfiler(std::unique_ptr profiler) { + activityProfiler_ = std::move(profiler); + initClientIfRegistered(); + } + + ActivityProfilerInterface& activityProfiler() { + return *activityProfiler_; + } + + ClientInterface* client() { + return client_; + } + + void initProfilerIfRegistered() { + static std::once_flag once; + if (activityProfiler_) { + std::call_once(once, [this] { + if (!activityProfiler_->isInitialized()) { + activityProfiler_->init(); + initChildActivityProfilers(); + } + }); + } + } + + bool isProfilerInitialized() const { + return activityProfiler_ && activityProfiler_->isInitialized(); + } + + bool isProfilerRegistered() const { + return activityProfiler_ != nullptr; + } + + void suppressLogMessages() { + suppressLibkinetoLogMessages(); + } + + // Provides access to profier configuration manaegement + ConfigLoader& configLoader() { + return configLoader_; + } + + void registerProfilerFactory( + ChildActivityProfilerFactory factory) { + if (isProfilerInitialized()) { + activityProfiler_->addChildActivityProfiler(factory()); + } else { + childProfilerFactories_.push_back(factory); + } + } + + private: + + void initChildActivityProfilers() { + if (!isProfilerInitialized()) { + return; + } + for (const auto& factory : childProfilerFactories_) { + activityProfiler_->addChildActivityProfiler(factory()); + } + childProfilerFactories_.clear(); + } + + // Client is initialized once both it and libkineto has registered + void initClientIfRegistered(); + + ConfigLoader& configLoader_; + std::unique_ptr activityProfiler_{}; + ClientInterface* client_{}; + int32_t clientRegisterThread_{0}; + + bool isLoaded_{false}; + std::vector childProfilerFactories_; +}; + +// Singleton +LibkinetoApi& api(); + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h b/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h new file mode 100644 index 0000000000000000000000000000000000000000..caa6b4d92760d384eca2b1383a679fe7435c53b3 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h @@ -0,0 +1,16 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +namespace libkineto { + +inline int64_t timeSinceEpoch( + const std::chrono::time_point& t) { + return std::chrono::duration_cast( + t.time_since_epoch()) + .count(); +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl b/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl new file mode 100644 index 0000000000000000000000000000000000000000..330c54a22dfcedf895f0eba4077713a7c4cd8072 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +def get_libkineto_api_srcs(): + return [ + "src/ThreadUtil.cpp", + "src/libkineto_api.cpp", + ] + +def get_libkineto_cupti_srcs(with_api = True): + return [ + "src/CudaDeviceProperties.cpp", + "src/CuptiActivityApi.cpp", + "src/CuptiActivityPlatform.cpp", + "src/CuptiCallbackApi.cpp", + "src/CuptiEventApi.cpp", + "src/CuptiMetricApi.cpp", + "src/CuptiRangeProfilerApi.cpp", + "src/Demangle.cpp", + "src/EventProfiler.cpp", + "src/EventProfilerController.cpp", + "src/WeakSymbols.cpp", + "src/cupti_strings.cpp", + ] + (get_libkineto_cpu_only_srcs(with_api)) + +def get_libkineto_roctracer_srcs(with_api = True): + return [ + "src/RoctracerActivityApi.cpp", + ] + (get_libkineto_cpu_only_srcs(with_api)) + +def get_libkineto_cpu_only_srcs(with_api = True): + return [ + "src/AbstractConfig.cpp", + "src/CuptiActivityProfiler.cpp", + "src/ActivityProfilerController.cpp", + "src/ActivityProfilerProxy.cpp", + "src/ActivityType.cpp", + "src/Config.cpp", + "src/ConfigLoader.cpp", + "src/CuptiActivityApi.cpp", + "src/Demangle.cpp", + "src/GenericTraceActivity.cpp", + "src/ILoggerObserver.cpp", + "src/Logger.cpp", + "src/init.cpp", + "src/output_csv.cpp", + "src/output_json.cpp", + ] + (get_libkineto_api_srcs() if with_api else []) + +def get_libkineto_public_headers(): + return [ + "include/AbstractConfig.h", + "include/ActivityProfilerInterface.h", + "include/ActivityType.h", + "include/Config.h", + "include/ClientInterface.h", + "include/GenericTraceActivity.h", + "include/GenericTraceActivity.h", + "include/IActivityProfiler.h", + "include/ILoggerObserver.h", + "include/ITraceActivity.h", + "include/TraceSpan.h", + "include/ThreadUtil.h", + "include/libkineto.h", + "include/time_since_epoch.h", + ] + +# kineto code should be updated to not have to +# suppress these warnings. +KINETO_COMPILER_FLAGS = [ + "-fexceptions", + "-Wno-deprecated-declarations", + "-Wno-unused-function", + "-Wno-unused-private-field", +] diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp new file mode 100644 index 0000000000000000000000000000000000000000..780047912ed09996d3952901267d46aab99cf78c --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp @@ -0,0 +1,38 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +#include +#include + +#include "kineto/libkineto/sample_programs/kineto_playground.cuh" + +using namespace kineto; + +static const std::string kFileName = "/tmp/kineto_playground_trace.json"; + +int main() { + warmup(); + + // Kineto config + + // Empty types set defaults to all types + std::set types; + + auto& profiler = libkineto::api().activityProfiler(); + libkineto::api().initProfilerIfRegistered(); + profiler.prepareTrace(types); + + // Good to warm up after prepareTrace to get cupti initialization to settle + warmup(); + profiler.startTrace(); + playground(); + + auto trace = profiler.stopTrace(); + LOG(INFO) << "Stopped and processed trace. Got " << trace->activities()->size() << " activities."; + trace->save(kFileName); + return 0; +} + diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu new file mode 100644 index 0000000000000000000000000000000000000000..54c6f82ff4be2e468c0e868b49b3a9130de97490 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu @@ -0,0 +1,60 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include + +#include "kineto_playground.cuh" + + +namespace kineto { + +void warmup(void) { + // Inititalizing CUDA can take a while which we normally do not want to see in Kineto traces. + // This is done in various ways that take Kineto as dependency. This is our way of doing warmup + // for kineto_playground + size_t bytes = 1000; + float* mem = NULL; + auto error = cudaMalloc(&mem, bytes); + if (error != cudaSuccess) { + printf("cudaMalloc failed during kineto_playground warmup. error code: %d", error); + return; + } + + cudaFree(mem); +} + +void basicMemcpyMemset(void) { + size_t size = (1 << 8) * sizeof(float); + float *hostMemSrc, *deviceMem, *hostMemDst; + cudaError_t err; + + hostMemSrc = (float*)malloc(size); + hostMemDst = (float*)malloc(size); + err = cudaMalloc(&deviceMem, size); + if (err != cudaSuccess) { + printf("cudaMalloc failed during %s", __func__); + return; + } + + memset(hostMemSrc, 1, size); + cudaMemcpy(deviceMem, hostMemSrc, size, cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + printf("cudaMemcpy failed during %s", __func__); + return; + } + + cudaMemcpy(hostMemDst, deviceMem, size, cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + printf("cudaMemcpy failed during %s", __func__); + return; + } + + free(hostMemSrc); + free(hostMemDst); + cudaFree(deviceMem); +} + +void playground(void) { + // Add your experimental CUDA implementation here. +} + +} diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh new file mode 100644 index 0000000000000000000000000000000000000000..54e1ee59ada9ae88370b38146567ed87be2b914b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh @@ -0,0 +1,18 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +namespace kineto { + +// Warms up CUDA before the tracing starts +void warmup(void); + +// Basic usage of cudaMemcpy and cudaMemset +void basicMemcpyMemset(void); + +// Your experimental code goes in here! +void playground(void); + +} diff --git a/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp b/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d60ab43c9a3e198167beb7987d619b0bb8e9ed13 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp @@ -0,0 +1,188 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "AbstractConfig.h" + +#include +#include +#include + +#include "Logger.h" + +using namespace std::chrono; + +using std::string; +using std::vector; + +namespace KINETO_NAMESPACE { + +constexpr char kWhitespace[] = "\t\n "; + +static bool isWhitespace(string& s) { + return s.find_first_not_of(kWhitespace) == string::npos; +} + +// Remove whitespace from both end of string +static inline string trim(string& s) { + if (s.empty()) { + return s; + } else if (isWhitespace(s)) { + return ""; + } + auto start = s.find_first_not_of(kWhitespace); + auto end = s.find_last_not_of(kWhitespace); + return s.substr(start, end - start + 1); +} + +// Helper function for split. +// Return the index of char d in string s. +// If not found, returns the length of the string. +static int find(const char* s, char delim) { + int i; + for (i = 0; s[i]; i++) { + if (s[i] == delim) { + break; + } + } + return i; +} + +// Split a string by delimiter +static vector split(const string& s, char delim) { + vector res; + const char* cs = s.c_str(); + for (int i = find(cs, delim); cs[i]; cs += i + 1, i = find(cs, delim)) { + res.emplace_back(cs, i); + } + res.emplace_back(cs); + return res; +} + +// Remove a trailing comment. +static inline string stripComment(const string& s) { + std::size_t pos = s.find("#"); + return s.substr(0, pos); +} + +string AbstractConfig::toLower(string& s) const { + string res = s; + for (int i = 0; i < res.size(); i++) { + if (res[i] >= 'A' && res[i] <= 'Z') { + res[i] += ('a' - 'A'); + } + } + return res; +} + +bool AbstractConfig::endsWith(const string& s, const string& suffix) const { + if (suffix.size() > s.size()) { + return false; + } + return s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +vector AbstractConfig::splitAndTrim(const string& s, char delim) const { + auto res = split(s, delim); + for (string& x : res) { + x = trim(x); + } + return res; +} + +int64_t AbstractConfig::toIntRange(const string& val, int64_t min, int64_t max) + const { + char* invalid; + int64_t res = strtoll(val.c_str(), &invalid, 10); + if (val.empty() || *invalid) { + throw std::invalid_argument(fmt::format("Invalid integer: {}", val)); + } else if (res < min || res > max) { + throw std::invalid_argument(fmt::format( + "Invalid argument: {} - expected range [{}, {}]", res, min, max)); + } + return res; +} + +int32_t AbstractConfig::toInt32(const string& val) const { + return toIntRange(val, 0, ~0u / 2); +} + +int64_t AbstractConfig::toInt64(const string& val) const { + return toIntRange(val, 0, ~0ul / 2); +} + +bool AbstractConfig::toBool(string& val) const { + const std::array bool_vals{ + "n", "y", "no", "yes", "f", "t", "false", "true"}; + const string lower_val = toLower(val); + for (int i = 0; i < bool_vals.size(); i++) { + if (lower_val == bool_vals[i]) { + return i % 2; + } + } + throw std::invalid_argument(fmt::format("Invalid bool argument: {}", val)); + return false; +} + +bool AbstractConfig::parse(const string& conf) { + std::istringstream iss(conf); + string line; + + timestamp_ = system_clock::now(); + + // Read the string stream 1 line at a time to parse. + while (std::getline(iss, line)) { + line = stripComment(line); + if (isWhitespace(line)) { + continue; + } + vector key_val = splitAndTrim(line, '='); + if (key_val.size() != 2) { + LOG(ERROR) << "Invalid config line: " << line; + return false; + } else { + bool handled = false; + try { + handled = handleOption(key_val[0], key_val[1]); + if (!handled) { + for (auto& feature_cfg : featureConfigs_) { + if (feature_cfg.second->handleOption(key_val[0], key_val[1])) { + handled = true; + break; + } + } + } + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to parse config line: " << line; + LOG(ERROR) << e.what(); + return false; + } + if (!handled) { + // This might be due to using a newer config option on an + // older binary where it is not supported. In this case, + // print a warning message - but it is expected to work! + LOG(WARNING) << "Unrecognized config line: " << line; + } + } + } + + validate(timestamp_); + + // Store original text, used to detect updates + source_ = conf; + timestamp_ = system_clock::now(); + return true; +} + +bool AbstractConfig::handleOption( + const std::string& /* unused */, + std::string& /* unused */) { + LOG(ERROR) << "handleOption unimplemented"; + return false; +} + +void AbstractConfig::printActivityProfilerConfig(std::ostream& s) const { + for (const auto& feature_cfg : featureConfigs_) { + feature_cfg.second->printActivityProfilerConfig(s); + } +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h b/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h new file mode 100644 index 0000000000000000000000000000000000000000..157af879379a5f5fc5e274f22604987a97f17af4 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h @@ -0,0 +1,29 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + + +#include +#include + +#include "libkineto.h" +#include "CuptiActivityBuffer.h" + +namespace KINETO_NAMESPACE { + +struct ActivityBuffers { + std::list> cpu; + std::unique_ptr gpu; + + // Add a wrapper object to the underlying struct stored in the buffer + template + const ITraceActivity& addActivityWrapper(const T& act) { + wrappers_.push_back(std::make_unique(act)); + return *wrappers_.back().get(); + } + + private: + std::vector> wrappers_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h b/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1bf642cd68051e487004d33e19c5eb181e1c41 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h @@ -0,0 +1,60 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +class ActivityLogger; + +class ActivityLoggerFactory { + + public: + using FactoryFunc = + std::function(const std::string& url)>; + + // Add logger factory for a protocol prefix + void addProtocol(const std::string& protocol, FactoryFunc f) { + factories_[tolower(protocol)] = f; + } + + // Create a logger, invoking the factory for the protocol specified in url + std::unique_ptr makeLogger(const std::string& url) const { + std::string protocol = extractProtocol(url); + auto it = factories_.find(tolower(protocol)); + if (it != factories_.end()) { + return it->second(stripProtocol(url)); + } + throw std::invalid_argument(fmt::format( + "No logger registered for the {} protocol prefix", + protocol)); + return nullptr; + } + + private: + static std::string tolower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c) { return std::tolower(c); } + ); + return s; + } + + static std::string extractProtocol(std::string url) { + return url.substr(0, url.find("://")); + } + + static std::string stripProtocol(std::string url) { + size_t pos = url.find("://"); + return pos == url.npos ? url : url.substr(pos + 3); + } + + std::map factories_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c85d41ed73ff059bcd7ee69c36a0bcc6c3d5c4ca --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp @@ -0,0 +1,246 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "ActivityProfilerController.h" + +#include +#include + +#include "ActivityLoggerFactory.h" +#include "ActivityTrace.h" +#include "CuptiActivityApi.h" +#ifdef HAS_ROCTRACER +#include "RoctracerActivityApi.h" +#endif +#include "ThreadUtil.h" +#include "output_json.h" +#include "output_membuf.h" + +#include "Logger.h" + +using namespace std::chrono; + +namespace KINETO_NAMESPACE { + +constexpr milliseconds kProfilerIntervalMsecs(1000); + +ActivityProfilerController::ActivityProfilerController( + ConfigLoader& configLoader, bool cpuOnly) + : configLoader_(configLoader) { +#ifdef HAS_ROCTRACER + profiler_ = std::make_unique( + RoctracerActivityApi::singleton(), cpuOnly); +#else + profiler_ = std::make_unique( + CuptiActivityApi::singleton(), cpuOnly); +#endif + configLoader_.addHandler(ConfigLoader::ConfigKind::ActivityProfiler, this); +} + +ActivityProfilerController::~ActivityProfilerController() { + configLoader_.removeHandler( + ConfigLoader::ConfigKind::ActivityProfiler, this); + if (profilerThread_) { + // signaling termination of the profiler loop + stopRunloop_ = true; + profilerThread_->join(); + delete profilerThread_; + profilerThread_ = nullptr; + } +} + +static ActivityLoggerFactory initLoggerFactory() { + ActivityLoggerFactory factory; + factory.addProtocol("file", [](const std::string& url) { + return std::unique_ptr(new ChromeTraceLogger(url)); + }); + return factory; +} + +static ActivityLoggerFactory& loggerFactory() { + static ActivityLoggerFactory factory = initLoggerFactory(); + return factory; +} + +void ActivityProfilerController::addLoggerFactory( + const std::string& protocol, ActivityLoggerFactory::FactoryFunc factory) { + loggerFactory().addProtocol(protocol, factory); +} + +static std::unique_ptr makeLogger(const Config& config) { + if (config.activitiesLogToMemory()) { + return std::make_unique(config); + } + return loggerFactory().makeLogger(config.activitiesLogUrl()); +} + +bool ActivityProfilerController::canAcceptConfig() { + return !profiler_->isActive(); +} + +void ActivityProfilerController::acceptConfig(const Config& config) { + VLOG(1) << "acceptConfig"; + if (config.activityProfilerEnabled()) { + scheduleTrace(config); + } +} + +void ActivityProfilerController::profilerLoop() { + setThreadName("Kineto Activity Profiler"); + VLOG(0) << "Entering activity profiler loop"; + + auto now = system_clock::now(); + auto next_wakeup_time = now + kProfilerIntervalMsecs; + + while (!stopRunloop_) { + now = system_clock::now(); + + while (now < next_wakeup_time) { + /* sleep override */ + std::this_thread::sleep_for(next_wakeup_time - now); + now = system_clock::now(); + } + + if (!profiler_->isActive()) { + std::lock_guard lock(asyncConfigLock_); + if (asyncRequestConfig_ + && !asyncRequestConfig_->hasProfileStartIteration()) { + // Note on now + kProfilerIntervalMsecs + // Profiler interval does not align perfectly upto startTime - warmup. Waiting until the next tick + // won't allow sufficient time for the profiler to warm up. So check if we are very close to the warmup time and trigger warmup + if (now + kProfilerIntervalMsecs + >= (asyncRequestConfig_->requestTimestamp() - asyncRequestConfig_->activitiesWarmupDuration())) { + LOG(INFO) << "Received on-demand activity trace request by " + << " profile timestamp = " + << asyncRequestConfig_-> + requestTimestamp().time_since_epoch().count(); + activateConfig(now); + } + } + } + + while (next_wakeup_time < now) { + next_wakeup_time += kProfilerIntervalMsecs; + } + + if (profiler_->isActive()) { + next_wakeup_time = profiler_->performRunLoopStep(now, next_wakeup_time); + VLOG(1) << "Profiler loop: " + << duration_cast(system_clock::now() - now).count() + << "ms"; + } + } + + VLOG(0) << "Exited activity profiling loop"; +} + +void ActivityProfilerController::step() { + int64_t currentIter = ++iterationCount_; + VLOG(0) << "Step called , iteration = " << currentIter; + + // optimization to not take the lock unless necessary + if (asyncRequestConfig_ && !profiler_->isActive()) { + std::lock_guard lock(asyncConfigLock_); + auto startIter = asyncRequestConfig_->startIterationIncludingWarmup(); + + if (asyncRequestConfig_->hasProfileStartIteration() + && currentIter >= startIter) { + LOG(INFO) << "Received on-demand activity trace request by profile" + << " start iteration = " + << asyncRequestConfig_->profileStartIteration() + << " current iteration = " << currentIter; + + if (currentIter > startIter) { + // adjust the start iteration if it is in the past + auto newProfileStart = currentIter + + asyncRequestConfig_->activitiesWarmupIterations(); + LOG(INFO) << "Start iteration updated to " << newProfileStart; + asyncRequestConfig_->setProfileStartIteration(newProfileStart); + } + activateConfig(system_clock::now()); + } + } + + if (profiler_->isActive()) { + auto now = system_clock::now(); + auto next_wakeup_time = now + kProfilerIntervalMsecs; + profiler_->performRunLoopStep(now, next_wakeup_time, currentIter); + } +} + +void ActivityProfilerController::activateConfig( + std::chrono::time_point now) { + logger_ = makeLogger(*asyncRequestConfig_); + profiler_->setLogger(logger_.get()); + profiler_->configure(*asyncRequestConfig_, now); + asyncRequestConfig_ = nullptr; +} + +void ActivityProfilerController::scheduleTrace(const Config& config) { + VLOG(1) << "scheduleTrace"; + if (profiler_->isActive()) { + LOG(ERROR) << "Ignored request - profiler busy"; + return; + } + int64_t currentIter = iterationCount_; + if (config.hasProfileStartIteration() && currentIter < 0) { + LOG(ERROR) << "Ignored profile iteration count based request as " + << "application is not updating iteration count"; + return; + } + std::lock_guard lock(asyncConfigLock_); + asyncRequestConfig_ = config.clone(); + + auto startIter = asyncRequestConfig_->startIterationIncludingWarmup(); + + if (asyncRequestConfig_->hasProfileStartIteration() + && (currentIter > startIter) + && asyncRequestConfig_->profileStartIterationRoundUp() > 0) { + auto newProfileStart + = currentIter + asyncRequestConfig_->activitiesWarmupIterations(); + // round up to nearest multiple + auto divisor = asyncRequestConfig_->profileStartIterationRoundUp(); + auto rem = newProfileStart % divisor; + newProfileStart += ((rem == 0) ? 0 : divisor - rem); + LOG(INFO) << "Rounding up profiler start iteration to : " << newProfileStart; + asyncRequestConfig_->setProfileStartIteration(newProfileStart); + } + + // start a profilerLoop() thread to handle request + if (!profilerThread_) { + profilerThread_ = + new std::thread(&ActivityProfilerController::profilerLoop, this); + } +} + +void ActivityProfilerController::prepareTrace(const Config& config) { + // Requests from ActivityProfilerApi have higher priority than + // requests from other sources (signal, daemon). + // Cancel any ongoing request and refuse new ones. + auto now = system_clock::now(); + if (profiler_->isActive()) { + LOG(WARNING) << "Cancelling current trace request in order to start " + << "higher priority synchronous request"; + if (libkineto::api().client()) { + libkineto::api().client()->stop(); + } + profiler_->stopTrace(now); + profiler_->reset(); + } + + profiler_->configure(config, now); +} + +std::unique_ptr ActivityProfilerController::stopTrace() { + profiler_->stopTrace(std::chrono::system_clock::now()); + auto logger = std::make_unique(profiler_->config()); + profiler_->processTrace(*logger); + profiler_->reset(); + return std::make_unique(std::move(logger), loggerFactory()); +} + +void ActivityProfilerController::addMetadata( + const std::string& key, const std::string& value) { + profiler_->addMetadata(key, value); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h new file mode 100644 index 0000000000000000000000000000000000000000..415f107cbed6aab4777c65e9e51d65686002e762 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h @@ -0,0 +1,84 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ActivityLoggerFactory.h" +#include "CuptiActivityProfiler.h" +#include "ActivityProfilerInterface.h" +#include "ActivityTraceInterface.h" +#include "ConfigLoader.h" +#include "CuptiActivityApi.h" + +namespace KINETO_NAMESPACE { + +class Config; + +class ActivityProfilerController : public ConfigLoader::ConfigHandler { + public: + explicit ActivityProfilerController(ConfigLoader& configLoader, bool cpuOnly); + ActivityProfilerController(const ActivityProfilerController&) = delete; + ActivityProfilerController& operator=(const ActivityProfilerController&) = + delete; + + ~ActivityProfilerController(); + + static void addLoggerFactory( + const std::string& protocol, + ActivityLoggerFactory::FactoryFunc factory); + + bool canAcceptConfig() override; + void acceptConfig(const Config& config) override; + + void scheduleTrace(const Config& config); + + void prepareTrace(const Config& config); + + void startTrace() { + profiler_->startTrace(std::chrono::system_clock::now()); + } + + void step(); + + std::unique_ptr stopTrace(); + + bool isActive() { + return profiler_->isActive(); + } + + void transferCpuTrace( + std::unique_ptr cpuTrace) { + return profiler_->transferCpuTrace(std::move(cpuTrace)); + } + + void recordThreadInfo() { + profiler_->recordThreadInfo(); + } + + void addChildActivityProfiler( + std::unique_ptr profiler) { + profiler_->addChildActivityProfiler(std::move(profiler)); + } + + void addMetadata(const std::string& key, const std::string& value); + + private: + void profilerLoop(); + void activateConfig(std::chrono::time_point now); + + std::unique_ptr asyncRequestConfig_; + std::mutex asyncConfigLock_; + std::unique_ptr profiler_; + std::unique_ptr logger_; + std::thread* profilerThread_{nullptr}; + std::atomic_bool stopRunloop_{false}; + std::atomic iterationCount_{-1}; + ConfigLoader& configLoader_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b2d36b7b3abf9c3e0aed838a10e4054a5d292139 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp @@ -0,0 +1,119 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "ActivityProfilerProxy.h" + +#include "ActivityProfilerController.h" +#include "Config.h" +#include "CuptiActivityApi.h" +#include "Logger.h" +#include + +namespace KINETO_NAMESPACE { + +ActivityProfilerProxy::ActivityProfilerProxy( + bool cpuOnly, ConfigLoader& configLoader) + : cpuOnly_(cpuOnly), configLoader_(configLoader) { +} + +ActivityProfilerProxy::~ActivityProfilerProxy() { + delete controller_; +}; + +void ActivityProfilerProxy::init() { + if (!controller_) { + controller_ = new ActivityProfilerController(configLoader_, cpuOnly_); + } +} + +void ActivityProfilerProxy::scheduleTrace(const std::string& configStr) { + Config config; + config.parse(configStr); + controller_->scheduleTrace(config); +} + +void ActivityProfilerProxy::scheduleTrace(const Config& config) { + controller_->scheduleTrace(config); +} + +void ActivityProfilerProxy::prepareTrace( + const std::set& activityTypes, + const std::string& configStr) { + Config config; + bool validate_required = true; + + // allow user provided config to override default options + if (!configStr.empty()) { + if (!config.parse(configStr)) { + LOG(WARNING) << "Failed to parse config : " << configStr; + } + // parse also runs validate + validate_required = false; + } + + config.setClientDefaults(); + config.setSelectedActivityTypes(activityTypes); + + if (validate_required) { + config.validate(std::chrono::system_clock::now()); + } + + controller_->prepareTrace(config); +} + +void ActivityProfilerProxy::startTrace() { + controller_->startTrace(); +} + +std::unique_ptr +ActivityProfilerProxy::stopTrace() { + return controller_->stopTrace(); +} + +void ActivityProfilerProxy::step() { + controller_->step(); +} + +bool ActivityProfilerProxy::isActive() { + return controller_->isActive(); +} + +void ActivityProfilerProxy::pushCorrelationId(uint64_t id) { + CuptiActivityApi::pushCorrelationID(id, + CuptiActivityApi::CorrelationFlowType::Default); +} + +void ActivityProfilerProxy::popCorrelationId() { + CuptiActivityApi::popCorrelationID( + CuptiActivityApi::CorrelationFlowType::Default); +} + +void ActivityProfilerProxy::pushUserCorrelationId(uint64_t id) { + CuptiActivityApi::pushCorrelationID(id, + CuptiActivityApi::CorrelationFlowType::User); +} + +void ActivityProfilerProxy::popUserCorrelationId() { + CuptiActivityApi::popCorrelationID( + CuptiActivityApi::CorrelationFlowType::User); +} + +void ActivityProfilerProxy::transferCpuTrace( + std::unique_ptr traceBuffer) { + controller_->transferCpuTrace(std::move(traceBuffer)); +} + +void ActivityProfilerProxy::addMetadata( + const std::string& key, const std::string& value) { + controller_->addMetadata(key, value); +} + +void ActivityProfilerProxy::recordThreadInfo() { + controller_->recordThreadInfo(); +} + +void ActivityProfilerProxy::addChildActivityProfiler( + std::unique_ptr profiler) { + controller_->addChildActivityProfiler(std::move(profiler)); +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h new file mode 100644 index 0000000000000000000000000000000000000000..b5cf84b2f1ddb005060fea0927c99fc63d144d99 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h @@ -0,0 +1,73 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include "ActivityProfilerInterface.h" + +#include +#include +#include + +#include "ActivityType.h" +#include "ITraceActivity.h" + +namespace libkineto { + // previous declaration is struct so this one must be too. + struct CpuTraceBuffer; +} + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +class ActivityProfilerController; +class Config; +class ConfigLoader; + +class ActivityProfilerProxy : public ActivityProfilerInterface { + + public: + ActivityProfilerProxy(bool cpuOnly, ConfigLoader& configLoader); + ~ActivityProfilerProxy() override; + + void init() override; + bool isInitialized() override { + return controller_ != nullptr; + } + + bool isActive() override; + + void recordThreadInfo() override; + + void scheduleTrace(const std::string& configStr) override; + void scheduleTrace(const Config& config); + + void prepareTrace( + const std::set& activityTypes, + const std::string& configStr = "") override; + + void startTrace() override; + void step() override; + std::unique_ptr stopTrace() override; + + void pushCorrelationId(uint64_t id) override; + void popCorrelationId() override; + + void pushUserCorrelationId(uint64_t id) override; + void popUserCorrelationId() override; + + void transferCpuTrace( + std::unique_ptr traceBuffer) override; + + void addMetadata(const std::string& key, const std::string& value) override; + + virtual void addChildActivityProfiler( + std::unique_ptr profiler) override; + + private: + bool cpuOnly_{true}; + ConfigLoader& configLoader_; + ActivityProfilerController* controller_{nullptr}; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h b/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h new file mode 100644 index 0000000000000000000000000000000000000000..0be76af08e47c16ebee2ac1d1ad01c4425ff17a5 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h @@ -0,0 +1,45 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +#include "ActivityLoggerFactory.h" +#include "ActivityTraceInterface.h" +#include "output_json.h" +#include "output_membuf.h" + +namespace libkineto { + +class ActivityTrace : public ActivityTraceInterface { + public: + ActivityTrace( + std::unique_ptr tmpLogger, + const ActivityLoggerFactory& factory) + : memLogger_(std::move(tmpLogger)), + loggerFactory_(factory) { + } + + const std::vector* activities() override { + return memLogger_->traceActivities(); + }; + + void save(const std::string& url) override { + std::string prefix; + // if no protocol is specified, default to file + if (url.find("://") == url.npos) { + prefix = "file://"; + } + memLogger_->log(*loggerFactory_.makeLogger(prefix + url)); + }; + + private: + // Activities are logged into a buffer + std::unique_ptr memLogger_; + + // Alternative logger used by save() if protocol prefix is specified + const ActivityLoggerFactory& loggerFactory_; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp new file mode 100644 index 0000000000000000000000000000000000000000..18856b72370abdb6d9cf4309b32be4cae10805de --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp @@ -0,0 +1,58 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "ActivityType.h" + +#include + +namespace libkineto { + +struct ActivityTypeName { + const char* name; + ActivityType type; +}; + +static constexpr std::array map{{ + {"cpu_op", ActivityType::CPU_OP}, + {"user_annotation", ActivityType::USER_ANNOTATION}, + {"gpu_user_Annotation", ActivityType::GPU_USER_ANNOTATION}, + {"gpu_memcpy", ActivityType::GPU_MEMCPY}, + {"gpu_memset", ActivityType::GPU_MEMSET}, + {"kernel", ActivityType::CONCURRENT_KERNEL}, + {"external_correlation", ActivityType::EXTERNAL_CORRELATION}, + {"cuda_runtime", ActivityType::CUDA_RUNTIME}, + {"cuda_profiler_range", ActivityType::CUDA_PROFILER_RANGE}, + {"glow_runtime", ActivityType::GLOW_RUNTIME}, + {"cpu_instant_event", ActivityType::CPU_INSTANT_EVENT}, + {"python_function", ActivityType::PYTHON_FUNCTION}, + {"overhead", ActivityType::OVERHEAD}, + {"ENUM_COUNT", ActivityType::ENUM_COUNT} +}}; + +static constexpr bool matchingOrder(int idx = 0) { + return map[idx].type == ActivityType::ENUM_COUNT || + ((idx == (int) map[idx].type) && matchingOrder(idx + 1)); +} +static_assert(matchingOrder(), "ActivityTypeName map is out of order"); + +const char* toString(ActivityType t) { + return map[(int)t].name; +} + +ActivityType toActivityType(const std::string& str) { + for (int i = 0; i < activityTypeCount; i++) { + if (str == map[i].name) { + return map[i].type; + } + } + throw std::invalid_argument(fmt::format("Invalid activity type: {}", str)); +} + +const std::array activityTypes() { + std::array res; + for (int i = 0; i < activityTypeCount; i++) { + res[i] = map[i].type; + } + return res; +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/Config.cpp b/plugins/tensorboard-plugins/libkineto/src/Config.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95538840f378e83b2b44161823042c620b34fe93 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/Config.cpp @@ -0,0 +1,473 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "Config.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Logger.h" +#include "ThreadUtil.h" + +using namespace std::chrono; + +using std::string; +using std::vector; + +namespace KINETO_NAMESPACE { + +constexpr milliseconds kDefaultSamplePeriodMsecs(1000); +constexpr milliseconds kDefaultMultiplexPeriodMsecs(1000); +constexpr milliseconds kDefaultActivitiesProfileDurationMSecs(500); +constexpr int kDefaultActivitiesMaxGpuBufferSize(128 * 1024 * 1024); +constexpr seconds kDefaultActivitiesWarmupDurationSecs(5); +constexpr seconds kDefaultBufferUntilWarmup(10); +constexpr seconds kDefaultReportPeriodSecs(1); +constexpr int kDefaultSamplesPerReport(1); +constexpr int kDefaultMaxEventProfilersPerGpu(1); +constexpr int kDefaultEventProfilerHearbeatMonitorPeriod(0); +constexpr seconds kMaxRequestAge(10); + +// Event Profiler +constexpr char kEventsKey[] = "EVENTS"; +constexpr char kMetricsKey[] = "METRICS"; +constexpr char kSamplePeriodKey[] = "SAMPLE_PERIOD_MSECS"; +constexpr char kMultiplexPeriodKey[] = "MULTIPLEX_PERIOD_MSECS"; +constexpr char kReportPeriodKey[] = "REPORT_PERIOD_SECS"; +constexpr char kSamplesPerReportKey[] = "SAMPLES_PER_REPORT"; +constexpr char kEventsLogFileKey[] = "EVENTS_LOG_FILE"; +constexpr char kEventsEnabledDevicesKey[] = "EVENTS_ENABLED_DEVICES"; +constexpr char kOnDemandDurationKey[] = "EVENTS_DURATION_SECS"; +constexpr char kMaxEventProfilersPerGpuKey[] = "MAX_EVENT_PROFILERS_PER_GPU"; +constexpr char kHeartbeatMonitorPeriodKey[] = + "EVENTS_HEARTBEAT_MONITOR_PERIOD_SECS"; + +// Activity Profiler +constexpr char kActivitiesEnabledKey[] = "ACTIVITIES_ENABLED"; +constexpr char kActivityTypesKey[] = "ACTIVITY_TYPES"; +constexpr char kActivitiesLogFileKey[] = "ACTIVITIES_LOG_FILE"; +constexpr char kActivitiesDurationKey[] = "ACTIVITIES_DURATION_SECS"; +constexpr char kActivitiesDurationMsecsKey[] = "ACTIVITIES_DURATION_MSECS"; +constexpr char kActivitiesWarmupDurationSecsKey[] = "ACTIVITIES_WARMUP_PERIOD_SECS"; +constexpr char kActivitiesMaxGpuBufferSizeKey[] = + "ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB"; + +// Client Interface +constexpr char kClientInterfaceEnableOpInputsCollection[] = "CLIENT_INTERFACE_ENABLE_OP_INPUTS_COLLECTION"; + +constexpr char kActivitiesWarmupIterationsKey[] = "ACTIVITIES_WARMUP_ITERATIONS"; +constexpr char kActivitiesIterationsKey[] = "ACTIVITIES_ITERATIONS"; +// Common + +// Client-side timestamp used for synchronized start across hosts for +// distributed workloads. +// Specified in milliseconds Unix time (milliseconds since epoch). +// To use, compute a future timestamp as follows: +// * C++: + duration_cast( +// system_clock::now().time_since_epoch()).count() +// * Python: + int(time.time() * 1000) +// * Bash: $(( + $(date +%s%3N))) +// If used for a tracing request, timestamp must be far enough in the future +// to accommodate ACTIVITIES_WARMUP_PERIOD_SECS as well as any delays in +// propagating the request to the profiler. +// If the request can not be honored, it is up to the profilers to report +// an error somehow - no checks are done at config parse time. +// Note PROFILE_START_ITERATION has higher precedence +constexpr char kProfileStartTimeKey[] = "PROFILE_START_TIME"; +// DEPRECATED - USE PROFILE_START_TIME instead +constexpr char kRequestTimestampKey[] = "REQUEST_TIMESTAMP"; + +// Alternatively if the application supports reporting iterations +// start the profile at specific iteration. If the iteration count +// is >= this value the profile is started immediately. +// A value >= 0 is valid for this config option to take effect. +// Note PROFILE_START_ITERATION will take precedence over PROFILE_START_TIME. +constexpr char kProfileStartIterationKey[] = "PROFILE_START_ITERATION"; + +// Users can also start the profile on an integer multiple of the config +// value PROFILE_START_ITERATION_ROUNDUP. This knob behaves similar to +// PROFILE_START_ITERATION but instead of saying : "start collection trace on +// iteration 500", one can configure it to "start collecting trace on the next +// 100th iteration". +// +// For example, +// PROFILE_START_ITERATION_ROUNDUP = 1000, and the current iteration is 2010 +// The profile will then be collected on the next multiple of 1000 ie. 3000 +// Note PROFILE_START_ITERATION_ROUNDUP will also take precedence over +// PROFILE_START_TIME. +constexpr char kProfileStartIterationRoundUpKey[] + = "PROFILE_START_ITERATION_ROUNDUP"; + +// Enable on-demand trigger via kill -USR2 +// When triggered in this way, /tmp/libkineto.conf will be used as config. +constexpr char kEnableSigUsr2Key[] = "ENABLE_SIGUSR2"; + +// Enable communication through IPC Fabric +// and disable thrift communication with dynolog daemon +constexpr char kEnableIpcFabricKey[] = "ENABLE_IPC_FABRIC"; + +// Verbose log level +// The actual glog is not used and --v and --vmodule has no effect. +// Instead set the verbose level and modules in the config file. +constexpr char kLogVerboseLevelKey[] = "VERBOSE_LOG_LEVEL"; +// By default, all modules will log verbose messages >= verboseLogLevel. +// But to reduce noise we can specify one or more modules of interest. +// A module is a C/C++ object file (source file name), +// Example argument: ActivityProfiler.cpp,output_json.cpp +constexpr char kLogVerboseModulesKey[] = "VERBOSE_LOG_MODULES"; + +// Max devices supported on any system +constexpr uint8_t kMaxDevices = 8; + +namespace { + +struct FactoryMap { + + void addFactory( + std::string name, + std::function factory) { + std::lock_guard lock(lock_); + factories_[name] = factory; + } + + void addFeatureConfigs(Config& cfg) { + std::lock_guard lock(lock_); + for (const auto& p : factories_) { + cfg.addFeature(p.first, p.second(cfg)); + } + } + +// Config factories are shared between objects and since +// config objects can be created by multiple threads, we need a lock. + std::mutex lock_; + std::map> factories_; +}; + +std::shared_ptr configFactories() { + // Ensure this is safe to call during shutdown, even as static + // destructors are invoked. Once factories destructor has been + // invoked, weak_ptr.lock() will return nullptr. + // But calls before that point will have a valid shared_ptr, + // delaying destruction of the underlying FactoryMap. + static auto factories = std::make_shared(); + static std::weak_ptr weak_ptr = factories; + return weak_ptr.lock(); +} + +} // namespace + +void Config::addConfigFactory( + std::string name, + std::function factory) { + auto factories = configFactories(); + if (factories) { + factories->addFactory(name, factory); + } +} + +static string defaultTraceFileName() { + return fmt::format("/tmp/libkineto_activities_{}.json", processId()); +} + +Config::Config() + : verboseLogLevel_(-1), + samplePeriod_(kDefaultSamplePeriodMsecs), + reportPeriod_(duration_cast(kDefaultReportPeriodSecs)), + samplesPerReport_(kDefaultSamplesPerReport), + eventProfilerOnDemandDuration_(seconds(0)), + eventProfilerMaxInstancesPerGpu_(kDefaultMaxEventProfilersPerGpu), + eventProfilerHeartbeatMonitorPeriod_( + kDefaultEventProfilerHearbeatMonitorPeriod), + multiplexPeriod_(kDefaultMultiplexPeriodMsecs), + activityProfilerEnabled_(true), + activitiesLogFile_(defaultTraceFileName()), + activitiesLogUrl_(fmt::format("file://{}", activitiesLogFile_)), + activitiesMaxGpuBufferSize_(kDefaultActivitiesMaxGpuBufferSize), + activitiesWarmupDuration_(kDefaultActivitiesWarmupDurationSecs), + activitiesWarmupIterations_(0), + activitiesDuration_(kDefaultActivitiesProfileDurationMSecs), + activitiesRunIterations_(0), + activitiesOnDemandTimestamp_(milliseconds(0)), + profileStartTime_(milliseconds(0)), + profileStartIteration_(-1), + profileStartIterationRoundUp_(-1), + requestTimestamp_(milliseconds(0)), + enableSigUsr2_(false), + enableIpcFabric_(false) { + auto factories = configFactories(); + if (factories) { + factories->addFeatureConfigs(*this); + } +} + +uint8_t Config::createDeviceMask(const string& val) { + uint8_t res = 0; + for (const auto& d : splitAndTrim(val, ',')) { + res |= 1 << toIntRange(d, 0, kMaxDevices - 1); + } + return res; +} + +const seconds Config::maxRequestAge() const { + return kMaxRequestAge; +} + +static std::string getTimeStr(time_point t) { + std::time_t t_c = system_clock::to_time_t(t); + return fmt::format("{:%H:%M:%S}", fmt::localtime(t_c)); +} + +static time_point handleRequestTimestamp(int64_t ms) { + auto t = time_point(milliseconds(ms)); + auto now = system_clock::now(); + if (t > now) { + throw std::invalid_argument(fmt::format( + "Invalid {}: {} - time is in future", + kRequestTimestampKey, + getTimeStr(t))); + } else if ((now - t) > kMaxRequestAge) { + throw std::invalid_argument(fmt::format( + "Invalid {}: {} - time is more than {}s in the past", + kRequestTimestampKey, + getTimeStr(t), + kMaxRequestAge.count())); + } + return t; +} + +void Config::setActivityTypes( + const std::vector& selected_activities) { + selectedActivityTypes_.clear(); + if (selected_activities.size() > 0) { + for (const auto& activity : selected_activities) { + if (activity == "") { + continue; + } + selectedActivityTypes_.insert(toActivityType(activity)); + } + } +} + +bool Config::handleOption(const std::string& name, std::string& val) { + // Event Profiler + if (!name.compare(kEventsKey)) { + vector event_names = splitAndTrim(val, ','); + eventNames_.insert(event_names.begin(), event_names.end()); + } else if (!name.compare(kMetricsKey)) { + vector metric_names = splitAndTrim(val, ','); + metricNames_.insert(metric_names.begin(), metric_names.end()); + } else if (!name.compare(kSamplePeriodKey)) { + samplePeriod_ = milliseconds(toInt32(val)); + } else if (!name.compare(kMultiplexPeriodKey)) { + multiplexPeriod_ = milliseconds(toInt32(val)); + } else if (!name.compare(kReportPeriodKey)) { + setReportPeriod(seconds(toInt32(val))); + } else if (!name.compare(kSamplesPerReportKey)) { + samplesPerReport_ = toInt32(val); + } else if (!name.compare(kEventsLogFileKey)) { + eventLogFile_ = val; + } else if (!name.compare(kEventsEnabledDevicesKey)) { + eventProfilerDeviceMask_ = createDeviceMask(val); + } else if (!name.compare(kOnDemandDurationKey)) { + eventProfilerOnDemandDuration_ = seconds(toInt32(val)); + eventProfilerOnDemandTimestamp_ = timestamp(); + } else if (!name.compare(kMaxEventProfilersPerGpuKey)) { + eventProfilerMaxInstancesPerGpu_ = toInt32(val); + } else if (!name.compare(kHeartbeatMonitorPeriodKey)) { + eventProfilerHeartbeatMonitorPeriod_ = seconds(toInt32(val)); + } + + // Activity Profiler + else if (!name.compare(kActivitiesDurationKey)) { + activitiesDuration_ = + duration_cast(seconds(toInt32(val))); + activitiesOnDemandTimestamp_ = timestamp(); + } else if (!name.compare(kActivityTypesKey)) { + vector activity_types = splitAndTrim(toLower(val), ','); + setActivityTypes(activity_types); + } else if (!name.compare(kActivitiesDurationMsecsKey)) { + activitiesDuration_ = milliseconds(toInt32(val)); + activitiesOnDemandTimestamp_ = timestamp(); + } else if (!name.compare(kActivitiesIterationsKey)) { + activitiesRunIterations_ = toInt32(val); + activitiesOnDemandTimestamp_ = timestamp(); + } else if (!name.compare(kLogVerboseLevelKey)) { + verboseLogLevel_ = toInt32(val); + } else if (!name.compare(kLogVerboseModulesKey)) { + verboseLogModules_ = splitAndTrim(val, ','); + } else if (!name.compare(kActivitiesEnabledKey)) { + activityProfilerEnabled_ = toBool(val); + } else if (!name.compare(kActivitiesLogFileKey)) { + activitiesLogFile_ = val; + activitiesLogUrl_ = fmt::format("file://{}", val); + activitiesOnDemandTimestamp_ = timestamp(); + } else if (!name.compare(kActivitiesMaxGpuBufferSizeKey)) { + activitiesMaxGpuBufferSize_ = toInt32(val) * 1024 * 1024; + } else if (!name.compare(kActivitiesWarmupDurationSecsKey)) { + activitiesWarmupDuration_ = seconds(toInt32(val)); + } else if (!name.compare(kActivitiesWarmupIterationsKey)) { + activitiesWarmupIterations_ = toInt32(val); + } + + // Client Interface + else if (!name.compare(kClientInterfaceEnableOpInputsCollection)) { + enableOpInputsCollection_ = toBool(val); + } + + // Common + else if (!name.compare(kRequestTimestampKey)) { + VLOG(0) << kRequestTimestampKey + << " has been deprecated - please use " + << kProfileStartTimeKey; + requestTimestamp_ = handleRequestTimestamp(toInt64(val)); + } else if (!name.compare(kProfileStartTimeKey)) { + profileStartTime_ = + time_point(milliseconds(toInt64(val))); + } else if (!name.compare(kProfileStartIterationKey)) { + profileStartIteration_ = toInt32(val); + } else if (!name.compare(kProfileStartIterationRoundUpKey)) { + profileStartIterationRoundUp_ = toInt32(val); + } else if (!name.compare(kEnableSigUsr2Key)) { + enableSigUsr2_ = toBool(val); + } else if (!name.compare(kEnableIpcFabricKey)) { + enableIpcFabric_ = toBool(val); + } else { + return false; + } + return true; +} + +std::chrono::milliseconds Config::activitiesDurationDefault() const { + return kDefaultActivitiesProfileDurationMSecs; +}; + +void Config::updateActivityProfilerRequestReceivedTime() { + activitiesOnDemandTimestamp_ = system_clock::now(); +} + +void Config::setClientDefaults() { + AbstractConfig::setClientDefaults(); + activitiesLogToMemory_ = true; +} + +void Config::validate( + const time_point& fallbackProfileStartTime) { + if (samplePeriod_.count() == 0) { + LOG(WARNING) << "Sample period must be greater than 0, setting to 1ms"; + samplePeriod_ = milliseconds(1); + } + + if (multiplexPeriod_ < samplePeriod_) { + LOG(WARNING) << "Multiplex period can not be smaller " + << "than sample period"; + LOG(WARNING) << "Setting multiplex period to " << samplePeriod_.count() + << "ms"; + multiplexPeriod_ = samplePeriod_; + } + + if ((multiplexPeriod_ % samplePeriod_).count() != 0) { + LOG(WARNING) << "Multiplex period must be a " + << "multiple of sample period"; + multiplexPeriod_ = alignUp(multiplexPeriod_, samplePeriod_); + LOG(WARNING) << "Setting multiplex period to " << multiplexPeriod_.count() + << "ms"; + } + + if ((reportPeriod_ % multiplexPeriod_).count() != 0 || + reportPeriod_.count() == 0) { + LOG(WARNING) << "Report period must be a " + << "multiple of multiplex period"; + reportPeriod_ = alignUp(reportPeriod_, multiplexPeriod_); + LOG(WARNING) << "Setting report period to " << reportPeriod_.count() + << "ms"; + } + + if (samplesPerReport_ < 1) { + LOG(WARNING) << "Samples per report must be in the range " + << "[1, report period / sample period]"; + LOG(WARNING) << "Setting samples per report to 1"; + samplesPerReport_ = 1; + } + + int max_samples_per_report = reportPeriod_ / samplePeriod_; + if (samplesPerReport_ > max_samples_per_report) { + LOG(WARNING) << "Samples per report must be in the range " + << "[1, report period / sample period] ([1, " + << reportPeriod_.count() << "ms / " << samplePeriod_.count() + << "ms = " << max_samples_per_report << "])"; + LOG(WARNING) << "Setting samples per report to " << max_samples_per_report; + samplesPerReport_ = max_samples_per_report; + } + + if (!hasProfileStartTime()) { + VLOG(0) + << "No explicit timestamp has been set. " + << "Defaulting it to now + activitiesWarmupDuration with buffer."; + profileStartTime_ = fallbackProfileStartTime + + activitiesWarmupDuration() + kDefaultBufferUntilWarmup; + } + + if (profileStartIterationRoundUp_ == 0) { + // setting to 0 will mess up modulo arithmetic, set it to -1 so it has no effect + LOG(WARNING) << "Profiler start iteration round up should be >= 1."; + profileStartIterationRoundUp_ = -1; + } + + if (profileStartIterationRoundUp_ > 0 && !hasProfileStartIteration()) { + VLOG(0) << "Setting profiler start iteration to 0 so this config is " + << "triggered via iteration count."; + profileStartIteration_ = 0; + } + + if (selectedActivityTypes_.size() == 0) { + selectDefaultActivityTypes(); + } +} + +void Config::setReportPeriod(milliseconds msecs) { + reportPeriod_ = msecs; +} + +void Config::printActivityProfilerConfig(std::ostream& s) const { + s << "Log file: " << activitiesLogFile() << std::endl; + if (hasProfileStartIteration()) { + s << "Trace start Iteration: " << profileStartIteration() << std::endl; + s << "Trace warmup Iterations: " << activitiesWarmupIterations() << std::endl; + s << "Trace profile Iterations: " << activitiesRunIterations() << std::endl; + if (profileStartIterationRoundUp() > 0) { + s << "Trace start iteration roundup : " << profileStartIterationRoundUp() + << std::endl; + } + } else if (hasProfileStartTime()) { + std::time_t t_c = system_clock::to_time_t(requestTimestamp()); + LOG(INFO) << "Trace start time: " + << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(t_c)); + s << "Trace duration: " << activitiesDuration().count() << "ms" + << std::endl; + s << "Warmup duration: " << activitiesWarmupDuration().count() << "s" + << std::endl; + } + + s << "Max GPU buffer size: " << activitiesMaxGpuBufferSize() / 1024 / 1024 + << "MB" << std::endl; + + std::vector activities; + for (const auto& activity : selectedActivityTypes_) { + activities.push_back(toString(activity)); + } + s << "Enabled activities: " + << fmt::format("{}", fmt::join(activities, ",")) << std::endl; + + AbstractConfig::printActivityProfilerConfig(s); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4080b678d371e98757897d4d7726c159887377e1 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp @@ -0,0 +1,300 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "ConfigLoader.h" + +#ifdef __linux__ +#include +#endif + +#include +#include +#include +#include +#include + +#include "DaemonConfigLoader.h" + +#include "Logger.h" + +using namespace std::chrono; +using std::string; + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +constexpr char kConfigFileEnvVar[] = "KINETO_CONFIG"; +#ifdef __linux__ +constexpr char kConfigFile[] = "/etc/libkineto.conf"; +constexpr char kOnDemandConfigFile[] = "/tmp/libkineto.conf"; +#else +constexpr char kConfigFile[] = "libkineto.conf"; +constexpr char kOnDemandConfigFile[] = "libkineto.conf"; +#endif + +constexpr std::chrono::seconds kConfigUpdateIntervalSecs(300); +constexpr std::chrono::seconds kOnDemandConfigUpdateIntervalSecs(5); + +#ifdef __linux__ +static struct sigaction originalUsr2Handler = {}; +#endif + +// Use SIGUSR2 to initiate profiling. +// Look for an on-demand config file. +// If none is found, default to base config. +// Try to not affect existing handlers +static bool hasOriginalSignalHandler() { +#ifdef __linux__ + return originalUsr2Handler.sa_handler != nullptr || + originalUsr2Handler.sa_sigaction != nullptr; +#else + return false; +#endif +} + +static void handle_signal(int signal) { +#ifdef __linux__ + if (signal == SIGUSR2) { + ConfigLoader::instance().handleOnDemandSignal(); + if (hasOriginalSignalHandler()) { + // Invoke original handler and reinstate ours + struct sigaction act; + sigaction(SIGUSR2, &originalUsr2Handler, &act); + raise(SIGUSR2); + sigaction(SIGUSR2, &act, &originalUsr2Handler); + } + } +#endif +} + +static void setupSignalHandler(bool enableSigUsr2) { +#ifdef __linux__ + if (enableSigUsr2) { + struct sigaction act = {}; + act.sa_handler = &handle_signal; + act.sa_flags = SA_NODEFER; + if (sigaction(SIGUSR2, &act, &originalUsr2Handler) < 0) { + PLOG(ERROR) << "Failed to register SIGUSR2 handler"; + } + if (originalUsr2Handler.sa_handler == &handle_signal) { + originalUsr2Handler = {}; + } + } else if (hasOriginalSignalHandler()) { + sigaction(SIGUSR2, &originalUsr2Handler, nullptr); + originalUsr2Handler = {}; + } +#endif +} + +// return an empty string if reading gets any errors. Otherwise a config string. +static std::string readConfigFromConfigFile(const char* filename) { + // Read whole file into a string. + std::ifstream file(filename); + std::string conf; + try { + conf.assign( + std::istreambuf_iterator(file), std::istreambuf_iterator()); + } catch (std::exception& e) { + VLOG(0) << "Error reading " << filename << ": " + << e.what(); + conf = ""; + } + return conf; +} + +static std::function()>& +daemonConfigLoaderFactory() { + static std::function()> factory = nullptr; + return factory; +} + +void ConfigLoader::setDaemonConfigLoaderFactory( + std::function()> factory) { + daemonConfigLoaderFactory() = factory; +} + +ConfigLoader& ConfigLoader::instance() { + static ConfigLoader config_loader; + return config_loader; +} + +// return an empty string if polling gets any errors. Otherwise a config string. +std::string ConfigLoader::readOnDemandConfigFromDaemon( + time_point now) { + if (!daemonConfigLoader_) { + return ""; + } + bool events = canHandlerAcceptConfig(ConfigKind::EventProfiler); + bool activities = canHandlerAcceptConfig(ConfigKind::ActivityProfiler); + return daemonConfigLoader_->readOnDemandConfig(events, activities); +} + +int ConfigLoader::contextCountForGpu(uint32_t device) { + if (!daemonConfigLoader_) { + // FIXME: Throw error? + return 0; + } + return daemonConfigLoader_->gpuContextCount(device); +} + +ConfigLoader::ConfigLoader() + : configUpdateIntervalSecs_(kConfigUpdateIntervalSecs), + onDemandConfigUpdateIntervalSecs_(kOnDemandConfigUpdateIntervalSecs), + stopFlag_(false), + onDemandSignal_(false) { +} + +void ConfigLoader::startThread() { + if (!updateThread_) { + // Create default base config here - at this point static initializers + // of extensions should have run and registered all config feature factories + std::lock_guard lock(configLock_); + if (!config_) { + config_ = std::make_unique(); + } + updateThread_ = + std::make_unique(&ConfigLoader::updateConfigThread, this); + } +} + +ConfigLoader::~ConfigLoader() { + if (updateThread_) { + stopFlag_ = true; + { + std::lock_guard lock(updateThreadMutex_); + updateThreadCondVar_.notify_one(); + } + updateThread_->join(); + } +#if !USE_GOOGLE_LOG + Logger::clearLoggerObservers(); +#endif // !USE_GOOGLE_LOG +} + +void ConfigLoader::handleOnDemandSignal() { + onDemandSignal_ = true; + { + std::lock_guard lock(updateThreadMutex_); + updateThreadCondVar_.notify_one(); + } +} + +const char* ConfigLoader::configFileName() { + if (!configFileName_) { + configFileName_ = getenv(kConfigFileEnvVar); + if (configFileName_ == nullptr) { + configFileName_ = kConfigFile; + } + } + return configFileName_; +} + +DaemonConfigLoader* ConfigLoader::daemonConfigLoader() { + if (!daemonConfigLoader_ && daemonConfigLoaderFactory()) { + daemonConfigLoader_ = daemonConfigLoaderFactory()(); + daemonConfigLoader_->setCommunicationFabric(config_->ipcFabricEnabled()); + } + return daemonConfigLoader_.get(); +} + +void ConfigLoader::updateBaseConfig() { + // First try reading local config file + // If that fails, read from daemon + // TODO: Invert these once daemon path fully rolled out + std::string config_str = readConfigFromConfigFile(configFileName()); + if (config_str.empty() && daemonConfigLoader()) { + // If local config file was not successfully loaded (e.g. not found) + // then try the daemon + config_str = daemonConfigLoader()->readBaseConfig(); + } + if (config_str != config_->source()) { + std::lock_guard lock(configLock_); + config_ = std::make_unique(); + config_->parse(config_str); + if (daemonConfigLoader()) { + daemonConfigLoader()->setCommunicationFabric(config_->ipcFabricEnabled()); + } + setupSignalHandler(config_->sigUsr2Enabled()); + SET_LOG_VERBOSITY_LEVEL( + config_->verboseLogLevel(), + config_->verboseLogModules()); + VLOG(0) << "Detected base config change"; + } +} + +void ConfigLoader::configureFromSignal( + time_point now, + Config& config) { + LOG(INFO) << "Received on-demand profiling signal, " + << "reading config from " << kOnDemandConfigFile; + // Reset start time to 0 in order to compute new default start time + const std::string config_str = "PROFILE_START_TIME=0\n" + + readConfigFromConfigFile(kOnDemandConfigFile); + config.parse(config_str); + config.setSignalDefaults(); + notifyHandlers(config); +} + +void ConfigLoader::configureFromDaemon( + time_point now, + Config& config) { + const std::string config_str = readOnDemandConfigFromDaemon(now); + if (config_str.empty()) { + return; + } + + LOG(INFO) << "Received config from dyno:\n" << config_str; + config.parse(config_str); + notifyHandlers(config); +} + +void ConfigLoader::updateConfigThread() { + auto now = system_clock::now(); + auto next_config_load_time = now; + auto next_on_demand_load_time = now + onDemandConfigUpdateIntervalSecs_; + seconds interval = configUpdateIntervalSecs_; + if (interval > onDemandConfigUpdateIntervalSecs_) { + interval = onDemandConfigUpdateIntervalSecs_; + } + auto onDemandConfig = std::make_unique(); + + // This can potentially sleep for long periods of time, so allow + // the desctructor to wake it to avoid a 5-minute long destruct period. + for (;;) { + { + std::unique_lock lock(updateThreadMutex_); + updateThreadCondVar_.wait_for(lock, interval); + } + if (stopFlag_) { + break; + } + now = system_clock::now(); + if (now > next_config_load_time) { + updateBaseConfig(); + next_config_load_time = now + configUpdateIntervalSecs_; + } + if (onDemandSignal_.exchange(false)) { + onDemandConfig = config_->clone(); + configureFromSignal(now, *onDemandConfig); + } else if (now > next_on_demand_load_time) { + onDemandConfig = std::make_unique(); + configureFromDaemon(now, *onDemandConfig); + next_on_demand_load_time = now + onDemandConfigUpdateIntervalSecs_; + } + if (onDemandConfig->verboseLogLevel() >= 0) { + LOG(INFO) << "Setting verbose level to " + << onDemandConfig->verboseLogLevel() + << " from on-demand config"; + SET_LOG_VERBOSITY_LEVEL( + onDemandConfig->verboseLogLevel(), + onDemandConfig->verboseLogModules()); + } + } +} + +bool ConfigLoader::hasNewConfig(const Config& oldConfig) { + std::lock_guard lock(configLock_); + return config_->timestamp() > oldConfig.timestamp(); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..4ce3468e48db116b2a40d992f000a3af1338e70a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h @@ -0,0 +1,147 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "Config.h" + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ILoggerObserver.h" + +namespace libkineto { + class LibkinetoApi; +} + +namespace KINETO_NAMESPACE { + +using namespace libkineto; +class DaemonConfigLoader; + +class ConfigLoader { + public: + + static ConfigLoader& instance(); + + enum ConfigKind { + ActivityProfiler = 0, + EventProfiler, + NumConfigKinds + }; + + struct ConfigHandler { + virtual ~ConfigHandler() {} + virtual bool canAcceptConfig() = 0; + virtual void acceptConfig(const Config& cfg) = 0; + }; + + void addHandler(ConfigKind kind, ConfigHandler* handler) { + std::lock_guard lock(updateThreadMutex_); + handlers_[kind].push_back(handler); + startThread(); + } + + void removeHandler(ConfigKind kind, ConfigHandler* handler) { + std::lock_guard lock(updateThreadMutex_); + auto it = std::find( + handlers_[kind].begin(), handlers_[kind].end(), handler); + if (it != handlers_[kind].end()) { + handlers_[kind].erase(it); + } + } + + void notifyHandlers(const Config& cfg) { + std::lock_guard lock(updateThreadMutex_); + for (auto& key_val : handlers_) { + for (ConfigHandler* handler : key_val.second) { + handler->acceptConfig(cfg); + } + } + } + + bool canHandlerAcceptConfig(ConfigKind kind) { + std::lock_guard lock(updateThreadMutex_); + for (ConfigHandler* handler : handlers_[kind]) { + if (!handler->canAcceptConfig()) { + return false; + } + } + return true; + } + + void initBaseConfig() { + bool init = false; + { + std::lock_guard lock(configLock_); + init = !config_ || config_->source().empty(); + } + if (init) { + updateBaseConfig(); + } + } + + inline std::unique_ptr getConfigCopy() { + std::lock_guard lock(configLock_); + return config_->clone(); + } + + bool hasNewConfig(const Config& oldConfig); + int contextCountForGpu(uint32_t gpu); + + void handleOnDemandSignal(); + + static void setDaemonConfigLoaderFactory( + std::function()> factory); + + private: + ConfigLoader(); + ~ConfigLoader(); + + const char* configFileName(); + DaemonConfigLoader* daemonConfigLoader(); + + void startThread(); + void updateConfigThread(); + void updateBaseConfig(); + + // Create configuration when receiving SIGUSR2 + void configureFromSignal( + std::chrono::time_point now, + Config& config); + + // Create configuration when receiving request from a daemon + void configureFromDaemon( + std::chrono::time_point now, + Config& config); + + std::string readOnDemandConfigFromDaemon( + std::chrono::time_point now); + + std::mutex configLock_; + std::atomic configFileName_{nullptr}; + std::unique_ptr config_; + std::unique_ptr daemonConfigLoader_; + std::map> handlers_; + + std::chrono::seconds configUpdateIntervalSecs_; + std::chrono::seconds onDemandConfigUpdateIntervalSecs_; + std::unique_ptr updateThread_; + std::condition_variable updateThreadCondVar_; + std::mutex updateThreadMutex_; + std::atomic_bool stopFlag_{false}; + std::atomic_bool onDemandSignal_{false}; + +#if !USE_GOOGLE_LOG + std::unique_ptr> loggerObservers_; + std::mutex loggerObserversMutex_; +#endif // !USE_GOOGLE_LOG +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e909d5f9cfda13b95cc4abab547d964fe47b48a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Kineto Contributors + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "CudaDeviceProperties.h" + +#include +#include + +#include +#include + +#include "Logger.h" + +namespace KINETO_NAMESPACE { + +static const std::vector createDeviceProps() { + std::vector props; + int device_count; + cudaError_t error_id = cudaGetDeviceCount(&device_count); + // Return empty vector if error. + if (error_id != cudaSuccess) { + LOG(ERROR) << "cudaGetDeviceCount failed with code " << error_id; + return {}; + } + VLOG(0) << "Device count is " << device_count; + for (size_t i = 0; i < device_count; ++i) { + cudaDeviceProp prop; + error_id = cudaGetDeviceProperties(&prop, i); + // Return empty vector if any device property fail to get. + if (error_id != cudaSuccess) { + LOG(ERROR) << "cudaGetDeviceProperties failed with " << error_id; + return {}; + } + props.push_back(prop); + LOGGER_OBSERVER_ADD_DEVICE(i); + } + return props; +} + +static const std::vector& deviceProps() { + static const std::vector props = createDeviceProps(); + return props; +} + +static const std::string createDevicePropertiesJson( + size_t id, const cudaDeviceProp& props) { + return fmt::format(R"JSON( + {{ + "id": {}, "name": "{}", "totalGlobalMem": {}, + "computeMajor": {}, "computeMinor": {}, + "maxThreadsPerBlock": {}, "maxThreadsPerMultiprocessor": {}, + "regsPerBlock": {}, "regsPerMultiprocessor": {}, "warpSize": {}, + "sharedMemPerBlock": {}, "sharedMemPerMultiprocessor": {}, + "numSms": {}, "sharedMemPerBlockOptin": {} + }})JSON", + id, props.name, props.totalGlobalMem, + props.major, props.minor, + props.maxThreadsPerBlock, props.maxThreadsPerMultiProcessor, + props.regsPerBlock, props.regsPerMultiprocessor, props.warpSize, + props.sharedMemPerBlock, props.sharedMemPerMultiprocessor, + props.multiProcessorCount, props.sharedMemPerBlockOptin); +} + +static const std::string createDevicePropertiesJson() { + std::vector jsonProps; + const auto& props = deviceProps(); + for (size_t i = 0; i < props.size(); i++) { + jsonProps.push_back(createDevicePropertiesJson(i, props[i])); + } + return fmt::format("{}", fmt::join(jsonProps, ",")); +} + +const std::string& devicePropertiesJson() { + static std::string devicePropsJson = createDevicePropertiesJson(); + return devicePropsJson; +} + +int smCount(uint32_t deviceId) { + const std::vector &props = deviceProps(); + return deviceId >= props.size() ? 0 : + props[deviceId].multiProcessorCount; +} + +float kernelOccupancy( + uint32_t deviceId, + uint16_t registersPerThread, + int32_t staticSharedMemory, + int32_t dynamicSharedMemory, + int32_t blockX, + int32_t blockY, + int32_t blockZ, + float blocksPerSm) { + // Calculate occupancy + float occupancy = -1.0; + const std::vector &props = deviceProps(); + if (deviceId < props.size()) { + cudaOccFuncAttributes occFuncAttr; + occFuncAttr.maxThreadsPerBlock = INT_MAX; + occFuncAttr.numRegs = registersPerThread; + occFuncAttr.sharedSizeBytes = staticSharedMemory; + occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF; + occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT; + occFuncAttr.maxDynamicSharedSizeBytes = 0; + const cudaOccDeviceState occDeviceState = {}; + int blockSize = blockX * blockY * blockZ; + size_t dynamicSmemSize = dynamicSharedMemory; + cudaOccResult occ_result; + cudaOccDeviceProp prop(props[deviceId]); + cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor( + &occ_result, &prop, &occFuncAttr, &occDeviceState, + blockSize, dynamicSmemSize); + if (status == CUDA_OCC_SUCCESS) { + if (occ_result.activeBlocksPerMultiprocessor < blocksPerSm) { + blocksPerSm = occ_result.activeBlocksPerMultiprocessor; + } + occupancy = blocksPerSm * blockSize / + (float) props[deviceId].maxThreadsPerMultiProcessor; + } else { + LOG_EVERY_N(ERROR, 1000) << "Failed to calculate occupancy, status = " + << status; + } + } + return occupancy; +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..b731fde0c2aab4c9bd3e97f475d204dad02986e7 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Kineto Contributors + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace KINETO_NAMESPACE { + +int smCount(uint32_t deviceId); + +// Return estimated achieved occupancy for a kernel +float kernelOccupancy( + uint32_t deviceId, + uint16_t registersPerThread, + int32_t staticSharedMemory, + int32_t dynamicSharedMemory, + int32_t blockX, + int32_t blockY, + int32_t blockZ, + float blocks_per_sm); + +// Return compute properties for each device as a json string +const std::string& devicePropertiesJson(); + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h new file mode 100644 index 0000000000000000000000000000000000000000..09c29504060ecbbac609aa2d021ff643f45c143e --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h @@ -0,0 +1,114 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +#include "ITraceActivity.h" +#include "CuptiActivityPlatform.h" +#include "ThreadUtil.h" +#include "cupti_strings.h" + +namespace libkineto { + class ActivityLogger; +} + +namespace KINETO_NAMESPACE { + +using namespace libkineto; +struct TraceSpan; + +// These classes wrap the various CUPTI activity types +// into subclasses of ITraceActivity so that they can all be accessed +// using the ITraceActivity interface and logged via ActivityLogger. + +// Abstract base class, templated on Cupti activity type +template +struct CuptiActivity : public ITraceActivity { + explicit CuptiActivity(const T* activity, const ITraceActivity* linked) + : activity_(*activity), linked_(linked) {} + int64_t timestamp() const override { + return nsToUs(unixEpochTimestamp(activity_.start)); + } + int64_t duration() const override { + return nsToUs(activity_.end - activity_.start); + } + // TODO(T107507796): Deprecate ITraceActivity + int64_t correlationId() const override {return 0;} + int32_t getThreadId() const override {return 0;} + const ITraceActivity* linkedActivity() const override {return linked_;} + int flowType() const override {return kLinkAsyncCpuGpu;} + int flowId() const override {return correlationId();} + const T& raw() const {return activity_;} + const TraceSpan* traceSpan() const override {return nullptr;} + + protected: + const T& activity_; + const ITraceActivity* linked_{nullptr}; +}; + +// CUpti_ActivityAPI - CUDA runtime activities +struct RuntimeActivity : public CuptiActivity { + explicit RuntimeActivity( + const CUpti_ActivityAPI* activity, + const ITraceActivity* linked, + int32_t threadId) + : CuptiActivity(activity, linked), threadId_(threadId) {} + int64_t correlationId() const override {return activity_.correlationId;} + int64_t deviceId() const override {return processId();} + int64_t resourceId() const override {return threadId_;} + ActivityType type() const override {return ActivityType::CUDA_RUNTIME;} + bool flowStart() const override; + const std::string name() const override {return runtimeCbidName(activity_.cbid);} + void log(ActivityLogger& logger) const override; + const std::string metadataJson() const override; + + private: + const int32_t threadId_; +}; + +// CUpti_ActivityAPI - CUDA runtime activities +struct OverheadActivity : public CuptiActivity { + explicit OverheadActivity( + const CUpti_ActivityOverhead* activity, + const ITraceActivity* linked, + int32_t threadId=0) + : CuptiActivity(activity, linked), threadId_(threadId) {} + + int64_t timestamp() const override { + return nsToUs(unixEpochTimestamp(activity_.start)); + } + int64_t duration() const override { + return nsToUs(activity_.end - activity_.start); + } + // TODO: Update this with PID ordering + int64_t deviceId() const override {return -1;} + int64_t resourceId() const override {return threadId_;} + ActivityType type() const override {return ActivityType::OVERHEAD;} + bool flowStart() const override; + const std::string name() const override {return overheadKindString(activity_.overheadKind);} + void log(ActivityLogger& logger) const override; + const std::string metadataJson() const override; + + private: + const int32_t threadId_; +}; + +// Base class for GPU activities. +// Can also be instantiated directly. +template +struct GpuActivity : public CuptiActivity { + explicit GpuActivity(const T* activity, const ITraceActivity* linked) + : CuptiActivity(activity, linked) {} + int64_t correlationId() const override {return raw().correlationId;} + int64_t deviceId() const override {return raw().deviceId;} + int64_t resourceId() const override {return raw().streamId;} + ActivityType type() const override; + bool flowStart() const override {return false;} + const std::string name() const override; + void log(ActivityLogger& logger) const override; + const std::string metadataJson() const override; + const T& raw() const {return CuptiActivity::raw();} +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp new file mode 100644 index 0000000000000000000000000000000000000000..1ff2dafe06b0016ce7b904ef4b55e047c69bcc1c --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp @@ -0,0 +1,111 @@ + /* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "CuptiActivity.h" + +#include + +#include "Demangle.h" +#include "output_base.h" + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +template<> +inline const std::string GpuActivity::name() const { + return demangle(raw().name); +} + +template<> +inline ActivityType GpuActivity::type() const { + return ActivityType::CONCURRENT_KERNEL; +} + +static inline std::string memcpyName(uint8_t kind, uint8_t src, uint8_t dst) { + return fmt::format( + "Memcpy {} ({} -> {})", + memcpyKindString((CUpti_ActivityMemcpyKind)kind), + memoryKindString((CUpti_ActivityMemoryKind)src), + memoryKindString((CUpti_ActivityMemoryKind)dst)); +} + +template<> +inline ActivityType GpuActivity::type() const { + return ActivityType::GPU_MEMCPY; +} + +template<> +inline const std::string GpuActivity::name() const { + return memcpyName(raw().copyKind, raw().srcKind, raw().dstKind); +} + +template<> +inline ActivityType GpuActivity::type() const { + return ActivityType::GPU_MEMCPY; +} + +template<> +inline const std::string GpuActivity::name() const { + return memcpyName(raw().copyKind, raw().srcKind, raw().dstKind); +} + +template<> +inline const std::string GpuActivity::name() const { + const char* memory_kind = + memoryKindString((CUpti_ActivityMemoryKind)raw().memoryKind); + return fmt::format("Memset ({})", memory_kind); +} + +template<> +inline ActivityType GpuActivity::type() const { + return ActivityType::GPU_MEMSET; +} + +inline void RuntimeActivity::log(ActivityLogger& logger) const { + logger.handleActivity(*this); +} + +inline void OverheadActivity::log(ActivityLogger& logger) const { + logger.handleActivity(*this); +} + +inline bool OverheadActivity::flowStart() const { + return false; +} + +inline const std::string OverheadActivity::metadataJson() const { + return ""; +} + +template +inline void GpuActivity::log(ActivityLogger& logger) const { + logger.handleGpuActivity(*this); +} + +inline bool RuntimeActivity::flowStart() const { + return activity_.cbid == CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 || + (activity_.cbid >= CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 && + activity_.cbid <= CUPTI_RUNTIME_TRACE_CBID_cudaMemset2DAsync_v3020) || + activity_.cbid == + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000 || + activity_.cbid == + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000; +} + +inline const std::string RuntimeActivity::metadataJson() const { + return fmt::format(R"JSON( + "cbid": {}, "correlation": {})JSON", + activity_.cbid, activity_.correlationId); +} + +template +inline const std::string GpuActivity::metadataJson() const { + return ""; +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5718bed2f89b06cc702d1b82976cd42e5fceebd0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp @@ -0,0 +1,343 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "CuptiActivityApi.h" + +#include +#include + +#include "cupti_call.h" +#include "Logger.h" + +using namespace std::chrono; + +namespace KINETO_NAMESPACE { + +// TODO: do we want this to be configurable? +// Set to 2MB to avoid constantly creating buffers (espeically for networks +// that has many small memcpy such as sparseNN) +// Consider putting this on huge pages? +constexpr size_t kBufSize(2 * 1024 * 1024); + +CuptiActivityApi& CuptiActivityApi::singleton() { + static CuptiActivityApi instance; + return instance; +} + +void CuptiActivityApi::pushCorrelationID(int id, CorrelationFlowType type) { +#ifdef HAS_CUPTI + if (!singleton().externalCorrelationEnabled_) { + return; + } + VLOG(2) << "pushCorrelationID(" << id << ")"; + switch(type) { + case Default: + CUPTI_CALL(cuptiActivityPushExternalCorrelationId( + CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, id)); + break; + case User: + CUPTI_CALL(cuptiActivityPushExternalCorrelationId( + CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, id)); + } +#endif +} + +void CuptiActivityApi::popCorrelationID(CorrelationFlowType type) { +#ifdef HAS_CUPTI + if (!singleton().externalCorrelationEnabled_) { + return; + } + switch(type) { + case Default: + CUPTI_CALL(cuptiActivityPopExternalCorrelationId( + CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, nullptr)); + break; + case User: + CUPTI_CALL(cuptiActivityPopExternalCorrelationId( + CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, nullptr)); + } +#endif +} + +static int getSMCount() { +#ifdef HAS_CUPTI + // There may be a simpler way to get the number of SMs.... + // Look for domain_d - this has 80 instances on Volta and + // 56 instances on Pascal, corresponding to the number of SMs + // FIXME: This does not work on Turing and later + uint32_t domainCount{0}; + CUPTI_CALL(cuptiDeviceGetNumEventDomains(0, &domainCount)); + std::vector ids(domainCount); + size_t sz = sizeof(CUpti_EventDomainID) * domainCount; + CUPTI_CALL(cuptiDeviceEnumEventDomains(0, &sz, ids.data())); + for (CUpti_EventDomainID id : ids) { + char name[16]; + name[0] = '\0'; + sz = sizeof(name); + CUPTI_CALL(cuptiEventDomainGetAttribute( + id, CUPTI_EVENT_DOMAIN_ATTR_NAME, &sz, name)); + if (strncmp(name, "domain_d", sz) == 0) { + uint32_t count{0}; + sz = sizeof(count); + CUPTI_CALL(cuptiDeviceGetEventDomainAttribute( + 0, id, CUPTI_EVENT_DOMAIN_ATTR_TOTAL_INSTANCE_COUNT, &sz, &count)); + return count; + } + } +#endif + + return -1; +} + +int CuptiActivityApi::smCount() { + static int sm_count = getSMCount(); + return sm_count; +} + +static bool nextActivityRecord( + uint8_t* buffer, + size_t valid_size, + CUpti_Activity*& record) { +#ifdef HAS_CUPTI + CUptiResult status = CUPTI_CALL_NOWARN( + cuptiActivityGetNextRecord(buffer, valid_size, &record)); + if (status != CUPTI_SUCCESS) { + if (status != CUPTI_ERROR_MAX_LIMIT_REACHED) { + CUPTI_CALL(status); + } + record = nullptr; + } +#endif + return record != nullptr; +} + +void CuptiActivityApi::setMaxBufferSize(int size) { + maxGpuBufferCount_ = 1 + size / kBufSize; +} + +void CuptiActivityApi::forceLoadCupti() { +#ifdef HAS_CUPTI + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); +#endif +} + +#ifdef HAS_CUPTI +void CUPTIAPI CuptiActivityApi::bufferRequestedTrampoline( + uint8_t** buffer, + size_t* size, + size_t* maxNumRecords) { + singleton().bufferRequested(buffer, size, maxNumRecords); +} + +void CuptiActivityApi::bufferRequested( + uint8_t** buffer, size_t* size, size_t* maxNumRecords) { + std::lock_guard guard(mutex_); + if (allocatedGpuTraceBuffers_.size() >= maxGpuBufferCount_) { + stopCollection = true; + LOG(WARNING) << "Exceeded max GPU buffer count (" + << allocatedGpuTraceBuffers_.size() + << " > " << maxGpuBufferCount_ + << ") - terminating tracing"; + } + + auto buf = std::make_unique(kBufSize); + *buffer = buf->data(); + *size = kBufSize; + + allocatedGpuTraceBuffers_[*buffer] = std::move(buf); + + *maxNumRecords = 0; +} +#endif + +std::unique_ptr +CuptiActivityApi::activityBuffers() { + { + std::lock_guard guard(mutex_); + if (allocatedGpuTraceBuffers_.empty()) { + return nullptr; + } + } + +#ifdef HAS_CUPTI + VLOG(1) << "Flushing GPU activity buffers"; + time_point t1; + if (VLOG_IS_ON(1)) { + t1 = system_clock::now(); + } + // Can't hold mutex_ during this call, since bufferCompleted + // will be called by libcupti and mutex_ is acquired there. + CUPTI_CALL(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED)); + if (VLOG_IS_ON(1)) { + flushOverhead = + duration_cast(system_clock::now() - t1).count(); + } +#endif + std::lock_guard guard(mutex_); + // Transfer ownership of buffers to caller. A new map is created on-demand. + return std::move(readyGpuTraceBuffers_); +} + +#ifdef HAS_CUPTI +int CuptiActivityApi::processActivitiesForBuffer( + uint8_t* buf, + size_t validSize, + std::function handler) { + int count = 0; + if (buf && validSize) { + CUpti_Activity* record{nullptr}; + while ((nextActivityRecord(buf, validSize, record))) { + handler(record); + ++count; + } + } + return count; +} +#endif + +const std::pair CuptiActivityApi::processActivities( + CuptiActivityBufferMap& buffers, + std::function handler) { + std::pair res{0, 0}; +#ifdef HAS_CUPTI + for (auto& pair : buffers) { + // No lock needed - only accessed from this thread + auto& buf = pair.second; + res.first += processActivitiesForBuffer(buf->data(), buf->size(), handler); + res.second += buf->size(); + } +#endif + return res; +} + +void CuptiActivityApi::clearActivities() { + { + std::lock_guard guard(mutex_); + if (allocatedGpuTraceBuffers_.empty()) { + return; + } + } + // Can't hold mutex_ during this call, since bufferCompleted + // will be called by libcupti and mutex_ is acquired there. +#ifdef HAS_CUPTI + CUPTI_CALL(cuptiActivityFlushAll(0)); +#endif + // FIXME: We might want to make sure we reuse + // the same memory during warmup and tracing. + // Also, try to use the amount of memory required + // for active tracing during warmup. + std::lock_guard guard(mutex_); + // Throw away ready buffers as a result of above flush + readyGpuTraceBuffers_ = nullptr; +} + +#ifdef HAS_CUPTI +void CUPTIAPI CuptiActivityApi::bufferCompletedTrampoline( + CUcontext ctx, + uint32_t streamId, + uint8_t* buffer, + size_t /* unused */, + size_t validSize) { + singleton().bufferCompleted(ctx, streamId, buffer, 0, validSize); +} + +void CuptiActivityApi::bufferCompleted( + CUcontext ctx, + uint32_t streamId, + uint8_t* buffer, + size_t /* unused */, + size_t validSize) { + + std::lock_guard guard(mutex_); + auto it = allocatedGpuTraceBuffers_.find(buffer); + if (it == allocatedGpuTraceBuffers_.end()) { + LOG(ERROR) << "bufferCompleted called with unknown buffer: " + << (void*) buffer; + return; + } + + if (!readyGpuTraceBuffers_) { + readyGpuTraceBuffers_ = std::make_unique(); + } + // Set valid size of buffer before moving to ready map + it->second->setSize(validSize); + (*readyGpuTraceBuffers_)[it->first] = std::move(it->second); + allocatedGpuTraceBuffers_.erase(it); + + // report any records dropped from the queue; to avoid unnecessary cupti + // API calls, we make it report only in verbose mode (it doesn't happen + // often in our testing anyways) + if (VLOG_IS_ON(1)) { + size_t dropped = 0; + CUPTI_CALL(cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped)); + if (dropped != 0) { + LOG(WARNING) << "Dropped " << dropped << " activity records"; + } + } +} +#endif + +void CuptiActivityApi::enableCuptiActivities( + const std::set& selected_activities) { +#ifdef HAS_CUPTI + static bool registered = false; + if (!registered) { + CUPTI_CALL( + cuptiActivityRegisterCallbacks(bufferRequestedTrampoline, bufferCompletedTrampoline)); + } + + externalCorrelationEnabled_ = false; + for (const auto& activity : selected_activities) { + if (activity == ActivityType::GPU_MEMCPY) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY)); + } + if (activity == ActivityType::GPU_MEMSET) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); + } + if (activity == ActivityType::CONCURRENT_KERNEL) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); + } + if (activity == ActivityType::EXTERNAL_CORRELATION) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION)); + externalCorrelationEnabled_ = true; + } + if (activity == ActivityType::CUDA_RUNTIME) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME)); + } + if (activity == ActivityType::OVERHEAD) { + CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD)); + } + } +#endif + + // Explicitly enabled, so reset this flag if set + stopCollection = false; +} + +void CuptiActivityApi::disableCuptiActivities( + const std::set& selected_activities) { +#ifdef HAS_CUPTI + for (const auto& activity : selected_activities) { + if (activity == ActivityType::GPU_MEMCPY) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY)); + } + if (activity == ActivityType::GPU_MEMSET) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); + } + if (activity == ActivityType::CONCURRENT_KERNEL) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); + } + if (activity == ActivityType::EXTERNAL_CORRELATION) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION)); + } + if (activity == ActivityType::CUDA_RUNTIME) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME)); + } + if (activity == ActivityType::OVERHEAD) { + CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD)); + } + } + externalCorrelationEnabled_ = false; +#endif +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h new file mode 100644 index 0000000000000000000000000000000000000000..92af51ecac9ec99181c4726c3849894de9e32b33 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h @@ -0,0 +1,100 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include +#endif + +#include "ActivityType.h" +#include "CuptiActivityBuffer.h" + + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +#ifndef HAS_CUPTI +using CUpti_Activity = void; +#endif + +class CuptiActivityApi { + public: + enum CorrelationFlowType { + Default, + User + }; + + CuptiActivityApi() = default; + CuptiActivityApi(const CuptiActivityApi&) = delete; + CuptiActivityApi& operator=(const CuptiActivityApi&) = delete; + + virtual ~CuptiActivityApi() {} + + static CuptiActivityApi& singleton(); + + virtual int smCount(); + static void pushCorrelationID(int id, CorrelationFlowType type); + static void popCorrelationID(CorrelationFlowType type); + + void enableCuptiActivities( + const std::set& selected_activities); + void disableCuptiActivities( + const std::set& selected_activities); + void clearActivities(); + + virtual std::unique_ptr activityBuffers(); + + virtual const std::pair processActivities( + CuptiActivityBufferMap&, + std::function handler); + + void setMaxBufferSize(int size); + + std::atomic_bool stopCollection{false}; + int64_t flushOverhead{0}; + + static void forceLoadCupti(); + + private: +#ifdef HAS_CUPTI + int processActivitiesForBuffer( + uint8_t* buf, + size_t validSize, + std::function handler); + static void CUPTIAPI + bufferRequestedTrampoline(uint8_t** buffer, size_t* size, size_t* maxNumRecords); + static void CUPTIAPI bufferCompletedTrampoline( + CUcontext ctx, + uint32_t streamId, + uint8_t* buffer, + size_t /* unused */, + size_t validSize); +#endif // HAS_CUPTI + + int maxGpuBufferCount_{0}; + CuptiActivityBufferMap allocatedGpuTraceBuffers_; + std::unique_ptr readyGpuTraceBuffers_; + std::mutex mutex_; + bool externalCorrelationEnabled_{false}; + + protected: +#ifdef HAS_CUPTI + void bufferRequested(uint8_t** buffer, size_t* size, size_t* maxNumRecords); + void bufferCompleted( + CUcontext ctx, + uint32_t streamId, + uint8_t* buffer, + size_t /* unused */, + size_t validSize); +#endif +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..1c3fbef62c8d8f42ff5da1718e20315cc1ba95d5 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h @@ -0,0 +1,51 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ITraceActivity.h" + +namespace KINETO_NAMESPACE { + +class CuptiActivityBuffer { + public: + explicit CuptiActivityBuffer(size_t size) : size_(size) { + buf_.reserve(size); + } + CuptiActivityBuffer() = delete; + CuptiActivityBuffer& operator=(const CuptiActivityBuffer&) = delete; + CuptiActivityBuffer(CuptiActivityBuffer&&) = default; + CuptiActivityBuffer& operator=(CuptiActivityBuffer&&) = default; + + size_t size() const { + return size_; + } + + void setSize(size_t size) { + assert(size <= buf_.capacity()); + size_ = size; + } + + uint8_t* data() { + return buf_.data(); + } + + private: + + std::vector buf_; + size_t size_; + + std::vector> wrappers_; +}; + +using CuptiActivityBufferMap = + std::map>; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa2ef2f3a8c9cbb7f10567c158d6ee3e8e26eed0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp @@ -0,0 +1,31 @@ +#include + +namespace chrono = std::chrono; + +namespace KINETO_NAMESPACE { + +#ifdef _WIN32 +uint64_t epochs_diff() { + // On Windows, steady_clock wraps the QueryPerformanceCounter function. + // https://docs.microsoft.com/en-us/cpp/standard-library/steady-clock-struct?view=msvc-160 + auto steady = + chrono::time_point_cast(chrono::steady_clock::now()); + auto system = + chrono::time_point_cast(chrono::system_clock::now()); + + auto time_since_unix = system.time_since_epoch().count(); + auto time_since_boot = steady.time_since_epoch().count(); + return time_since_unix - time_since_boot; +} + +uint64_t unixEpochTimestamp(uint64_t ts) { + static uint64_t diff = epochs_diff(); + return ts + diff; +} +#else +uint64_t unixEpochTimestamp(uint64_t ts) { + return ts; +} +#endif // _WIN32 + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h new file mode 100644 index 0000000000000000000000000000000000000000..78de8373d5fe391d48edffc897aff6893aa6f54f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace KINETO_NAMESPACE { + +// cupti's timestamps are platform specific. This function convert the raw +// cupti timestamp to time since unix epoch. So that on different platform, +// correction can work correctly. +uint64_t unixEpochTimestamp(uint64_t ts); + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97c23ef047d75aff75b56773a20801ce83fb1653 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp @@ -0,0 +1,841 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "CuptiActivityProfiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include +#endif + +#include "Config.h" +#include "time_since_epoch.h" +#ifdef HAS_CUPTI +#include "CuptiActivity.h" +#include "CuptiActivity.tpp" +#include "CuptiActivityApi.h" +#endif // HAS_CUPTI +#ifdef HAS_ROCTRACER +#include "RoctracerActivityApi.h" +#endif +#include "output_base.h" + +#include "Logger.h" +#include "ThreadUtil.h" + +using namespace std::chrono; +using namespace libkineto; +using std::string; + +namespace KINETO_NAMESPACE { + +void CuptiActivityProfiler::transferCpuTrace( + std::unique_ptr cpuTrace) { + std::lock_guard guard(mutex_); + const string& trace_name = cpuTrace->span.name; + if (currentRunloopState_ != RunloopState::CollectTrace && + currentRunloopState_ != RunloopState::ProcessTrace) { + VLOG(0) << "Trace collection not in progress - discarding span " + << trace_name; + return; + } + + cpuTrace->span.iteration = iterationCountMap_[trace_name]++; + + VLOG(0) << "Received iteration " << cpuTrace->span.iteration << " of span " + << trace_name << " (" << cpuTrace->activities.size() << " activities / " + << cpuTrace->gpuOpCount << " gpu activities)"; + traceBuffers_->cpu.push_back(std::move(cpuTrace)); +} + +#ifdef HAS_ROCTRACER +CuptiActivityProfiler::CuptiActivityProfiler(RoctracerActivityApi& cupti, bool cpuOnly) +#else +CuptiActivityProfiler::CuptiActivityProfiler(CuptiActivityApi& cupti, bool cpuOnly) +#endif + : cupti_(cupti), + flushOverhead_{0, 0}, + setupOverhead_{0, 0}, + cpuOnly_{cpuOnly}, + currentRunloopState_{RunloopState::WaitForRequest}, + stopCollection_{false} {} + +void CuptiActivityProfiler::processTraceInternal(ActivityLogger& logger) { + LOG(INFO) << "Processing " << traceBuffers_->cpu.size() + << " CPU buffers"; + VLOG(0) << "Profile time range: " << captureWindowStartTime_ << " - " + << captureWindowEndTime_; + logger.handleTraceStart(metadata_); + for (auto& cpu_trace : traceBuffers_->cpu) { + string trace_name = cpu_trace->span.name; + VLOG(0) << "Processing CPU buffer for " << trace_name << " (" + << cpu_trace->span.iteration << ") - " + << cpu_trace->activities.size() << " records"; + VLOG(0) << "Span time range: " << cpu_trace->span.startTime << " - " + << cpu_trace->span.endTime; + processCpuTrace(*cpu_trace, logger); + LOGGER_OBSERVER_ADD_EVENT_COUNT(cpu_trace->activities.size()); + } + +#ifdef HAS_CUPTI + if (!cpuOnly_) { + VLOG(0) << "Retrieving GPU activity buffers"; + traceBuffers_->gpu = cupti_.activityBuffers(); + if (VLOG_IS_ON(1)) { + addOverheadSample(flushOverhead_, cupti_.flushOverhead); + } + if (traceBuffers_->gpu) { + const auto count_and_size = cupti_.processActivities( + *traceBuffers_->gpu, + std::bind(&CuptiActivityProfiler::handleCuptiActivity, this, std::placeholders::_1, &logger)); + LOG(INFO) << "Processed " << count_and_size.first + << " GPU records (" << count_and_size.second << " bytes)"; + LOGGER_OBSERVER_ADD_EVENT_COUNT(count_and_size.first); + } + } +#endif // HAS_CUPTI +#ifdef HAS_ROCTRACER + if (!cpuOnly_) { + VLOG(0) << "Retrieving GPU activity buffers"; + const int count = cupti_.processActivities(logger); + LOG(INFO) << "Processed " << count + << " GPU records"; + LOGGER_OBSERVER_ADD_EVENT_COUNT(count); + } +#endif // HAS_ROCTRACER + + for (const auto& session : sessions_){ + LOG(INFO) << "Processing child profiler trace"; + session->processTrace(logger); + } + + finalizeTrace(*config_, logger); +} + +CuptiActivityProfiler::CpuGpuSpanPair& CuptiActivityProfiler::recordTraceSpan( + TraceSpan& span, int gpuOpCount) { + TraceSpan gpu_span(gpuOpCount, span.iteration, span.name, "GPU: "); + auto& iterations = traceSpans_[span.name]; + iterations.push_back({span, gpu_span}); + return iterations.back(); +} + +void CuptiActivityProfiler::processCpuTrace( + libkineto::CpuTraceBuffer& cpuTrace, + ActivityLogger& logger) { + if (cpuTrace.activities.size() == 0) { + LOG(WARNING) << "CPU trace is empty!"; + return; + } + + CpuGpuSpanPair& span_pair = recordTraceSpan(cpuTrace.span, cpuTrace.gpuOpCount); + TraceSpan& cpu_span = span_pair.first; + for (auto const& act : cpuTrace.activities) { + VLOG(2) << act.correlationId() << ": OP " << act.activityName; + if (config_->selectedActivityTypes().count(act.type())) { + act.log(logger); + } + clientActivityTraceMap_[act.correlationId()] = &span_pair; + activityMap_[act.correlationId()] = &act; + + recordThreadInfo(act.resourceId(), act.getThreadId(), act.deviceId()); + } + logger.handleTraceSpan(cpu_span); +} + +#ifdef HAS_CUPTI +inline void CuptiActivityProfiler::handleCorrelationActivity( + const CUpti_ActivityExternalCorrelation* correlation) { + if (correlation->externalKind == CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0) { + cpuCorrelationMap_[correlation->correlationId] = correlation->externalId; + } else if (correlation->externalKind == CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1){ + userCorrelationMap_[correlation->correlationId] = correlation->externalId; + } else { + LOG(ERROR) << "Invalid CUpti_ActivityExternalCorrelation sent to handleCuptiActivity"; + } +} +#endif // HAS_CUPTI + +static GenericTraceActivity createUserGpuSpan( + const libkineto::ITraceActivity& cpuTraceActivity, + const libkineto::ITraceActivity& gpuTraceActivity) { + GenericTraceActivity res( + *cpuTraceActivity.traceSpan(), + ActivityType::GPU_USER_ANNOTATION, + cpuTraceActivity.name()); + res.startTime = gpuTraceActivity.timestamp(); + res.device = gpuTraceActivity.deviceId(); + res.resource = gpuTraceActivity.resourceId(); + res.endTime = + gpuTraceActivity.timestamp() + gpuTraceActivity.duration(); + res.id = cpuTraceActivity.correlationId(); + return res; +} + +void CuptiActivityProfiler::GpuUserEventMap::insertOrExtendEvent( + const ITraceActivity& userActivity, + const ITraceActivity& gpuActivity) { + StreamKey key(gpuActivity.deviceId(), gpuActivity.resourceId()); + CorrelationSpanMap& correlationSpanMap = streamSpanMap_[key]; + auto it = correlationSpanMap.find(userActivity.correlationId()); + if (it == correlationSpanMap.end()) { + auto it_success = correlationSpanMap.insert({ + userActivity.correlationId(), createUserGpuSpan(userActivity, gpuActivity) + }); + it = it_success.first; + } + GenericTraceActivity& span = it->second; + if (gpuActivity.timestamp() < span.startTime || span.startTime == 0) { + span.startTime = gpuActivity.timestamp(); + } + int64_t gpu_activity_end = gpuActivity.timestamp() + gpuActivity.duration(); + if (gpu_activity_end > span.endTime) { + span.endTime = gpu_activity_end; + } +} + +const CuptiActivityProfiler::CpuGpuSpanPair& CuptiActivityProfiler::defaultTraceSpan() { + static TraceSpan span(0, 0, "Unknown", ""); + static CpuGpuSpanPair span_pair(span, span); + return span_pair; +} + +void CuptiActivityProfiler::GpuUserEventMap::logEvents(ActivityLogger *logger) { + for (auto const& streamMapPair : streamSpanMap_) { + for (auto const& correlationSpanPair : streamMapPair.second) { + correlationSpanPair.second.log(*logger); + } + } +} + +#ifdef HAS_CUPTI +inline bool CuptiActivityProfiler::outOfRange(const ITraceActivity& act) { + bool out_of_range = act.timestamp() < captureWindowStartTime_ || + (act.timestamp() + act.duration()) > captureWindowEndTime_; + if (out_of_range) { + VLOG(2) << "TraceActivity outside of profiling window: " << act.name() + << " (" << act.timestamp() << " < " << captureWindowStartTime_ << " or " + << (act.timestamp() + act.duration()) << " > " << captureWindowEndTime_; + } + return out_of_range; +} + +inline static bool isBlockListedRuntimeCbid(CUpti_CallbackId cbid) { + // Some CUDA calls that are very frequent and also not very interesting. + // Filter these out to reduce trace size. + if (cbid == CUPTI_RUNTIME_TRACE_CBID_cudaGetDevice_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaSetDevice_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaGetLastError_v3020 || + // Don't care about cudaEvents + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventCreate_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventCreateWithFlags_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventRecord_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventDestroy_v3020 || + cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventSynchronize_v3020) { + return true; + } + + return false; +} + +void CuptiActivityProfiler::handleRuntimeActivity( + const CUpti_ActivityAPI* activity, + ActivityLogger* logger) { + if (isBlockListedRuntimeCbid(activity->cbid)) { + return; + } + VLOG(2) << activity->correlationId + << ": CUPTI_ACTIVITY_KIND_RUNTIME, cbid=" << activity->cbid + << " tid=" << activity->threadId; + int32_t tid = activity->threadId; + const auto& it = resourceInfo_.find({processId(), tid}); + if (it != resourceInfo_.end()) { + tid = it->second.id; + } + const ITraceActivity* linked = linkedActivity( + activity->correlationId, cpuCorrelationMap_); + const auto& runtime_activity = + traceBuffers_->addActivityWrapper(RuntimeActivity(activity, linked, tid)); + checkTimestampOrder(&runtime_activity); + if (outOfRange(runtime_activity)) { + return; + } + runtime_activity.log(*logger); +} + +void CuptiActivityProfiler::handleOverheadActivity( + const CUpti_ActivityOverhead* activity, + ActivityLogger* logger) { + VLOG(2) << ": CUPTI_ACTIVITY_KIND_OVERHEAD" << " overheadKind=" << activity->overheadKind; + + const auto& overhead_activity = + traceBuffers_->addActivityWrapper(OverheadActivity(activity, nullptr)); + overhead_activity.log(*logger); +} + + +inline void CuptiActivityProfiler::updateGpuNetSpan( + const ITraceActivity& gpuOp) { + if (!gpuOp.linkedActivity()) { + VLOG(0) << "Missing linked activity"; + return; + } + const auto& it = clientActivityTraceMap_.find( + gpuOp.linkedActivity()->correlationId()); + if (it == clientActivityTraceMap_.end()) { + // No correlation id mapping? + return; + } + TraceSpan& gpu_span = it->second->second; + if (gpuOp.timestamp() < gpu_span.startTime || gpu_span.startTime == 0) { + gpu_span.startTime = gpuOp.timestamp(); + } + if ((gpuOp.timestamp() + gpuOp.duration()) > gpu_span.endTime) { + gpu_span.endTime = gpuOp.timestamp() + gpuOp.duration(); + } +} + +// I've observed occasional broken timestamps attached to GPU events... +void CuptiActivityProfiler::checkTimestampOrder(const ITraceActivity* act1) { + // Correlated GPU runtime activity cannot + // have timestamp greater than the GPU activity's + const auto& it = correlatedCudaActivities_.find(act1->correlationId()); + if (it == correlatedCudaActivities_.end()) { + correlatedCudaActivities_.insert({act1->correlationId(), act1}); + return; + } + + // Activities may be appear in the buffers out of order. + // If we have a runtime activity in the map, it should mean that we + // have a GPU activity passed in, and vice versa. + const ITraceActivity* act2 = it->second; + if (act2->type() == ActivityType::CUDA_RUNTIME) { + // Buffer is out-of-order. + // Swap so that runtime activity is first for the comparison below. + std::swap(act1, act2); + } + if (act1->timestamp() > act2->timestamp()) { + LOG(WARNING) << "GPU op timestamp (" << act2->timestamp() + << ") < runtime timestamp (" << act1->timestamp() << ") by " + << act1->timestamp() - act2->timestamp() << "us"; + LOG(WARNING) << "Name: " << act2->name() + << " Device: " << act2->deviceId() + << " Stream: " << act2->resourceId(); + } +} + +inline void CuptiActivityProfiler::handleGpuActivity( + const ITraceActivity& act, + ActivityLogger* logger) { + if (outOfRange(act)) { + return; + } + checkTimestampOrder(&act); + VLOG(2) << act.correlationId() << ": " + << act.name(); + recordStream(act.deviceId(), act.resourceId(), ""); + act.log(*logger); + updateGpuNetSpan(act); + if (config_->selectedActivityTypes().count(ActivityType::GPU_USER_ANNOTATION)) { + const auto& it = userCorrelationMap_.find(act.correlationId()); + if (it != userCorrelationMap_.end()) { + const auto& it2 = activityMap_.find(it->second); + if (it2 != activityMap_.end()) { + recordStream(act.deviceId(), act.resourceId(), "context"); + gpuUserEventMap_.insertOrExtendEvent(*it2->second, act); + } + } + } +} + +const ITraceActivity* CuptiActivityProfiler::linkedActivity( + int32_t correlationId, + const std::unordered_map& correlationMap) { + const auto& it = correlationMap.find(correlationId); + if (it != correlationMap.end()) { + const auto& it2 = activityMap_.find(it->second); + if (it2 != activityMap_.end()) { + return it2->second; + } + } + return nullptr; +} + +template +inline void CuptiActivityProfiler::handleGpuActivity( + const T* act, ActivityLogger* logger) { + const ITraceActivity* linked = linkedActivity( + act->correlationId, cpuCorrelationMap_); + const auto& gpu_activity = + traceBuffers_->addActivityWrapper(GpuActivity(act, linked)); + handleGpuActivity(gpu_activity, logger); +} + +void CuptiActivityProfiler::handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger) { + switch (record->kind) { + case CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION: + handleCorrelationActivity( + reinterpret_cast( + record)); + break; + case CUPTI_ACTIVITY_KIND_RUNTIME: + handleRuntimeActivity( + reinterpret_cast(record), logger); + break; + case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: + handleGpuActivity( + reinterpret_cast(record), logger); + break; + case CUPTI_ACTIVITY_KIND_MEMCPY: + handleGpuActivity( + reinterpret_cast(record), logger); + break; + case CUPTI_ACTIVITY_KIND_MEMCPY2: + handleGpuActivity( + reinterpret_cast(record), logger); + break; + case CUPTI_ACTIVITY_KIND_MEMSET: + handleGpuActivity( + reinterpret_cast(record), logger); + break; + case CUPTI_ACTIVITY_KIND_OVERHEAD: + handleOverheadActivity (reinterpret_cast(record), logger); + break; + default: + LOG(WARNING) << "Unexpected activity type: " << record->kind; + break; + } +} +#endif // HAS_CUPTI + +void CuptiActivityProfiler::configureChildProfilers() { + // If child profilers are enabled create profiler sessions + for (auto& profiler: profilers_) { + int64_t start_time_ms = duration_cast( + profileStartTime_.time_since_epoch()).count(); + LOG(INFO) << "Running child profiler " << profiler->name() << " for " + << config_->activitiesDuration().count() << " ms"; + auto session = profiler->configure( + start_time_ms, + config_->activitiesDuration().count(), + config_->selectedActivityTypes(), + *config_ + ); + if (session) { + sessions_.push_back(std::move(session)); + } + } +} + +void CuptiActivityProfiler::configure( + const Config& config, + const time_point& now) { + std::lock_guard guard(mutex_); + if (isActive()) { + LOG(ERROR) << "CuptiActivityProfiler already busy, terminating"; + return; + } + + config_ = config.clone(); + + if (config_->activitiesDuration().count() == 0) { + // Use default if not specified + config_->setActivitiesDuration( + config_->activitiesDurationDefault()); + } + + // Ensure we're starting in a clean state + resetTraceData(); + +#if !USE_GOOGLE_LOG + // Add a LoggerObserverCollector to collect all logs during the trace. + loggerCollectorMetadata_ = std::make_unique(); + Logger::addLoggerObserver(loggerCollectorMetadata_.get()); +#endif // !USE_GOOGLE_LOG + + profileStartTime_ = config_->requestTimestamp(); + + if (config_->hasProfileStartIteration()) { + profileStartIter_ = config_->profileStartIteration(); + profileEndIter_ = profileStartIter_ + config_->activitiesRunIterations(); + } else { + + profileStartIter_ = -1; + profileEndIter_ = (std::numeric_limits::max)(); + + if (profileStartTime_ < now) { + LOG(ERROR) << "Not starting tracing - start timestamp is in the past. Time difference (ms): " << duration_cast(now - profileStartTime_).count(); + return; + } else if ((profileStartTime_ - now) < config_->activitiesWarmupDuration()) { + LOG(ERROR) << "Not starting tracing - insufficient time for warmup. Time to warmup (ms): " << duration_cast(profileStartTime_ - now).count() ; + return; + } + } + + if (LOG_IS_ON(INFO)) { + config_->printActivityProfilerConfig(LIBKINETO_DBG_STREAM); + } + if (!cpuOnly_ && !libkineto::api().client()) { + if (profileStartIter_ < 0) { + LOG(INFO) << "GPU-only tracing for " + << config_->activitiesDuration().count() << "ms"; + } else { + LOG(INFO) << "GPU-only tracing for " + << config_->activitiesRunIterations() << " iterations"; + } + } + + // Set useful metadata into the logger. + LOGGER_OBSERVER_SET_TRACE_DURATION_MS(config_->activitiesDuration().count()); + if (!config_->requestTraceID().empty()) { + LOGGER_OBSERVER_SET_TRACE_ID(config_->requestTraceID()); + } + if (!config_->requestGroupTraceID().empty()) { + LOGGER_OBSERVER_SET_GROUP_TRACE_ID(config_->requestGroupTraceID()); + } + LOGGER_OBSERVER_ADD_DESTINATION(config_->activitiesLogUrl()); + +#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) + if (!cpuOnly_) { + // Enabling CUPTI activity tracing incurs a larger perf hit at first, + // presumably because structures are allocated and initialized, callbacks + // are activated etc. After a while the overhead decreases and stabilizes. + // It's therefore useful to perform some warmup before starting recording. + LOG(INFO) << "Enabling GPU tracing"; + cupti_.setMaxBufferSize(config_->activitiesMaxGpuBufferSize()); + + time_point timestamp; + if (VLOG_IS_ON(1)) { + timestamp = system_clock::now(); + } +#ifdef HAS_CUPTI + cupti_.enableCuptiActivities(config_->selectedActivityTypes()); +#else + cupti_.enableActivities(config_->selectedActivityTypes()); +#endif + if (VLOG_IS_ON(1)) { + auto t2 = system_clock::now(); + addOverheadSample( + setupOverhead_, duration_cast(t2 - timestamp).count()); + } + } +#endif // HAS_CUPTI || HAS_ROCTRACER + + if (profilers_.size() > 0) { + configureChildProfilers(); + } + + if (libkineto::api().client()) { + libkineto::api().client()->warmup(config_->isOpInputsCollectionEnabled()); + } + if (profileStartIter_ >= 0) { + LOG(INFO) << "Tracing starting on iteration = " << profileStartIter_; + } else { + LOG(INFO) << "Tracing starting in " + << duration_cast(profileStartTime_ - now).count() << "s"; + } + + traceBuffers_ = std::make_unique(); + captureWindowStartTime_ = captureWindowEndTime_ = 0; + currentRunloopState_ = RunloopState::Warmup; +} + +void CuptiActivityProfiler::startTraceInternal(const time_point& now) { + captureWindowStartTime_ = libkineto::timeSinceEpoch(now); + VLOG(0) << "Warmup -> CollectTrace"; + for (auto& session: sessions_){ + LOG(INFO) << "Starting child profiler session"; + session->start(); + } + currentRunloopState_ = RunloopState::CollectTrace; +} + +void CuptiActivityProfiler::stopTraceInternal(const time_point& now) { + if (captureWindowEndTime_ == 0) { + captureWindowEndTime_ = libkineto::timeSinceEpoch(now); + } +#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) + if (!cpuOnly_) { + time_point timestamp; + if (VLOG_IS_ON(1)) { + timestamp = system_clock::now(); + } +#ifdef HAS_CUPTI + cupti_.disableCuptiActivities(config_->selectedActivityTypes()); +#else + cupti_.disableActivities(config_->selectedActivityTypes()); +#endif + if (VLOG_IS_ON(1)) { + auto t2 = system_clock::now(); + addOverheadSample( + setupOverhead_, duration_cast(t2 - timestamp).count()); + } + } +#endif // HAS_CUPTI || HAS_ROCTRACER + + if (currentRunloopState_ == RunloopState::CollectTrace) { + VLOG(0) << "CollectTrace -> ProcessTrace"; + } else { + LOG(WARNING) << "Called stopTrace with state == " << + static_cast::type>( + currentRunloopState_.load()); + } + for (auto& session: sessions_){ + LOG(INFO) << "Stopping child profiler session"; + session->stop(); + } + currentRunloopState_ = RunloopState::ProcessTrace; +} + +void CuptiActivityProfiler::resetInternal() { + resetTraceData(); + currentRunloopState_ = RunloopState::WaitForRequest; +} + +bool CuptiActivityProfiler::isWarmupDone( + const time_point& now, + int64_t currentIter) const { + // is it a time based config + if (profileStartIter_ < 0) { + // qualify that this check is not being called from application step() API + // this avoids races between the step() API and periodically invoked + // profiler run loop step() method + return (currentIter < 0) && (now >= profileStartTime_); + } + // this is an iteration based config + if (currentIter < 0) { + return false; + } + return currentIter >= profileStartIter_; +} + +bool CuptiActivityProfiler::isCollectionDone( + const time_point& now, + int64_t currentIter) const { + // is it a time based config + if (profileStartIter_ < 0) { + // qualify that this check is not being called from application step() API + return (currentIter < 0) && (now >= profileEndTime_); + } + // this is an iteration based config + if (currentIter < 0) { + return false; + } + return currentIter >= profileEndIter_; +} + +const time_point CuptiActivityProfiler::performRunLoopStep( + const time_point& now, + const time_point& nextWakeupTime, + int64_t currentIter) { + auto new_wakeup_time = nextWakeupTime; + bool warmup_done = false, collection_done = false; + + VLOG_IF(1, currentIter >= 0) << "Run loop on application step(), iteration = " + << currentIter; + + switch (currentRunloopState_) { + case RunloopState::WaitForRequest: + VLOG(1) << "State: WaitForRequest"; + // Nothing to do + break; + + case RunloopState::Warmup: + VLOG(1) << "State: Warmup"; + warmup_done = isWarmupDone(now, currentIter); +#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) + // Flushing can take a while so avoid doing it close to the start time + if (!cpuOnly_ && currentIter < 0 && + (profileStartIter_ >= 0 || nextWakeupTime < profileStartTime_)) { + cupti_.clearActivities(); + } + + if (cupti_.stopCollection) { + // Go to process trace to clear any outstanding buffers etc + LOG(WARNING) << "Trace terminated during warmup"; + std::lock_guard guard(mutex_); + stopTraceInternal(now); + resetInternal(); + VLOG(0) << "Warmup -> WaitForRequest"; + break; + } +#endif // HAS_CUPTI || HAS_ROCTRACER + + if (warmup_done) { + UST_LOGGER_MARK_COMPLETED(kWarmUpStage); + if (profileStartIter_ < 0 && + (now > profileStartTime_ + milliseconds(10))) { + LOG(WARNING) + << "Tracing started " + << duration_cast(now - profileStartTime_).count() + << "ms late!"; + } else { + LOG(INFO) << "Tracing started"; + } + startTrace(now); + if (libkineto::api().client()) { + libkineto::api().client()->start(); + } + if (nextWakeupTime > profileEndTime_) { + new_wakeup_time = profileEndTime_; + } + } else if (nextWakeupTime > profileStartTime_) { + new_wakeup_time = profileStartTime_; + } + + break; + + case RunloopState::CollectTrace: + VLOG(1) << "State: CollectTrace"; + // captureWindowStartTime_ can be set by external threads, + // so recompute end time. + // FIXME: Is this a good idea for synced start? + if (profileStartIter_ < 0) { + std::lock_guard guard(mutex_); + profileEndTime_ = time_point( + microseconds(captureWindowStartTime_)) + + config_->activitiesDuration(); + } + + collection_done = isCollectionDone(now, currentIter); + + // TODO revisit stopCollection_ is not used right now + if (collection_done || stopCollection_.exchange(false) +#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) + || cupti_.stopCollection +#endif // HAS_CUPTI || HAS_ROCTRACER + ){ + // Update runloop state first to prevent further updates to shared state + LOG(INFO) << "Tracing complete."; + if (currentIter > 0) { + LOG(INFO) << "This state change was invoked by application's step() call"; + } + // FIXME: Need to communicate reason for stopping on errors + if (libkineto::api().client()) { + libkineto::api().client()->stop(); + } + std::lock_guard guard(mutex_); + stopTraceInternal(now); + VLOG_IF(0, collection_done) << "Reached profile end time"; + + UST_LOGGER_MARK_COMPLETED(kCollectionStage); + } else if (profileStartIter_ >= 0) { + // nothing to do here + } else if (now < profileEndTime_ && profileEndTime_ < nextWakeupTime) { + new_wakeup_time = profileEndTime_; + } + + break; + + case RunloopState::ProcessTrace: + VLOG(1) << "State: ProcessTrace"; + // skip this state transition if it called from the step() api + // of the profiler. + // else it could lead to a race between the profiler thread and an + // application thread calling step() + if (currentIter >= 0) { + return new_wakeup_time; + } + // FIXME: Probably want to allow interruption here + // for quickly handling trace request via synchronous API + std::lock_guard guard(mutex_); + processTraceInternal(*logger_); + UST_LOGGER_MARK_COMPLETED(kPostProcessingStage); + resetInternal(); + VLOG(0) << "ProcessTrace -> WaitForRequest"; + break; + } + + return new_wakeup_time; +} + +void CuptiActivityProfiler::finalizeTrace(const Config& config, ActivityLogger& logger) { + LOG(INFO) << "Recorded nets:"; + { + for (const auto& it : iterationCountMap_) { + LOG(INFO) << it.first << ": " << it.second << " iterations"; + } + iterationCountMap_.clear(); + } + + // Process names + int32_t pid = processId(); + string process_name = processName(pid); + if (!process_name.empty()) { + logger.handleDeviceInfo( + {pid, process_name, "CPU"}, captureWindowStartTime_); + if (!cpuOnly_) { + // GPU events use device id as pid (0-7). + constexpr int kMaxGpuCount = 8; + for (int gpu = 0; gpu < kMaxGpuCount; gpu++) { + logger.handleDeviceInfo( + {gpu, process_name, fmt::format("GPU {}", gpu)}, + captureWindowStartTime_); + } + } + } + + // Thread & stream info + for (auto pair : resourceInfo_) { + const auto& resource = pair.second; + logger.handleResourceInfo(resource, captureWindowStartTime_); + } + + for (const auto& iterations : traceSpans_) { + for (const auto& span_pair : iterations.second) { + const TraceSpan& gpu_span = span_pair.second; + if (gpu_span.opCount > 0) { + logger.handleTraceSpan(gpu_span); + } + } + } + + // Overhead info + overheadInfo_.push_back(ActivityLogger::OverheadInfo("CUPTI Overhead")); + for(const auto& info : overheadInfo_) { + logger.handleOverheadInfo(info, captureWindowStartTime_); + } + + gpuUserEventMap_.logEvents(&logger); + +#if !USE_GOOGLE_LOG + // Save logs from LoggerCollector objects into Trace metadata. + auto LoggerMD = loggerCollectorMetadata_->extractCollectorMetadata(); + std::unordered_map> LoggerMDString; + for (auto& md : LoggerMD) { + LoggerMDString[toString(md.first)] = md.second; + } +#endif // !USE_GOOGLE_LOG + + logger.finalizeTrace(config, std::move(traceBuffers_), captureWindowEndTime_, LoggerMDString); +} + +void CuptiActivityProfiler::resetTraceData() { +#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) + if (!cpuOnly_) { + cupti_.clearActivities(); + } +#endif // HAS_CUPTI || HAS_ROCTRACER + activityMap_.clear(); + cpuCorrelationMap_.clear(); + correlatedCudaActivities_.clear(); + gpuUserEventMap_.clear(); + traceSpans_.clear(); + clientActivityTraceMap_.clear(); + traceBuffers_ = nullptr; + metadata_.clear(); + sessions_.clear(); +#if !USE_GOOGLE_LOG + Logger::removeLoggerObserver(loggerCollectorMetadata_.get()); +#endif // !USE_GOOGLE_LOG +} + + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..208833a4db720429982a63ed72ffa4762ef00bd0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h @@ -0,0 +1,364 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ThreadUtil.h" +#include "TraceSpan.h" +#include "libkineto.h" +#include "output_base.h" +#include "GenericTraceActivity.h" +#include "IActivityProfiler.h" +#include "LoggerCollector.h" + +namespace KINETO_NAMESPACE { + +class Config; +class CuptiActivityApi; +class RoctracerActivityApi; + +class CuptiActivityProfiler { + public: + CuptiActivityProfiler(CuptiActivityApi& cupti, bool cpuOnly); + CuptiActivityProfiler(RoctracerActivityApi& rai, bool cpuOnly); + CuptiActivityProfiler(const CuptiActivityProfiler&) = delete; + CuptiActivityProfiler& operator=(const CuptiActivityProfiler&) = delete; + + bool isActive() const { + return currentRunloopState_ != RunloopState::WaitForRequest; + } + + // Invoke at a regular interval to perform profiling activities. + // When not active, an interval of 1-5 seconds is probably fine, + // depending on required warm-up time and delayed start time. + // When active, it's a good idea to invoke more frequently to stay below + // memory usage limit (ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB) during warmup. + const std::chrono::time_point performRunLoopStep( + const std::chrono::time_point& now, + const std::chrono::time_point& nextWakeupTime, + int64_t currentIter = -1); + + // Used for async requests + void setLogger(ActivityLogger* logger) { + logger_ = logger; + } + + // Synchronous control API + void startTrace( + const std::chrono::time_point& now) { + std::lock_guard guard(mutex_); + startTraceInternal(now); + } + + void stopTrace(const std::chrono::time_point& now) { + std::lock_guard guard(mutex_); + stopTraceInternal(now); + } + + // Process CPU and GPU traces + void processTrace(ActivityLogger& logger) { + std::lock_guard guard(mutex_); + processTraceInternal(logger); + } + + void reset() { + std::lock_guard guard(mutex_); + resetInternal(); + } + + // Set up profiler as specified in config. + void configure( + const Config& config, + const std::chrono::time_point& now); + + // Registered with client API to pass CPU trace events over + void transferCpuTrace( + std::unique_ptr cpuTrace); + + Config& config() { + return *config_; + } + + inline void recordThreadInfo() { + int32_t sysTid = systemThreadId(); + // Note we're using the lower 32 bits of the (opaque) pthread id + // as key, because that's what CUPTI records. + int32_t tid = threadId(); + int32_t pid = processId(); + std::lock_guard guard(mutex_); + recordThreadInfo(sysTid, tid, pid); + } + + // T107508020: We can deprecate the recordThreadInfo(void) once we optimized profiler_kineto + void recordThreadInfo(int32_t sysTid, int32_t tid, int32_t pid) { + if (resourceInfo_.find({pid, tid}) == resourceInfo_.end()) { + resourceInfo_.emplace( + std::make_pair(pid, tid), + ActivityLogger::ResourceInfo( + pid, + sysTid, + sysTid, // sortindex + fmt::format("thread {} ({})", sysTid, getThreadName()))); + } + } + + void addMetadata(const std::string& key, const std::string& value) { + std::lock_guard guard(mutex_); + metadata_[key] = value; + } + + void addChildActivityProfiler( + std::unique_ptr profiler) { + std::lock_guard guard(mutex_); + profilers_.push_back(std::move(profiler)); + } + + protected: + + using CpuGpuSpanPair = std::pair; + static const CpuGpuSpanPair& defaultTraceSpan(); + + private: + + // Map of gpu activities to user defined events + class GpuUserEventMap { + public: + // Insert a user defined event which maps to the gpu trace activity. + // If the user defined event mapping already exists this will update the + // gpu side span to include the span of gpuTraceActivity. + void insertOrExtendEvent(const ITraceActivity& cpuTraceActivity, + const ITraceActivity& gpuTraceActivity); + // Log out the events to the logger + void logEvents(ActivityLogger *logger); + + void clear() { + streamSpanMap_.clear(); + } + + private: + // device id and stream name + using StreamKey = std::pair; + + // map of correlation id to TraceSpan + using CorrelationSpanMap = + std::unordered_map; + std::map streamSpanMap_; + }; + + GpuUserEventMap gpuUserEventMap_; + // id -> activity* + std::unordered_map activityMap_; + // cuda runtime id -> pytorch op id + // CUPTI provides a mechanism for correlating Cuda events to arbitrary + // external events, e.g.operator activities from PyTorch. + std::unordered_map cpuCorrelationMap_; + // CUDA runtime <-> GPU Activity + std::unordered_map + correlatedCudaActivities_; + std::unordered_map userCorrelationMap_; + + // data structure to collect cuptiActivityFlushAll() latency overhead + struct profilerOverhead { + int64_t overhead; + int cntr; + }; + + bool isWarmupDone( + const std::chrono::time_point& now, + int64_t currentIter) const; + + bool isCollectionDone( + const std::chrono::time_point& now, + int64_t currentIter) const; + + void startTraceInternal( + const std::chrono::time_point& now); + + void stopTraceInternal( + const std::chrono::time_point& now); + + void processTraceInternal(ActivityLogger& logger); + + void resetInternal(); + + void finalizeTrace(const Config& config, ActivityLogger& logger); + + void configureChildProfilers(); + + // Process a single CPU trace + void processCpuTrace( + libkineto::CpuTraceBuffer& cpuTrace, + ActivityLogger& logger); + + // Create resource names for streams + inline void recordStream(int device, int id, const char* postfix) { + if (resourceInfo_.find({device, id}) == resourceInfo_.end()) { + resourceInfo_.emplace( + std::make_pair(device, id), + ActivityLogger::ResourceInfo( + device, id, id, fmt::format( + "stream {} {}", id, postfix))); + } + } + + // Record client trace span for subsequent lookups from activities + // Also creates a corresponding GPU-side span. + CpuGpuSpanPair& recordTraceSpan(TraceSpan& span, int gpuOpCount); + + // Returns true if net name is to be tracked for a specified number of + // iterations. + bool iterationTargetMatch(libkineto::CpuTraceBuffer& trace); + + // net name to id + int netId(const std::string& netName); + + const ITraceActivity* linkedActivity( + int32_t correlationId, + const std::unordered_map& correlationMap); + +#ifdef HAS_CUPTI + // Process generic CUPTI activity + void handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger); + + // Process specific GPU activity types + void updateGpuNetSpan(const ITraceActivity& gpuOp); + bool outOfRange(const ITraceActivity& act); + void handleCorrelationActivity( + const CUpti_ActivityExternalCorrelation* correlation); + void handleRuntimeActivity( + const CUpti_ActivityAPI* activity, ActivityLogger* logger); + void handleOverheadActivity( + const CUpti_ActivityOverhead* activity, ActivityLogger* logger); + void handleGpuActivity(const ITraceActivity& act, + ActivityLogger* logger); + template + void handleGpuActivity(const T* act, ActivityLogger* logger); +#endif // HAS_CUPTI + + void resetTraceData(); + + void addOverheadSample(profilerOverhead& counter, int64_t overhead) { + counter.overhead += overhead; + counter.cntr++; + } + int64_t getOverhead(const profilerOverhead& counter) { + if (counter.cntr == 0) { + return 0; + } + return counter.overhead / counter.cntr; + } + + void checkTimestampOrder(const ITraceActivity* act1); + + // On-demand request configuration + std::unique_ptr config_; + + // Logger used during trace processing + ActivityLogger* logger_; + + // Calls to CUPTI is encapsulated behind this interface +#ifdef HAS_ROCTRACER + RoctracerActivityApi& cupti_; // Design failure here +#else + CuptiActivityApi& cupti_; +#endif + + enum class RunloopState { + WaitForRequest, + Warmup, + CollectTrace, + ProcessTrace + }; + + // Start and end time used for triggering and stopping profiling + std::chrono::time_point profileStartTime_; + std::chrono::time_point profileEndTime_; + int64_t profileStartIter_ = -1, profileEndIter_ = -1; + + + // All recorded trace spans, both CPU and GPU + // Trace Id -> list of iterations. + // Using map of lists for the iterator semantics, since we are recording + // pointers to the elements in this structure. + std::map> traceSpans_; + + // Maintain a map of client trace activity to trace span. + // Maps correlation id -> TraceSpan* held by traceSpans_. + using ActivityTraceMap = std::unordered_map; + ActivityTraceMap clientActivityTraceMap_; + + // Cache thread names and system thread ids for pthread ids, + // and stream ids for GPU streams + std::map< + std::pair, + ActivityLogger::ResourceInfo> resourceInfo_; + + std::vector overheadInfo_; + + // the overhead to flush the activity buffer + profilerOverhead flushOverhead_; + // the overhead to enable/disable activity tracking + profilerOverhead setupOverhead_; + + bool cpuOnly_{false}; + + // *************************************************************************** + // Below state is shared with external threads. + // These need to either be atomic, accessed under lock or only used + // by external threads in separate runloop phases from the profiler thread. + // *************************************************************************** + + // Mutex to protect non-atomic access to below state + std::mutex mutex_; + + // Runloop phase + std::atomic currentRunloopState_{RunloopState::WaitForRequest}; + + // Keep track of the start time of the first net in the current trace. + // This is only relevant to Caffe2 as PyTorch does not have nets. + // All CUDA events before this time will be removed + // Can be written by external threads during collection. + int64_t captureWindowStartTime_{0}; + // Similarly, all CUDA API events after the last net event will be removed + int64_t captureWindowEndTime_{0}; + + // span name -> iteration count + std::map iterationCountMap_; + // Flag used to stop tracing from external api callback. + // Needs to be atomic since it's set from a different thread. + std::atomic_bool stopCollection_{false}; + + // Buffers where trace data is stored + std::unique_ptr traceBuffers_; + + // Trace metadata + std::unordered_map metadata_; + + // child activity profilers + std::vector> profilers_; + + // a vector of active profiler plugin sessions + std::vector> sessions_; + + // LoggerCollector to collect all LOGs during the trace +#if !USE_GOOGLE_LOG + std::unique_ptr loggerCollectorMetadata_; +#endif // !USE_GOOGLE_LOG +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1876003998dc0c66f882d939ca8100750cfd046a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp @@ -0,0 +1,260 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "CuptiCallbackApi.h" + +#include +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include "cupti_call.h" +#endif +#include "Logger.h" + + +namespace KINETO_NAMESPACE { + +// limit on number of handles per callback type +constexpr size_t MAX_CB_FNS_PER_CB = 8; + +// Reader Writer lock types +using ReaderWriterLock = std::shared_timed_mutex; +using ReaderLockGuard = std::shared_lock; +using WriteLockGuard = std::unique_lock; + +static ReaderWriterLock callbackLock_; + +/* Callback Table : + * Overall goal of the design is to optimize the lookup of function + * pointers. The table is structured at two levels and the leaf + * elements in the table are std::list to enable fast access/inserts/deletes + * + * | + * -> cb id 0 -> std::list of callbacks + * ... + * -> cb id n -> std::list of callbacks + * | + * ... + * CallbackTable is the finaly table type above + * See type declrartions in header file. + */ + + +/* callback_switchboard : is the global callback handler we register + * with CUPTI. The goal is to make it as efficient as possible + * to re-direct to the registered callback(s). + * + * Few things to care about : + * a) use if/then switches rather than map/hash structures + * b) avoid dynamic memory allocations + * c) be aware of locking overheads + */ +#ifdef HAS_CUPTI +static void CUPTIAPI callback_switchboard( +#else +static void callback_switchboard( +#endif + void* /* unused */, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo) { + + // below statement is likey going to call a mutex + // on the singleton access + CuptiCallbackApi::singleton().__callback_switchboard( + domain, cbid, cbInfo); +} + + +void CuptiCallbackApi::__callback_switchboard( + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo) { + VLOG(0) << "Callback: domain = " << domain << ", cbid = " << cbid; + CallbackList *cblist = nullptr; + + switch (domain) { + + // add the fastest path for kernel launch callbacks + // as these are the most frequent ones + case CUPTI_CB_DOMAIN_RUNTIME_API: + switch (cbid) { + case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000: + cblist = &callbacks_.runtime[ + CUDA_LAUNCH_KERNEL - __RUNTIME_CB_DOMAIN_START]; + break; + default: + break; + } + break; + + case CUPTI_CB_DOMAIN_RESOURCE: + switch (cbid) { + case CUPTI_CBID_RESOURCE_CONTEXT_CREATED: + cblist = &callbacks_.resource[ + RESOURCE_CONTEXT_CREATED - __RESOURCE_CB_DOMAIN_START]; + break; + case CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING: + cblist = &callbacks_.resource[ + RESOURCE_CONTEXT_DESTROYED - __RESOURCE_CB_DOMAIN_START]; + break; + default: + break; + } + break; + + default: + return; + } + + // ignore callbacks that are not handled + if (cblist == nullptr) { + return; + } + + // make a copy of the callback list so we avoid holding lock + // in common case this should be just one func pointer copy + std::array callbacks; + int num_cbs = 0; + { + ReaderLockGuard rl(callbackLock_); + int i = 0; + for (auto it = cblist->begin(); + it != cblist->end() && i < MAX_CB_FNS_PER_CB; + it++, i++) { + callbacks[i] = *it; + } + num_cbs = i; + } + + for (int i = 0; i < num_cbs; i++) { + auto fn = callbacks[i]; + fn(domain, cbid, cbInfo); + } +} + +CuptiCallbackApi& CuptiCallbackApi::singleton() { + static CuptiCallbackApi instance; + return instance; +} + +CuptiCallbackApi::CuptiCallbackApi() { +#ifdef HAS_CUPTI + lastCuptiStatus_ = CUPTI_ERROR_UNKNOWN; + lastCuptiStatus_ = CUPTI_CALL_NOWARN( + cuptiSubscribe(&subscriber_, + (CUpti_CallbackFunc)callback_switchboard, + nullptr)); + + initSuccess_ = (lastCuptiStatus_ == CUPTI_SUCCESS); +#endif +} + +CuptiCallbackApi::CallbackList* CuptiCallbackApi::CallbackTable::lookup( + CUpti_CallbackDomain domain, CuptiCallBackID cbid) { + size_t idx; + + switch (domain) { + + case CUPTI_CB_DOMAIN_RESOURCE: + assert(cbid >= __RESOURCE_CB_DOMAIN_START); + assert(cbid < __RESOURCE_CB_DOMAIN_END); + idx = cbid - __RESOURCE_CB_DOMAIN_START; + return &resource.at(idx); + + case CUPTI_CB_DOMAIN_RUNTIME_API: + assert(cbid >= __RUNTIME_CB_DOMAIN_START); + assert(cbid < __RUNTIME_CB_DOMAIN_END); + idx = cbid - __RUNTIME_CB_DOMAIN_START; + return &runtime.at(idx); + + default: + LOG(WARNING) << " Unsupported callback domain : " << domain; + return nullptr; + } +} + +bool CuptiCallbackApi::registerCallback( + CUpti_CallbackDomain domain, + CuptiCallBackID cbid, + CuptiCallbackFn cbfn) { + CallbackList* cblist = callbacks_.lookup(domain, cbid); + + if (!cblist) { + LOG(WARNING) << "Could not register callback -- domain = " << domain + << " callback id = " << cbid; + return false; + } + + // avoid duplicates + auto it = std::find(cblist->begin(), cblist->end(), cbfn); + if (it != cblist->end()) { + LOG(WARNING) << "Adding duplicate callback -- domain = " << domain + << " callback id = " << cbid; + return true; + } + + if (cblist->size() == MAX_CB_FNS_PER_CB) { + LOG(WARNING) << "Already registered max callback -- domain = " << domain + << " callback id = " << cbid; + } + + WriteLockGuard wl(callbackLock_); + cblist->push_back(cbfn); + return true; +} + +bool CuptiCallbackApi::deleteCallback( + CUpti_CallbackDomain domain, + CuptiCallBackID cbid, + CuptiCallbackFn cbfn) { + CallbackList* cblist = callbacks_.lookup(domain, cbid); + if (!cblist) { + LOG(WARNING) << "Attempting to remove unsupported callback -- domain = " << domain + << " callback id = " << cbid; + return false; + } + + // Locks are not required here as + // https://en.cppreference.com/w/cpp/container/list/erase + // "References and iterators to the erased elements are invalidated. + // Other references and iterators are not affected." + auto it = std::find(cblist->begin(), cblist->end(), cbfn); + if (it == cblist->end()) { + LOG(WARNING) << "Could not find callback to remove -- domain = " << domain + << " callback id = " << cbid; + return false; + } + + WriteLockGuard wl(callbackLock_); + cblist->erase(it); + return true; +} + +bool CuptiCallbackApi::enableCallback( + CUpti_CallbackDomain domain, CUpti_CallbackId cbid) { +#ifdef HAS_CUPTI + if (initSuccess_) { + lastCuptiStatus_ = CUPTI_CALL_NOWARN( + cuptiEnableCallback(1, subscriber_, domain, cbid)); + return (lastCuptiStatus_ == CUPTI_SUCCESS); + } +#endif + return false; +} + +bool CuptiCallbackApi::disableCallback( + CUpti_CallbackDomain domain, CUpti_CallbackId cbid) { +#ifdef HAS_CUPTI + if (initSuccess_) { + lastCuptiStatus_ = CUPTI_CALL_NOWARN( + cuptiEnableCallback(0, subscriber_, domain, cbid)); + return (lastCuptiStatus_ == CUPTI_SUCCESS); + } +#endif + return false; +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h new file mode 100644 index 0000000000000000000000000000000000000000..4526f3750b4a134bc888843b8ff347a1f2bf8d5f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h @@ -0,0 +1,130 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#ifdef HAS_CUPTI +#include +#endif +#include +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "CuptiCallbackApiMock.h" + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + + +/* CuptiCallbackApi : Provides an abstraction over CUPTI callback + * interface. This enables various callback functions to be registered + * with this class. The class registers a global callback handler that + * redirects to the respective callbacks. + * + * Note: one design choice we made is to only support simple function pointers + * in order to speed up the implementation for fast path. + */ + +using CuptiCallbackFn = void(*)( + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo); + + +class CuptiCallbackApi { + + public: + + /* Global list of supported callback ids + * use the class namespace to avoid confusing with CUPTI enums*/ + enum CuptiCallBackID { + CUDA_LAUNCH_KERNEL = 0, + // can possibly support more callback ids per domain + // + __RUNTIME_CB_DOMAIN_START = CUDA_LAUNCH_KERNEL, + + // Callbacks under Resource CB domain + RESOURCE_CONTEXT_CREATED, + RESOURCE_CONTEXT_DESTROYED, + + __RUNTIME_CB_DOMAIN_END = RESOURCE_CONTEXT_CREATED, + __RESOURCE_CB_DOMAIN_START = RESOURCE_CONTEXT_CREATED, + + __RESOURCE_CB_DOMAIN_END = RESOURCE_CONTEXT_DESTROYED + 1, + }; + + + CuptiCallbackApi(const CuptiCallbackApi&) = delete; + CuptiCallbackApi& operator=(const CuptiCallbackApi&) = delete; + + static CuptiCallbackApi& singleton(); + + bool initSuccess() const { + return initSuccess_; + } + +#ifdef HAS_CUPTI + CUptiResult getCuptiStatus() const { + return lastCuptiStatus_; + } +#endif + + bool registerCallback( + CUpti_CallbackDomain domain, + CuptiCallBackID cbid, + CuptiCallbackFn cbfn); + + // returns false if callback was not found + bool deleteCallback( + CUpti_CallbackDomain domain, + CuptiCallBackID cbid, + CuptiCallbackFn cbfn); + + bool enableCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid); + bool disableCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid); + + + // Please do not use this method. This has to be exposed as public + // so it is accessible from the callback handler + void __callback_switchboard( + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo); + + private: + + explicit CuptiCallbackApi(); + + // For callback table design overview see the .cpp file + using CallbackList = std::list; + + // level 2 tables sizes are known at compile time + constexpr static size_t RUNTIME_CB_DOMAIN_SIZE + = (__RUNTIME_CB_DOMAIN_END - __RUNTIME_CB_DOMAIN_START); + + constexpr static size_t RESOURCE_CB_DOMAIN_SIZE + = (__RESOURCE_CB_DOMAIN_END - __RESOURCE_CB_DOMAIN_START); + + // level 1 table is a struct + struct CallbackTable { + std::array runtime; + std::array resource; + + CallbackList* lookup(CUpti_CallbackDomain domain, CuptiCallBackID cbid); + }; + + CallbackTable callbacks_; + bool initSuccess_ = false; + +#ifdef HAS_CUPTI + CUptiResult lastCuptiStatus_; + CUpti_SubscriberHandle subscriber_; +#endif +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h new file mode 100644 index 0000000000000000000000000000000000000000..fd51267274f99a0c9949eaac6fdae2dff917c7a0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h @@ -0,0 +1,32 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +// Provides data structures to mock CUPTI Callback API +#ifndef HAS_CUPTI + +enum CUpti_CallbackDomain { + CUPTI_CB_DOMAIN_RESOURCE, + CUPTI_CB_DOMAIN_RUNTIME_API, +}; +enum CUpti_CallbackId { + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, + CUPTI_CBID_RESOURCE_CONTEXT_CREATED, + CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING, +}; + +using CUcontext = void*; + +struct CUpti_ResourceData { + CUcontext context; +}; + +constexpr int CUPTI_API_ENTER = 0; +constexpr int CUPTI_API_EXIT = 0; + +struct CUpti_CallbackData { + CUcontext context; + const char* symbolName; + int callbackSite; +}; +#endif // HAS_CUPTI diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f1d48c1d00bb7defb6b622c13da55da99312a3b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp @@ -0,0 +1,112 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "CuptiEventApi.h" + +#include + +#include "Logger.h" +#include "cupti_call.h" + +using namespace std::chrono; +using std::vector; + +namespace KINETO_NAMESPACE { + +CuptiEventApi::CuptiEventApi(CUcontext context) + : context_(context) { + CUPTI_CALL(cuptiGetDeviceId(context_, (uint32_t*)&device_)); +} + +CUpti_EventGroupSets* CuptiEventApi::createGroupSets( + vector& ids) { + CUpti_EventGroupSets* group_sets = nullptr; + CUptiResult res = CUPTI_CALL(cuptiEventGroupSetsCreate( + context_, sizeof(CUpti_EventID) * ids.size(), ids.data(), &group_sets)); + + if (res != CUPTI_SUCCESS || group_sets == nullptr) { + const char* errstr = nullptr; + CUPTI_CALL(cuptiGetResultString(res, &errstr)); + throw std::system_error(EINVAL, std::generic_category(), errstr); + } + + return group_sets; +} + +void CuptiEventApi::destroyGroupSets(CUpti_EventGroupSets* sets) { + CUPTI_CALL(cuptiEventGroupSetsDestroy(sets)); +} + +bool CuptiEventApi::setContinuousMode() { + // Avoid logging noise for CUPTI_ERROR_LEGACY_PROFILER_NOT_SUPPORTED + CUptiResult res = CUPTI_CALL_NOWARN(cuptiSetEventCollectionMode( + context_, CUPTI_EVENT_COLLECTION_MODE_CONTINUOUS)); + if (res == CUPTI_ERROR_LEGACY_PROFILER_NOT_SUPPORTED) { + return false; + } + // Log warning on other errors + CUPTI_CALL(res); + return (res == CUPTI_SUCCESS); +} + +void CuptiEventApi::enablePerInstance(CUpti_EventGroup eventGroup) { + uint32_t profile_all = 1; + CUPTI_CALL(cuptiEventGroupSetAttribute( + eventGroup, + CUPTI_EVENT_GROUP_ATTR_PROFILE_ALL_DOMAIN_INSTANCES, + sizeof(profile_all), + &profile_all)); +} + +uint32_t CuptiEventApi::instanceCount(CUpti_EventGroup eventGroup) { + uint32_t instance_count = 0; + size_t s = sizeof(instance_count); + CUPTI_CALL(cuptiEventGroupGetAttribute( + eventGroup, CUPTI_EVENT_GROUP_ATTR_INSTANCE_COUNT, &s, &instance_count)); + return instance_count; +} + +void CuptiEventApi::enableGroupSet(CUpti_EventGroupSet& set) { + CUptiResult res = CUPTI_CALL_NOWARN(cuptiEventGroupSetEnable(&set)); + if (res != CUPTI_SUCCESS) { + const char* errstr = nullptr; + CUPTI_CALL(cuptiGetResultString(res, &errstr)); + throw std::system_error(EIO, std::generic_category(), errstr); + } +} + +void CuptiEventApi::disableGroupSet(CUpti_EventGroupSet& set) { + CUPTI_CALL(cuptiEventGroupSetDisable(&set)); +} + +void CuptiEventApi::readEvent( + CUpti_EventGroup grp, + CUpti_EventID id, + vector& vals) { + size_t s = sizeof(int64_t) * vals.size(); + CUPTI_CALL(cuptiEventGroupReadEvent( + grp, + CUPTI_EVENT_READ_FLAG_NONE, + id, + &s, + reinterpret_cast(vals.data()))); +} + +vector CuptiEventApi::eventsInGroup(CUpti_EventGroup grp) { + uint32_t group_size = 0; + size_t s = sizeof(group_size); + CUPTI_CALL(cuptiEventGroupGetAttribute( + grp, CUPTI_EVENT_GROUP_ATTR_NUM_EVENTS, &s, &group_size)); + size_t events_size = group_size * sizeof(CUpti_EventID); + vector res(group_size); + CUPTI_CALL(cuptiEventGroupGetAttribute( + grp, CUPTI_EVENT_GROUP_ATTR_EVENTS, &events_size, res.data())); + return res; +} + +CUpti_EventID CuptiEventApi::eventId(const std::string& name) { + CUpti_EventID id{0}; + CUPTI_CALL(cuptiEventGetIdFromName(device_, name.c_str(), &id)); + return id; +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h new file mode 100644 index 0000000000000000000000000000000000000000..79610f93f0ecfa62a9508d4caddfa876518169d3 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h @@ -0,0 +1,49 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include + +namespace KINETO_NAMESPACE { + +// C++ interface to CUPTI Events C API. +// Virtual methods are here mainly to allow easier testing. +class CuptiEventApi { + public: + explicit CuptiEventApi(CUcontext context_); + virtual ~CuptiEventApi() {} + + CUdevice device() { + return device_; + } + + virtual CUpti_EventGroupSets* createGroupSets( + std::vector& ids); + virtual void destroyGroupSets(CUpti_EventGroupSets* sets); + + virtual bool setContinuousMode(); + + virtual void enablePerInstance(CUpti_EventGroup eventGroup); + virtual uint32_t instanceCount(CUpti_EventGroup eventGroup); + + virtual void enableGroupSet(CUpti_EventGroupSet& set); + virtual void disableGroupSet(CUpti_EventGroupSet& set); + + virtual void + readEvent(CUpti_EventGroup g, CUpti_EventID id, std::vector& vals); + virtual std::vector eventsInGroup(CUpti_EventGroup g); + + virtual CUpti_EventID eventId(const std::string& name); + + protected: + // Unit testing + CuptiEventApi() : context_(nullptr), device_(0) {} + + private: + CUcontext context_; + CUdevice device_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36401e7434108d1da079aa4ba0264192c5d62838 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp @@ -0,0 +1,107 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "CuptiMetricApi.h" + +#include + +#include "Logger.h" +#include "cupti_call.h" + +using namespace std::chrono; +using std::vector; + +namespace KINETO_NAMESPACE { + +CUpti_MetricID CuptiMetricApi::idFromName(const std::string& name) { + CUpti_MetricID metric_id{~0u}; + CUptiResult res = + CUPTI_CALL(cuptiMetricGetIdFromName(device_, name.c_str(), &metric_id)); + if (res == CUPTI_ERROR_INVALID_METRIC_NAME) { + LOG(WARNING) << "Invalid metric name: " << name; + } + return metric_id; +} + +// Return a map of event IDs and names for a given metric id. +// Note that many events don't have a name. In that case the name will +// be set to the empty string. +std::map CuptiMetricApi::events( + CUpti_MetricID metric_id) { + uint32_t num_events = 0; + CUPTI_CALL(cuptiMetricGetNumEvents(metric_id, &num_events)); + vector ids(num_events); + size_t array_size = num_events * sizeof(CUpti_EventID); + CUPTI_CALL(cuptiMetricEnumEvents(metric_id, &array_size, ids.data())); + std::map res; + for (CUpti_EventID id : ids) { + // Attempt to lookup name from CUPTI + constexpr size_t kMaxEventNameLength = 64; + char cupti_name[kMaxEventNameLength]; + size_t size = kMaxEventNameLength; + CUPTI_CALL( + cuptiEventGetAttribute(id, CUPTI_EVENT_ATTR_NAME, &size, cupti_name)); + cupti_name[kMaxEventNameLength - 1] = 0; + + // CUPTI "helpfully" returns "event_name" when the event is unnamed. + if (size > 0 && strcmp(cupti_name, "event_name") != 0) { + res.emplace(id, cupti_name); + } else { + res.emplace(id, ""); + } + } + return res; +} + +CUpti_MetricValueKind CuptiMetricApi::valueKind(CUpti_MetricID metric) { + CUpti_MetricValueKind res{CUPTI_METRIC_VALUE_KIND_FORCE_INT}; + size_t value_kind_size = sizeof(res); + CUPTI_CALL(cuptiMetricGetAttribute( + metric, CUPTI_METRIC_ATTR_VALUE_KIND, &value_kind_size, &res)); + return res; +} + +CUpti_MetricEvaluationMode CuptiMetricApi::evaluationMode( + CUpti_MetricID metric) { + CUpti_MetricEvaluationMode eval_mode{ + CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE}; + size_t eval_mode_size = sizeof(eval_mode); + CUPTI_CALL(cuptiMetricGetAttribute( + metric, CUPTI_METRIC_ATTR_EVALUATION_MODE, &eval_mode_size, &eval_mode)); + return eval_mode; +} + +// FIXME: Consider caching value kind here +SampleValue CuptiMetricApi::calculate( + CUpti_MetricID metric, + CUpti_MetricValueKind kind, + vector& events, + vector& values, + int64_t duration) { + CUpti_MetricValue metric_value; + CUPTI_CALL(cuptiMetricGetValue( + device_, + metric, + events.size() * sizeof(CUpti_EventID), + events.data(), + values.size() * sizeof(int64_t), + reinterpret_cast(values.data()), + duration, + &metric_value)); + + switch (kind) { + case CUPTI_METRIC_VALUE_KIND_DOUBLE: + case CUPTI_METRIC_VALUE_KIND_PERCENT: + return SampleValue(metric_value.metricValueDouble); + case CUPTI_METRIC_VALUE_KIND_UINT64: + case CUPTI_METRIC_VALUE_KIND_INT64: + case CUPTI_METRIC_VALUE_KIND_THROUGHPUT: + return SampleValue(metric_value.metricValueUint64); + case CUPTI_METRIC_VALUE_KIND_UTILIZATION_LEVEL: + return SampleValue((int)metric_value.metricValueUtilizationLevel); + default: + assert(false); + } + return SampleValue(-1); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h new file mode 100644 index 0000000000000000000000000000000000000000..f45d38cd6169dc7fd30208dbb7dac09fd8a9dee5 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h @@ -0,0 +1,38 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +#include +#include + +#include "SampleListener.h" + +namespace KINETO_NAMESPACE { + +// C++ interface to CUPTI Metrics C API. +// Virtual methods are here mainly to allow easier testing. +class CuptiMetricApi { + public: + explicit CuptiMetricApi(CUdevice device) : device_(device) {} + virtual ~CuptiMetricApi() {} + + virtual CUpti_MetricID idFromName(const std::string& name); + virtual std::map events(CUpti_MetricID metric_id); + + virtual CUpti_MetricValueKind valueKind(CUpti_MetricID metric); + virtual CUpti_MetricEvaluationMode evaluationMode(CUpti_MetricID metric); + + virtual SampleValue calculate( + CUpti_MetricID metric, + CUpti_MetricValueKind kind, + std::vector& events, + std::vector& values, + int64_t duration); + + private: + CUdevice device_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d1b08ab2c13d0615221e71f43f07c3d3fe102a2f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp @@ -0,0 +1,504 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#ifdef HAS_CUPTI +#include +#if defined(CUDART_VERSION) && CUDART_VERSION > 10000 && CUDART_VERSION < 11040 +#include +#include +#include +#endif // cuda version > 10.00 and < 11.04 +#endif // HAS_CUPTI + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ScopeExit.h" +#include "CuptiNvPerfMetric.h" +#include "Logger.h" + +namespace KINETO_NAMESPACE { + +// Add a namespace to isolate these utility functions that are only +// going to be used by the CuptiRangeProfiler. These included calls +// to NVIDIA PerfWorks APIs. +namespace nvperf { + + +// Largely based on NVIDIA sample code provided with CUDA release +// files Metric.cpp and Eval.cpp + +// ------------------------------------------------- +// Metric and Counter Data Configuration +// ------------------------------------------------- + + +// Note: Be carful before modifying the code below. There is a specific +// sequence one needs to follow to program the metrics else things may +// stop working. We tried to keep the flow consistent with the example +// code from NVIDIA. Since most of the programmability comes from +// the CUPTI profiler metric names this should be okay. + +// Only supported on CUDA RT Version between 10.0 and 11.04. +// After CUDA RT 11.04, the structure has changed. +// TODO update the structure NVPA_RawMetricsConfig to support 11.04 +#if defined(CUDART_VERSION) && CUDART_VERSION > 10000 && CUDART_VERSION < 11040 + +bool getRawMetricRequests( + NVPA_MetricsContext* metricsContext, + std::vector metricNames, + std::vector& rawMetricsDeps, + std::vector& rawMetricRequests) { + bool isolated = true; + /* Bug in collection with collection of metrics without instances, keep it + * to true*/ + bool keepInstances = true; + + for (const auto& metricName : metricNames) { + + NVPW_MetricsContext_GetMetricProperties_Begin_Params + getMetricPropertiesBeginParams = { + NVPW_MetricsContext_GetMetricProperties_Begin_Params_STRUCT_SIZE, nullptr}; + getMetricPropertiesBeginParams.pMetricsContext = metricsContext; + getMetricPropertiesBeginParams.pMetricName = metricName.c_str(); + + if (!NVPW_CALL( + NVPW_MetricsContext_GetMetricProperties_Begin( + &getMetricPropertiesBeginParams))) { + return false; + } + + for (const char** metricDepsIt = + getMetricPropertiesBeginParams.ppRawMetricDependencies; + *metricDepsIt; + ++metricDepsIt) { + rawMetricsDeps.push_back(*metricDepsIt); + } + + NVPW_MetricsContext_GetMetricProperties_End_Params + getMetricPropertiesEndParams = { + NVPW_MetricsContext_GetMetricProperties_End_Params_STRUCT_SIZE, nullptr}; + getMetricPropertiesEndParams.pMetricsContext = metricsContext; + + if (!NVPW_CALL(NVPW_MetricsContext_GetMetricProperties_End( + &getMetricPropertiesEndParams))) { + return false; + } + } + + for (const auto& rawMetricName : rawMetricsDeps) { + NVPA_RawMetricRequest metricRequest = {NVPA_RAW_METRIC_REQUEST_STRUCT_SIZE, nullptr}; + metricRequest.pMetricName = rawMetricName.c_str(); + metricRequest.isolated = isolated; + metricRequest.keepInstances = keepInstances; + rawMetricRequests.push_back(metricRequest); + VLOG(1) << "Adding raw metric struct : raw metric = " << rawMetricName + << " isolated = " << isolated << " keepinst = " << keepInstances; + } + + if (rawMetricRequests.size() == 0) { + LOG(WARNING) << "CUPTI Profiler was unable to configure any metrics"; + return false; + } + return true; +} + +// Setup CUPTI Profiler Config Image +bool getProfilerConfigImage( + const std::string& chipName, + const std::vector& metricNames, + std::vector& configImage, + const uint8_t* counterAvailabilityImage) { + + NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { + NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; + metricsContextCreateParams.pChipName = chipName.c_str(); + + if (!NVPW_CALL( + NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { + return false; + } + + NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { + NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; + metricsContextDestroyParams.pMetricsContext = + metricsContextCreateParams.pMetricsContext; + + SCOPE_EXIT([&]() { + NVPW_MetricsContext_Destroy( + (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); + }); + + // Get all raw metrics required for given metricNames list + std::vector rawMetricRequests; + + // note: we need a variable at this functions scope to hold the string + // pointers for underlying C char arrays. + std::vector rawMetricDeps; + + if (!getRawMetricRequests( + metricsContextCreateParams.pMetricsContext, + metricNames, + rawMetricDeps, + rawMetricRequests)) { + return false; + } + + NVPA_RawMetricsConfigOptions metricsConfigOptions = { + NVPA_RAW_METRICS_CONFIG_OPTIONS_STRUCT_SIZE, nullptr}; + metricsConfigOptions.activityKind = NVPA_ACTIVITY_KIND_PROFILER; + metricsConfigOptions.pChipName = chipName.c_str(); + NVPA_RawMetricsConfig* rawMetricsConfig; + if (!NVPW_CALL( + NVPA_RawMetricsConfig_Create( + &metricsConfigOptions, &rawMetricsConfig))) { + return false; + } + + // TODO check if this is required + if (counterAvailabilityImage) { + NVPW_RawMetricsConfig_SetCounterAvailability_Params + setCounterAvailabilityParams = { + NVPW_RawMetricsConfig_SetCounterAvailability_Params_STRUCT_SIZE, nullptr}; + setCounterAvailabilityParams.pRawMetricsConfig = rawMetricsConfig; + setCounterAvailabilityParams.pCounterAvailabilityImage = + counterAvailabilityImage; + if (!NVPW_CALL( + NVPW_RawMetricsConfig_SetCounterAvailability( + &setCounterAvailabilityParams))) { + return false; + } + } + + NVPW_RawMetricsConfig_Destroy_Params rawMetricsConfigDestroyParams = { + NVPW_RawMetricsConfig_Destroy_Params_STRUCT_SIZE, nullptr}; + rawMetricsConfigDestroyParams.pRawMetricsConfig = rawMetricsConfig; + SCOPE_EXIT([&]() { + NVPW_RawMetricsConfig_Destroy( + (NVPW_RawMetricsConfig_Destroy_Params*)&rawMetricsConfigDestroyParams); + }); + + // Start a Raw Metric Pass group + NVPW_RawMetricsConfig_BeginPassGroup_Params beginPassGroupParams = { + NVPW_RawMetricsConfig_BeginPassGroup_Params_STRUCT_SIZE, nullptr}; + beginPassGroupParams.pRawMetricsConfig = rawMetricsConfig; + if (!NVPW_CALL( + NVPW_RawMetricsConfig_BeginPassGroup(&beginPassGroupParams))) { + return false; + } + + // Add all raw metrics + NVPW_RawMetricsConfig_AddMetrics_Params addMetricsParams = { + NVPW_RawMetricsConfig_AddMetrics_Params_STRUCT_SIZE, nullptr}; + addMetricsParams.pRawMetricsConfig = rawMetricsConfig; + addMetricsParams.pRawMetricRequests = rawMetricRequests.data(); + addMetricsParams.numMetricRequests = rawMetricRequests.size(); + if (!NVPW_CALL( + NVPW_RawMetricsConfig_AddMetrics(&addMetricsParams))) { + return false; + } + + // End pass group + NVPW_RawMetricsConfig_EndPassGroup_Params endPassGroupParams = { + NVPW_RawMetricsConfig_EndPassGroup_Params_STRUCT_SIZE, nullptr}; + endPassGroupParams.pRawMetricsConfig = rawMetricsConfig; + if (!NVPW_CALL( + NVPW_RawMetricsConfig_EndPassGroup(&endPassGroupParams))) { + return false; + } + + // Setup Config Image generation + NVPW_RawMetricsConfig_GenerateConfigImage_Params generateConfigImageParams = { + NVPW_RawMetricsConfig_GenerateConfigImage_Params_STRUCT_SIZE, nullptr}; + generateConfigImageParams.pRawMetricsConfig = rawMetricsConfig; + if (!NVPW_CALL( + NVPW_RawMetricsConfig_GenerateConfigImage(&generateConfigImageParams))) { + return false; + } + + // Get the Config Image size... nearly there + NVPW_RawMetricsConfig_GetConfigImage_Params getConfigImageParams = { + NVPW_RawMetricsConfig_GetConfigImage_Params_STRUCT_SIZE, nullptr}; + getConfigImageParams.pRawMetricsConfig = rawMetricsConfig; + getConfigImageParams.bytesAllocated = 0; + getConfigImageParams.pBuffer = nullptr; + if (!NVPW_CALL( + NVPW_RawMetricsConfig_GetConfigImage(&getConfigImageParams))) { + return false; + } + + configImage.resize(getConfigImageParams.bytesCopied); + + // Write the Config image binary + getConfigImageParams.bytesAllocated = configImage.size(); + getConfigImageParams.pBuffer = configImage.data(); + if (!NVPW_CALL( + NVPW_RawMetricsConfig_GetConfigImage(&getConfigImageParams))) { + return false; + } + + return true; +} + +bool getCounterDataPrefixImage( + const std::string& chipName, + const std::vector& metricNames, + std::vector& counterDataImagePrefix) { + + NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { + NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; + metricsContextCreateParams.pChipName = chipName.c_str(); + + if (!NVPW_CALL( + NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { + return false; + } + + NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { + NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; + metricsContextDestroyParams.pMetricsContext = + metricsContextCreateParams.pMetricsContext; + + + SCOPE_EXIT([&]() { + NVPW_MetricsContext_Destroy( + (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); + }); + + // Get all raw metrics required for given metricNames list + std::vector rawMetricRequests; + + // note: we need a variable at this functions scope to hold the string + // pointers for underlying C char arrays. + std::vector rawMetricDeps; + + if (!getRawMetricRequests( + metricsContextCreateParams.pMetricsContext, + metricNames, + rawMetricDeps, + rawMetricRequests)) { + return false; + } + + // Setup Counter Data builder + NVPW_CounterDataBuilder_Create_Params counterDataBuilderCreateParams = { + NVPW_CounterDataBuilder_Create_Params_STRUCT_SIZE, nullptr}; + counterDataBuilderCreateParams.pChipName = chipName.c_str(); + if (!NVPW_CALL( + NVPW_CounterDataBuilder_Create(&counterDataBuilderCreateParams))) { + return false; + } + + NVPW_CounterDataBuilder_Destroy_Params counterDataBuilderDestroyParams = { + NVPW_CounterDataBuilder_Destroy_Params_STRUCT_SIZE, nullptr}; + counterDataBuilderDestroyParams.pCounterDataBuilder = + counterDataBuilderCreateParams.pCounterDataBuilder; + SCOPE_EXIT([&]() { + NVPW_CounterDataBuilder_Destroy(( + NVPW_CounterDataBuilder_Destroy_Params*)&counterDataBuilderDestroyParams); + }); + + // Add metrics to counter data image prefix + NVPW_CounterDataBuilder_AddMetrics_Params addMetricsParams = { + NVPW_CounterDataBuilder_AddMetrics_Params_STRUCT_SIZE, nullptr}; + addMetricsParams.pCounterDataBuilder = + counterDataBuilderCreateParams.pCounterDataBuilder; + addMetricsParams.pRawMetricRequests = rawMetricRequests.data(); + addMetricsParams.numMetricRequests = rawMetricRequests.size(); + if (!NVPW_CALL( + NVPW_CounterDataBuilder_AddMetrics(&addMetricsParams))) { + return false; + } + + // Get image prefix size + NVPW_CounterDataBuilder_GetCounterDataPrefix_Params + getCounterDataPrefixParams = { + NVPW_CounterDataBuilder_GetCounterDataPrefix_Params_STRUCT_SIZE, nullptr}; + getCounterDataPrefixParams.pCounterDataBuilder = + counterDataBuilderCreateParams.pCounterDataBuilder; + getCounterDataPrefixParams.bytesAllocated = 0; + getCounterDataPrefixParams.pBuffer = nullptr; + if (!NVPW_CALL( + NVPW_CounterDataBuilder_GetCounterDataPrefix( + &getCounterDataPrefixParams))) { + return false; + } + + counterDataImagePrefix.resize(getCounterDataPrefixParams.bytesCopied); + + // Now write counter data image prefix + getCounterDataPrefixParams.bytesAllocated = counterDataImagePrefix.size(); + getCounterDataPrefixParams.pBuffer = counterDataImagePrefix.data(); + if (!NVPW_CALL( + NVPW_CounterDataBuilder_GetCounterDataPrefix( + &getCounterDataPrefixParams))) { + return false; + } + + return true; +} + +// ------------------------------------------------- +// Metric and Counter Evaluation Utilities +// ------------------------------------------------- + +std::string getRangeDescription( + const std::vector& counterDataImage, + int rangeIndex) { + std::vector descriptionPtrs; + + NVPW_Profiler_CounterData_GetRangeDescriptions_Params getRangeDescParams = { + NVPW_Profiler_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE, nullptr}; + getRangeDescParams.pCounterDataImage = counterDataImage.data(); + getRangeDescParams.rangeIndex = rangeIndex; + + if (!NVPW_CALL( + NVPW_Profiler_CounterData_GetRangeDescriptions(&getRangeDescParams))) { + return ""; + } + + descriptionPtrs.resize(getRangeDescParams.numDescriptions); + getRangeDescParams.ppDescriptions = descriptionPtrs.data(); + + if (!NVPW_CALL( + NVPW_Profiler_CounterData_GetRangeDescriptions(&getRangeDescParams))) { + return ""; + } + + std::string rangeName; + + for (size_t i = 0; i < getRangeDescParams.numDescriptions; i++) { + if (i > 0) { + rangeName.append("/"); + } + rangeName.append(descriptionPtrs[i]); + } + return rangeName; +} + +CuptiProfilerResult evalMetricValues( + const std::string& chipName, + const std::vector& counterDataImage, + const std::vector& metricNames, + bool verbose) { + + if (!counterDataImage.size()) { + LOG(ERROR) << "Counter Data Image is empty!"; + return {}; + } + + NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { + NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; + metricsContextCreateParams.pChipName = chipName.c_str(); + if (!NVPW_CALL( + NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { + return {}; + } + + NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { + NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; + metricsContextDestroyParams.pMetricsContext = + metricsContextCreateParams.pMetricsContext; + SCOPE_EXIT([&]() { + NVPW_MetricsContext_Destroy( + (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); + }); + + NVPW_CounterData_GetNumRanges_Params getNumRangesParams = { + NVPW_CounterData_GetNumRanges_Params_STRUCT_SIZE, nullptr}; + getNumRangesParams.pCounterDataImage = counterDataImage.data(); + if (!NVPW_CALL( + NVPW_CounterData_GetNumRanges(&getNumRangesParams))) { + return {}; + } + + // TBD in the future support special chars in metric name + // for now these are default + const bool isolated = true; + + // API takes a 2D array of chars + std::vector metricNamePtrs; + + for (const auto& metric : metricNames) { + metricNamePtrs.push_back(metric.c_str()); + } + + CuptiProfilerResult result{ + .metricNames = metricNames}; + + for (size_t rangeIndex = 0; rangeIndex < getNumRangesParams.numRanges; + ++rangeIndex) { + + CuptiRangeMeasurement rangeData { + .rangeName = getRangeDescription(counterDataImage, rangeIndex)}; + rangeData.values.resize(metricNames.size()); + + // First set Counter data image with current range + NVPW_MetricsContext_SetCounterData_Params setCounterDataParams = { + NVPW_MetricsContext_SetCounterData_Params_STRUCT_SIZE, nullptr}; + + setCounterDataParams.pMetricsContext = + metricsContextCreateParams.pMetricsContext; + setCounterDataParams.pCounterDataImage = counterDataImage.data(); + setCounterDataParams.isolated = isolated; + setCounterDataParams.rangeIndex = rangeIndex; + + NVPW_CALL(NVPW_MetricsContext_SetCounterData(&setCounterDataParams)); + + + // Now we can evaluate GPU metrics + NVPW_MetricsContext_EvaluateToGpuValues_Params evalToGpuParams = { + NVPW_MetricsContext_EvaluateToGpuValues_Params_STRUCT_SIZE, nullptr}; + evalToGpuParams.pMetricsContext = + metricsContextCreateParams.pMetricsContext; + evalToGpuParams.numMetrics = metricNamePtrs.size(); + evalToGpuParams.ppMetricNames = metricNamePtrs.data(); + evalToGpuParams.pMetricValues = rangeData.values.data(); + + if (!NVPW_CALL(NVPW_MetricsContext_EvaluateToGpuValues(&evalToGpuParams))) { + LOG(WARNING) << "Failed to evaluate metris for range : " + << rangeData.rangeName; + continue; + } + + if (verbose) { + for (size_t i = 0; i < metricNames.size(); i++) { + LOG(INFO) << "rangeName: " << rangeData.rangeName + << "\tmetricName: " << metricNames[i] + << "\tgpuValue: " << rangeData.values[i]; + } + } + + result.rangeVals.emplace_back(std::move(rangeData)); + } + + return result; +} + +#else + +bool getProfilerConfigImage( + const std::string& /*chipName*/, + const std::vector& /*metricNames*/, + std::vector& /*configImage*/, + const uint8_t* /*counterAvailabilityImage*/) { + return false; +} + +bool getCounterDataPrefixImage( + const std::string& /*chipName*/, + const std::vector& /*metricNames*/, + std::vector& /*counterDataImagePrefix*/) { + return false; +} + +CuptiProfilerResult evalMetricValues( + const std::string& /*chipName*/, + const std::vector& /*counterDataImage*/, + const std::vector& /*metricNames*/, + bool /*verbose*/) { + return {}; +} + +#endif // cuda version > 10.00 and < 11.04 + +} // namespace nvperf +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h new file mode 100644 index 0000000000000000000000000000000000000000..d5dd1b1c1d20b066891f8be679e6d6371d4f4a9b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h @@ -0,0 +1,71 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "Logger.h" + +namespace KINETO_NAMESPACE { + +struct CuptiRangeMeasurement { + std::string rangeName; + std::vector values; +}; + +struct CuptiProfilerResult { + std::vector metricNames; + // rangeName, list values + std::vector rangeVals; +}; + +/* Utilities for CUPTI and NVIDIA PerfWorks Metric API + */ + +#define NVPW_CALL(call) \ + [&]() -> bool { \ + NVPA_Status _status_ = call; \ + if (_status_ != NVPA_STATUS_SUCCESS) { \ + LOG(WARNING) << fmt::format( \ + "function {} failed with error ({})", \ + #call, \ + (int)_status_); \ + return false; \ + } \ + return true; \ + }() + +// fixme - add a results string +// nvpperfGetResultString(_status_, &_errstr_); + +namespace nvperf { + +// Setup CUPTI profiler configuration blob and counter data image prefix +bool getProfilerConfigImage( + const std::string& chipName, + const std::vector& metricNames, + std::vector& configImage, + const uint8_t* counterAvailabilityImage = nullptr); + +// Setup CUPTI profiler configuration blob and counter data image prefix +bool getCounterDataPrefixImage( + const std::string& chipName, + const std::vector& metricNames, + std::vector& counterDataImagePrefix); + +/* NV Perf Metric Evaluation helpers + * - utilities to read binary data and obtain metrics for ranges + */ +CuptiProfilerResult evalMetricValues( + const std::string& chipName, + const std::vector& counterDataImage, + const std::vector& metricNames, + bool verbose = false); + + +} // namespace nvperf +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5f18ed7b0b70963eb2deab126ff4f7119ed582b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp @@ -0,0 +1,751 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#ifdef HAS_CUPTI +#include +#include +#endif // HAS_CUPTI +#include +#include + +#ifdef HAS_CUPTI +#include "cupti_call.h" +#endif + +#include "time_since_epoch.h" +#include "Logger.h" +#include "Demangle.h" + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "CuptiCallbackApiMock.h" +#include "CuptiRangeProfilerApi.h" + +#if HAS_CUPTI_RANGE_PROFILER +#include +#include +#include "cupti_call.h" +#endif // HAS_CUPTI_RANGE_PROFILER + +namespace KINETO_NAMESPACE { + +#if HAS_CUPTI_RANGE_PROFILER +constexpr char kRootUserRangeName[] = "__profile__"; +constexpr int kCallbacksCountToFlush = 500; + +// Should we set Counter availability image ourselves? +// Disabled this right now as this call conflicts with DCGM +// It is not clear why it should conflict except it being a profiler API call +// TODO Revisit +constexpr bool kSetCounterAvail = false; + +// Shared state to track one Cupti Profiler API per Device +namespace { +// per device profiler maps +std::unordered_map profiler_map; +std::unordered_map enable_flag; +std::unordered_map disable_flag; + +std::mutex contextMutex_; +std::unordered_map ctx_to_dev; +std::set active_devices; +} + +// forward declarations +void __trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid); +void __trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); + +/// Helper functions + +// Available raw counters +std::vector getCounterAvailiability(CUcontext cuContext) { + std::vector counterAvailabilityImage; + CUpti_Profiler_GetCounterAvailability_Params getCounterAvailabilityParams = { + CUpti_Profiler_GetCounterAvailability_Params_STRUCT_SIZE, nullptr}; + getCounterAvailabilityParams.ctx = cuContext; + CUPTI_CALL( + cuptiProfilerGetCounterAvailability(&getCounterAvailabilityParams)); + + counterAvailabilityImage.clear(); + counterAvailabilityImage.resize( + getCounterAvailabilityParams.counterAvailabilityImageSize); + + getCounterAvailabilityParams.pCounterAvailabilityImage = + counterAvailabilityImage.data(); + CUPTI_CALL( + cuptiProfilerGetCounterAvailability(&getCounterAvailabilityParams)); + + return counterAvailabilityImage; +} + +std::string getChipName(int deviceId) { + // Get chip name for the cuda device + CUpti_Device_GetChipName_Params getChipNameParams = { + CUpti_Device_GetChipName_Params_STRUCT_SIZE, nullptr}; + + getChipNameParams.deviceIndex = deviceId; + CUPTI_CALL(cuptiDeviceGetChipName(&getChipNameParams)); + + return getChipNameParams.pChipName; +} + +inline uint32_t getDevID(CUcontext ctx) { + uint32_t device_id = UINT32_MAX; + CUPTI_CALL(cuptiGetDeviceId(ctx, &device_id)); + if (device_id == UINT32_MAX) { + LOG(ERROR) << "Could not determine dev id for = " << ctx; + } + return device_id; +} + +// We use CUPTI Callback functions in three ways : +// 1. Track cuda contexts and maintain a list of active GPUs to profile +// 2. Callbacks on kernel launches to track the name of automatic +// ranges that correspond to names of kernels +// 3. Lastly CUPTI profiler has to be enabled on the same thread executing +// the CUDA kernels. We use Callbacks to enable the profiler +// asynchronously from another thread. + +void disableKernelCallbacks(); + +void trackCudaCtx( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo) { + auto *d = reinterpret_cast(cbInfo); + auto ctx = d->context; + uint32_t device_id = getDevID(ctx); + + if (device_id == UINT32_MAX) { + return; + } + + __trackCudaCtx(ctx, device_id, cbid); +} + +void __trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid) { + std::lock_guard g(contextMutex_); + if (cbid == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { + VLOG(0) << "CUPTI Profiler observed CUDA Context created = " + << ctx << " device id = " << device_id; + active_devices.insert(device_id); + if constexpr (kSetCounterAvail) { + if (active_devices.size() == 1) { + CuptiRBProfilerSession::setCounterAvailabilityImage( + getCounterAvailiability(ctx)); + } + } + ctx_to_dev[ctx] = device_id; + + } else if (cbid == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { + VLOG(0) << "CUPTI Profiler observed CUDA Context destroyed = " + << ctx << " device id = " << device_id; + auto it = active_devices.find(device_id); + if (it != active_devices.end()) { + active_devices.erase(it); + ctx_to_dev.erase(ctx); + } + } +} + +void trackCudaKernelLaunch( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId /*cbid*/, + const CUpti_CallbackData* cbInfo) { + VLOG(1) << " Trace : Callback name = " + << (cbInfo->symbolName ? cbInfo->symbolName: "") + << " context ptr = " << cbInfo->context; + auto ctx = cbInfo->context; + // should be in CUPTI_API_ENTER call site + if (cbInfo->callbackSite != CUPTI_API_ENTER) { + return; + } + __trackCudaKernelLaunch(ctx, cbInfo->symbolName); +} + +void __trackCudaKernelLaunch( + CUcontext ctx, + const char* kernelName) { + VLOG(0) << " Tracking kernel name = " << (kernelName ? kernelName : "") + << " context ptr = " << ctx; + + uint32_t device_id = 0; + auto it = ctx_to_dev.find(ctx); + if (it == ctx_to_dev.end()) { + // Warning here could be too noisy + VLOG(0) << " Could not find corresponding device to ctx = " << ctx; + return; + } else { + device_id = it->second; + } + + auto pit = profiler_map.find(device_id); + if (pit == profiler_map.end() || pit->second == nullptr) { + return; + } + auto profiler = pit->second; + + if (enable_flag[device_id]) { + LOG(INFO) << "Callback handler is enabling cupti profiler"; + profiler->startAndEnable(); + enable_flag[device_id] = false; + + } else if (disable_flag[device_id]) { + LOG(INFO) << "Callback handler is disabling cupti profiler"; + profiler->disableAndStop(); + return; + } + + if (profiler->curRange_ == CUPTI_AutoRange) { + profiler->logKernelName(kernelName ? kernelName : "__missing__"); + } + + /* TODO add per kernel time logging + if (measure_per_kernel) { + profiler->kernelStartTs_.push_back( + std::chrono::high_resolution_clock::now()); + } + */ + + // periodically flush profiler data from GPU + if (profiler->numCallbacks_ % kCallbacksCountToFlush == 0) { + profiler->flushCounterData(); + } + profiler->numCallbacks_++; +} + +void enableKernelCallbacks() { + auto& cbapi = CuptiCallbackApi::singleton(); + bool status = cbapi.enableCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); + if (!status) { + LOG(WARNING) << "CUPTI Range Profiler unable to " + << "enable cuda kernel launch callback"; + return; + } + LOG(INFO) << "CUPTI Profiler kernel callbacks enabled"; +} + +void disableKernelCallbacks() { + auto& cbapi = CuptiCallbackApi::singleton(); + bool status = cbapi.disableCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); + if (!status) { + LOG(WARNING) << "CUPTI Range Profiler unable to " + << "disable cuda kernel launch callback"; + return; + } + LOG(INFO) << "CUPTI Profiler kernel callbacks disabled"; +} + +// static +std::set CuptiRBProfilerSession::getActiveDevices() { + std::lock_guard g(contextMutex_); + return active_devices; +} + +// static +void CuptiRBProfilerSession::initCupti() { + CUpti_Profiler_Initialize_Params profilerInitializeParams = { + CUpti_Profiler_Initialize_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerInitialize(&profilerInitializeParams)); +} + +// static +void CuptiRBProfilerSession::deInitCupti() { + CUpti_Profiler_DeInitialize_Params profilerDeInitializeParams = { + CUpti_Profiler_DeInitialize_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerDeInitialize(&profilerDeInitializeParams)); +} + +// static +void CuptiRBProfilerSession::staticInit() { + CuptiRBProfilerSession::initCupti(); + + // Register CUPTI callbacks + auto& cbapi = CuptiCallbackApi::singleton(); + CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RESOURCE; + bool status = cbapi.registerCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED, trackCudaCtx); + status = status && cbapi.registerCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED, trackCudaCtx); + status = status && cbapi.enableCallback( + domain, CUPTI_CBID_RESOURCE_CONTEXT_CREATED); + status = status && cbapi.enableCallback( + domain, CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); + + if (!status) { + LOG(WARNING) << "CUPTI Range Profiler unable to attach cuda context " + << "create and destroy callbacks"; + CUPTI_CALL(cbapi.getCuptiStatus()); + return; + } + + domain = CUPTI_CB_DOMAIN_RUNTIME_API; + status = cbapi.registerCallback( + domain, CuptiCallbackApi::CUDA_LAUNCH_KERNEL, trackCudaKernelLaunch); + + if (!status) { + LOG(WARNING) << "CUPTI Range Profiler unable to attach cuda kernel " + << "launch callback"; + return; + } +} + +// static +std::vector& CuptiRBProfilerSession::counterAvailabilityImage() { + static std::vector counterAvailabilityImage_; + return counterAvailabilityImage_; +} + + +// Setup the profiler sessions +CuptiRBProfilerSession::CuptiRBProfilerSession( + const std::vector& metricNames, + int deviceId, + int maxRanges, + int numNestingLevels, + CUcontext cuContext) + : metricNames_(metricNames), + chipName_(getChipName(deviceId)), + deviceId_(deviceId), + maxRanges_(maxRanges), + numNestingLevels_(numNestingLevels), + cuContext_(cuContext) { + CuptiRBProfilerSession::initCupti(); + + LOG(INFO) << "Initializing CUPTI profiler session : device = " << deviceId + << " chip = " << chipName_; + /* Generate configuration for metrics, this can also be done offline*/ + NVPW_InitializeHost_Params initializeHostParams = { + NVPW_InitializeHost_Params_STRUCT_SIZE, nullptr}; + NVPW_CALL(NVPW_InitializeHost(&initializeHostParams)); + + if (metricNames.size()) { + if (!nvperf::getProfilerConfigImage( + chipName_, + metricNames, + configImage, + CuptiRBProfilerSession::counterAvailabilityImage().data())) { + LOG(ERROR) << "Failed to create configImage or counterDataImagePrefix"; + return; + } + if (!nvperf::getCounterDataPrefixImage( + chipName_, + metricNames, + counterDataImagePrefix)) { + LOG(ERROR) << "Failed to create counterDataImagePrefix"; + return; + } + } else { + LOG(ERROR) << "No metrics provided to profile"; + return; + } + + if (!createCounterDataImage()) { + LOG(ERROR) << "Failed to create counterDataImage"; + return; + } + + LOG(INFO) << "Size of structs\n" + << " config image size = " << configImage.size() << " B" + << " counter data image prefix = " + << counterDataImagePrefix.size() << " B" + << " counter data image size = " << counterDataImage.size() / 1024 + << " KB" + << " counter sb image size = " + << counterDataScratchBuffer.size() << " B"; + + beginPassParams_ = {CUpti_Profiler_BeginPass_Params_STRUCT_SIZE, nullptr}; + endPassParams_ = {CUpti_Profiler_EndPass_Params_STRUCT_SIZE, nullptr}; + + initSuccess_ = true; + profiler_map[deviceId] = this; +} + +// used in unittests only +CuptiRBProfilerSession::CuptiRBProfilerSession(int deviceId, CUcontext ctx) + : deviceId_(deviceId), cuContext_(ctx) { + initSuccess_ = true; + profiler_map[deviceId] = this; +} + +void CuptiRBProfilerSession::startInternal( + CUpti_ProfilerRange profilerRange, + CUpti_ProfilerReplayMode profilerReplayMode) { + LOG(INFO) << "Starting profiler session: profiler range = " + << ((profilerRange == CUPTI_AutoRange) ? "autorange" : "userrange") + << " replay mode = " + << ((profilerReplayMode == CUPTI_KernelReplay) ? "kernel" : "user"); + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return; + } + + if (cuContext_ == nullptr) { + for (const auto& it : ctx_to_dev) { + if (it.second == deviceId_) { + cuContext_ = it.first; + break; + } + } + LOG(INFO) << " Cupti Profiler using CUDA context = " << cuContext_; + } + + profilerStartTs_ = std::chrono::high_resolution_clock::now(); + curRange_ = profilerRange; + curReplay_ = profilerReplayMode; + + CUpti_Profiler_BeginSession_Params beginSessionParams = { + CUpti_Profiler_BeginSession_Params_STRUCT_SIZE, nullptr}; + + beginSessionParams.ctx = cuContext_; + beginSessionParams.counterDataImageSize = counterDataImage.size(); + beginSessionParams.pCounterDataImage = counterDataImage.data(); + beginSessionParams.counterDataScratchBufferSize = + counterDataScratchBuffer.size(); + beginSessionParams.pCounterDataScratchBuffer = counterDataScratchBuffer.data(); + beginSessionParams.range = profilerRange; + beginSessionParams.replayMode = profilerReplayMode; + beginSessionParams.maxRangesPerPass = maxRanges_; + beginSessionParams.maxLaunchesPerPass = maxRanges_; + + auto status = CUPTI_CALL(cuptiProfilerBeginSession(&beginSessionParams)); + if (status != CUPTI_SUCCESS) { + LOG(WARNING) << "Failed to start CUPTI profiler"; + initSuccess_ = false; + return; + } + + // Set counter configuration + CUpti_Profiler_SetConfig_Params setConfigParams = { + CUpti_Profiler_SetConfig_Params_STRUCT_SIZE, nullptr}; + + setConfigParams.ctx = cuContext_; + setConfigParams.pConfig = configImage.data(); + setConfigParams.configSize = configImage.size(); + setConfigParams.passIndex = 0; + setConfigParams.minNestingLevel = 1; + setConfigParams.numNestingLevels = numNestingLevels_; + status = CUPTI_CALL(cuptiProfilerSetConfig(&setConfigParams)); + + if (status != CUPTI_SUCCESS) { + LOG(WARNING) << "Failed to configure CUPTI profiler"; + initSuccess_ = false; + return; + } + profilerInitDoneTs_ = std::chrono::high_resolution_clock::now(); + + if (curRange_ == CUPTI_AutoRange) { + enableKernelCallbacks(); + } + profilingActive_ = true; +} + +void CuptiRBProfilerSession::stop() { + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return; + } + LOG(INFO) << "Stop profiler session on device = " << deviceId_; + + CUpti_Profiler_UnsetConfig_Params unsetConfigParams = { + CUpti_Profiler_UnsetConfig_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerUnsetConfig(&unsetConfigParams)); + + CUpti_Profiler_EndSession_Params endSessionParams = { + CUpti_Profiler_EndSession_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerEndSession(&endSessionParams)); + + disableKernelCallbacks(); + + profilerStopTs_ = std::chrono::high_resolution_clock::now(); + profilingActive_ = false; +} + +void CuptiRBProfilerSession::beginPass() { + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return; + } + CUPTI_CALL(cuptiProfilerBeginPass(&beginPassParams_)); +} + +bool CuptiRBProfilerSession::endPass() { + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return true; + } + CUPTI_CALL(cuptiProfilerEndPass(&endPassParams_)); + return endPassParams_.allPassesSubmitted; +} + +void CuptiRBProfilerSession::flushCounterData() { + LOG(INFO) << "Flushing counter data on device = " << deviceId_; + CUpti_Profiler_FlushCounterData_Params flushCounterDataParams = { + CUpti_Profiler_FlushCounterData_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerFlushCounterData(&flushCounterDataParams)); +} + +/// Enable and disable the profiler +void CuptiRBProfilerSession::enable() { + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return; + } + CUpti_Profiler_EnableProfiling_Params enableProfilingParams = { + CUpti_Profiler_EnableProfiling_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerEnableProfiling(&enableProfilingParams)); +} + +void CuptiRBProfilerSession::disable() { + if (!initSuccess_) { + LOG(WARNING) << __func__ << "() bailing out since initialization failed"; + return; + } + CUpti_Profiler_DisableProfiling_Params disableProfilingParams = { + CUpti_Profiler_DisableProfiling_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerDisableProfiling(&disableProfilingParams)); +} + +/// User range based profiling +void CuptiRBProfilerSession::pushRange(const std::string& rangeName) { + LOG(INFO) << " CUPTI pushrange ( " << rangeName << " )"; + CUpti_Profiler_PushRange_Params pushRangeParams = { + CUpti_Profiler_PushRange_Params_STRUCT_SIZE, nullptr}; + pushRangeParams.pRangeName = rangeName.c_str(); + CUPTI_CALL(cuptiProfilerPushRange(&pushRangeParams)); +} + +void CuptiRBProfilerSession::popRange() { + LOG(INFO) << " CUPTI pop range"; + CUpti_Profiler_PopRange_Params popRangeParams = { + CUpti_Profiler_PopRange_Params_STRUCT_SIZE, nullptr}; + CUPTI_CALL(cuptiProfilerPopRange(&popRangeParams)); +} + +void CuptiRBProfilerSession::startAndEnable() { + startInternal(curRange_, curReplay_); + if (curReplay_ == CUPTI_UserReplay) { + beginPass(); + } + enable(); + if (curRange_ == CUPTI_UserRange) { + pushRange(kRootUserRangeName); + } + enable_flag[deviceId_] = false; +} + +void CuptiRBProfilerSession::disableAndStop() { + if (curRange_ == CUPTI_UserRange) { + popRange(); + } + disable(); + if (curReplay_ == CUPTI_UserReplay) { + endPass(); + flushCounterData(); + } + stop(); + disable_flag[deviceId_] = false; +} + +void CuptiRBProfilerSession::asyncStartAndEnable( + CUpti_ProfilerRange profilerRange, + CUpti_ProfilerReplayMode profilerReplayMode) { + LOG(INFO) << "Starting CUPTI profiler asynchronously on device = " + << deviceId_ << " profiler range = " + << ((profilerRange == CUPTI_AutoRange) ? "autorange" : "userrange") + << " replay mode = " + << ((profilerReplayMode == CUPTI_KernelReplay) ? "kernel" : "user"); + curReplay_ = profilerReplayMode; + curRange_ = profilerRange; + enable_flag[deviceId_] = true; + enableKernelCallbacks(); +} + +void CuptiRBProfilerSession::asyncDisableAndStop() { + LOG(INFO) << "Stopping CUPTI profiler asynchronously on device = " + << deviceId_ << " cu context = " << cuContext_; + disable_flag[deviceId_] = true; +} + + +CuptiProfilerResult CuptiRBProfilerSession::evaluateMetrics( + bool verbose) { + if (!initSuccess_) { + LOG(WARNING) << "Profiling failed, no results to return"; + return {}; + } + if (profilingActive_) { + disableAndStop(); + } + + LOG(INFO) << "Total kernels logged = " << kernelNames_.size(); + if (verbose) { + for (const auto& kernel : kernelNames_) { + std::cout << demangle(kernel) << std::endl; + } + LOG(INFO) << "Profiler Range data : "; + } + + auto results = nvperf::evalMetricValues( + chipName_, counterDataImage, metricNames_, verbose /*verbose*/); + + // profiler end-end duration + auto duration_ms = std::chrono::duration_cast( + profilerStopTs_ - profilerStartTs_); + + auto init_dur_ms = std::chrono::duration_cast( + profilerInitDoneTs_ - profilerStartTs_); + LOG(INFO) << "Total profiler time = " << duration_ms.count() << " ms"; + LOG(INFO) << "Total profiler init time = " << init_dur_ms.count() << " ms"; + + return results; +} + +std::unique_ptr CuptiRBProfilerSession::getProfilerTraceSpan() { + return std::make_unique( + timeSinceEpoch(profilerStartTs_), + timeSinceEpoch(profilerStopTs_), + "__cupti_profiler__" + ); +} + +void CuptiRBProfilerSession::saveCounterData( + const std::string& /*CounterDataFileName*/, + const std::string& /*CounterDataSBFileName*/) { + /* TBD write binary files for counter data and counter scratch buffer */ +} + +/// Setup counter data +bool CuptiRBProfilerSession::createCounterDataImage() { + CUpti_Profiler_CounterDataImageOptions counterDataImageOptions; + counterDataImageOptions.pCounterDataPrefix = counterDataImagePrefix.data(); + counterDataImageOptions.counterDataPrefixSize = counterDataImagePrefix.size(); + counterDataImageOptions.maxNumRanges = maxRanges_; + counterDataImageOptions.maxNumRangeTreeNodes = maxRanges_; + counterDataImageOptions.maxRangeNameLength = 64; + + // Calculate size of counter data image + CUpti_Profiler_CounterDataImage_CalculateSize_Params calculateSizeParams = { + CUpti_Profiler_CounterDataImage_CalculateSize_Params_STRUCT_SIZE, nullptr}; + calculateSizeParams.pOptions = &counterDataImageOptions; + calculateSizeParams.sizeofCounterDataImageOptions = + CUpti_Profiler_CounterDataImageOptions_STRUCT_SIZE; + + CUPTI_CALL( + cuptiProfilerCounterDataImageCalculateSize(&calculateSizeParams)); + counterDataImage.resize(calculateSizeParams.counterDataImageSize); + + // Initialize counter data image + CUpti_Profiler_CounterDataImage_Initialize_Params initializeParams = { + CUpti_Profiler_CounterDataImage_Initialize_Params_STRUCT_SIZE, nullptr}; + initializeParams.sizeofCounterDataImageOptions = + CUpti_Profiler_CounterDataImageOptions_STRUCT_SIZE; + initializeParams.pOptions = &counterDataImageOptions; + initializeParams.counterDataImageSize = + calculateSizeParams.counterDataImageSize; + initializeParams.pCounterDataImage = counterDataImage.data(); + CUPTI_CALL(cuptiProfilerCounterDataImageInitialize(&initializeParams)); + + // Calculate counter Scratch Buffer size + CUpti_Profiler_CounterDataImage_CalculateScratchBufferSize_Params + scratchBufferSizeParams = { + CUpti_Profiler_CounterDataImage_CalculateScratchBufferSize_Params_STRUCT_SIZE, nullptr}; + + scratchBufferSizeParams.counterDataImageSize = + calculateSizeParams.counterDataImageSize; + scratchBufferSizeParams.pCounterDataImage = + initializeParams.pCounterDataImage; + CUPTI_CALL(cuptiProfilerCounterDataImageCalculateScratchBufferSize( + &scratchBufferSizeParams)); + + counterDataScratchBuffer.resize( + scratchBufferSizeParams.counterDataScratchBufferSize); + + // Initialize scratch buffer + CUpti_Profiler_CounterDataImage_InitializeScratchBuffer_Params + initScratchBufferParams = { + CUpti_Profiler_CounterDataImage_InitializeScratchBuffer_Params_STRUCT_SIZE, nullptr}; + + initScratchBufferParams.counterDataImageSize = + calculateSizeParams.counterDataImageSize; + + initScratchBufferParams.pCounterDataImage = + initializeParams.pCounterDataImage; + initScratchBufferParams.counterDataScratchBufferSize = + scratchBufferSizeParams.counterDataScratchBufferSize; + initScratchBufferParams.pCounterDataScratchBuffer = + counterDataScratchBuffer.data(); + + CUPTI_CALL(cuptiProfilerCounterDataImageInitializeScratchBuffer( + &initScratchBufferParams)); + + return true; +} + +#elif defined(HAS_CUPTI) + +// Create empty stubs for the API when CUPTI is not present. +CuptiRBProfilerSession::CuptiRBProfilerSession( + const std::vector& metricNames, + int deviceId, + int maxRanges, + int numNestingLevels, + CUcontext cuContext) + : metricNames_(metricNames), + deviceId_(deviceId), + maxRanges_(maxRanges), + numNestingLevels_(numNestingLevels), + cuContext_(cuContext) {} +void CuptiRBProfilerSession::stop() {} +void CuptiRBProfilerSession::enable() {} +void CuptiRBProfilerSession::disable() {} +void CuptiRBProfilerSession::beginPass() {} +bool CuptiRBProfilerSession::endPass() { return true; } +void CuptiRBProfilerSession::flushCounterData() {} +void CuptiRBProfilerSession::pushRange(const std::string& /*rangeName*/) {} +void CuptiRBProfilerSession::popRange() {} +void CuptiRBProfilerSession::asyncStartAndEnable( + CUpti_ProfilerRange /*profilerRange*/, + CUpti_ProfilerReplayMode /*profilerReplayMode*/) {} +void CuptiRBProfilerSession::asyncDisableAndStop() {} +CuptiProfilerResult CuptiRBProfilerSession::evaluateMetrics(bool verbose) { + static CuptiProfilerResult res; + return res; +}; +void CuptiRBProfilerSession::saveCounterData( + const std::string& /*CounterDataFileName*/, + const std::string& /*CounterDataSBFileName*/) {} +void CuptiRBProfilerSession::initCupti() {} +void CuptiRBProfilerSession::deInitCupti() {} +void CuptiRBProfilerSession::staticInit() {} +bool CuptiRBProfilerSession::createCounterDataImage() { return true; } +void CuptiRBProfilerSession::startInternal( + CUpti_ProfilerRange /*profilerRange*/, + CUpti_ProfilerReplayMode /*profilerReplayMode*/) {} +std::vector& CuptiRBProfilerSession::counterAvailabilityImage() { + static std::vector _vec; + return _vec; +} +#endif // HAS_CUPTI_RANGE_PROFILER + +namespace testing { + +void trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid) { +#if HAS_CUPTI_RANGE_PROFILER + __trackCudaCtx(ctx, device_id, cbid); +#endif // HAS_CUPTI_RANGE_PROFILER +} + +void trackCudaKernelLaunch(CUcontext ctx, const char* kernelName) { +#if HAS_CUPTI_RANGE_PROFILER + __trackCudaKernelLaunch(ctx, kernelName); +#endif // HAS_CUPTI_RANGE_PROFILER +} + +} // namespace testing +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h new file mode 100644 index 0000000000000000000000000000000000000000..98a0b3ea5f4850dfa060e4e86d5ebf210692db1a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h @@ -0,0 +1,220 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#ifdef HAS_CUPTI +#include +#include +// Using CUDA 11 and above due to usage of API: cuptiProfilerGetCounterAvailability. +#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 && CUDART_VERSION < 11040 && CUDA_VERSION >= 11000 +#define HAS_CUPTI_RANGE_PROFILER 1 +#endif // CUDART_VERSION > 10.00 and < 11.04 && CUDA_VERSION >= 11.00 +#endif // HAS_CUPTI + +#if HAS_CUPTI_RANGE_PROFILER +#include +#include +#include +#else +using CUpti_ProfilerRange = enum +{ + CUPTI_AutoRange, + CUPTI_UserRange, +}; + +using CUpti_ProfilerReplayMode = enum +{ + CUPTI_KernelReplay, + CUPTI_UserReplay, +}; +#endif // HAS_CUPTI_RANGE_PROFILER + +#include +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "TraceSpan.h" +#include "CuptiCallbackApi.h" +#include "CuptiNvPerfMetric.h" + +/* Cupti Range based profiler session + * See : https://docs.nvidia.com/cupti/Cupti/r_main.html#r_profiler + */ + +namespace KINETO_NAMESPACE { + +class CuptiRBProfilerSession { + public: + // Initialize and configure CUPTI Profiler counters. + // - Metric names must be provided as string vector. + // - Supported values by CUPTI can be found at - + // https://docs.nvidia.com/cupti/Cupti/r_main.html#r_host_metrics_api + explicit CuptiRBProfilerSession( + const std::vector& metricNames, + int deviceId, + int maxRanges, + int numNestingLevels = 1, + CUcontext cuContext = 0); + + virtual ~CuptiRBProfilerSession() = default; + + // Start profiling session + // This function has to be called from the CPU thread running + // the CUDA context. If this is not the case asyncStartAndEnable() + // can be used + void start( + CUpti_ProfilerRange profilerRange = CUPTI_AutoRange, + CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay) { + startInternal(profilerRange, profilerReplayMode); + } + + // Stop profiling session + virtual void stop(); + + virtual void enable(); + virtual void disable(); + + // Profiler passes + // GPU hardware has limited performance monitoring resources + // the CUPTI profiler may need to run multiple passes to collect + // data for a given range + // If we use kernel replay model the kernels are automatically replayed + // else, you can use the beginPass() and endPass() functions below + // for user to manage the replays + + // starts a profiler pass with given kernels in between + virtual void beginPass(); + + // end a profiler pass with given kernels in between + // returns true if no more passes are required + virtual bool endPass(); + + // flushes the counter data - required if you use user replay + virtual void flushCounterData(); + + // Each pass can contain multiple of ranges + // metrics configured in a pass are collected per each range-stack. + virtual void pushRange(const std::string& rangeName); + virtual void popRange(); + + // utilities for common operations + void startAndEnable(); + void disableAndStop(); + + // Async APIs : these will can be called from another thread + // outside the CUDA context being profiled + void asyncStartAndEnable( + CUpti_ProfilerRange profilerRange = CUPTI_AutoRange, + CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay); + void asyncDisableAndStop(); + + void printMetrics() { + evaluateMetrics(true); + } + + std::unique_ptr getProfilerTraceSpan(); + + virtual CuptiProfilerResult evaluateMetrics(bool verbose = false); + + void saveCounterData( + const std::string& CounterDataFileName, + const std::string& CounterDataSBFileName); + + // This is not thread safe so please only call after + // profiling has stopped + const std::vector& getKernelNames() const { + return kernelNames_; + } + + int deviceId() const { + return deviceId_; + } + + bool profilingActive() const { + return profilingActive_; + } + + static std::set getActiveDevices(); + + static void initCupti(); + + static void deInitCupti(); + + static void staticInit(); + + static void setCounterAvailabilityImage(std::vector img) { + counterAvailabilityImage() = img; + } + protected: + CuptiRBProfilerSession(int deviceId, CUcontext ctx); + + virtual void startInternal( + CUpti_ProfilerRange profilerRange, + CUpti_ProfilerReplayMode profilerReplayMode); + + CUpti_ProfilerRange curRange_ = CUPTI_AutoRange; + CUpti_ProfilerReplayMode curReplay_ = CUPTI_KernelReplay; + + private: + + bool createCounterDataImage(); + + + // log kernel name that used with callbacks + void logKernelName(const char* kernel) { + std::lock_guard lg(kernelNamesMutex_); + kernelNames_.emplace_back(kernel); + } + + std::vector metricNames_; + std::string chipName_; + + uint32_t deviceId_ = 0; + int maxRanges_; + int numNestingLevels_; + CUcontext cuContext_; + + + // data buffers for configuration and counter data collection + std::vector counterDataImagePrefix; + std::vector configImage; + std::vector counterDataImage; + std::vector counterDataScratchBuffer; + + std::chrono::time_point profilerStartTs_; + std::chrono::time_point + profilerInitDoneTs_; + std::chrono::time_point profilerStopTs_; + + std::mutex kernelNamesMutex_; + // raw kernel names (not demangled) + std::vector kernelNames_; + + uint32_t numCallbacks_ = 0; + + static std::vector& counterAvailabilityImage(); + +#if HAS_CUPTI_RANGE_PROFILER + CUpti_Profiler_BeginPass_Params beginPassParams_; + CUpti_Profiler_EndPass_Params endPassParams_; +#endif + + bool initSuccess_ = false; + bool profilingActive_ = false; + + friend void __trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); +}; + +// called directly only in unit tests +namespace testing { + +void trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid); +void trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); + +} // namespace testing + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04b1ad0cb3f807cf87d32bc03de0ca9b552b0063 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp @@ -0,0 +1,68 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +#include +#include + +#include +#include + +using namespace std::chrono; + +namespace KINETO_NAMESPACE { + +// number of ranges affect the size of counter data binary used by +// the CUPTI Profiler. these defaults can be tuned +constexpr int KMaxAutoRanges = 1500; // supports 1500 kernels +constexpr int KMaxUserRanges = 10; // enable upto 10 sub regions marked by user + +constexpr char kCuptiProfilerMetricsKey[] = "CUPTI_PROFILER_METRICS"; +constexpr char kCuptiProfilerPerKernelKey[] = "CUPTI_PROFILER_ENABLE_PER_KERNEL"; +constexpr char kCuptiProfilerMaxRangesKey[] = "CUPTI_PROFILER_MAX_RANGES"; + +CuptiRangeProfilerConfig::CuptiRangeProfilerConfig(Config& cfg) + : parent_(&cfg), + cuptiProfilerPerKernel_(false), + cuptiProfilerMaxRanges_(0) {} + +bool CuptiRangeProfilerConfig::handleOption(const std::string& name, std::string& val) { + VLOG(0) << " handling : " << name << " = " << val; + // Cupti Range based Profiler configuration + if (!name.compare(kCuptiProfilerMetricsKey)) { + activitiesCuptiMetrics_ = splitAndTrim(val, ','); + } else if (!name.compare(kCuptiProfilerPerKernelKey)) { + cuptiProfilerPerKernel_ = toBool(val); + } else if (!name.compare(kCuptiProfilerMaxRangesKey)) { + cuptiProfilerMaxRanges_ = toInt64(val); + } else { + return false; + } + return true; +} + +void CuptiRangeProfilerConfig::setDefaults() { + if (activitiesCuptiMetrics_.size() > 0 && cuptiProfilerMaxRanges_ == 0) { + cuptiProfilerMaxRanges_ = + cuptiProfilerPerKernel_ ? KMaxAutoRanges : KMaxUserRanges; + } +} + +void CuptiRangeProfilerConfig::printActivityProfilerConfig(std::ostream& s) const { + if (activitiesCuptiMetrics_.size() > 0) { + s << "Cupti Profiler metrics : " + << fmt::format("{}", fmt::join(activitiesCuptiMetrics_, ", ")) << std::endl; + s << "Cupti Profiler measure per kernel : " + << cuptiProfilerPerKernel_ << std::endl; + s << "Cupti Profiler max ranges : " << cuptiProfilerMaxRanges_ << std::endl; + } +} + +void CuptiRangeProfilerConfig::registerFactory() { + Config::addConfigFactory( + kCuptiProfilerConfigName, + [](Config& cfg) { return new CuptiRangeProfilerConfig(cfg); }); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..549b8a4e8b40c66b59bae974eb87c7f64967344e --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h @@ -0,0 +1,86 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include "Config.h" + +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +constexpr char kCuptiProfilerConfigName[] = "cupti_rb_profiler"; + +class CuptiRangeProfilerConfig : public AbstractConfig { + public: + bool handleOption(const std::string& name, std::string& val) override; + + void validate( + const std::chrono::time_point& + fallbackProfileStartTime) override {} + + static CuptiRangeProfilerConfig& get(const Config& cfg) { + return dynamic_cast(cfg.feature( + kCuptiProfilerConfigName)); + } + + Config& parent() const { + return *parent_; + } + + std::vector activitiesCuptiMetrics() const { + return activitiesCuptiMetrics_; + } + + bool cuptiProfilerPerKernel() const { + return cuptiProfilerPerKernel_; + } + + int64_t cuptiProfilerMaxRanges() const { + return cuptiProfilerMaxRanges_; + } + + void setSignalDefaults() override { + setDefaults(); + } + + void setClientDefaults() override { + setDefaults(); + } + + void printActivityProfilerConfig(std::ostream& s) const override; + + static void registerFactory(); + protected: + AbstractConfig* cloneDerived(AbstractConfig& parent) const override { + CuptiRangeProfilerConfig* clone = new CuptiRangeProfilerConfig(*this); + clone->parent_ = dynamic_cast(&parent); + return clone; + } + + private: + CuptiRangeProfilerConfig() = delete; + explicit CuptiRangeProfilerConfig(Config& parent); + explicit CuptiRangeProfilerConfig( + const CuptiRangeProfilerConfig& other) = default; + + // some defaults will depend on other configuration + void setDefaults(); + + // Associated Config object + Config* parent_; + + // Counter metrics exposed via CUPTI Profiler API + std::vector activitiesCuptiMetrics_; + + // Collect profiler metrics per kernel - autorange made + bool cuptiProfilerPerKernel_{false}; + + // max number of ranges to configure the profiler for. + // this has to be set before hand to reserve space for the output + int64_t cuptiProfilerMaxRanges_ = 0; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h b/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..9b0ed92863648824a57ce8193ddc16d7cf23622e --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h @@ -0,0 +1,27 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include + +namespace KINETO_NAMESPACE { + +class DaemonConfigLoader { + public: + virtual ~DaemonConfigLoader() {} + + // Return the base config from the daemon + virtual std::string readBaseConfig() = 0; + + // Return a configuration string from the daemon, if one has been posted. + virtual std::string readOnDemandConfig(bool events, bool activities) = 0; + + // Returns the number of tracked contexts for this device. The daemon has a + // global view. If an unexpedted error occurs, return -1. + virtual int gpuContextCount(uint32_t device) = 0; + + virtual void setCommunicationFabric(bool enabled) = 0; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp b/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f84f0b8ec36f621061cb1e8bb8dd948cb8aed7b3 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp @@ -0,0 +1,49 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "Demangle.h" + +#ifndef _MSC_VER +#include +#endif +#include +#include + +namespace KINETO_NAMESPACE { + +static constexpr int kMaxSymbolSize = 1024; + +std::string demangle(const char* name) { +#ifndef _MSC_VER + if (!name) { + return ""; + } + + if (strlen(name) > kMaxSymbolSize) { + return name; + } + + int status; + size_t len = 0; + char* demangled = abi::__cxa_demangle(name, nullptr, &len, &status); + if (status != 0) { + return name; + } + std::string res(demangled); + // The returned buffer must be freed! + free(demangled); + return res; +#else + // TODO: demangling on Windows + if (!name) { + return ""; + } else { + return name; + } +#endif +} + +std::string demangle(const std::string& name) { + return demangle(name.c_str()); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/Demangle.h b/plugins/tensorboard-plugins/libkineto/src/Demangle.h new file mode 100644 index 0000000000000000000000000000000000000000..6dcf0776f1abf30e7e3614272fa02f6bae1bdf35 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/Demangle.h @@ -0,0 +1,12 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +namespace KINETO_NAMESPACE { + +std::string demangle(const char* name); +std::string demangle(const std::string& name); + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbf2755238974392ff6205f05a5c80a1733bf2ee --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp @@ -0,0 +1,635 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "EventProfiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "CuptiEventApi.h" +#include "Logger.h" + +using namespace std::chrono; +using std::accumulate; +using std::endl; +using std::map; +using std::ostream; +using std::string; +using std::unique_ptr; +using std::vector; + +namespace KINETO_NAMESPACE { + +static std::mutex& logMutex() { + static std::mutex instance; + return instance; +} + +// --------------------------------------------------------------------- +// class Event +// --------------------------------------------------------------------- + +// Compute domain instance percentiles +PercentileList& Event::percentiles( + PercentileList& pcs, + const SampleSlice& slice) const { + vector instance_values; + instance_values.reserve(instanceCount); + for (int i = 0; i < instanceCount; i++) { + instance_values.push_back(sumInstance(i, slice)); + } + return KINETO_NAMESPACE::percentiles(instance_values, pcs); +} + +// Add up all samples for a given domain instance +int64_t Event::sumInstance(int i, const SampleSlice& slice) const { + auto r = toIdxRange(slice); + auto start = samples_.cbegin(); + std::advance(start, r.first); + auto end = start; + std::advance(end, r.second); + return accumulate(start, end, 0ul, [i](int64_t a, const Sample& b) { + return a + b.second[i]; + }); +} + +// Add up all samples across all domain instances +int64_t Event::sumAll(const SampleSlice& slice) const { + int64_t res = 0; + for (int i = 0; i < instanceCount; i++) { + res += sumInstance(i, slice); + } + return res; +} + +// Print raw sample values for all domains +void Event::printSamples(ostream& s, CUdevice device) const { + // Don't mess up output with interleaved lines + // Probably OK to reuse logMutex() here since this is + // used for debugging, but need to keep an eye on it. + std::lock_guard lock(logMutex()); + s << "Device " << device << " " << name << ":" << endl; + for (const auto& sample : samples_) { + const auto& vals = sample.second; + for (int64_t val : vals) { + s << val << " "; + } + s << endl; + } +} + +// --------------------------------------------------------------------- +// class Metric +// --------------------------------------------------------------------- +Metric::Metric( + string name, + CUpti_MetricID id, + vector events, + CUpti_MetricEvaluationMode eval_mode, + CuptiMetricApi& cupti_metrics) + : name(std::move(name)), + id_(id), + events_(std::move(events)), + evalMode_(eval_mode), + cuptiMetrics_(cupti_metrics), + valueKind_(cuptiMetrics_.valueKind(id)) {} + +// Return per-SM vector as well as total +struct Metric::CalculatedValues Metric::calculate( + map& event_map, + nanoseconds sample_duration, + const SampleSlice& slice) { + vector metric_values; + vector ev_values; + ev_values.reserve(events_.size()); + if (evalMode_ & CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE) { + int instance_count = instanceCount(event_map); + metric_values.reserve(instance_count); + for (int i = 0; i < instance_count; i++) { + ev_values.clear(); + for (CUpti_EventID event_id : events_) { + ev_values.push_back(event_map[event_id].sumInstance(i, slice)); + } + metric_values.push_back(cuptiMetrics_.calculate( + id_, valueKind_, events_, ev_values, sample_duration.count())); + } + } + + // FIXME: Check assumption that all instances are profiled + ev_values.clear(); + for (CUpti_EventID event_id : events_) { + ev_values.push_back(event_map[event_id].sumAll(slice)); + } + SampleValue total = cuptiMetrics_.calculate( + id_, valueKind_, events_, ev_values, sample_duration.count()); + if (evalMode_ & CUPTI_METRIC_EVALUATION_MODE_AGGREGATE) { + metric_values.push_back(total); + } + return {metric_values, std::move(total)}; +} + +void Metric::printDescription(ostream& s) const { + s << fmt::format("{} ({})", name, fmt::join(events_, ",")) << endl; +} + +// --------------------------------------------------------------------- +// class EventGroupSet +// --------------------------------------------------------------------- + +// Each domain has a set of counters. +// Some counters in a domain can be collected simultaneously in a "group" +// Counters from different domains can also be collected at the same time +// Therefore we have a "set of groups", or group set, with counters that +// can all be collected at once. +EventGroupSet::EventGroupSet( + CUpti_EventGroupSet& set, + map& events, + CuptiEventApi& cupti) + : set_(set), events_(events), cuptiEvents_(cupti), enabled_(false) { + for (int g = 0; g < set.numEventGroups; g++) { + CUpti_EventGroup grp = set.eventGroups[g]; + // Profile all domain instances + cuptiEvents_.enablePerInstance(grp); + uint32_t instance_count = cuptiEvents_.instanceCount(grp); + for (const auto& id : cuptiEvents_.eventsInGroup(grp)) { + VLOG(0) << "Instance count for " << id << ":" << instance_count; + events_[id].instanceCount = instance_count; + } + } +} + +EventGroupSet::~EventGroupSet() { + // Disable EventGroupSet in Cupti. + if (enabled_) { + setEnabled(false); + } +} + +// Enable or disable this group set +void EventGroupSet::setEnabled(bool enabled) { + if (enabled && !enabled_) { + cuptiEvents_.enableGroupSet(set_); + } else if (!enabled && enabled_) { + cuptiEvents_.disableGroupSet(set_); + } + enabled_ = enabled; +} + +// Collect counter values for each counter in group set +void EventGroupSet::collectSample() { + auto timestamp = system_clock::now(); + for (int g = 0; g < set_.numEventGroups; g++) { + CUpti_EventGroup grp = set_.eventGroups[g]; + for (const auto& id : cuptiEvents_.eventsInGroup(grp)) { + Event& ev = events_[id]; + vector vals(ev.instanceCount); + // FIXME: Use cuptiEventGroupReadAllEvents + cuptiEvents_.readEvent(grp, id, vals); + + if (VLOG_IS_ON(0)) { + for (int64_t v : vals) { + if (v == CUPTI_EVENT_OVERFLOW) { + LOG(WARNING) << "Counter overflow detected " + << "- decrease sample period!" << endl; + } + } + } + + ev.addSample(timestamp, vals); + } + } + + if (VLOG_IS_ON(1)) { + auto t2 = system_clock::now(); + VLOG(1) << "Device " << cuptiEvents_.device() << " Sample (us): " + << duration_cast(t2 - timestamp).count(); + } +} + +// Print names of events in this group set, ordered by group +void EventGroupSet::printDescription(ostream& s) const { + for (int g = 0; g < set_.numEventGroups; g++) { + s << " Events in group " << g << ": "; + for (const auto& id : cuptiEvents_.eventsInGroup(set_.eventGroups[g])) { + s << id << " (" << events_[id].name << ") "; + } + s << endl; + } +} + +// --------------------------------------------------------------------- +// class EventProfiler +// --------------------------------------------------------------------- + +// Find nearest factor of a number by linear search, +// starting at hi and lo - hi searches up and lo searches down +static int nearestFactor(int hi, int lo, int number) { + return number % hi == 0 + ? hi + : number % lo == 0 ? lo : nearestFactor(hi + 1, lo - 1, number); +} + +static int nearestFactor(int count, int max) { + return nearestFactor(count, count, max); +} + +void EventProfiler::initEvents(const std::set& eventNames) { + events_.clear(); + // Build event map + for (const auto& name : eventNames) { + events_.emplace(cuptiEvents_->eventId(name), name); + } +} + +void EventProfiler::initMetrics(const std::set& metricNames) { + metrics_.clear(); + // Add events from metrics + metrics_.reserve(metricNames.size()); + for (const auto& metric_name : metricNames) { + CUpti_MetricID metric_id = cuptiMetrics_->idFromName(metric_name); + if (metric_id == ~0) { + continue; + } + + const auto& events = cuptiMetrics_->events(metric_id); + vector event_ids; + event_ids.reserve(events.size()); + for (const auto& pair : events) { + CUpti_EventID id = pair.first; + const string& event_name = pair.second; + if (event_name.empty()) { + // For unnamed events, use metric name and event id + // FIXME: For subsequent metrics using the same event, + // this will be confusing + events_.emplace(id, metric_name + "_" + event_name); + } else { + events_.emplace(id, event_name); + } + event_ids.push_back(id); + } + metrics_.emplace_back( + metric_name, + metric_id, + event_ids, + cuptiMetrics_->evaluationMode(metric_id), + *cuptiMetrics_); + } +} + +bool EventProfiler::initEventGroups() { + sets_.clear(); + if (eventGroupSets_) { + cuptiEvents_->destroyGroupSets(eventGroupSets_); + eventGroupSets_ = nullptr; + } + if (events_.empty()) { + return true; + } + + // Determine sets of groups to be collected + vector ids; + ids.reserve(events_.size()); + for (const auto& ev : events_) { + ids.push_back(ev.first); + } + eventGroupSets_ = cuptiEvents_->createGroupSets(ids); + VLOG(0) << "Number of group sets: " << eventGroupSets_->numSets; + for (int i = 0; i < eventGroupSets_->numSets; i++) { + sets_.push_back( + EventGroupSet(eventGroupSets_->sets[i], events_, *cuptiEvents_)); + } + return !sets_.empty(); +} + +static unique_ptr alignAndValidateConfigs( + Config& base, + Config* onDemand) { + auto now = system_clock::now(); + if (!onDemand || + now > + (onDemand->eventProfilerOnDemandStartTime() + + onDemand->eventProfilerOnDemandDuration())) { + base.validate(now); + return base.clone(); + } + + auto res = base.clone(); + res->addEvents(onDemand->eventNames()); + res->addMetrics(onDemand->metricNames()); + + int sample_period = + std::min(base.samplePeriod().count(), onDemand->samplePeriod().count()); + if (sample_period < base.samplePeriod().count() && + (base.samplePeriod().count() % sample_period) != 0) { + sample_period = nearestFactor(sample_period, base.samplePeriod().count()); + LOG(WARNING) + << "On-demand sample period must be a factor of base sample period. " + << "Adjusting from " << onDemand->samplePeriod().count() << "ms to " + << sample_period << "ms."; + } + base.setSamplePeriod(milliseconds(sample_period)); + base.validate(now); + res->setSamplePeriod(base.samplePeriod()); + res->setMultiplexPeriod(base.multiplexPeriod()); + res->validate(now); + onDemand->setSamplePeriod(base.samplePeriod()); + onDemand->setMultiplexPeriod(base.multiplexPeriod()); + onDemand->validate(now); + + return res; +} + +static milliseconds minReportPeriod(const Config& config, int num_sets) { + return config.multiplexPeriod() * num_sets; +} + +static bool canSupportReportPeriod(const Config& config, int num_sets) { + // Can we get through the groups an even number per report period? + milliseconds min_report_period = minReportPeriod(config, num_sets); + return (config.reportPeriod().count() % min_report_period.count()) == 0; +} + +static int completeSamplesPerReport(const Config& config, int num_sets) { + if (num_sets <= 1) { + return config.reportPeriod() / config.samplePeriod(); + } + // Numnber of complete sample collections in the report period + // E.g. if report period is 10000ms, sample period 500ms, + // multiplex period 2000ms and num_sets is 5 then # of complete samples is + // (2000ms / 500ms) * (10000ms / 2000ms / 5) = 4 * 1 = 4 + int samples_per_multiplex_period = + config.multiplexPeriod() / config.samplePeriod(); + int multiplex_periods_per_report = + config.reportPeriod() / config.multiplexPeriod(); + return (multiplex_periods_per_report / num_sets) * + samples_per_multiplex_period; +} + +static bool canSupportSamplesPerReport(const Config& config, int num_sets) { + // Can samples per report can be honored with an exact *full* set of samples? + // We don't support partial samples at this point. + int full_samples_per_report = completeSamplesPerReport(config, num_sets); + return (full_samples_per_report % config.samplesPerReport()) == 0; +} + +static void adjustConfig(Config& config, int num_sets) { + // Don't change sample period and multiplex period here, since that can + // cause overflows and perf degradation. Report period and samples per + // report is OK to change (with warning). + if (!canSupportReportPeriod(config, num_sets)) { + milliseconds min_report_period = minReportPeriod(config, num_sets); + LOG(WARNING) << "Report period must be a multiple of " + << min_report_period.count() << "ms (" << num_sets + << " event sets * " << config.multiplexPeriod().count() + << "ms multiplex period), in order to get complete samples."; + auto new_report_period = + Config::alignUp(config.reportPeriod(), min_report_period); + double sf = + ((double)new_report_period.count()) / config.reportPeriod().count(); + int new_samples_per_report = std::round(config.samplesPerReport() * sf); + LOG(WARNING) << "Adjusting report period from " + << config.reportPeriod().count() << "ms to " + << new_report_period.count() << "ms"; + if (new_samples_per_report != config.samplesPerReport()) { + LOG(WARNING) << "Adjusting samples per report from " + << config.samplesPerReport() << " to " + << new_samples_per_report; + } + config.setReportPeriod(new_report_period); + config.setSamplesPerReport(new_samples_per_report); + } + // Ensure that samples per report can be honored with + // an exact *full* set of samples. Don't support partial + // samples at this point. + if (!canSupportSamplesPerReport(config, num_sets)) { + int full_samples_per_report = completeSamplesPerReport(config, num_sets); + int adjusted_count = + nearestFactor(config.samplesPerReport(), full_samples_per_report); + LOG(WARNING) + << "Samples per report must be such that an even number of " + << "complete samples can be aggregated in each report period. Adjusting" + << " from " << config.samplesPerReport() << " to " << adjusted_count + << " (complete sample count is " << full_samples_per_report << ")"; + config.setSamplesPerReport(adjusted_count); + } +} + +// Prepare profiler +EventProfiler::EventProfiler( + std::unique_ptr cupti_events, + std::unique_ptr cupti_metrics, + vector>& loggers, + vector>& onDemandLoggers) + : cuptiEvents_(std::move(cupti_events)), + cuptiMetrics_(std::move(cupti_metrics)), + loggers_(loggers), + onDemandLoggers_(onDemandLoggers) {} + +void EventProfiler::reportSamples() { + dispatchSamples(*config_, loggers_, baseSamples_); + baseSamples_ += completeSamplesPerReport(*config_, sets_.size()); +} + +void EventProfiler::reportOnDemandSamples() { + dispatchSamples(*onDemandConfig_, onDemandLoggers_, onDemandSamples_); + onDemandSamples_ += completeSamplesPerReport(*onDemandConfig_, sets_.size()); +} + +EventProfiler::~EventProfiler() { + if (eventGroupSets_) { + for (auto& set : sets_) { + set.setEnabled(false); + } + cuptiEvents_->destroyGroupSets(eventGroupSets_); + } + VLOG(0) << "Stopped event profiler for device " << device(); +} + +void EventProfiler::updateLoggers(Config& config, Config* on_demand_config) { + // Update loggers. + for (auto& logger : loggers_) { + std::lock_guard lock(logMutex()); + logger->update(config); + } + + if (on_demand_config) { + // Update onDemand loggers. + for (auto& logger : onDemandLoggers_) { + std::lock_guard lock(logMutex()); + logger->update(*on_demand_config); + } + } +} + +bool EventProfiler::applyConfig(const Config& config) { + // Initialize events, metrics, and event group sets. + // TODO: Send warnings / errors back to dyno for onDemand config + try { + if (!initEventsAndMetrics(config)) { + return false; + } + } catch (const std::exception& ex) { + LOG(WARNING) << "Failed to apply config (" << ex.what() << ")"; + return false; + } + + return true; +} + +bool EventProfiler::initEventsAndMetrics(const Config& config) { + initEvents(config.eventNames()); + initMetrics(config.metricNames()); + // We now have the total list of events to collect + // They need to be organized into groups for multiplexing + if (!initEventGroups()) { + LOG(WARNING) << "No events/metrics initialized successfully"; + return false; + } + + if (VLOG_IS_ON(1)) { + printMetrics(LIBKINETO_DBG_STREAM); + printSets(LIBKINETO_DBG_STREAM); + } + return true; +} + +void EventProfiler::printSets(ostream& s) const { + for (int i = 0; i < sets_.size(); i++) { + s << "Set " << i << endl; + sets_[i].printDescription(s); + } +} + +void EventProfiler::printMetrics(ostream& s) const { + s << "Metrics:" << endl; + for (const Metric& m : metrics_) { + m.printDescription(s); + } +} + +void EventProfiler::printAllSamples(ostream& s, CUdevice device) const { + for (const auto& pair : events_) { + const Event& ev = pair.second; + ev.printSamples(s, device); + } +} + +void EventProfiler::enableNextCounterSet() { + if (sets_.size() > 1) { + auto t1 = system_clock::now(); + + VLOG(1) << "Disabling set " << curEnabledSet_; + sets_[curEnabledSet_].setEnabled(false); + curEnabledSet_ = (curEnabledSet_ + 1) % sets_.size(); + VLOG(1) << "Enabling set " << curEnabledSet_; + sets_[curEnabledSet_].setEnabled(true); + + if (VLOG_IS_ON(1)) { + auto t2 = system_clock::now(); + VLOG(1) << "Switch (us): " + << duration_cast(t2 - t1).count(); + } + } +} + +// Notify listeners of collected samples +void EventProfiler::dispatchSamples( + const Config& config, + const vector>& loggers, + int sample_offset) { + Sample sample(events_.size() + metrics_.size()); + // Normalize values to per second + auto delta = config.reportPeriod() / config.samplesPerReport(); + double sf = 1000.0 * sets_.size() / delta.count(); + for (int i = 0; i < config.samplesPerReport(); i++) { + sample.stats.clear(); + sample.deltaMsec = (delta * i).count(); + SampleSlice slice = {sample_offset, i, config.samplesPerReport()}; + VLOG(1) << "Slice: " << sample_offset << ", " << i << ", " + << config.samplesPerReport(); + for (const auto& pair : events_) { + const Event& ev = pair.second; + int64_t total = std::round(sf * ev.sumAll(slice)); + PercentileList pcs = initPercentiles(config.percentiles()); + normalize(ev.percentiles(pcs, slice), sf); + sample.stats.push_back({ev.name, std::move(pcs), SampleValue(total)}); + } + + for (auto& m : metrics_) { + // calculate returns a pair of per-SM vector and a total + auto vals = m.calculate(events_, delta, slice); + PercentileList pcs = initPercentiles(config.percentiles()); + sample.stats.push_back( + {m.name, std::move(percentiles(vals.perInstance, pcs)), vals.total}); + } + + for (auto& logger : loggers) { + std::lock_guard lock(logMutex()); + logger->handleSample(device(), sample, config.ipcFabricEnabled()); + } + } + + if (VLOG_IS_ON(2)) { + printAllSamples(LIBKINETO_DBG_STREAM, device()); + } +} + +void EventProfiler::configure(Config& config, Config* onDemandConfig) { + if (!sets_.empty()) { + sets_[curEnabledSet_].setEnabled(false); + clearSamples(); + } + + config_ = config.clone(); + onDemandConfig_ = onDemandConfig ? onDemandConfig->clone() : nullptr; + mergedConfig_ = alignAndValidateConfigs(*config_, onDemandConfig_.get()); + if (!applyConfig(*mergedConfig_)) { + LOG(WARNING) << "Failed to apply config!"; + mergedConfig_ = config_->clone(); + applyConfig(*config_); + } + if (!sets_.empty()) { + // Make timing adjustments based on multiplexing requirements. + adjustConfig(*config_, sets_.size()); + if (onDemandConfig_) { + int duration = onDemandConfig_->eventProfilerOnDemandDuration().count(); + LOG(INFO) << "On demand profiler activated for " << duration << " secs"; + adjustConfig(*onDemandConfig_, sets_.size()); + } + // If events or metrics were added or removed, need to tell loggers + updateLoggers(*config_, onDemandConfig_.get()); + } + + curEnabledSet_ = 0; + if (!sets_.empty()) { + sets_[0].setEnabled(true); + } else { + VLOG(0) << "No counters profiled!"; + } + + baseSamples_ = 0; + onDemandSamples_ = 0; +} + +void EventProfiler::collectSample() { + if (sets_.empty()) { + return; + } + sets_[curEnabledSet_].collectSample(); + if (VLOG_IS_ON(1)) { + printAllSamples(LIBKINETO_DBG_STREAM, device()); + } +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..fafd5b9bb8336b28b210ba58d588d3a798a73969 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h @@ -0,0 +1,341 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "Config.h" +#include "CuptiEventApi.h" +#include "CuptiMetricApi.h" +#include "SampleListener.h" + +namespace KINETO_NAMESPACE { + +// Helper function for computing percentiles (nearest-rank). +// Modifies the input. +template +inline PercentileList& percentiles(std::vector values, PercentileList& pcs) { + auto size = values.size(); + for (auto& x : pcs) { + int idx = std::min(size - 1, (x.first * size) / 100); + std::nth_element(values.begin(), values.begin() + idx, values.end()); + x.second = SampleValue(values[idx]); + } + return pcs; +} + +// Helper function for normalizing a percentile list +// Modifies the input +inline PercentileList& normalize(PercentileList& pcs, double sf) { + for (auto& pc : pcs) { + pc.second *= sf; + } + return pcs; +} + +// A slice of the sample buffer +struct SampleSlice { + // Start offset (samples) + int offset; + // Slice number + int index; + // Out of this many + int count; +}; + +// A sampled event +class Event { + public: + /* implicit */ Event(std::string name) : name(std::move(name)) {} + /* implicit */ Event(const char* name) : name(name) {} + Event() : name("INVALID") {} + + Event(const Event&) = delete; + Event& operator=(const Event&) = delete; + Event(Event&&) = default; + Event& operator=(Event&&) = default; + + void addSample( + std::chrono::time_point timestamp, + const std::vector& values) { + assert(values.size() == instanceCount); + samples_.emplace_back(timestamp, values); + } + + // Sum samples for a single domain instance + int64_t sumInstance(int i, const SampleSlice& slice) const; + + // Sum all samples across all domain instances + int64_t sumAll(const SampleSlice& slice) const; + + // Create list of percentiles + PercentileList& percentiles(PercentileList& pcs, const SampleSlice& slice) + const; + + void eraseSamples(int count) { + auto end = samples_.begin(); + std::advance(end, count); + samples_.erase(samples_.begin(), end); + } + + void clearSamples() { + samples_.clear(); + } + + int sampleCount() { + return samples_.size(); + } + + void printSamples(std::ostream& s, CUdevice device) const; + + // Event name (see nvprof --query-events) + std::string name; + + // Number of domain instances for this event, e.g. number of SMs + int instanceCount = 0; + + private: + std::pair toIdxRange(const SampleSlice& slice) const { + int size = (samples_.size() - slice.offset) / slice.count; + return std::make_pair(slice.offset + (slice.index * size), size); + } + + // List of collected samples, where each sample has values for + // one or more domain instances + using Sample = std::pair< + std::chrono::time_point, + std::vector>; + std::list samples_; +}; + +class Metric { + public: + Metric( + std::string name, + CUpti_MetricID id, + std::vector events, + CUpti_MetricEvaluationMode eval_mode, + CuptiMetricApi& cupti_metrics); + + struct CalculatedValues { + std::vector perInstance; + SampleValue total; + }; + + struct CalculatedValues calculate( + std::map& events, + std::chrono::nanoseconds sample_duration, + const SampleSlice& slice); + + int instanceCount(std::map& events) { + return events[events_[0]].instanceCount; + } + + void printDescription(std::ostream& s) const; + + std::string name; + + private: + CUpti_MetricID id_; + std::vector events_; + CUpti_MetricEvaluationMode evalMode_; + // Calls to CUPTI is encapsulated behind this interface + CuptiMetricApi& cuptiMetrics_; + CUpti_MetricValueKind valueKind_; +}; + +/** + * A set of event groups. + * Holds all the events that may be collected in a single pass. + * A group contains one or more counters for a single domain. + * A group set contains zero or one groups per domain. + */ +class EventGroupSet { + public: + EventGroupSet( + CUpti_EventGroupSet& set, + std::map& events, + CuptiEventApi& cupti); + ~EventGroupSet(); + + EventGroupSet(const EventGroupSet&) = delete; + EventGroupSet& operator=(const EventGroupSet&) = delete; + EventGroupSet(EventGroupSet&&) = default; + EventGroupSet& operator=(EventGroupSet&&) = delete; + + // Number of groups = number of domains profiled + int groupCount() const { + return set_.numEventGroups; + } + + void setEnabled(bool enabled); + // Take a sample of counters in this group set + void collectSample(); + void printDescription(std::ostream& s) const; + + private: + CUpti_EventGroupSet& set_; + std::map& events_; + // Calls to CUPTI is encapsulated behind this interface + CuptiEventApi& cuptiEvents_; + bool enabled_; +}; + +// The sampler +class EventProfiler { + public: + explicit EventProfiler( + std::unique_ptr cupti_events, + std::unique_ptr cupti_metrics, + std::vector>& loggers, + std::vector>& onDemandLoggers); + EventProfiler(const EventProfiler&) = delete; + EventProfiler& operator=(const EventProfiler&) = delete; + ~EventProfiler(); + + void configure(Config& config, Config* onDemandConfig); + + bool isOnDemandActive() { + return !!onDemandConfig_; + } + + // Print the counter sets. Multiple sets will be multiplexed. + void printSets(std::ostream& s) const; + + // Print metrics descriptions + void printMetrics(std::ostream& s) const; + + bool enableForDevice(Config& cfg); + + CUdevice device() { + return cuptiEvents_->device(); + } + + bool setContinuousMode() { + return cuptiEvents_->setContinuousMode(); + } + + std::chrono::milliseconds samplePeriod() { + return mergedConfig_->samplePeriod(); + } + + std::chrono::milliseconds multiplexPeriod() { + return mergedConfig_->multiplexPeriod(); + } + + std::chrono::milliseconds reportPeriod() { + return config_->reportPeriod(); + } + + std::chrono::milliseconds onDemandReportPeriod() { + return onDemandConfig_->reportPeriod(); + } + + // Read values of currently running counters. + void collectSample(); + + void reportSamples(); + void reportOnDemandSamples(); + + bool enabled() { + return sets_.size() > 0; + } + + bool multiplexEnabled() { + return sets_.size() > 1; + } + + // Multiplex counters. + void enableNextCounterSet(); + + void eraseReportedSamples() { + int erase_count = baseSamples_; + if (onDemandConfig_ && + onDemandConfig_->eventProfilerOnDemandDuration().count() > 0) { + erase_count = std::min(baseSamples_, onDemandSamples_); + } + eraseSamples(erase_count); + baseSamples_ -= erase_count; + onDemandSamples_ -= erase_count; + } + + void clearSamples() { + for (auto& pair : events_) { + pair.second.clearSamples(); + } + baseSamples_ = 0; + onDemandSamples_ = 0; + } + + private: + // Functions to initialize profiler based on Config settings. + bool applyConfig(const Config& config); + bool initEventsAndMetrics(const Config& config); + void initEvents(const std::set& eventNames); + void initMetrics(const std::set& metricNames); + bool initEventGroups(); + + PercentileList initPercentiles(const std::vector& percentiles) { + PercentileList res; + res.reserve(percentiles.size()); + for (int p : percentiles) { + res.emplace_back(p, SampleValue(0)); + } + return res; + } + + // Notify listeners of collected samples + void dispatchSamples( + const Config& config, + const std::vector>& loggers, + int report_nr); + + void eraseSamples(int count) { + for (auto& pair : events_) { + pair.second.eraseSamples(count); + } + } + + void updateLoggers(Config& config, Config* on_demand_config); + + // Print all collected samples since last clear. + void printAllSamples(std::ostream& s, CUdevice device) const; + + // Calls to CUPTI is encapsulated behind these interfaces + std::unique_ptr cuptiEvents_; + std::unique_ptr cuptiMetrics_; + // The CUpti API reports event IDs, we must map them to our event objects + std::map events_; + // List of metrics + std::vector metrics_; + // The countert sets needed to collect all counters + std::vector sets_; + // The event group set object returned by Cupti. + // Saved s.t. we can call cuptiEventGroupSetsDestroy to free memory when + // the object is no longer needed. + CUpti_EventGroupSets* eventGroupSets_ = nullptr; + // Current multiplexed counter set + int curEnabledSet_{0}; + + std::unique_ptr config_; + std::unique_ptr onDemandConfig_; + std::unique_ptr mergedConfig_; + int baseSamples_{0}; + int onDemandSamples_{0}; + + // Shared between profiler threads + // Vectors are read-only but calling loggers require lock + const std::vector>& loggers_; + const std::vector>& onDemandLoggers_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0427cc7a90cbc49d31262bcce63f1f81c5b6293f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp @@ -0,0 +1,423 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "EventProfilerController.h" + +#include +#include +#include + +#include "ConfigLoader.h" +#include "CuptiEventApi.h" +#include "CuptiMetricApi.h" +#include "EventProfiler.h" +#include "output_csv.h" + +#include "Logger.h" +#include "ThreadUtil.h" + +using namespace std::chrono; +using std::unique_ptr; +using std::vector; + +namespace KINETO_NAMESPACE { + +namespace { + +vector(const Config&)>>& +loggerFactories() { + static vector(const Config&)>> + factories; + return factories; +} + +vector(const Config&)>>& +onDemandLoggerFactories() { + static vector(const Config&)>> + factories; + return factories; +} + +vector> makeLoggers(const Config& config) { + vector> loggers; + for (const auto& factory : loggerFactories()) { + loggers.push_back(factory(config)); + } + loggers.push_back(std::make_unique()); + loggers.push_back(std::make_unique()); + return loggers; +} + +vector> makeOnDemandLoggers( + const Config& config) { + vector> loggers; + for (const auto& factory : onDemandLoggerFactories()) { + loggers.push_back(factory(config)); + } + loggers.push_back(std::make_unique()); + return loggers; +} + +vector>& loggers(const Config& config) { + static auto res = makeLoggers(config); + return res; +} + +vector>& onDemandLoggers( + const Config& config) { + static auto res = makeOnDemandLoggers(config); + return res; +} + +} // anon namespace + +// Keep an eye on profiling threads. +// We've observed deadlocks in Cuda11 in libcuda / libcupti.. +namespace detail { + +class HeartbeatMonitor { + + public: + ~HeartbeatMonitor() { + stopMonitoring(); + } + + static HeartbeatMonitor& instance() { + static HeartbeatMonitor monitor; + return monitor; + } + + void profilerHeartbeat() { + int32_t tid = systemThreadId(); + std::lock_guard lock(mutex_); + profilerAliveMap_[tid]++; + } + + void setPeriod(seconds period) { + { + std::lock_guard lock(mutex_); + if (period_ == period) { + return; + } + period_ = period; + } + if (period == seconds(0)) { + stopMonitoring(); + } else { + startMonitoring(); + } + } + + private: + HeartbeatMonitor() = default; + + void monitorLoop() { + std::unique_lock lock(mutex_); + while(!stopMonitor_) { + auto cv_status = condVar_.wait_for(lock, seconds(period_)); + // Don't perform check on spurious wakeup or on notify + if (cv_status == std::cv_status::timeout) { + for (auto& pair : profilerAliveMap_) { + int32_t tid = pair.first; + int& i = pair.second; + if (i == 0) { + LOG(ERROR) << "Thread " << tid << " appears stuck!"; + } + i = 0; + } + } + } + } + + void startMonitoring() { + if (!monitorThread_) { + VLOG(0) << "Starting monitoring thread"; + stopMonitor_ = false; + monitorThread_ = std::make_unique( + &HeartbeatMonitor::monitorLoop, this); + } + } + + void stopMonitoring() { + if (monitorThread_) { + VLOG(0) << "Stopping monitoring thread"; + stopMonitor_ = true; + condVar_.notify_one(); + monitorThread_->join(); + monitorThread_ = nullptr; + VLOG(0) << "Monitoring thread terminated"; + } + } + + std::map profilerAliveMap_; + std::unique_ptr monitorThread_; + std::mutex mutex_; + std::condition_variable condVar_; + std::atomic_bool stopMonitor_{false}; + seconds period_{0}; +}; + +} // namespace detail + +namespace { +// Profiler map singleton +std::map>& profilerMap() { + static std::map> instance; + return instance; +} + +void reportLateSample( + int sleepMs, + int sampleMs, + int reportMs, + int reprogramMs) { + LOG_EVERY_N(WARNING, 10) << "Lost sample due to delays (ms): " << sleepMs + << ", " << sampleMs << ", " << reportMs << ", " + << reprogramMs; +} + +void configureHeartbeatMonitor( + detail::HeartbeatMonitor& monitor, const Config& base, const Config* onDemand) { + seconds base_period = + base.eventProfilerHeartbeatMonitorPeriod(); + seconds on_demand_period = !onDemand ? seconds(0) : + onDemand->eventProfilerHeartbeatMonitorPeriod(); + monitor.setPeriod( + on_demand_period > seconds(0) ? on_demand_period : base_period); +} + +} // anon namespace + +void EventProfilerController::addLoggerFactory( + std::function(const Config&)> factory) { + loggerFactories().push_back(factory); +} + +void EventProfilerController::addOnDemandLoggerFactory( + std::function(const Config&)> factory) { + onDemandLoggerFactories().push_back(factory); +} + +EventProfilerController::EventProfilerController( + CUcontext context, + ConfigLoader& configLoader, + detail::HeartbeatMonitor& heartbeatMonitor) + : configLoader_(configLoader), heartbeatMonitor_(heartbeatMonitor) { + auto cupti_events = std::make_unique(context); + auto cupti_metrics = + std::make_unique(cupti_events->device()); + configLoader_.addHandler( + ConfigLoader::ConfigKind::EventProfiler, this); + auto config = configLoader.getConfigCopy(); + profiler_ = std::make_unique( + std::move(cupti_events), + std::move(cupti_metrics), + loggers(*config), + onDemandLoggers(*config)); + profilerThread_ = std::make_unique( + &EventProfilerController::profilerLoop, this); +} + +EventProfilerController::~EventProfilerController() { + if (profilerThread_) { + // signaling termination of the profiler loop + stopRunloop_ = true; + profilerThread_->join(); + } + configLoader_.removeHandler( + ConfigLoader::ConfigKind::EventProfiler, this); + VLOG(0) << "Stopped event profiler"; +} + +// Must be called under lock +void EventProfilerController::start(CUcontext ctx, ConfigLoader& configLoader) { + profilerMap()[ctx] = unique_ptr( + new EventProfilerController( + ctx, configLoader, detail::HeartbeatMonitor::instance())); +} + +// Must be called under lock +void EventProfilerController::stop(CUcontext ctx) { + profilerMap()[ctx] = nullptr; +} + +bool EventProfilerController::canAcceptConfig() { + std::lock_guard guard(mutex_); + return !newOnDemandConfig_; +} + +void EventProfilerController::acceptConfig(const Config& config) { + if (config.eventProfilerOnDemandDuration().count() == 0) { + // Ignore - not for this profiler + return; + } + std::lock_guard guard(mutex_); + if (newOnDemandConfig_) { + LOG(ERROR) << "On demand request already queued - ignoring new request"; + return; + } + newOnDemandConfig_ = config.clone(); + LOG(INFO) << "Received new on-demand config"; +} + +bool EventProfilerController::enableForDevice(Config& cfg) { + // FIXME: Use device unique id! + if (!cfg.eventProfilerEnabledForDevice(profiler_->device())) { + return false; + } + // context count includes the new context + int instances = configLoader_.contextCountForGpu(profiler_->device()); + VLOG(0) << "Device context count: " << instances; + return instances >= 0 && instances <= cfg.maxEventProfilersPerGpu(); +} + +void EventProfilerController::profilerLoop() { + // We limit the number of profilers that can exist per GPU + auto config = configLoader_.getConfigCopy(); + if (!enableForDevice(*config)) { + VLOG(0) << "Not starting EventProfiler - profilers for GPU " + << profiler_->device() << " exceeds profilers per GPU limit (" + << config->maxEventProfilersPerGpu() << ")"; + return; + } + + if (!profiler_->setContinuousMode()) { + VLOG(0) << "Continuous mode not supported for GPU " + << profiler_->device() << ". Not starting Event Profiler."; + return; + } + + VLOG(0) << "Starting Event Profiler for GPU " << profiler_->device(); + setThreadName("CUPTI Event Profiler"); + + time_point next_sample_time; + time_point next_report_time; + time_point next_on_demand_report_time; + time_point next_multiplex_time; + std::unique_ptr on_demand_config = nullptr; + bool reconfigure = true; + bool restart = true; + int report_count = 0; + int on_demand_report_count = 0; + while (!stopRunloop_) { + heartbeatMonitor_.profilerHeartbeat(); + if (configLoader_.hasNewConfig(*config)) { + config = configLoader_.getConfigCopy(); + VLOG(0) << "Base config changed"; + report_count = 0; + reconfigure = true; + } + + auto now = system_clock::now(); + if (on_demand_config && + now > (on_demand_config->eventProfilerOnDemandStartTime() + + on_demand_config->eventProfilerOnDemandDuration())) { + on_demand_config = nullptr; + LOG(INFO) << "On-demand profiling complete"; + reconfigure = true; + } + + if (!profiler_->isOnDemandActive()) { + std::lock_guard lock(mutex_); + if (newOnDemandConfig_) { + VLOG(0) << "Received on-demand config, reconfiguring"; + on_demand_config = std::move(newOnDemandConfig_); + reconfigure = true; + on_demand_report_count = 0; + } + } + + if (reconfigure) { + try { + profiler_->configure(*config, on_demand_config.get()); + } catch (const std::exception& ex) { + LOG(ERROR) << "Encountered error while configuring event profiler: " + << ex.what(); + // Exit profiling entirely when encountering an error here + // as it indicates a serious problem or bug. + break; + } + configureHeartbeatMonitor( + heartbeatMonitor_, *config, on_demand_config.get()); + reconfigure = false; + restart = true; + } + + if (restart) { + now = system_clock::now(); + next_sample_time = now + profiler_->samplePeriod(); + next_report_time = now + profiler_->reportPeriod(); + if (profiler_->isOnDemandActive()) { + next_on_demand_report_time = now + profiler_->onDemandReportPeriod(); + } + next_multiplex_time = now + profiler_->multiplexPeriod(); + // Collect an initial sample and throw it away + // The next sample is the first valid one + profiler_->collectSample(); + profiler_->clearSamples(); + restart = false; + } + + auto start_sleep = now; + while (now < next_sample_time) { + /* sleep override */ + std::this_thread::sleep_for(next_sample_time - now); + now = system_clock::now(); + } + int sleep_time = duration_cast(now - start_sleep).count(); + + auto start_sample = now; + profiler_->collectSample(); + now = system_clock::now(); + int sample_time = duration_cast(now - start_sample).count(); + + next_sample_time += profiler_->samplePeriod(); + if (now > next_sample_time) { + reportLateSample(sleep_time, sample_time, 0, 0); + restart = true; + continue; + } + + auto start_report = now; + if (now > next_report_time) { + VLOG(1) << "Report #" << report_count++; + profiler_->reportSamples(); + next_report_time += profiler_->reportPeriod(); + } + if (profiler_->isOnDemandActive() && now > next_on_demand_report_time) { + VLOG(1) << "OnDemand Report #" << on_demand_report_count++; + profiler_->reportOnDemandSamples(); + next_on_demand_report_time += profiler_->onDemandReportPeriod(); + } + profiler_->eraseReportedSamples(); + now = system_clock::now(); + int report_time = duration_cast(now - start_report).count(); + + if (now > next_sample_time) { + reportLateSample(sleep_time, sample_time, report_time, 0); + restart = true; + continue; + } + + auto start_multiplex = now; + if (profiler_->multiplexEnabled() && now > next_multiplex_time) { + profiler_->enableNextCounterSet(); + next_multiplex_time += profiler_->multiplexPeriod(); + } + now = system_clock::now(); + int multiplex_time = + duration_cast(now - start_multiplex).count(); + + if (now > next_sample_time) { + reportLateSample(sleep_time, sample_time, report_time, multiplex_time); + restart = true; + } + + VLOG(0) << "Runloop execution time: " + << duration_cast(now - start_sample).count() << "ms"; + } + + VLOG(0) << "Device " << profiler_->device() + << ": Exited event profiling loop"; +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h new file mode 100644 index 0000000000000000000000000000000000000000..007a82faa9289ada9256d09907167471eb6520b9 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h @@ -0,0 +1,63 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "ConfigLoader.h" + +namespace KINETO_NAMESPACE { + +class Config; +class ConfigLoader; +class EventProfiler; +class SampleListener; + +namespace detail { +class HeartbeatMonitor; +} + +class EventProfilerController : public ConfigLoader::ConfigHandler { + public: + EventProfilerController(const EventProfilerController&) = delete; + EventProfilerController& operator=(const EventProfilerController&) = delete; + + ~EventProfilerController(); + + static void start(CUcontext ctx, ConfigLoader& configLoader); + static void stop(CUcontext ctx); + + static void addLoggerFactory( + std::function(const Config&)> factory); + + static void addOnDemandLoggerFactory( + std::function(const Config&)> factory); + + bool canAcceptConfig() override; + + void acceptConfig(const Config& config) override; + + private: + explicit EventProfilerController( + CUcontext context, + ConfigLoader& configLoader, + detail::HeartbeatMonitor& heartbeatMonitor); + bool enableForDevice(Config& cfg); + void profilerLoop(); + + ConfigLoader& configLoader_; + std::unique_ptr newOnDemandConfig_; + detail::HeartbeatMonitor& heartbeatMonitor_; + std::unique_ptr profiler_; + std::unique_ptr profilerThread_; + std::atomic_bool stopRunloop_{false}; + std::mutex mutex_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp b/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e00b1256c4fa301e288e619ee9ef8c56c8b8569 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp @@ -0,0 +1,10 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "GenericTraceActivity.h" +#include "output_base.h" + +namespace libkineto { + void GenericTraceActivity::log(ActivityLogger& logger) const { + logger.handleGenericActivity(*this); + } +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp b/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f0106578811837c9cc677def30d5697d43a94221 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp @@ -0,0 +1,54 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ILoggerObserver.h" + +#if !USE_GOOGLE_LOG + +#include +#include + +namespace libkineto { + +struct LoggerTypeName { + constexpr LoggerTypeName(const char* n, LoggerOutputType t) : name(n), type(t) {}; + const char* name; + LoggerOutputType type; +}; + +static constexpr std::array LoggerMap{{ + {"VERBOSE", LoggerOutputType::VERBOSE}, + {"INFO", LoggerOutputType::INFO}, + {"WARNING", LoggerOutputType::WARNING}, + {"ERROR", LoggerOutputType::ERROR}, + {"STAGE", LoggerOutputType::STAGE}, + {"???", LoggerOutputType::ENUM_COUNT} +}}; + +static constexpr bool matchingOrder(int idx = 0) { + return LoggerMap[idx].type == LoggerOutputType::ENUM_COUNT || + ((idx == (int) LoggerMap[idx].type) && matchingOrder(idx + 1)); +} +static_assert(matchingOrder(), "LoggerTypeName map is out of order"); + +const char* toString(LoggerOutputType t) { + if(t < VERBOSE || t >= ENUM_COUNT) { + return LoggerMap[ENUM_COUNT].name; + } + return LoggerMap[(int)t].name; +} + +LoggerOutputType toLoggerOutputType(const std::string& str) { + for (int i = 0; i < LoggerTypeCount; i++) { + if (str == LoggerMap[i].name) { + return LoggerMap[i].type; + } + } + throw std::invalid_argument(fmt::format("Invalid activity type: {}", str)); +} + +} // namespace libkineto + + +#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/Logger.cpp b/plugins/tensorboard-plugins/libkineto/src/Logger.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbde765f51f7a5f03c31a9c79e6d00ce9a2070b6 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/Logger.cpp @@ -0,0 +1,136 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "Logger.h" +#include "ILoggerObserver.h" + +#ifndef USE_GOOGLE_LOG + +#include +#include +#include +#include +#include + +#include +#include + +#include "ThreadUtil.h" + +namespace KINETO_NAMESPACE { + +std::atomic_int Logger::severityLevel_{VERBOSE}; +std::atomic_int Logger::verboseLogLevel_{-1}; +std::atomic Logger::verboseLogModules_{~0ull}; + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wglobal-constructors" +std::mutex Logger::loggerObserversMutex_; +#pragma GCC diagnostic pop + + +Logger::Logger(int severity, int line, const char* filePath, int errnum) + : buf_(), out_(LIBKINETO_DBG_STREAM), errnum_(errnum), messageSeverity_(severity) { + buf_ << toString((LoggerOutputType) severity) << ":"; + + const auto tt = + std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + const char* file = strrchr(filePath, '/'); + buf_ << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(tt)) << " " + << processId() << ":" << systemThreadId() << " " + << (file ? file + 1 : filePath) << ":" << line << "] "; +} + +Logger::~Logger() { +#ifdef __linux__ + if (errnum_ != 0) { + thread_local char buf[1024]; + buf_ << " : " << strerror_r(errnum_, buf, sizeof(buf)); + } +#endif + + { + std::lock_guard guard(loggerObserversMutex_); + for (auto* observer : loggerObservers()) { + // Output to observers. Current Severity helps keep track of which bucket the output goes. + if (observer) { + observer->write(buf_.str(), (LoggerOutputType) messageSeverity_); + } + } + } + + // Finally, print to terminal or console. + out_ << buf_.str() << std::endl; +} + +void Logger::setVerboseLogModules(const std::vector& modules) { + uint64_t mask = 0; + if (modules.empty()) { + mask = ~0ull; + } else { + for (const std::string& name : modules) { + mask |= hash(name.c_str()); + } + } + verboseLogModules_ = mask; +} + +void Logger::addLoggerObserver(ILoggerObserver* observer) { + if (observer == nullptr) { + return; + } + std::lock_guard guard(loggerObserversMutex_); + loggerObservers().insert(observer); +} + +void Logger::removeLoggerObserver(ILoggerObserver* observer) { + std::lock_guard guard(loggerObserversMutex_); + loggerObservers().erase(observer); +} + +void Logger::addLoggerObserverDevice(int64_t device) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->addDevice(device); + } +} + +void Logger::addLoggerObserverEventCount(int64_t count) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->addEventCount(count); + } +} + +void Logger::setLoggerObserverTraceDurationMS(int64_t duration) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->setTraceDurationMS(duration); + } +} + +void Logger::setLoggerObserverTraceID(const std::string& tid) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->setTraceID(tid); + } +} + +void Logger::setLoggerObserverGroupTraceID(const std::string& gtid) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->setGroupTraceID(gtid); + } +} + +void Logger::addLoggerObserverDestination(const std::string& dest) { + std::lock_guard guard(loggerObserversMutex_); + for (auto observer : loggerObservers()) { + observer->addDestination(dest); + } +} + +} // namespace KINETO_NAMESPACE + +#endif // USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/Logger.h b/plugins/tensorboard-plugins/libkineto/src/Logger.h new file mode 100644 index 0000000000000000000000000000000000000000..868fc84b9f4ee86d88805bed81468a5df6988257 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/Logger.h @@ -0,0 +1,244 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +#define LIBKINETO_DBG_STREAM std::cerr + +#if USE_GOOGLE_LOG + +#include + +#define SET_LOG_SEVERITY_LEVEL(level) +#define SET_LOG_VERBOSITY_LEVEL(level, modules) +#define LOGGER_OBSERVER_ADD_DEVICE(device) +#define LOGGER_OBSERVER_ADD_EVENT_COUNT(count) +#define LOGGER_OBSERVER_SET_TRACE_DURATION_MS(duration) +#define LOGGER_OBSERVER_SET_TRACE_ID(tid) +#define LOGGER_OBSERVER_SET_GROUP_TRACE_ID(gtid) +#define LOGGER_OBSERVER_ADD_DESTINATION(dest) +#define UST_LOGGER_MARK_COMPLETED(stage) + +#else // !USE_GOOGLE_LOG +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ILoggerObserver.h" + +#ifdef _MSC_VER +// unset a predefined ERROR (windows) +#undef ERROR +#endif // _MSC_VER + +namespace KINETO_NAMESPACE { + +class Logger { + public: + Logger(int severity, int line, const char* filePath, int errnum = 0); + ~Logger(); + + inline std::ostream& stream() { + return buf_; + } + + static inline void setSeverityLevel(int level) { + severityLevel_ = level; + } + + static inline int severityLevel() { + return severityLevel_; + } + + static inline void setVerboseLogLevel(int level) { + verboseLogLevel_ = level; + } + + static inline int verboseLogLevel() { + return verboseLogLevel_; + } + + // This is constexpr so that the hash for a file name is computed at compile + // time when used in the VLOG macros. + // This way, there is no string comparison for matching VLOG modules, + // only a comparison of pre-computed hashes. + // No fancy hashing needed here. It's pretty inefficient (one character + // at a time) but the strings are not large and it's not in the critical path. + static constexpr uint64_t rol(uint64_t val, int amount) { + return val << amount | val >> (63 - amount); + } + static constexpr uint64_t hash(const char* s) { + uint64_t hash = hash_rec(s, 0); + return hash & rol(0x41a0240682483014ull, hash & 63); + } + static constexpr uint64_t hash_rec(const char* s, int off) { + // Random constants! + return (!s[off] ? 57ull : (hash_rec(s, off + 1) * 293) ^ s[off]); + } + static constexpr const char* basename(const char* s, int off = 0) { + return !s[off] + ? s + : s[off] == '/' ? basename(&s[off + 1]) : basename(s, off + 1); + } + + static void setVerboseLogModules(const std::vector& modules); + + static inline uint64_t verboseLogModules() { + return verboseLogModules_; + } + + static void clearLoggerObservers() { + std::lock_guard g(loggerObserversMutex_); + loggerObservers().clear(); + } + + static void addLoggerObserver(ILoggerObserver* observer); + + static void removeLoggerObserver(ILoggerObserver* observer); + + static void addLoggerObserverDevice(int64_t device); + + static void addLoggerObserverEventCount(int64_t count); + + static void setLoggerObserverTraceDurationMS(int64_t duration); + + static void setLoggerObserverTraceID(const std::string& tid); + + static void setLoggerObserverGroupTraceID(const std::string& gtid); + + static void addLoggerObserverDestination(const std::string& dest); + + private: + std::stringstream buf_; + std::ostream& out_; + int errnum_; + int messageSeverity_; + static std::atomic_int severityLevel_; + static std::atomic_int verboseLogLevel_; + static std::atomic verboseLogModules_; + static std::set& loggerObservers() { + static auto* inst = new std::set(); + return *inst; + } + static std::mutex loggerObserversMutex_; +}; + +class VoidLogger { + public: + VoidLogger() {} + void operator&(std::ostream&) {} +}; + +} // namespace KINETO_NAMESPACE + +#ifdef LOG // Undefine in case these are already defined (quite likely) +#undef LOG +#undef LOG_IS_ON +#undef LOG_IF +#undef LOG_EVERY_N +#undef LOG_IF_EVERY_N +#undef DLOG +#undef DLOG_IF +#undef VLOG +#undef VLOG_IF +#undef VLOG_EVERY_N +#undef VLOG_IS_ON +#undef DVLOG +#undef LOG_FIRST_N +#undef CHECK +#undef DCHECK +#undef DCHECK_EQ +#undef PLOG +#undef PCHECK +#undef LOG_OCCURRENCES +#endif + +#define LOG_IS_ON(severity) \ + (severity >= libkineto::Logger::severityLevel()) + +#define LOG_IF(severity, condition) \ + !(LOG_IS_ON(severity) && (condition)) ? (void)0 : libkineto::VoidLogger() & \ + libkineto::Logger(severity, __LINE__, __FILE__).stream() + +#define LOG(severity) LOG_IF(severity, true) + +#define LOCAL_VARNAME_CONCAT(name, suffix) _##name##suffix##_ + +#define LOCAL_VARNAME(name) LOCAL_VARNAME_CONCAT(name, __LINE__) + +#define LOG_OCCURRENCES LOCAL_VARNAME(log_count) + +#define LOG_EVERY_N(severity, rate) \ + static int LOG_OCCURRENCES = 0; \ + LOG_IF(severity, LOG_OCCURRENCES++ % rate == 0) \ + << "(x" << LOG_OCCURRENCES << ") " + +template +struct __to_constant__ { + static const uint64_t val = n; +}; +#define FILENAME_HASH \ + __to_constant__::val +#define VLOG_IS_ON(verbosity) \ + (libkineto::Logger::verboseLogLevel() >= verbosity && \ + (libkineto::Logger::verboseLogModules() & FILENAME_HASH) == FILENAME_HASH) + +#define VLOG_IF(verbosity, condition) \ + LOG_IF(VERBOSE, VLOG_IS_ON(verbosity) && (condition)) + +#define VLOG(verbosity) VLOG_IF(verbosity, true) + +#define VLOG_EVERY_N(verbosity, rate) \ + static int LOG_OCCURRENCES = 0; \ + VLOG_IF(verbosity, LOG_OCCURRENCES++ % rate == 0) \ + << "(x" << LOG_OCCURRENCES << ") " + +#define PLOG(severity) \ + libkineto::Logger(severity, __LINE__, __FILE__, errno).stream() + +#define SET_LOG_SEVERITY_LEVEL(level) \ + libkineto::Logger::setSeverityLevel(level) + +#define SET_LOG_VERBOSITY_LEVEL(level, modules) \ + libkineto::Logger::setVerboseLogLevel(level); \ + libkineto::Logger::setVerboseLogModules(modules) + +// Logging the set of devices the trace is collect on. +#define LOGGER_OBSERVER_ADD_DEVICE(device_count) \ + libkineto::Logger::addLoggerObserverDevice(device_count) + +// Incrementing the number of events collected by this trace. +#define LOGGER_OBSERVER_ADD_EVENT_COUNT(count) \ + libkineto::Logger::addLoggerObserverEventCount(count) + +// Record duration of trace in milliseconds. +#define LOGGER_OBSERVER_SET_TRACE_DURATION_MS(duration) \ + libkineto::Logger::setLoggerObserverTraceDurationMS(duration) + +// Record the trace id when given. +#define LOGGER_OBSERVER_SET_TRACE_ID(tid) \ + libkineto::Logger::setLoggerObserverTraceID(tid) + +// Record the group trace id when given. +#define LOGGER_OBSERVER_SET_GROUP_TRACE_ID(gtid) \ + libkineto::Logger::setLoggerObserverGroupTraceID(gtid) + +// Log the set of destinations the trace is sent to. +#define LOGGER_OBSERVER_ADD_DESTINATION(dest) \ + libkineto::Logger::addLoggerObserverDestination(dest) + +// UST Logger Semantics to describe when a stage is complete. +#define UST_LOGGER_MARK_COMPLETED(stage) \ + LOG(libkineto::LoggerOutputType::STAGE) << "Completed Stage: " << stage + +#endif // USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h b/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h new file mode 100644 index 0000000000000000000000000000000000000000..bb05aab218dc137cfe2f0107694a049ee2ea6508 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h @@ -0,0 +1,70 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#if !USE_GOOGLE_LOG + +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "ILoggerObserver.h" + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +class LoggerCollector : public ILoggerObserver { + public: + LoggerCollector() : buckets_() {} + + void write(const std::string& message, LoggerOutputType ot = ERROR) override { + // Skip STAGE output type which is only used by USTLoggerCollector. + if (ot != STAGE) { + buckets_[ot].push_back(message); + } + } + + const std::map> extractCollectorMetadata() override { + return buckets_; + } + + void reset() override { + trace_duration_ms = 0; + event_count = 0; + destinations.clear(); + } + + void addDevice(const int64_t device) override { + devices.insert(device); + } + + void setTraceDurationMS(const int64_t duration) override { + trace_duration_ms = duration; + } + + void addEventCount(const int64_t count) override { + event_count += count; + } + + void addDestination(const std::string& dest) override { + destinations.insert(dest); + } + + protected: + std::map> buckets_; + + // These are useful metadata to collect from CUPTIActivityProfiler for internal tracking. + std::set devices; + int64_t trace_duration_ms{0}; + std::atomic event_count{0}; + std::set destinations; + +}; + +} // namespace KINETO_NAMESPACE + +#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73eff13e2a08bcfecefb03f5b229bde89b7e96cb --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp @@ -0,0 +1,569 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "RoctracerActivityApi.h" + +#include +#include +#include + +#include "Demangle.h" +#include "output_base.h" +#include "ThreadUtil.h" + +typedef uint64_t timestamp_t; + +static timestamp_t timespec_to_ns(const timespec& time) { + return ((timestamp_t)time.tv_sec * 1000000000) + time.tv_nsec; + } + +using namespace std::chrono; + +namespace KINETO_NAMESPACE { + +constexpr size_t kBufSize(2 * 1024 * 1024); + +RoctracerActivityApi& RoctracerActivityApi::singleton() { + static RoctracerActivityApi instance; + return instance; +} + +RoctracerActivityApi::RoctracerActivityApi() { + gpuTraceBuffers_ = std::make_unique>(); +} + +RoctracerActivityApi::~RoctracerActivityApi() { + disableActivities(std::set()); + endTracing(); +} + +void RoctracerActivityApi::pushCorrelationID(int id, CorrelationFlowType type) { +#ifdef HAS_ROCTRACER + if (!singleton().externalCorrelationEnabled_) { + return; + } + // placeholder +#endif +} + +void RoctracerActivityApi::popCorrelationID(CorrelationFlowType type) { +#ifdef HAS_ROCTRACER + if (!singleton().externalCorrelationEnabled_) { + return; + } + // placeholder +#endif +} + +void RoctracerActivityApi::setMaxBufferSize(int size) { + maxGpuBufferCount_ = 1 + size / kBufSize; +} + +int RoctracerActivityApi::processActivities( + ActivityLogger& logger) { + // Find offset to map from monotonic clock to system clock. + // This will break time-ordering of events but is status quo. + + timespec t0, t1, t00; + clock_gettime(CLOCK_REALTIME, &t0); + clock_gettime(CLOCK_MONOTONIC, &t1); + clock_gettime(CLOCK_REALTIME, &t00); + + const timestamp_t toffset = (timespec_to_ns(t0) >> 1) + (timespec_to_ns(t00) >> 1) - timespec_to_ns(t1); + + int count = 0; + + // Basic Api calls + + for (auto &item : rows_) { + GenericTraceActivity a; + a.startTime = (item.begin + toffset) / 1000; + a.endTime = (item.end + toffset) / 1000; + a.id = item.id; + a.device = item.pid; + a.resource = item.tid; + a.activityType = ActivityType::CUDA_RUNTIME; + a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; + + logger.handleGenericActivity(a); + ++count; + } + + // Malloc/Free calls + for (auto &item : mallocRows_) { + GenericTraceActivity a; + a.startTime = (item.begin + toffset) / 1000; + a.endTime = (item.end + toffset) / 1000; + a.id = item.id; + a.device = item.pid; + a.resource = item.tid; + a.activityType = ActivityType::CUDA_RUNTIME; + a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; + + a.addMetadata("ptr", item.ptr); + if (item.cid == HIP_API_ID_hipMalloc) { + a.addMetadata("size", item.size); + } + + logger.handleGenericActivity(a); + ++count; + } + + // HipMemcpy calls + for (auto &item : copyRows_) { + GenericTraceActivity a; + a.startTime = (item.begin + toffset) / 1000; + a.endTime = (item.end + toffset) / 1000; + a.id = item.id; + a.device = item.pid; + a.resource = item.tid; + a.activityType = ActivityType::CUDA_RUNTIME; + a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; + + a.addMetadata("src", item.src); + a.addMetadata("dst", item.dst); + a.addMetadata("size", item.size); + a.addMetadata("kind", item.kind); + if ((item.cid == HIP_API_ID_hipMemcpyAsync) || (item.cid == HIP_API_ID_hipMemcpyWithStream)) { + a.addMetadata("stream", fmt::format("{}", reinterpret_cast(item.stream))); + } + + logger.handleGenericActivity(a); + ++count; + } + + // Kernel Launch Api calls + + for (auto &item : kernelRows_) { + GenericTraceActivity a; + a.startTime = (item.begin + toffset) / 1000; + a.endTime = (item.end + toffset) / 1000; + a.id = item.id; + a.device = item.pid; + a.resource = item.tid; + a.activityType = ActivityType::CUDA_RUNTIME; + a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); + a.flow.id = item.id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; + + if (item.functionAddr != nullptr) { + a.addMetadataQuoted( + "kernel", demangle(hipKernelNameRefByPtr(item.functionAddr, item.stream))); + } + else if (item.function != nullptr) { + a.addMetadataQuoted( + "kernel", demangle(hipKernelNameRef(item.function))); + } + a.addMetadata("grid dim", fmt::format("[{}, {}, {}]", item.gridX, item.gridY, item.gridZ)); + a.addMetadata("block dim", fmt::format("[{}, {}, {}]", item.workgroupX, item.workgroupY, item.workgroupZ)); + a.addMetadata("shared size", item.groupSegmentSize); + a.addMetadata("stream", fmt::format("{}", reinterpret_cast(item.stream))); + + // Stash launches to tie to the async ops + kernelLaunches_[a.id] = a; + + // Stash kernel names to tie to the async ops + std::string name; + if (item.functionAddr != nullptr) { + name = demangle(hipKernelNameRefByPtr(item.functionAddr, item.stream)); + } + else if (item.function != nullptr) { + name = demangle(hipKernelNameRef(item.function)); + } + if (!name.empty()) { + uint32_t string_id = reverseStrings_[name]; + if (string_id == 0) { + string_id = nextStringId_++; + reverseStrings_[name] = string_id; + strings_[string_id] = name; + } + kernelNames_[item.id] = string_id; + } + + logger.handleGenericActivity(a); + ++count; + } + + // Async Ops + + for (auto& buffer : *gpuTraceBuffers_) { + const roctracer_record_t* record = (const roctracer_record_t*)(buffer.data); + const roctracer_record_t* end_record = (const roctracer_record_t*)(buffer.data + buffer.validSize); + GenericTraceActivity a; + + while (record < end_record) { + if ((record->domain == ACTIVITY_DOMAIN_HIP_API) && (loggedIds_.contains(record->op))) { + const char *name = roctracer_op_string(record->domain, record->op, record->kind); + a.device = record->process_id; + a.resource = record->thread_id; + + a.startTime = (record->begin_ns + toffset) / 1000; + a.endTime = (record->end_ns + toffset) / 1000; + a.id = record->correlation_id; + + a.activityType = ActivityType::CUDA_RUNTIME; + a.activityName = std::string(name); + a.flow.id = record->correlation_id; + a.flow.type = kLinkAsyncCpuGpu; + a.flow.start = true; + + logger.handleGenericActivity(a); + ++count; + } + else if (record->domain == ACTIVITY_DOMAIN_HCC_OPS) { + // Overlay launch metadata for kernels + auto kit = kernelLaunches_.find(record->correlation_id); + if (kit != kernelLaunches_.end()) { + a = (*kit).second; + } + + const char *name = roctracer_op_string(record->domain, record->op, record->kind); + a.device = record->device_id; + a.resource = record->queue_id; + + a.startTime = (record->begin_ns + toffset) / 1000; + a.endTime = (record->end_ns + toffset) / 1000; + a.id = record->correlation_id; + + a.activityType = ActivityType::CONCURRENT_KERNEL; + a.activityName = std::string(name); + a.flow.id = record->correlation_id; + a.flow.type = kLinkAsyncCpuGpu; + + auto it = kernelNames_.find(record->correlation_id); + if (it != kernelNames_.end()) { + a.activityName = strings_[it->second]; + } + + logger.handleGenericActivity(a); + ++count; + } + + roctracer_next_record(record, &record); + } + } + return count; +} + +void RoctracerActivityApi::clearActivities() { + gpuTraceBuffers_->clear(); + rows_.clear(); + kernelRows_.clear(); + copyRows_.clear(); + mallocRows_.clear(); + kernelLaunches_.clear(); +} + +void RoctracerActivityApi::api_callback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg) +{ + RoctracerActivityApi *dis = &singleton(); + + if (domain == ACTIVITY_DOMAIN_HIP_API && dis->loggedIds_.contains(cid)) { + const hip_api_data_t* data = (const hip_api_data_t*)(callback_data); + + // Pack callbacks into row structures + + static timespec timestamp; // FIXME verify thread safety + + if (data->phase == ACTIVITY_API_PHASE_ENTER) { + clock_gettime(CLOCK_MONOTONIC, ×tamp); // record proper clock + } + else { // (data->phase == ACTIVITY_API_PHASE_EXIT) + timespec endTime; + timespec startTime { timestamp }; + clock_gettime(CLOCK_MONOTONIC, &endTime); // record proper clock + + switch (cid) { + case HIP_API_ID_hipLaunchKernel: + case HIP_API_ID_hipExtLaunchKernel: + case HIP_API_ID_hipLaunchCooperativeKernel: // Should work here + { + auto &args = data->args.hipLaunchKernel; + dis->kernelRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + args.function_address, + nullptr, + args.numBlocks.x, + args.numBlocks.y, + args.numBlocks.z, + args.dimBlocks.x, + args.dimBlocks.y, + args.dimBlocks.z, + args.sharedMemBytes, + args.stream + ); + } + break; + case HIP_API_ID_hipHccModuleLaunchKernel: + case HIP_API_ID_hipModuleLaunchKernel: + case HIP_API_ID_hipExtModuleLaunchKernel: + { + auto &args = data->args.hipModuleLaunchKernel; + dis->kernelRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + nullptr, + args.f, + args.gridDimX, + args.gridDimY, + args.gridDimZ, + args.blockDimX, + args.blockDimY, + args.blockDimZ, + args.sharedMemBytes, + args.stream + ); + } + break; + case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: + case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: +#if 0 + { + auto &args = data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList__val; + dis->kernelRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + args.function_address, + nullptr, + args.numBlocks.x, + args.numBlocks.y, + args.numBlocks.z, + args.dimBlocks.x, + args.dimBlocks.y, + args.dimBlocks.z, + args.sharedMemBytes, + args.stream + ); + } +#endif + break; + case HIP_API_ID_hipMalloc: + dis->mallocRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + data->args.hipMalloc.ptr__val, + data->args.hipMalloc.size + ); + break; + case HIP_API_ID_hipFree: + dis->mallocRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + data->args.hipFree.ptr, + 0 + ); + break; + case HIP_API_ID_hipMemcpy: + { + auto &args = data->args.hipMemcpy; + dis->copyRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + args.src, + args.dst, + args.sizeBytes, + args.kind, + static_cast(0) // use placeholder? + ); + } + break; + case HIP_API_ID_hipMemcpyAsync: + case HIP_API_ID_hipMemcpyWithStream: + { + auto &args = data->args.hipMemcpyAsync; + dis->copyRows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime), + args.src, + args.dst, + args.sizeBytes, + args.kind, + args.stream + ); + } + break; + default: + dis->rows_.emplace_back(data->correlation_id, + domain, + cid, + processId(), + systemThreadId(), + timespec_to_ns(startTime), + timespec_to_ns(endTime) + ); + break; + } + } + } +} + +void RoctracerActivityApi::activity_callback(const char* begin, const char* end, void* arg) +{ + size_t size = end - begin; + uint8_t *buffer = (uint8_t*) malloc(size); + auto &gpuTraceBuffers = singleton().gpuTraceBuffers_; + memcpy(buffer, begin, size); + gpuTraceBuffers->emplace_back(buffer, size); +} + +void RoctracerActivityApi::enableActivities( + const std::set& selected_activities) { +#ifdef HAS_ROCTRACER + if (!registered_) { + roctracer_set_properties(ACTIVITY_DOMAIN_HIP_API, nullptr); // Magic encantation + + // Set some api calls to ignore + loggedIds_.setInvertMode(true); // Omit the specified api + loggedIds_.add("hipGetDevice"); + loggedIds_.add("hipSetDevice"); + loggedIds_.add("hipGetLastError"); + loggedIds_.add("__hipPushCallConfiguration"); + loggedIds_.add("__hipPopCallConfiguration"); + loggedIds_.add("hipCtxSetCurrent"); + loggedIds_.add("hipEventRecord"); + loggedIds_.add("hipEventQuery"); + loggedIds_.add("hipGetDeviceProperties"); + loggedIds_.add("hipPeekAtLastError"); + loggedIds_.add("hipModuleGetFunction"); + loggedIds_.add("hipEventCreateWithFlags"); + + // Enable API callbacks + if (loggedIds_.invertMode() == true) { + // exclusion list - enable entire domain and turn off things in list + roctracer_enable_domain_callback(ACTIVITY_DOMAIN_HIP_API, api_callback, nullptr); + const std::unordered_map &filter = loggedIds_.filterList(); + for (auto it = filter.begin(); it != filter.end(); ++it) { + roctracer_disable_op_callback(ACTIVITY_DOMAIN_HIP_API, it->first); + } + } + else { + // inclusion list - only enable things in the list + const std::unordered_map &filter = loggedIds_.filterList(); + roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API); + for (auto it = filter.begin(); it != filter.end(); ++it) { + roctracer_enable_op_callback(ACTIVITY_DOMAIN_HIP_API, it->first, api_callback, nullptr); + } + } + //roctracer_enable_domain_callback(ACTIVITY_DOMAIN_ROCTX, api_callback, nullptr); + + // Allocate default tracing pool + roctracer_properties_t properties; + memset(&properties, 0, sizeof(roctracer_properties_t)); + properties.buffer_size = 0x1000; + roctracer_open_pool(&properties); + + // Enable async op collection + roctracer_properties_t hcc_cb_properties; + memset(&hcc_cb_properties, 0, sizeof(roctracer_properties_t)); + hcc_cb_properties.buffer_size = 0x4000; + hcc_cb_properties.buffer_callback_fun = activity_callback; + roctracer_open_pool_expl(&hcc_cb_properties, &hccPool_); + roctracer_enable_domain_activity_expl(ACTIVITY_DOMAIN_HCC_OPS, hccPool_); + + registered_ = true; + } + + for (const auto& activity : selected_activities) { + if (activity == ActivityType::EXTERNAL_CORRELATION) { + externalCorrelationEnabled_ = true; + } + } + + roctracer_start(); +#endif +} + +void RoctracerActivityApi::disableActivities( + const std::set& selected_activities) { +#ifdef HAS_ROCTRACER + roctracer_stop(); + roctracer_flush_activity_expl(hccPool_); + + for (const auto& activity : selected_activities) { + if (activity == ActivityType::EXTERNAL_CORRELATION) { + externalCorrelationEnabled_ = false; + } + } +#endif +} + +void RoctracerActivityApi::endTracing() { + if (registered_ == true) { + roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API); + //roctracer_disable_domain_callback(ACTIVITY_DOMAIN_ROCTX); + + roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HCC_OPS); + roctracer_close_pool_expl(hccPool_); + } +} + + +ApiIdList::ApiIdList() +: invert_(true) +{ +} + +void ApiIdList::add(std::string apiName) +{ + uint32_t cid = 0; + if (roctracer_op_code(ACTIVITY_DOMAIN_HIP_API, apiName.c_str(), &cid, nullptr) == ROCTRACER_STATUS_SUCCESS) { + filter_[cid] = 1; + } +} +void ApiIdList::remove(std::string apiName) +{ + uint32_t cid = 0; + if (roctracer_op_code(ACTIVITY_DOMAIN_HIP_API, apiName.c_str(), &cid, nullptr) == ROCTRACER_STATUS_SUCCESS) { + filter_.erase(cid); + } +} + +bool ApiIdList::loadUserPrefs() +{ + // placeholder + return false; +} +bool ApiIdList::contains(uint32_t apiId) +{ + return (filter_.find(apiId) != filter_.end()) ? !invert_ : invert_; // XOR +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h new file mode 100644 index 0000000000000000000000000000000000000000..28280253e7c8426e85c11d679785bcd74fa2a0c7 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h @@ -0,0 +1,171 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HAS_ROCTRACER +#include +#include +#include +#include +#include +#endif + +#include "ActivityType.h" +#include "GenericTraceActivity.h" +#include "RoctracerActivityBuffer.h" + + +namespace KINETO_NAMESPACE { + +using namespace libkineto; + +class ApiIdList +{ +public: + ApiIdList(); + bool invertMode() { return invert_; } + void setInvertMode(bool invert) { invert_ = invert; } + void add(std::string apiName); + void remove(std::string apiName); + bool loadUserPrefs(); + bool contains(uint32_t apiId); + const std::unordered_map &filterList() { return filter_; } + +private: + std::unordered_map filter_; + bool invert_; +}; + +struct roctracerRow { + roctracerRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid + , uint32_t tid, uint64_t begin, uint64_t end) + : id(id), domain(domain), cid(cid), pid(pid), tid(tid), begin(begin), end(end) {} + uint64_t id; // correlation_id + uint32_t domain; + uint32_t cid; + uint32_t pid; + uint32_t tid; + uint64_t begin; + uint64_t end; +}; + +struct kernelRow : public roctracerRow { + kernelRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid + , uint32_t tid, uint64_t begin, uint64_t end + , const void *faddr, hipFunction_t function + , unsigned int gx, unsigned int gy, unsigned int gz + , unsigned int wx, unsigned int wy, unsigned int wz + , size_t gss, hipStream_t stream) + : roctracerRow(id, domain, cid, pid, tid, begin, end), functionAddr(faddr) + , function(function), gridX(gx), gridY(gy), gridZ(gz) + , workgroupX(wx), workgroupY(wy), workgroupZ(wz), groupSegmentSize(gss) + , stream(stream) {} + const void* functionAddr; + hipFunction_t function; + unsigned int gridX; + unsigned int gridY; + unsigned int gridZ; + unsigned int workgroupX; + unsigned int workgroupY; + unsigned int workgroupZ; + size_t groupSegmentSize; + hipStream_t stream; +}; + +struct copyRow : public roctracerRow { + copyRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid + , uint32_t tid, uint64_t begin, uint64_t end + , const void* src, const void *dst, size_t size, hipMemcpyKind kind + , hipStream_t stream) + : roctracerRow(id, domain, cid, pid, tid, begin, end) + , src(src), dst(dst), size(size), kind(kind), stream(stream) {} + const void *src; + const void *dst; + size_t size; + hipMemcpyKind kind; + hipStream_t stream; +}; + +struct mallocRow : public roctracerRow { + mallocRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid + , uint32_t tid, uint64_t begin, uint64_t end + , const void* ptr, size_t size) + : roctracerRow(id, domain, cid, pid, tid, begin, end) + , ptr(ptr), size(size) {} + const void *ptr; + size_t size; +}; + + +class RoctracerActivityApi { + public: + enum CorrelationFlowType { + Default, + User + }; + + RoctracerActivityApi(); + RoctracerActivityApi(const RoctracerActivityApi&) = delete; + RoctracerActivityApi& operator=(const RoctracerActivityApi&) = delete; + + virtual ~RoctracerActivityApi(); + + static RoctracerActivityApi& singleton(); + + static void pushCorrelationID(int id, CorrelationFlowType type); + static void popCorrelationID(CorrelationFlowType type); + + void enableActivities( + const std::set& selected_activities); + void disableActivities( + const std::set& selected_activities); + void clearActivities(); + + int processActivities(ActivityLogger& logger); + + void setMaxBufferSize(int size); + + std::atomic_bool stopCollection{false}; + + private: + bool registered_{false}; + void endTracing(); + +#ifdef HAS_ROCTRACER + roctracer_pool_t *hccPool_{NULL}; + static void api_callback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg); + static void activity_callback(const char* begin, const char* end, void* arg); + + //Name cache + uint32_t nextStringId_{2}; + std::map strings_; + std::map reverseStrings_; + std::map kernelNames_; + + ApiIdList loggedIds_; + + // Api callback data + std::deque rows_; + std::deque kernelRows_; + std::deque copyRows_; + std::deque mallocRows_; + std::map kernelLaunches_; +#endif + + int maxGpuBufferCount_{0}; + std::unique_ptr> gpuTraceBuffers_; + bool externalCorrelationEnabled_{true}; +}; + +} // namespace KINETO_NAMESPACE + diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..cd8a5709a841b7c988ab3f2d1f3108d693343584 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h @@ -0,0 +1,30 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +class RoctracerActivityBuffer { + public: + // data must be allocated using malloc. + // Ownership is transferred to this object. + RoctracerActivityBuffer(uint8_t* data, size_t validSize) + : data(data), validSize(validSize) {} + + ~RoctracerActivityBuffer() { + free(data); + } + + // Allocated by malloc + uint8_t* data{nullptr}; + + // Number of bytes used + size_t validSize; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/SampleListener.h b/plugins/tensorboard-plugins/libkineto/src/SampleListener.h new file mode 100644 index 0000000000000000000000000000000000000000..bff86ad122a051d4f3dfdbdd329a3b63d93a7c77 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/SampleListener.h @@ -0,0 +1,146 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +namespace KINETO_NAMESPACE { + +class Config; + +class SampleValue { + public: + template + explicit SampleValue(T v) { + init(v); + } + + SampleValue(const SampleValue&) = default; + SampleValue& operator=(const SampleValue&) = delete; + SampleValue(SampleValue&&) = default; + SampleValue& operator=(SampleValue&&) = default; + + bool isInt() const { + return type_ == INT64; + } + + int64_t getInt() const { + assert(isInt()); + return int_; + } + + bool isDouble() const { + return type_ == DOUBLE; + } + + double getDouble() const { + assert(isDouble()); + return dbl_; + } + + inline void operator*=(double x) { + assert(isDouble() || isInt()); + if (isDouble()) { + dbl_ *= x; + } else { + int_ = std::round(int_ * x); + } + } + + inline bool operator<(const SampleValue& o) const { + if (type_ != o.type_) { + return type_ < o.type_; + } else if (type_ == INT64) { + return int_ < o.int_; + } else if (type_ == DOUBLE) { + return dbl_ < o.dbl_; + } + assert(false); + return true; + } + + void print(std::ostream& s) const { + if (type_ == INT64) { + s << int_; + } else if (type_ == DOUBLE) { + s << dbl_; + } else { + assert(false); + } + } + + private: + enum Type { INT64, DOUBLE }; + + template + void init(T v); + + Type type_{INT64}; + union { + int64_t int_{0}; + double dbl_; + }; +}; + +template <> +inline void SampleValue::init(uint64_t v) { + int_ = v, type_ = INT64; +} +template <> +inline void SampleValue::init(int64_t v) { + int_ = v, type_ = INT64; +} +template <> +inline void SampleValue::init(int v) { + int_ = v, type_ = INT64; +} +template <> +inline void SampleValue::init(double v) { + dbl_ = v, type_ = DOUBLE; +} + +inline std::ostream& operator<<(std::ostream& out, const SampleValue& s) { + s.print(out); + return out; +} + +using PercentileList = std::vector>; + +struct Stat { + const std::string& name; + const PercentileList percentileValues; + SampleValue total; +}; + +struct Sample { + Sample(int stats_count) { + stats.reserve(stats_count); + } + + // Offset in milliseconds from first sample in report + int deltaMsec; + std::vector stats; +}; + +// Inherit from this to be notified of samples +class SampleListener { + public: + SampleListener(const SampleListener&) = delete; + SampleListener& operator=(const SampleListener&) = delete; + + virtual ~SampleListener(){}; + + // Report bucketed & aggregated values for event + virtual void handleSample(int device, const Sample& sample, bool from_new_version) = 0; + + virtual void update(const Config& config) = 0; + + protected: + SampleListener() = default; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h b/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h new file mode 100644 index 0000000000000000000000000000000000000000..b9a6bc83ef942c7fb0e4b198b0396e5d75aa5a3a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h @@ -0,0 +1,29 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +// Implement a simple scope handler allowing a function to release +// resources when an error or exception occurs + +template +class ScopeExit { + public: + explicit ScopeExit(T t) : t(t) {} + ~ScopeExit() { + t(); + } + T t; +}; + +template +ScopeExit makeScopeExit(T t) { + return ScopeExit(t); +}; + +// Add a level of indirection so __LINE__ is expanded +#define __kINETO_CONCAT(name, line) name##line +#define ANON_VAR(name, line) __kINETO_CONCAT(name, line) + +#define SCOPE_EXIT(func) \ + const auto ANON_VAR(SCOPE_BLOCK, __LINE__) = \ + makeScopeExit([=]() { func; }) diff --git a/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp b/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f67d54d58512aa47b05aed69748a6894aa06b1c --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp @@ -0,0 +1,203 @@ +#include "ThreadUtil.h" + +#ifndef _MSC_VER +#include +#include +#include +#include +#else // _MSC_VER +#include +#include +#define WIN32_LEAN_AND_MEAN +#define NOGDI +#include +#include +#undef ERROR +#endif // _MSC_VER + +#ifdef __ANDROID__ +#include +#endif + +#include +#include +#include + +namespace libkineto { + +namespace { +thread_local int32_t _pid = 0; +thread_local int32_t _tid = 0; +thread_local int32_t _sysTid = 0; +} + +int32_t processId() { + if (!_pid) { +#ifndef _MSC_VER + _pid = (int32_t)getpid(); +#else + _pid = (int32_t)GetCurrentProcessId(); +#endif + } + return _pid; +} + +int32_t systemThreadId() { + if (!_sysTid) { +#ifdef __APPLE__ + _sysTid = (int32_t)syscall(SYS_thread_selfid); +#elif defined _MSC_VER + _sysTid = (int32_t)GetCurrentThreadId(); +#else + _sysTid = (int32_t)syscall(SYS_gettid); +#endif + } + return _sysTid; +} + +int32_t threadId() { + if (!_tid) { +#ifdef __APPLE__ + uint64_t tid; + pthread_threadid_np(nullptr, &tid); + _tid = tid; +#elif defined _MSC_VER + _tid = (int32_t)GetCurrentThreadId(); +#else + pthread_t pth = pthread_self(); + int32_t* ptr = reinterpret_cast(&pth); + _tid = *ptr; +#endif + } + return _tid; +} + +namespace { +static constexpr size_t kMaxThreadNameLength = 16; + +static constexpr const char* basename(const char* s, int off = 0) { + return !s[off] + ? s + : s[off] == '/' ? basename(&s[off + 1]) : basename(s, off + 1); +} +#if defined(_MSC_VER) +void *getKernel32Func(const char* procName) { + return GetProcAddress(GetModuleHandleA("KERNEL32.DLL"), procName); +} +#endif +} + +bool setThreadName(const std::string& name) { +#ifdef __APPLE__ + return 0 == pthread_setname_np(name.c_str()); +#elif defined _MSC_VER + // Per https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-setthreaddescription + // Use runtime linking to set thread description + static auto _SetThreadDescription = reinterpret_cast(getKernel32Func("SetThreadDescription")); + if (!_SetThreadDescription) { + return false; + } + std::wstring_convert> conv; + std::wstring wname = conv.from_bytes(name); + HRESULT hr = _SetThreadDescription(GetCurrentThread(), wname.c_str()); + return SUCCEEDED(hr); +#else + return 0 == pthread_setname_np(pthread_self(), name.c_str()); +#endif +} + +std::string getThreadName() { +#ifndef _MSC_VER + char buf[kMaxThreadNameLength] = ""; + if ( +#ifndef __ANDROID__ + pthread_getname_np(pthread_self(), buf, kMaxThreadNameLength) != 0 +#else + prctl(PR_GET_NAME, buf, kMaxThreadNameLength) != 0 +#endif + ) { + return "Unknown"; + } + return buf; +#else // _MSC_VER + static auto _GetThreadDescription = reinterpret_cast(getKernel32Func("GetThreadDescription")); + if (!_GetThreadDescription) { + return "Unknown"; + } + PWSTR data; + HRESULT hr = _GetThreadDescription(GetCurrentThread(), &data); + if (!SUCCEEDED(hr)) { + return ""; + } + std::wstring_convert> conv; + std::string name = conv.to_bytes(data); + LocalFree(data); + return name; +#endif +} + +// Linux: +// Extract process name from /proc/pid/cmdline. This does not have +// the 16 character limit that /proc/pid/status and /prod/pid/comm has. +std::string processName(int32_t pid) { +#ifdef __linux__ + FILE* cmdfile = fopen(fmt::format("/proc/{}/cmdline", pid).c_str(), "r"); + if (cmdfile != nullptr) { + char* command = nullptr; + int scanned = fscanf(cmdfile, "%ms", &command); + fclose(cmdfile); + if (scanned > 0 && command) { + std::string ret(basename(command)); + free(command); + return ret; + } + } + std::cerr << "Failed to read process name for pid " << pid << std::endl; +#endif + return ""; +} + +// Max number of parent pids to collect, just for extra safeguarding. +constexpr int kMaxParentPids = 10; + +// Return a pair of +static std::pair parentPidAndCommand(int32_t pid) { +#ifdef __linux__ + FILE* statfile = fopen(fmt::format("/proc/{}/stat", pid).c_str(), "r"); + if (statfile == nullptr) { + return std::make_pair(0, ""); + } + int32_t parent_pid; + char* command = nullptr; + int scanned = fscanf(statfile, "%*d (%m[^)]) %*c %d", &command, &parent_pid); + fclose(statfile); + std::pair ret; + if (scanned == 2) { + ret = std::make_pair(parent_pid, std::string(command)); + } else { + std::cerr << "Failed to parse /proc/" << pid << "/stat" << std::endl; + ret = std::make_pair(0, ""); + } + + // The 'm' character in the format tells fscanf to allocate memory + // for the parsed string, which we need to free here. + free(command); + return ret; +#else + return std::make_pair(0, ""); +#endif +} + +std::vector> pidCommandPairsOfAncestors() { + std::vector> pairs; + pairs.reserve(kMaxParentPids + 1); + int32_t curr_pid = processId(); + for (int i = 0; i <= kMaxParentPids && curr_pid > 1; i++) { + std::pair ppid_and_comm = parentPidAndCommand(curr_pid); + pairs.push_back(std::make_pair(curr_pid, ppid_and_comm.second)); + curr_pid = ppid_and_comm.first; + } + return pairs; +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp b/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp new file mode 100644 index 0000000000000000000000000000000000000000..540a5ac8f97c8f38c7ee3d31ea285a3ab7c9f375 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp @@ -0,0 +1,12 @@ +#include + +#ifndef _MSC_VER +extern "C" { +// This function is needed to avoid superfluous dependency on GNU OpenMP library when cuPTI is linked statically +// For more details see https://github.com/pytorch/pytorch/issues/51026 +__attribute__((weak)) int acc_get_device_type() { + throw std::runtime_error("Dummy implementation of acc_get_device_type is not supposed to be called!"); +} + +} // extern "C" +#endif diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_call.h b/plugins/tensorboard-plugins/libkineto/src/cupti_call.h new file mode 100644 index 0000000000000000000000000000000000000000..fd6ebae7691ed607867db5717248ba22f4efa5c0 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/cupti_call.h @@ -0,0 +1,33 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +#ifdef HAS_CUPTI + +#include + +#define CUPTI_CALL(call) \ + [&]() -> CUptiResult { \ + CUptiResult _status_ = call; \ + if (_status_ != CUPTI_SUCCESS) { \ + const char* _errstr_ = nullptr; \ + cuptiGetResultString(_status_, &_errstr_); \ + LOG(WARNING) << fmt::format( \ + "function {} failed with error {} ({})", \ + #call, \ + _errstr_, \ + (int)_status_); \ + } \ + return _status_; \ + }() + +#define CUPTI_CALL_NOWARN(call) call + +#else + +#define CUPTI_CALL(call) call +#define CUPTI_CALL_NOWARN(call) call + +#endif // HAS_CUPTI diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4535273a277e04b0b6f98b539df82955ef62468f --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp @@ -0,0 +1,502 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "cupti_strings.h" + +namespace libkineto { + +const char* memcpyKindString( + CUpti_ActivityMemcpyKind kind) { + switch (kind) { + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD: + return "HtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH: + return "DtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA: + return "HtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH: + return "AtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA: + return "AtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD: + return "AtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA: + return "DtoA"; + case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD: + return "DtoD"; + case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH: + return "HtoH"; + case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP: + return "PtoP"; + default: + break; + } + return ""; +} + +const char* memoryKindString( + CUpti_ActivityMemoryKind kind) { + switch (kind) { + case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN: + return "Unknown"; + case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE: + return "Pageable"; + case CUPTI_ACTIVITY_MEMORY_KIND_PINNED: + return "Pinned"; + case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE: + return "Device"; + case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY: + return "Array"; + case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED: + return "Managed"; + case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE_STATIC: + return "Device Static"; + case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED_STATIC: + return "Managed Static"; + case CUPTI_ACTIVITY_MEMORY_KIND_FORCE_INT: + return "Force Int"; + default: + return "Unrecognized"; + } +} + +const char* overheadKindString( + CUpti_ActivityOverheadKind kind) { + switch (kind) { + case CUPTI_ACTIVITY_OVERHEAD_UNKNOWN: + return "Unknown"; + case CUPTI_ACTIVITY_OVERHEAD_DRIVER_COMPILER: + return "Driver Compiler"; + case CUPTI_ACTIVITY_OVERHEAD_CUPTI_BUFFER_FLUSH: + return "Buffer Flush"; + case CUPTI_ACTIVITY_OVERHEAD_CUPTI_INSTRUMENTATION: + return "Instrumentation"; + case CUPTI_ACTIVITY_OVERHEAD_CUPTI_RESOURCE: + return "Resource"; + case CUPTI_ACTIVITY_OVERHEAD_FORCE_INT: + return "Force Int"; + default: + return "Unrecognized"; + } +} + + + +static const char* runtimeCbidNames[] = { + "INVALID", + "cudaDriverGetVersion", + "cudaRuntimeGetVersion", + "cudaGetDeviceCount", + "cudaGetDeviceProperties", + "cudaChooseDevice", + "cudaGetChannelDesc", + "cudaCreateChannelDesc", + "cudaConfigureCall", + "cudaSetupArgument", + "cudaGetLastError", + "cudaPeekAtLastError", + "cudaGetErrorString", + "cudaLaunch", + "cudaFuncSetCacheConfig", + "cudaFuncGetAttributes", + "cudaSetDevice", + "cudaGetDevice", + "cudaSetValidDevices", + "cudaSetDeviceFlags", + "cudaMalloc", + "cudaMallocPitch", + "cudaFree", + "cudaMallocArray", + "cudaFreeArray", + "cudaMallocHost", + "cudaFreeHost", + "cudaHostAlloc", + "cudaHostGetDevicePointer", + "cudaHostGetFlags", + "cudaMemGetInfo", + "cudaMemcpy", + "cudaMemcpy2D", + "cudaMemcpyToArray", + "cudaMemcpy2DToArray", + "cudaMemcpyFromArray", + "cudaMemcpy2DFromArray", + "cudaMemcpyArrayToArray", + "cudaMemcpy2DArrayToArray", + "cudaMemcpyToSymbol", + "cudaMemcpyFromSymbol", + "cudaMemcpyAsync", + "cudaMemcpyToArrayAsync", + "cudaMemcpyFromArrayAsync", + "cudaMemcpy2DAsync", + "cudaMemcpy2DToArrayAsync", + "cudaMemcpy2DFromArrayAsync", + "cudaMemcpyToSymbolAsync", + "cudaMemcpyFromSymbolAsync", + "cudaMemset", + "cudaMemset2D", + "cudaMemsetAsync", + "cudaMemset2DAsync", + "cudaGetSymbolAddress", + "cudaGetSymbolSize", + "cudaBindTexture", + "cudaBindTexture2D", + "cudaBindTextureToArray", + "cudaUnbindTexture", + "cudaGetTextureAlignmentOffset", + "cudaGetTextureReference", + "cudaBindSurfaceToArray", + "cudaGetSurfaceReference", + "cudaGLSetGLDevice", + "cudaGLRegisterBufferObject", + "cudaGLMapBufferObject", + "cudaGLUnmapBufferObject", + "cudaGLUnregisterBufferObject", + "cudaGLSetBufferObjectMapFlags", + "cudaGLMapBufferObjectAsync", + "cudaGLUnmapBufferObjectAsync", + "cudaWGLGetDevice", + "cudaGraphicsGLRegisterImage", + "cudaGraphicsGLRegisterBuffer", + "cudaGraphicsUnregisterResource", + "cudaGraphicsResourceSetMapFlags", + "cudaGraphicsMapResources", + "cudaGraphicsUnmapResources", + "cudaGraphicsResourceGetMappedPointer", + "cudaGraphicsSubResourceGetMappedArray", + "cudaVDPAUGetDevice", + "cudaVDPAUSetVDPAUDevice", + "cudaGraphicsVDPAURegisterVideoSurface", + "cudaGraphicsVDPAURegisterOutputSurface", + "cudaD3D11GetDevice", + "cudaD3D11GetDevices", + "cudaD3D11SetDirect3DDevice", + "cudaGraphicsD3D11RegisterResource", + "cudaD3D10GetDevice", + "cudaD3D10GetDevices", + "cudaD3D10SetDirect3DDevice", + "cudaGraphicsD3D10RegisterResource", + "cudaD3D10RegisterResource", + "cudaD3D10UnregisterResource", + "cudaD3D10MapResources", + "cudaD3D10UnmapResources", + "cudaD3D10ResourceSetMapFlags", + "cudaD3D10ResourceGetSurfaceDimensions", + "cudaD3D10ResourceGetMappedArray", + "cudaD3D10ResourceGetMappedPointer", + "cudaD3D10ResourceGetMappedSize", + "cudaD3D10ResourceGetMappedPitch", + "cudaD3D9GetDevice", + "cudaD3D9GetDevices", + "cudaD3D9SetDirect3DDevice", + "cudaD3D9GetDirect3DDevice", + "cudaGraphicsD3D9RegisterResource", + "cudaD3D9RegisterResource", + "cudaD3D9UnregisterResource", + "cudaD3D9MapResources", + "cudaD3D9UnmapResources", + "cudaD3D9ResourceSetMapFlags", + "cudaD3D9ResourceGetSurfaceDimensions", + "cudaD3D9ResourceGetMappedArray", + "cudaD3D9ResourceGetMappedPointer", + "cudaD3D9ResourceGetMappedSize", + "cudaD3D9ResourceGetMappedPitch", + "cudaD3D9Begin", + "cudaD3D9End", + "cudaD3D9RegisterVertexBuffer", + "cudaD3D9UnregisterVertexBuffer", + "cudaD3D9MapVertexBuffer", + "cudaD3D9UnmapVertexBuffer", + "cudaThreadExit", + "cudaSetDoubleForDevice", + "cudaSetDoubleForHost", + "cudaThreadSynchronize", + "cudaThreadGetLimit", + "cudaThreadSetLimit", + "cudaStreamCreate", + "cudaStreamDestroy", + "cudaStreamSynchronize", + "cudaStreamQuery", + "cudaEventCreate", + "cudaEventCreateWithFlags", + "cudaEventRecord", + "cudaEventDestroy", + "cudaEventSynchronize", + "cudaEventQuery", + "cudaEventElapsedTime", + "cudaMalloc3D", + "cudaMalloc3DArray", + "cudaMemset3D", + "cudaMemset3DAsync", + "cudaMemcpy3D", + "cudaMemcpy3DAsync", + "cudaThreadSetCacheConfig", + "cudaStreamWaitEvent", + "cudaD3D11GetDirect3DDevice", + "cudaD3D10GetDirect3DDevice", + "cudaThreadGetCacheConfig", + "cudaPointerGetAttributes", + "cudaHostRegister", + "cudaHostUnregister", + "cudaDeviceCanAccessPeer", + "cudaDeviceEnablePeerAccess", + "cudaDeviceDisablePeerAccess", + "cudaPeerRegister", + "cudaPeerUnregister", + "cudaPeerGetDevicePointer", + "cudaMemcpyPeer", + "cudaMemcpyPeerAsync", + "cudaMemcpy3DPeer", + "cudaMemcpy3DPeerAsync", + "cudaDeviceReset", + "cudaDeviceSynchronize", + "cudaDeviceGetLimit", + "cudaDeviceSetLimit", + "cudaDeviceGetCacheConfig", + "cudaDeviceSetCacheConfig", + "cudaProfilerInitialize", + "cudaProfilerStart", + "cudaProfilerStop", + "cudaDeviceGetByPCIBusId", + "cudaDeviceGetPCIBusId", + "cudaGLGetDevices", + "cudaIpcGetEventHandle", + "cudaIpcOpenEventHandle", + "cudaIpcGetMemHandle", + "cudaIpcOpenMemHandle", + "cudaIpcCloseMemHandle", + "cudaArrayGetInfo", + "cudaFuncSetSharedMemConfig", + "cudaDeviceGetSharedMemConfig", + "cudaDeviceSetSharedMemConfig", + "cudaCreateTextureObject", + "cudaDestroyTextureObject", + "cudaGetTextureObjectResourceDesc", + "cudaGetTextureObjectTextureDesc", + "cudaCreateSurfaceObject", + "cudaDestroySurfaceObject", + "cudaGetSurfaceObjectResourceDesc", + "cudaMallocMipmappedArray", + "cudaGetMipmappedArrayLevel", + "cudaFreeMipmappedArray", + "cudaBindTextureToMipmappedArray", + "cudaGraphicsResourceGetMappedMipmappedArray", + "cudaStreamAddCallback", + "cudaStreamCreateWithFlags", + "cudaGetTextureObjectResourceViewDesc", + "cudaDeviceGetAttribute", + "cudaStreamDestroy", + "cudaStreamCreateWithPriority", + "cudaStreamGetPriority", + "cudaStreamGetFlags", + "cudaDeviceGetStreamPriorityRange", + "cudaMallocManaged", + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + "cudaStreamAttachMemAsync", + "cudaGetErrorName", + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + "cudaLaunchKernel", + "cudaGetDeviceFlags", + "cudaLaunch_ptsz", + "cudaLaunchKernel_ptsz", + "cudaMemcpy_ptds", + "cudaMemcpy2D_ptds", + "cudaMemcpyToArray_ptds", + "cudaMemcpy2DToArray_ptds", + "cudaMemcpyFromArray_ptds", + "cudaMemcpy2DFromArray_ptds", + "cudaMemcpyArrayToArray_ptds", + "cudaMemcpy2DArrayToArray_ptds", + "cudaMemcpyToSymbol_ptds", + "cudaMemcpyFromSymbol_ptds", + "cudaMemcpyAsync_ptsz", + "cudaMemcpyToArrayAsync_ptsz", + "cudaMemcpyFromArrayAsync_ptsz", + "cudaMemcpy2DAsync_ptsz", + "cudaMemcpy2DToArrayAsync_ptsz", + "cudaMemcpy2DFromArrayAsync_ptsz", + "cudaMemcpyToSymbolAsync_ptsz", + "cudaMemcpyFromSymbolAsync_ptsz", + "cudaMemset_ptds", + "cudaMemset2D_ptds", + "cudaMemsetAsync_ptsz", + "cudaMemset2DAsync_ptsz", + "cudaStreamGetPriority_ptsz", + "cudaStreamGetFlags_ptsz", + "cudaStreamSynchronize_ptsz", + "cudaStreamQuery_ptsz", + "cudaStreamAttachMemAsync_ptsz", + "cudaEventRecord_ptsz", + "cudaMemset3D_ptds", + "cudaMemset3DAsync_ptsz", + "cudaMemcpy3D_ptds", + "cudaMemcpy3DAsync_ptsz", + "cudaStreamWaitEvent_ptsz", + "cudaStreamAddCallback_ptsz", + "cudaMemcpy3DPeer_ptds", + "cudaMemcpy3DPeerAsync_ptsz", + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + "cudaMemPrefetchAsync", + "cudaMemPrefetchAsync_ptsz", + "cudaMemAdvise", + "cudaDeviceGetP2PAttribute", + "cudaGraphicsEGLRegisterImage", + "cudaEGLStreamConsumerConnect", + "cudaEGLStreamConsumerDisconnect", + "cudaEGLStreamConsumerAcquireFrame", + "cudaEGLStreamConsumerReleaseFrame", + "cudaEGLStreamProducerConnect", + "cudaEGLStreamProducerDisconnect", + "cudaEGLStreamProducerPresentFrame", + "cudaEGLStreamProducerReturnFrame", + "cudaGraphicsResourceGetMappedEglFrame", + "cudaMemRangeGetAttribute", + "cudaMemRangeGetAttributes", + "cudaEGLStreamConsumerConnectWithFlags", + "cudaLaunchCooperativeKernel", + "cudaLaunchCooperativeKernel_ptsz", + "cudaEventCreateFromEGLSync", + "cudaLaunchCooperativeKernelMultiDevice", + "cudaFuncSetAttribute", + "cudaImportExternalMemory", + "cudaExternalMemoryGetMappedBuffer", + "cudaExternalMemoryGetMappedMipmappedArray", + "cudaDestroyExternalMemory", + "cudaImportExternalSemaphore", + "cudaSignalExternalSemaphoresAsync", + "cudaSignalExternalSemaphoresAsync_ptsz", + "cudaWaitExternalSemaphoresAsync", + "cudaWaitExternalSemaphoresAsync_ptsz", + "cudaDestroyExternalSemaphore", + "cudaLaunchHostFunc", + "cudaLaunchHostFunc_ptsz", + "cudaGraphCreate", + "cudaGraphKernelNodeGetParams", + "cudaGraphKernelNodeSetParams", + "cudaGraphAddKernelNode", + "cudaGraphAddMemcpyNode", + "cudaGraphMemcpyNodeGetParams", + "cudaGraphMemcpyNodeSetParams", + "cudaGraphAddMemsetNode", + "cudaGraphMemsetNodeGetParams", + "cudaGraphMemsetNodeSetParams", + "cudaGraphAddHostNode", + "cudaGraphHostNodeGetParams", + "cudaGraphAddChildGraphNode", + "cudaGraphChildGraphNodeGetGraph", + "cudaGraphAddEmptyNode", + "cudaGraphClone", + "cudaGraphNodeFindInClone", + "cudaGraphNodeGetType", + "cudaGraphGetRootNodes", + "cudaGraphNodeGetDependencies", + "cudaGraphNodeGetDependentNodes", + "cudaGraphAddDependencies", + "cudaGraphRemoveDependencies", + "cudaGraphDestroyNode", + "cudaGraphInstantiate", + "cudaGraphLaunch", + "cudaGraphLaunch_ptsz", + "cudaGraphExecDestroy", + "cudaGraphDestroy", + "cudaStreamBeginCapture", + "cudaStreamBeginCapture_ptsz", + "cudaStreamIsCapturing", + "cudaStreamIsCapturing_ptsz", + "cudaStreamEndCapture", + "cudaStreamEndCapture_ptsz", + "cudaGraphHostNodeSetParams", + "cudaGraphGetNodes", + "cudaGraphGetEdges", + "cudaStreamGetCaptureInfo", + "cudaStreamGetCaptureInfo_ptsz", + "cudaGraphExecKernelNodeSetParams", + "cudaThreadExchangeStreamCaptureMode", + "cudaDeviceGetNvSciSyncAttributes", + "cudaOccupancyAvailableDynamicSMemPerBlock", + "cudaStreamSetFlags", + "cudaStreamSetFlags_ptsz", + "cudaGraphExecMemcpyNodeSetParams", + "cudaGraphExecMemsetNodeSetParams", + "cudaGraphExecHostNodeSetParams", + "cudaGraphExecUpdate", + "cudaGetFuncBySymbol", + "cudaCtxResetPersistingL2Cache", + "cudaGraphKernelNodeCopyAttributes", + "cudaGraphKernelNodeGetAttribute", + "cudaGraphKernelNodeSetAttribute", + "cudaStreamCopyAttributes", + "cudaStreamCopyAttributes_ptsz", + "cudaStreamGetAttribute", + "cudaStreamGetAttribute_ptsz", + "cudaStreamSetAttribute", + "cudaStreamSetAttribute_ptsz", + "cudaDeviceGetTexture1DLinearMaxWidth", + "cudaGraphUpload", + "cudaGraphUpload_ptsz", + "cudaGraphAddMemcpyNodeToSymbol", + "cudaGraphAddMemcpyNodeFromSymbol", + "cudaGraphAddMemcpyNode1D", + "cudaGraphMemcpyNodeSetParamsToSymbol", + "cudaGraphMemcpyNodeSetParamsFromSymbol", + "cudaGraphMemcpyNodeSetParams1D", + "cudaGraphExecMemcpyNodeSetParamsToSymbol", + "cudaGraphExecMemcpyNodeSetParamsFromSymbol", + "cudaGraphExecMemcpyNodeSetParams1D", + "cudaArrayGetSparseProperties", + "cudaMipmappedArrayGetSparseProperties", + "cudaGraphExecChildGraphNodeSetParams", + "cudaGraphAddEventRecordNode", + "cudaGraphEventRecordNodeGetEvent", + "cudaGraphEventRecordNodeSetEvent", + "cudaGraphAddEventWaitNode", + "cudaGraphEventWaitNodeGetEvent", + "cudaGraphEventWaitNodeSetEvent", + "cudaGraphExecEventRecordNodeSetEvent", + "cudaGraphExecEventWaitNodeSetEvent", + "cudaEventRecordWithFlags", + "cudaEventRecordWithFlags_ptsz", + "cudaDeviceGetDefaultMemPool", + "cudaMallocAsync", + "cudaMallocAsync_ptsz", + "cudaFreeAsync", + "cudaFreeAsync_ptsz", + "cudaMemPoolTrimTo", + "cudaMemPoolSetAttribute", + "cudaMemPoolGetAttribute", + "cudaMemPoolSetAccess", + "cudaArrayGetPlane", + "cudaMemPoolGetAccess", + "cudaMemPoolCreate", + "cudaMemPoolDestroy", + "cudaDeviceSetMemPool", + "cudaDeviceGetMemPool", + "cudaMemPoolExportToShareableHandle", + "cudaMemPoolImportFromShareableHandle", + "cudaMemPoolExportPointer", + "cudaMemPoolImportPointer", + "cudaMallocFromPoolAsync", + "cudaMallocFromPoolAsync_ptsz", + "cudaSignalExternalSemaphoresAsync", + "cudaSignalExternalSemaphoresAsync", + "cudaWaitExternalSemaphoresAsync", + "cudaWaitExternalSemaphoresAsync", + "cudaGraphAddExternalSemaphoresSignalNode", + "cudaGraphExternalSemaphoresSignalNodeGetParams", + "cudaGraphExternalSemaphoresSignalNodeSetParams", + "cudaGraphAddExternalSemaphoresWaitNode", + "cudaGraphExternalSemaphoresWaitNodeGetParams", + "cudaGraphExternalSemaphoresWaitNodeSetParams", + "cudaGraphExecExternalSemaphoresSignalNodeSetParams", + "cudaGraphExecExternalSemaphoresWaitNodeSetParams", + "SIZE" +}; + +const char* runtimeCbidName(CUpti_CallbackId cbid) { + constexpr int names_size = + sizeof(runtimeCbidNames) / sizeof(runtimeCbidNames[0]); + if (cbid < 0 || cbid >= names_size) { + return runtimeCbidNames[CUPTI_RUNTIME_TRACE_CBID_INVALID]; + } + return runtimeCbidNames[cbid]; +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..bbfebb983648005d8268d9a29d613d369d6a5384 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h @@ -0,0 +1,14 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include + +namespace libkineto { + +const char* memoryKindString(CUpti_ActivityMemoryKind kind); +const char* memcpyKindString(CUpti_ActivityMemcpyKind kind); +const char* runtimeCbidName(CUpti_CallbackId cbid); +const char* overheadKindString(CUpti_ActivityOverheadKind kind); + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/init.cpp b/plugins/tensorboard-plugins/libkineto/src/init.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e1022485ac5d17b5af1e0676b6a4595a138e1b5 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/init.cpp @@ -0,0 +1,139 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +#include "ActivityProfilerProxy.h" +#include "Config.h" +#ifdef HAS_CUPTI +#include "CuptiCallbackApi.h" +#include "CuptiActivityApi.h" +#include "EventProfilerController.h" +#endif +#include "cupti_call.h" +#include "libkineto.h" + +#include "Logger.h" + +namespace KINETO_NAMESPACE { + +#ifdef HAS_CUPTI +static bool initialized = false; +static std::mutex initMutex; + +static void initProfilers( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId /*cbid*/, + const CUpti_CallbackData* cbInfo) { + CUpti_ResourceData* d = (CUpti_ResourceData*)cbInfo; + CUcontext ctx = d->context; + + VLOG(0) << "CUDA Context created"; + std::lock_guard lock(initMutex); + + if (!initialized) { + libkineto::api().initProfilerIfRegistered(); + initialized = true; + VLOG(0) << "libkineto profilers activated"; + } + if (getenv("KINETO_DISABLE_EVENT_PROFILER") != nullptr) { + VLOG(0) << "Event profiler disabled via env var"; + } else { + ConfigLoader& config_loader = libkineto::api().configLoader(); + config_loader.initBaseConfig(); + EventProfilerController::start(ctx, config_loader); + } +} + +// Some models suffer from excessive instrumentation code gen +// on dynamic attach which can hang for more than 5+ seconds. +// If the workload was meant to be traced, preload the CUPTI +// to take the performance hit early on. +// https://docs.nvidia.com/cupti/r_main.html#r_overhead +static bool shouldPreloadCuptiInstrumentation() { + return getenv("PRELOAD_CUPTI_INSTRUMENTATION"); +} + +static void stopProfiler( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId /*cbid*/, + const CUpti_CallbackData* cbInfo) { + CUpti_ResourceData* d = (CUpti_ResourceData*)cbInfo; + CUcontext ctx = d->context; + + LOG(INFO) << "CUDA Context destroyed"; + std::lock_guard lock(initMutex); + EventProfilerController::stop(ctx); +} +#endif // HAS_CUPTI + +} // namespace KINETO_NAMESPACE + +// Callback interface with CUPTI and library constructors +using namespace KINETO_NAMESPACE; +extern "C" { + +// Return true if no CUPTI errors occurred during init +bool libkineto_init(bool cpuOnly, bool logOnError) { + bool success = true; +#ifdef HAS_CUPTI + if (!cpuOnly) { + // libcupti will be lazily loaded on this call. + // If it is not available (e.g. CUDA is not installed), + // then this call will return an error and we just abort init. + auto& cbapi = CuptiCallbackApi::singleton(); + bool status = false; + + if (cbapi.initSuccess()){ + const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RESOURCE; + status = cbapi.registerCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED, initProfilers); + status = status && cbapi.registerCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED, stopProfiler); + + if (status) { + status = cbapi.enableCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED); + status = status && cbapi.enableCallback( + domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED); + } + } + + if (!cbapi.initSuccess() || !status) { + success = false; + cpuOnly = true; + if (logOnError) { + CUPTI_CALL(cbapi.getCuptiStatus()); + LOG(WARNING) << "CUPTI initialization failed - " + << "CUDA profiler activities will be missing"; + LOG(INFO) << "If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to " + << "https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti"; + } + } + } + + if (shouldPreloadCuptiInstrumentation()) { + CuptiActivityApi::forceLoadCupti(); + } +#endif // HAS_CUPTI + + ConfigLoader& config_loader = libkineto::api().configLoader(); + libkineto::api().registerProfiler( + std::make_unique(cpuOnly, config_loader)); + + return success; +} + +// The cuda driver calls this function if the CUDA_INJECTION64_PATH environment +// variable is set +int InitializeInjection(void) { + LOG(INFO) << "Injection mode: Initializing libkineto"; + libkineto_init(false /*cpuOnly*/, true /*logOnError*/); + return 1; +} + +void suppressLibkinetoLogMessages() { + SET_LOG_SEVERITY_LEVEL(ERROR); +} + +} // extern C diff --git a/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp b/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9a622e4f5e5cfd54848cb8c6dc05b98da2fb6011 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp @@ -0,0 +1,41 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "libkineto.h" + +#include "ConfigLoader.h" +#include "ThreadUtil.h" + +namespace libkineto { + +LibkinetoApi& api() { + static LibkinetoApi instance(ConfigLoader::instance()); + return instance; +} + +void LibkinetoApi::initClientIfRegistered() { + if (client_) { + if (clientRegisterThread_ != threadId()) { + fprintf( + stderr, + "ERROR: External init callback must run in same thread as registerClient " + "(%d != %d)\n", + threadId(), + (int)clientRegisterThread_); + } else { + client_->init(); + } + } +} + +void LibkinetoApi::registerClient(ClientInterface* client) { + client_ = client; + if (client && activityProfiler_) { + // Can initialize straight away + client->init(); + } + // Assume here that the external init callback is *not* threadsafe + // and only call it if it's the same thread that called registerClient + clientRegisterThread_ = threadId(); +} + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/output_base.h b/plugins/tensorboard-plugins/libkineto/src/output_base.h new file mode 100644 index 0000000000000000000000000000000000000000..29d0d57768c91b8593f202cea51071a1affcd88d --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_base.h @@ -0,0 +1,104 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include +#include "CuptiActivity.h" +#endif // HAS_CUPTI +#include "ActivityBuffers.h" +#include "GenericTraceActivity.h" +#include "ThreadUtil.h" +#include "TraceSpan.h" + +namespace KINETO_NAMESPACE { + class Config; + class GpuKernelActivity; + struct RuntimeActivity; +} + +namespace libkineto { + +using namespace KINETO_NAMESPACE; + +class ActivityLogger { + public: + + virtual ~ActivityLogger() = default; + + struct DeviceInfo { + DeviceInfo(int64_t id, const std::string& name, const std::string& label) : + id(id), name(name), label(label) {} + int64_t id; + const std::string name; + const std::string label; + }; + + struct ResourceInfo { + ResourceInfo( + int64_t deviceId, + int64_t id, + int64_t sortIndex, + const std::string& name) : + id(id), sortIndex(sortIndex), deviceId(deviceId), name(name) {} + int64_t id; + int64_t sortIndex; + int64_t deviceId; + const std::string name; + }; + + struct OverheadInfo { + explicit OverheadInfo(const std::string& name) : name(name) {} + const std::string name; + }; + + virtual void handleDeviceInfo( + const DeviceInfo& info, + uint64_t time) = 0; + + virtual void handleResourceInfo(const ResourceInfo& info, int64_t time) = 0; + + virtual void handleOverheadInfo(const OverheadInfo& info, int64_t time) = 0; + + virtual void handleTraceSpan(const TraceSpan& span) = 0; + + virtual void handleActivity( + const libkineto::ITraceActivity& activity) = 0; + virtual void handleGenericActivity( + const libkineto::GenericTraceActivity& activity) = 0; + +#ifdef HAS_CUPTI + virtual void handleGpuActivity( + const GpuActivity& activity) = 0; + virtual void handleGpuActivity( + const GpuActivity& activity) = 0; + virtual void handleGpuActivity( + const GpuActivity& activity) = 0; + virtual void handleGpuActivity( + const GpuActivity& activity) = 0; +#endif // HAS_CUPTI + + virtual void handleTraceStart( + const std::unordered_map& metadata) = 0; + + void handleTraceStart() { + handleTraceStart(std::unordered_map()); + } + + virtual void finalizeTrace( + const KINETO_NAMESPACE::Config& config, + std::unique_ptr buffers, + int64_t endTime, + std::unordered_map>& metadata) = 0; + + protected: + ActivityLogger() = default; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp b/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e56c02293982745ed0c013b83bd04d9f42ea7305 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp @@ -0,0 +1,88 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "output_csv.h" + +#include +#include +#include + +#include +#include + +#include "Config.h" +#include "Logger.h" + +namespace KINETO_NAMESPACE { + +static void write_header( + std::ostream& out, + const std::vector& percentiles) { + out << "timestamp,delta_ms,device,event_name"; + for (int p : percentiles) { + out << ",p" << p; + } + out << ",total" << std::endl; +} + +void EventCSVLogger::update(const Config& config) { + eventNames_.clear(); + eventNames_.insert(config.eventNames().begin(), config.eventNames().end()); + eventNames_.insert(config.metricNames().begin(), config.metricNames().end()); + if (config.percentiles() != percentiles_) { + percentiles_ = config.percentiles(); + if (out_) { + write_header(*out_, percentiles_); + } + } +} + +void EventCSVLogger::handleSample(int device, const Sample& sample, bool from_new_version) { + using namespace std::chrono; + if (out_) { + auto now = system_clock::now(); + auto time = system_clock::to_time_t(now); + for (const Stat& s : sample.stats) { + if (eventNames_.find(s.name) == eventNames_.end()) { + continue; + } + *out_ << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(time)) << ","; + *out_ << sample.deltaMsec << ","; + *out_ << device << ","; + *out_ << s.name; + for (const auto& p : s.percentileValues) { + *out_ << "," << p.second; + } + *out_ << "," << s.total << std::endl; + } + } +} + +void EventCSVFileLogger::update(const Config& config) { + if (config.eventLogFile() != filename_) { + if (of_.is_open()) { + of_.close(); + out_ = nullptr; + percentiles_.clear(); + } + filename_ = config.eventLogFile(); + if (!filename_.empty()) { + of_.open(filename_, std::ios::out | std::ios::trunc); + out_ = &of_; + } + } + EventCSVLogger::update(config); +} + +void EventCSVDbgLogger::update(const Config& config) { + if (out_ && config.verboseLogLevel() < 0) { + out_ = nullptr; + } else if (!out_ && config.verboseLogLevel() >= 0) { + out_ = &LIBKINETO_DBG_STREAM; + } + if (config.verboseLogLevel() >= 0) { + percentiles_.clear(); + EventCSVLogger::update(config); + } +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_csv.h b/plugins/tensorboard-plugins/libkineto/src/output_csv.h new file mode 100644 index 0000000000000000000000000000000000000000..bca29f4db99af8aedf031aed869ff2efd3df6155 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_csv.h @@ -0,0 +1,39 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include "SampleListener.h" + +#include +#include +#include + +namespace KINETO_NAMESPACE { + +class EventCSVLogger : public SampleListener { + public: + void update(const Config& config) override; + void handleSample(int device, const Sample& sample, bool from_new_version) override; + + protected: + EventCSVLogger() : out_(nullptr) {} + + std::ostream* out_; + std::set eventNames_; + std::vector percentiles_; +}; + +class EventCSVFileLogger : public EventCSVLogger { + public: + void update(const Config& config) override; + + private: + std::ofstream of_; + std::string filename_; +}; + +class EventCSVDbgLogger : public EventCSVLogger { + public: + void update(const Config& config) override; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_json.cpp b/plugins/tensorboard-plugins/libkineto/src/output_json.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ef22339fad15d6a78e43d7fcb7761fbbc97333b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_json.cpp @@ -0,0 +1,583 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "output_json.h" + +#include +#include +#include +#include + +#include "Config.h" +#ifdef HAS_CUPTI +#include "CuptiActivity.h" +#include "CuptiActivity.tpp" +#include "CuptiActivityApi.h" +#include "CudaDeviceProperties.h" +#endif // HAS_CUPTI +#include "Demangle.h" +#include "TraceSpan.h" + +#include "Logger.h" + +using std::endl; +using namespace libkineto; + +namespace KINETO_NAMESPACE { + +static constexpr int kSchemaVersion = 1; +static constexpr char kFlowStart = 's'; +static constexpr char kFlowEnd = 'f'; + +#ifdef __linux__ +static constexpr char kDefaultLogFileFmt[] = + "/tmp/libkineto_activities_{}.json"; +#else +static constexpr char kDefaultLogFileFmt[] = "libkineto_activities_{}.json"; +#endif + +std::string& ChromeTraceLogger::sanitizeStrForJSON(std::string& value) { +// Replace all backslashes with forward slash because Windows paths causing JSONDecodeError. +#ifdef _WIN32 + std::replace(value.begin(), value.end(), '\\', '/'); +#endif + return value; +} + +void ChromeTraceLogger::metadataToJSON( + const std::unordered_map& metadata) { + for (const auto& kv : metadata) { + traceOf_ << fmt::format(R"JSON( + "{}": {},)JSON", kv.first, kv.second); + } +} + +void ChromeTraceLogger::handleTraceStart( + const std::unordered_map& metadata) { + traceOf_ << fmt::format(R"JSON( +{{ + "schemaVersion": {},)JSON", kSchemaVersion); + +#ifdef HAS_CUPTI + traceOf_ << fmt::format(R"JSON( + "deviceProperties": [{} + ],)JSON", devicePropertiesJson()); +#endif + + metadataToJSON(metadata); + traceOf_ << R"JSON( + "traceEvents": [)JSON"; +} + +static std::string defaultFileName() { + return fmt::format(kDefaultLogFileFmt, processId()); +} + +void ChromeTraceLogger::openTraceFile() { + traceOf_.open(fileName_, std::ofstream::out | std::ofstream::trunc); + if (!traceOf_) { + PLOG(ERROR) << "Failed to open '" << fileName_ << "'"; + } else { + LOG(INFO) << "Tracing to " << fileName_; + } +} + +ChromeTraceLogger::ChromeTraceLogger(const std::string& traceFileName) { + fileName_ = traceFileName.empty() ? defaultFileName() : traceFileName; + traceOf_.clear(std::ios_base::badbit); + openTraceFile(); +} + +static int64_t us(int64_t timestamp) { + // It's important that this conversion is the same here and in the CPU trace. + // No rounding! + return timestamp / 1000; +} + +void ChromeTraceLogger::handleDeviceInfo( + const DeviceInfo& info, + uint64_t time) { + if (!traceOf_) { + return; + } + + // M is for metadata + // process_name needs a pid and a name arg + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "name": "process_name", "ph": "M", "ts": {}, "pid": {}, "tid": 0, + "args": {{ + "name": "{}" + }} + }}, + {{ + "name": "process_labels", "ph": "M", "ts": {}, "pid": {}, "tid": 0, + "args": {{ + "labels": "{}" + }} + }}, + {{ + "name": "process_sort_index", "ph": "M", "ts": {}, "pid": {}, "tid": 0, + "args": {{ + "sort_index": {} + }} + }},)JSON", + time, info.id, + info.name, + time, info.id, + info.label, + time, info.id, + info.id < 8 ? info.id + 0x1000000ll : info.id); + // clang-format on +} + +void ChromeTraceLogger::handleResourceInfo( + const ResourceInfo& info, + int64_t time) { + if (!traceOf_) { + return; + } + + // M is for metadata + // thread_name needs a pid and a name arg + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "name": "thread_name", "ph": "M", "ts": {}, "pid": {}, "tid": {}, + "args": {{ + "name": "{}" + }} + }}, + {{ + "name": "thread_sort_index", "ph": "M", "ts": {}, "pid": {}, "tid": {}, + "args": {{ + "sort_index": {} + }} + }},)JSON", + time, info.deviceId, info.id, + info.name, + time, info.deviceId, info.id, + info.sortIndex); + // clang-format on +} + +void ChromeTraceLogger::handleOverheadInfo( + const OverheadInfo& info, + int64_t time) { + if (!traceOf_) { + return; + } + + // TOOD: reserve pid = -1 for overhead but we need to rethink how to scale this for + // other metadata + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "name": "process_name", "ph": "M", "ts": {}, "pid": -1, "tid": 0, + "args": {{ + "name": "{}" + }} + }}, + {{ + "name": "process_sort_index", "ph": "M", "ts": {}, "pid": -1, "tid": 0, + "args": {{ + "sort_index": {} + }} + }},)JSON", + time, + info.name, + time, + 0x100000All); + // clang-format on +} + +void ChromeTraceLogger::handleTraceSpan(const TraceSpan& span) { + if (!traceOf_) { + return; + } + + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "Trace", "ts": {}, "dur": {}, + "pid": "Spans", "tid": "{}", + "name": "{}{} ({})", + "args": {{ + "Op count": {} + }} + }}, + {{ + "name": "process_sort_index", "ph": "M", "ts": {}, + "pid": "Spans", "tid": 0, + "args": {{ + "sort_index": {} + }} + }},)JSON", + span.startTime, span.endTime - span.startTime, + span.name, + span.prefix, span.name, span.iteration, + span.opCount, + span.startTime, + // Large sort index to appear at the bottom + 0x20000000ll); + // clang-format on + + addIterationMarker(span); +} + +void ChromeTraceLogger::addIterationMarker(const TraceSpan& span) { + if (!traceOf_) { + return; + } + + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "name": "Iteration Start: {}", "ph": "i", "s": "g", + "pid": "Traces", "tid": "Trace {}", "ts": {} + }},)JSON", + span.name, + span.name, span.startTime); + // clang-format on +} + +static std::string traceActivityJson(const ITraceActivity& activity) { + // clang-format off + int64_t ts = activity.timestamp(); + int64_t duration = activity.duration(); + if (activity.type() == ActivityType::GPU_USER_ANNOTATION) { + // The GPU user annotations start at the same time as the + // first associated GPU activity. Since they appear later + // in the trace file, this causes a visualization issue in Chrome. + // Make it start one us earlier. + ts--; + duration++; // Still need it to end at the orginal point + } + return fmt::format(R"JSON( + "name": "{}", "pid": {}, "tid": {}, + "ts": {}, "dur": {})JSON", + activity.name(), activity.deviceId(), activity.resourceId(), + ts, duration); + // clang-format on +} + +void ChromeTraceLogger::handleGenericInstantEvent( + const libkineto::ITraceActivity& op) { + if (!traceOf_) { + return; + } + + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "i", "s": "t", "name": "{}", + "pid": {}, "tid": {}, + "ts": {}, + "args": {{ + {} + }} + }},)JSON", + op.name(), op.deviceId(), op.resourceId(), + op.timestamp(), op.metadataJson()); +} + +void ChromeTraceLogger::handleActivity( + const libkineto::ITraceActivity& op) { + if (!traceOf_) { + return; + } + + if (op.type() == ActivityType::CPU_INSTANT_EVENT) { + handleGenericInstantEvent(op); + return; + } + + const std::string op_metadata = op.metadataJson(); + std::string separator = ""; + if (op_metadata.find_first_not_of(" \t\n") != std::string::npos) { + separator = ",\n "; + } + std::string span = ""; + if (op.traceSpan()) { + span = fmt::format(R"JSON( + "Trace name": "{}", "Trace iteration": {},)JSON", + op.traceSpan()->name, + op.traceSpan()->iteration); + } + + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "{}", {}, + "args": {{{} + "External id": {}{}{} + }} + }},)JSON", + toString(op.type()), traceActivityJson(op), + // args + span, + op.correlationId(), separator, op_metadata); + // clang-format on + if (op.flowId() > 0) { + handleGenericLink(op); + } +} + +void ChromeTraceLogger::handleGenericActivity( + const libkineto::GenericTraceActivity& op) { + handleActivity(op); +} + +void ChromeTraceLogger::handleGenericLink(const ITraceActivity& act) { + static struct { + int type; + char longName[24]; + char shortName[16]; + } flow_names[] = { + {kLinkFwdBwd, "forward_backward", "fwd_bwd"}, + {kLinkAsyncCpuGpu, "async_cpu_to_gpu", "async_gpu"} + }; + for (auto& flow : flow_names) { + if (act.flowType() == flow.type) { + // Link the activities via flow ID in source and destination. + // The source node must return true from flowStart() + // and the destination node false. + if (act.flowStart()) { + handleLink(kFlowStart, act, act.flowId(), flow.longName, flow.shortName); + } else { + handleLink(kFlowEnd, act, act.flowId(), flow.longName, flow.shortName); + } + return; + } + } + LOG(ERROR) << "Unknown flow type: " << act.flowType(); +} + +void ChromeTraceLogger::handleLink( + char type, + const ITraceActivity& e, + int64_t id, + const std::string& cat, + const std::string& name) { + if (!traceOf_) { + return; + } + + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "{}", "id": {}, "pid": {}, "tid": {}, "ts": {}, + "cat": "{}", "name": "{}", "bp": "e" + }},)JSON", + type, id, e.deviceId(), e.resourceId(), e.timestamp(), cat, name); + // clang-format on +} + +#ifdef HAS_CUPTI +// GPU side kernel activity +void ChromeTraceLogger::handleGpuActivity( + const GpuActivity& activity) { + if (!traceOf_) { + return; + } + const CUpti_ActivityKernel4* kernel = &activity.raw(); + constexpr int threads_per_warp = 32; + float blocks_per_sm = -1.0; + float warps_per_sm = -1.0; + int sm_count = smCount(kernel->deviceId); + if (sm_count) { + blocks_per_sm = + (kernel->gridX * kernel->gridY * kernel->gridZ) / (float) sm_count; + warps_per_sm = + blocks_per_sm * (kernel->blockX * kernel->blockY * kernel->blockZ) + / threads_per_warp; + } + + // Calculate occupancy + float occupancy = KINETO_NAMESPACE::kernelOccupancy( + kernel->deviceId, + kernel->registersPerThread, + kernel->staticSharedMemory, + kernel->dynamicSharedMemory, + kernel->blockX, + kernel->blockY, + kernel->blockZ, + blocks_per_sm); + + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "Kernel", {}, + "args": {{ + "queued": {}, "device": {}, "context": {}, + "stream": {}, "correlation": {}, + "registers per thread": {}, + "shared memory": {}, + "blocks per SM": {}, + "warps per SM": {}, + "grid": [{}, {}, {}], + "block": [{}, {}, {}], + "est. achieved occupancy %": {} + }} + }},)JSON", + traceActivityJson(activity), + // args + us(kernel->queued), kernel->deviceId, kernel->contextId, + kernel->streamId, kernel->correlationId, + kernel->registersPerThread, + kernel->staticSharedMemory + kernel->dynamicSharedMemory, + blocks_per_sm, + warps_per_sm, + kernel->gridX, kernel->gridY, kernel->gridZ, + kernel->blockX, kernel->blockY, kernel->blockZ, + (int) (0.5 + occupancy * 100.0)); + // clang-format on + + auto to_id = activity.correlationId(); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); +} + +static std::string bandwidth(uint64_t bytes, uint64_t duration) { + return duration == 0 ? "\"N/A\"" : fmt::format("{}", bytes * 1.0 / duration); +} + +// GPU side memcpy activity +void ChromeTraceLogger::handleGpuActivity( + const GpuActivity& activity) { + if (!traceOf_) { + return; + } + const CUpti_ActivityMemcpy& memcpy = activity.raw(); + VLOG(2) << memcpy.correlationId << ": MEMCPY"; + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "Memcpy", {}, + "args": {{ + "device": {}, "context": {}, + "stream": {}, "correlation": {}, + "bytes": {}, "memory bandwidth (GB/s)": {} + }} + }},)JSON", + traceActivityJson(activity), + // args + memcpy.deviceId, memcpy.contextId, + memcpy.streamId, memcpy.correlationId, + memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); + // clang-format on + + int64_t to_id = activity.correlationId(); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); +} + +// GPU side memcpy activity +void ChromeTraceLogger::handleGpuActivity( + const GpuActivity& activity) { + if (!traceOf_) { + return; + } + const CUpti_ActivityMemcpy2& memcpy = activity.raw(); + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "Memcpy", {}, + "args": {{ + "fromDevice": {}, "inDevice": {}, "toDevice": {}, + "fromContext": {}, "inContext": {}, "toContext": {}, + "stream": {}, "correlation": {}, + "bytes": {}, "memory bandwidth (GB/s)": {} + }} + }},)JSON", + traceActivityJson(activity), + // args + memcpy.srcDeviceId, memcpy.deviceId, memcpy.dstDeviceId, + memcpy.srcContextId, memcpy.contextId, memcpy.dstContextId, + memcpy.streamId, memcpy.correlationId, + memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); + // clang-format on + + int64_t to_id = activity.correlationId(); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); +} + +void ChromeTraceLogger::handleGpuActivity( + const GpuActivity& activity) { + if (!traceOf_) { + return; + } + const CUpti_ActivityMemset& memset = activity.raw(); + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "ph": "X", "cat": "Memset", {}, + "args": {{ + "device": {}, "context": {}, + "stream": {}, "correlation": {}, + "bytes": {}, "memory bandwidth (GB/s)": {} + }} + }},)JSON", + traceActivityJson(activity), + // args + memset.deviceId, memset.contextId, + memset.streamId, memset.correlationId, + memset.bytes, bandwidth(memset.bytes, memset.end - memset.start)); + // clang-format on + + int64_t to_id = activity.correlationId(); + handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); +} +#endif // HAS_CUPTI + +void ChromeTraceLogger::finalizeTrace( + const Config& /*unused*/, + std::unique_ptr /*unused*/, + int64_t endTime, + std::unordered_map>& metadata) { + if (!traceOf_) { + LOG(ERROR) << "Failed to write to log file!"; + return; + } + LOG(INFO) << "Chrome Trace written to " << fileName_; + // clang-format off + traceOf_ << fmt::format(R"JSON( + {{ + "name": "Record Window End", "ph": "i", "s": "g", + "pid": "", "tid": "", "ts": {} + }} + ],)JSON", + endTime); + +#if !USE_GOOGLE_LOG + std::unordered_map PreparedMetadata; + for (const auto& kv : metadata) { + // Skip empty log buckets, ex. skip ERROR if its empty. + if (!kv.second.empty()) { + std::string value = "["; + // Ex. Each metadata from logger is a list of strings, expressed in JSON as + // "ERROR": ["Error 1", "Error 2"], + // "WARNING": ["Warning 1", "Warning 2", "Warning 3"], + // ... + int mdv_count = kv.second.size(); + for (const auto& v : kv.second) { + value.append("\"" + v + "\""); + if(mdv_count > 1) { + value.append(","); + mdv_count--; + } + } + value.append("]"); + PreparedMetadata[kv.first] = sanitizeStrForJSON(value); + } + } + metadataToJSON(PreparedMetadata); +#endif // !USE_GOOGLE_LOG + + // Putting this here because the last entry MUST not end with a comma. + traceOf_ << fmt::format(R"JSON( + "traceName": "{}" +}})JSON", sanitizeStrForJSON(fileName_)); + // clang-format on + + traceOf_.close(); +} + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_json.h b/plugins/tensorboard-plugins/libkineto/src/output_json.h new file mode 100644 index 0000000000000000000000000000000000000000..5a8a81e4a9fdeef09b0e9ace59b964d5ab99b7ad --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_json.h @@ -0,0 +1,91 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include +#endif +#include "GenericTraceActivity.h" +#include "output_base.h" + +namespace KINETO_NAMESPACE { + // Previous declaration of TraceSpan is struct. Must match the same here. + struct TraceSpan; +} + +namespace KINETO_NAMESPACE { + +class Config; + +class ChromeTraceLogger : public libkineto::ActivityLogger { + public: + explicit ChromeTraceLogger(const std::string& traceFileName); + + // Note: the caller of these functions should handle concurrency + // i.e., we these functions are not thread-safe + void handleDeviceInfo( + const DeviceInfo& info, + uint64_t time) override; + + void handleOverheadInfo(const OverheadInfo& info, int64_t time) override; + + void handleResourceInfo(const ResourceInfo& info, int64_t time) override; + + void handleTraceSpan(const TraceSpan& span) override; + + void handleActivity(const ITraceActivity& activity) override; + void handleGenericActivity(const GenericTraceActivity& activity) override; + +#ifdef HAS_CUPTI + void handleGpuActivity(const GpuActivity& activity) override; + void handleGpuActivity(const GpuActivity& activity) override; + void handleGpuActivity(const GpuActivity& activity) override; + void handleGpuActivity(const GpuActivity& activity) override; +#endif // HAS_CUPTI + + void handleTraceStart( + const std::unordered_map& metadata) override; + + void finalizeTrace( + const Config& config, + std::unique_ptr buffers, + int64_t endTime, + std::unordered_map>& metadata) override; + + std::string traceFileName() const { + return fileName_; + } + + private: + + // Create a flow event (arrow) + void handleLink( + char type, + const ITraceActivity& e, + int64_t id, + const std::string& cat, + const std::string& name); + + void addIterationMarker(const TraceSpan& span); + + void openTraceFile(); + + void handleGenericInstantEvent(const ITraceActivity& op); + + void handleGenericLink(const ITraceActivity& activity); + + void metadataToJSON(const std::unordered_map& metadata); + + std::string& sanitizeStrForJSON(std::string& value); + + std::string fileName_; + std::ofstream traceOf_; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_membuf.h b/plugins/tensorboard-plugins/libkineto/src/output_membuf.h new file mode 100644 index 0000000000000000000000000000000000000000..ef6aadeb65728e0e05e454f98b32ccecca229cf4 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/src/output_membuf.h @@ -0,0 +1,130 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include +#include + +#ifdef HAS_CUPTI +#include +#endif + +#include "Config.h" +#include "GenericTraceActivity.h" +#ifdef HAS_CUPTI +#include "CuptiActivity.h" +#include "CuptiActivity.tpp" +#endif // HAS_CUPTI +#include "output_base.h" + +namespace KINETO_NAMESPACE { + +class Config; + +class MemoryTraceLogger : public ActivityLogger { + public: + MemoryTraceLogger(const Config& config) : config_(config.clone()) { + activities_.reserve(100000); + } + + // Note: the caller of these functions should handle concurrency + // i.e., these functions are not thread-safe + void handleDeviceInfo( + const DeviceInfo& info, + uint64_t time) override { + deviceInfoList_.emplace_back(info, time); + } + + void handleResourceInfo(const ResourceInfo& info, int64_t time) override { + resourceInfoList_.emplace_back(info, time); + } + + void handleOverheadInfo(const OverheadInfo& info, int64_t time) override {} + + void handleTraceSpan(const TraceSpan& span) override { + // Handled separately + } + + template + void addActivityWrapper(const T& act) { + wrappers_.push_back(std::make_unique(act)); + activities_.push_back(wrappers_.back().get()); + } + + // Just add the pointer to the list - ownership of the underlying + // objects must be transferred in ActivityBuffers via finalizeTrace + void handleActivity(const ITraceActivity& activity) override { + activities_.push_back(&activity); + } + void handleGenericActivity(const GenericTraceActivity& activity) override { + addActivityWrapper(activity); + } + +#ifdef HAS_CUPTI + void handleGpuActivity(const GpuActivity& activity) override { + addActivityWrapper(activity); + } + void handleGpuActivity(const GpuActivity& activity) override { + addActivityWrapper(activity); + } + void handleGpuActivity(const GpuActivity& activity) override { + addActivityWrapper(activity); + } + void handleGpuActivity(const GpuActivity& activity) override { + addActivityWrapper(activity); + } +#endif // HAS_CUPTI + + void handleTraceStart( + const std::unordered_map& metadata) override { + metadata_ = metadata; + } + + void finalizeTrace( + const Config& config, + std::unique_ptr buffers, + int64_t endTime, + std::unordered_map>& metadata) override { + buffers_ = std::move(buffers); + endTime_ = endTime; + } + + const std::vector* traceActivities() { + return &activities_; + } + + void log(ActivityLogger& logger) { + logger.handleTraceStart(metadata_); + for (auto& activity : activities_) { + activity->log(logger); + } + for (auto& p : deviceInfoList_) { + logger.handleDeviceInfo(p.first, p.second); + } + for (auto& p : resourceInfoList_) { + logger.handleResourceInfo(p.first, p.second); + } + for (auto& cpu_trace_buffer : buffers_->cpu) { + logger.handleTraceSpan(cpu_trace_buffer->span); + } + // Hold on to the buffers + logger.finalizeTrace(*config_, nullptr, endTime_, loggerMetadata_); + } + + private: + + std::unique_ptr config_; + // Optimization: Remove unique_ptr by keeping separate vector per type + std::vector activities_; + std::vector> wrappers_; + std::vector> deviceInfoList_; + std::vector> resourceInfoList_; + std::unique_ptr buffers_; + std::unordered_map metadata_; + std::unordered_map> loggerMetadata_; + int64_t endTime_{0}; +}; + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt b/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca54460b36cd4ade93918c8512f1309b48552e65 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt @@ -0,0 +1,3 @@ +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +# TODO diff --git a/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp b/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16bc86e751cefdbee1d48aeb79fc849b7d151a18 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp @@ -0,0 +1,315 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "include/Config.h" + +#include +#include +#include +#include + +using namespace std::chrono; +using namespace KINETO_NAMESPACE; + +TEST(ParseTest, Whitespace) { + Config cfg; + // Check that various types of whitespace is ignored + EXPECT_TRUE(cfg.parse("")); + EXPECT_TRUE(cfg.parse(" ")); + EXPECT_TRUE(cfg.parse("\t")); + EXPECT_TRUE(cfg.parse("\n")); + EXPECT_TRUE(cfg.parse(" ")); + EXPECT_TRUE(cfg.parse("\t \n \t\t\n\n")); + // Only the above characters are supported + EXPECT_FALSE(cfg.parse("\r\n")); +} + +TEST(ParseTest, Comment) { + Config cfg; + // Anything following a '#' should be ignored, up to a newline + EXPECT_TRUE(cfg.parse("# comment")); + EXPECT_TRUE(cfg.parse(" # ~!@#$")); + EXPECT_TRUE(cfg.parse("\t#abc")); + EXPECT_TRUE(cfg.parse("###\n##")); + EXPECT_TRUE(cfg.parse("EVENTS=util ##ok")); + EXPECT_TRUE(cfg.parse("EVENTS=util ## EVENTS=instruction")); + // Whatever appears before the comment must be valid format + EXPECT_FALSE(cfg.parse("util ## not ok")); + EXPECT_FALSE(cfg.parse("## ok \n blah # not OK")); + // Check that a comment does not affect config parsing + EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS = 1 # Sample every millisecond")); + EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); +} + +TEST(ParseTest, Format) { + Config cfg; + // The basic format is just "name = value". + // Where both value and name can be almost anything. + // Leading and trailing whitespace should be removed + // for both 'name' and 'value', but internal whitespace is not. + EXPECT_FALSE(cfg.parse("events")); + EXPECT_TRUE(cfg.parse("events=")); + EXPECT_FALSE(cfg.parse("=events=")); + EXPECT_TRUE(cfg.parse("events=1,2,3")); + // Only one setting per line + EXPECT_FALSE(cfg.parse("events = 1,2,3 ; metrics = 4,5,6")); + // Names are case sensitive + EXPECT_TRUE(cfg.parse("EVENTS = 1,2,3 \n metrics = 4,5,6")); + EXPECT_EQ(cfg.eventNames(), std::set({"1", "2", "3"})); + EXPECT_EQ(cfg.metricNames().size(), 0); + // Leading and trailing whitespace removed for event and metric names, + // but not internal. + EXPECT_TRUE( + cfg.parse("EVENTS = 1, 2, 3 \n \tMETRICS\t = \t4,\t5\t,\ts i x ")); + EXPECT_EQ(cfg.eventNames(), std::set({"1", "2", "3"})); + EXPECT_EQ(cfg.metricNames(), std::set({"4", "5", "s i x"})); +} + +TEST(ParseTest, DefaultActivityTypes) { + Config cfg; + cfg.validate(std::chrono::system_clock::now()); + auto all_activities = activityTypes(); + // TODO: introduce optional activities + EXPECT_EQ(cfg.selectedActivityTypes(), + std::set(all_activities.begin(), all_activities.end() - 1)); +} + +TEST(ParseTest, ActivityTypes) { + Config cfg; + EXPECT_FALSE(cfg.parse("ACTIVITY_TYPES")); + EXPECT_TRUE(cfg.parse("ACTIVITY_TYPES=")); + EXPECT_FALSE(cfg.parse("=ACTIVITY_TYPES=")); + + EXPECT_EQ(cfg.selectedActivityTypes(), + std::set({ActivityType::CPU_OP, + ActivityType::CPU_INSTANT_EVENT, + ActivityType::PYTHON_FUNCTION, + ActivityType::USER_ANNOTATION, + ActivityType::GPU_USER_ANNOTATION, + ActivityType::GPU_MEMCPY, + ActivityType::GPU_MEMSET, + ActivityType::CONCURRENT_KERNEL, + ActivityType::EXTERNAL_CORRELATION, + ActivityType::GLOW_RUNTIME, + ActivityType::CUDA_RUNTIME, + ActivityType::CUDA_PROFILER_RANGE})); + + Config cfg2; + EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES=gpu_memcpy,gpu_MeMsEt,kernel")); + EXPECT_EQ(cfg2.selectedActivityTypes(), + std::set({ActivityType::GPU_MEMCPY, + ActivityType::GPU_MEMSET, + ActivityType::CONCURRENT_KERNEL})); + + EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES = cuda_Runtime,")); + EXPECT_EQ(cfg2.selectedActivityTypes(), + std::set({ActivityType::CUDA_RUNTIME})); + + // Should throw an exception because incorrect activity name + EXPECT_FALSE(cfg2.parse("ACTIVITY_TYPES = memcopy,cuda_runtime")); + + EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES = cpu_op")); + EXPECT_EQ(cfg2.selectedActivityTypes(), + std::set({ActivityType::CPU_OP})); +} + +TEST(ParseTest, SamplePeriod) { + Config cfg; + EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=10")); + EXPECT_EQ(cfg.samplePeriod(), milliseconds(10)); + EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=0")); + cfg.validate(std::chrono::system_clock::now()); + // 0 should be adjustd up to 1 + EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); + // Negative and non-int values should fail + EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=-10")); + EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=1.5")); + EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=")); + EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=string")); + EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); +} + +TEST(ParseTest, MultiplexPeriod) { + Config cfg; + auto now = std::chrono::system_clock::now(); + + EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=100\nMULTIPLEX_PERIOD_MSECS=100")); + EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(100)); + EXPECT_TRUE(cfg.parse("MULTIPLEX_PERIOD_MSECS = 0")); + cfg.validate(now); + // Adjusted to match sample period + EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(100)); + EXPECT_TRUE(cfg.parse("MULTIPLEX_PERIOD_MSECS \t= \t 750 \n")); + cfg.validate(now); + // Adjusted to match multiple of sample period + EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(800)); + EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=-10")); + EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=1.5")); + EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=")); + EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=string")); + // Previous value not affected + EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(800)); +} + +TEST(ParseTest, ReportPeriod) { + Config cfg; + EXPECT_TRUE(cfg.parse("REPORT_PERIOD_SECS=1")); + EXPECT_EQ(cfg.reportPeriod(), seconds(1)); + // Whitespace + EXPECT_TRUE(cfg.parse("REPORT_PERIOD_SECS = \t100")); + EXPECT_EQ(cfg.reportPeriod(), seconds(100)); + // Invalid types + EXPECT_FALSE(cfg.parse("REPORT_PERIOD_SECS=-1")); + EXPECT_EQ(cfg.reportPeriod(), seconds(100)); +} + +TEST(ParseTest, SamplesPerReport) { + Config cfg; + auto now = std::chrono::system_clock::now(); + + EXPECT_TRUE(cfg.parse(R"( + SAMPLE_PERIOD_MSECS = 1000 + REPORT_PERIOD_SECS = 1 + SAMPLES_PER_REPORT = 10)")); + cfg.validate(now); + // Adjusted down to one sample per report + EXPECT_EQ(cfg.samplesPerReport(), 1); + EXPECT_TRUE(cfg.parse(R"( + SAMPLE_PERIOD_MSECS = 1000 + REPORT_PERIOD_SECS = 10 + SAMPLES_PER_REPORT = 10)")); + cfg.validate(now); + // No adjustment needed + EXPECT_EQ(cfg.samplesPerReport(), 10); + EXPECT_TRUE(cfg.parse(R"( + SAMPLE_PERIOD_MSECS = 1000 + REPORT_PERIOD_SECS = 2 + SAMPLES_PER_REPORT = 10)")); + cfg.validate(now); + // Adjusted to 2 samples per report + EXPECT_EQ(cfg.samplesPerReport(), 2); + EXPECT_TRUE(cfg.parse(R"( + SAMPLE_PERIOD_MSECS = 200 + REPORT_PERIOD_SECS = 2 + SAMPLES_PER_REPORT = 10)")); + cfg.validate(now); + // No adjustment needed + EXPECT_EQ(cfg.samplesPerReport(), 10); + EXPECT_TRUE(cfg.parse("SAMPLES_PER_REPORT=0")); + cfg.validate(now); + // Adjusted up to 1 + EXPECT_EQ(cfg.samplesPerReport(), 1); + // Invalid value types + EXPECT_FALSE(cfg.parse("SAMPLES_PER_REPORT=-10")); + EXPECT_FALSE(cfg.parse("SAMPLES_PER_REPORT=1.5")); + EXPECT_EQ(cfg.samplesPerReport(), 1); + + EXPECT_TRUE(cfg.parse(R"( + SAMPLE_PERIOD_MSECS=1000 + MULTIPLEX_PERIOD_MSECS=500 # Must be a multiple of sample period + REPORT_PERIOD_SECS=0 # Must be non-zero multiple of multiplex period + SAMPLES_PER_REPORT=5 # Max report period / multiplex period)")); + cfg.validate(now); + // Multiple adjustments + EXPECT_EQ(cfg.samplePeriod(), milliseconds(1000)); + EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(1000)); + EXPECT_EQ(cfg.reportPeriod(), seconds(1)); + EXPECT_EQ(cfg.samplesPerReport(), 1); +} + +TEST(ParseTest, EnableSigUsr2) { + Config cfg; + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=yes")); + EXPECT_TRUE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=no")); + EXPECT_FALSE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=YES")); + EXPECT_TRUE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=NO")); + EXPECT_FALSE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=Y")); + EXPECT_TRUE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=N")); + EXPECT_FALSE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=T")); + EXPECT_TRUE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=F")); + EXPECT_FALSE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=true")); + EXPECT_TRUE(cfg.sigUsr2Enabled()); + EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=false")); + EXPECT_FALSE(cfg.sigUsr2Enabled()); + EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2= ")); + EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=2")); + EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=-1")); + EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=yep")); +} + +TEST(ParseTest, DeviceMask) { + Config cfg; + // Single device + EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0")); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(1)); + + // Two devices, internal whitespace + EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 1, 2")); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(0)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(1)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(3)); + + // Three devices, check that previous devices are ignored + EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0, 2,4")); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(1)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(3)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(4)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(5)); + + // Repeated numbers have no effect + EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,1,1,2,3,2,1,3,7,7,3")); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(1)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(3)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(4)); + EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(6)); + EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(7)); + + // 8 is larger than the max allowed + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 3,8")); + + // 300 cannot be held in an uint8_t + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 300")); + + // Various illegal cases + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,two,three")); + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,,2")); + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = -1")); + EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 1.0")); +} + +TEST(ParseTest, RequestTime) { + Config cfg; + system_clock::time_point now = system_clock::now(); + int64_t tgood_ms = + duration_cast(now.time_since_epoch()).count(); + EXPECT_TRUE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tgood_ms))); + + tgood_ms = duration_cast((now - seconds(5)).time_since_epoch()) + .count(); + EXPECT_TRUE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tgood_ms))); + + int64_t tbad_ms = + duration_cast((now - seconds(20)).time_since_epoch()) + .count(); + EXPECT_FALSE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tbad_ms))); + + EXPECT_FALSE(cfg.parse("REQUEST_TIMESTAMP = 0")); + EXPECT_FALSE(cfg.parse("REQUEST_TIMESTAMP = -1")); + + tbad_ms = duration_cast((now + seconds(10)).time_since_epoch()) + .count(); + EXPECT_FALSE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tbad_ms))); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e67980ee31a3386580974033201b7acae75d22b --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp @@ -0,0 +1,629 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include +#include +#include +#endif + +#include "include/libkineto.h" +#include "include/Config.h" +#include "src/CuptiActivityProfiler.h" +#include "src/ActivityTrace.h" +#include "src/CuptiActivityApi.h" +#include "src/output_base.h" +#include "src/output_json.h" +#include "src/output_membuf.h" + +#include "src/Logger.h" +#include "test/MockActivitySubProfiler.h" + +using namespace std::chrono; +using namespace KINETO_NAMESPACE; + +#define CUDA_LAUNCH_KERNEL CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 +#define CUDA_MEMCPY CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 + +namespace { +const TraceSpan& defaultTraceSpan() { + static TraceSpan span(0, 0, "Unknown", ""); + return span; +} +} + +// Provides ability to easily create a few test CPU-side ops +struct MockCpuActivityBuffer : public CpuTraceBuffer { + MockCpuActivityBuffer(int64_t startTime, int64_t endTime) { + span = TraceSpan(startTime, endTime,"Test trace"); + gpuOpCount = 0; + } + + void addOp(std::string name, int64_t startTime, int64_t endTime, int64_t correlation) { + GenericTraceActivity op(span, ActivityType::CPU_OP, name); + op.startTime = startTime; + op.endTime = endTime; + op.resource = systemThreadId(); + op.id = correlation; + activities.push_back(std::move(op)); + span.opCount++; + } +}; + +// Provides ability to easily create a few test CUPTI ops +struct MockCuptiActivityBuffer { + void addCorrelationActivity(int64_t correlation, CUpti_ExternalCorrelationKind externalKind, int64_t externalId) { + auto& act = *(CUpti_ActivityExternalCorrelation*) malloc(sizeof(CUpti_ActivityExternalCorrelation)); + act.kind = CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION; + act.externalId = externalId; + act.externalKind = externalKind; + act.correlationId = correlation; + activities.push_back(reinterpret_cast(&act)); + } + + void addRuntimeActivity( + CUpti_runtime_api_trace_cbid_enum cbid, + int64_t start_us, int64_t end_us, int64_t correlation) { + auto& act = createActivity( + start_us, end_us, correlation); + act.kind = CUPTI_ACTIVITY_KIND_RUNTIME; + act.cbid = cbid; + act.threadId = threadId(); + activities.push_back(reinterpret_cast(&act)); + } + + void addKernelActivity( + int64_t start_us, int64_t end_us, int64_t correlation) { + auto& act = createActivity( + start_us, end_us, correlation); + act.kind = CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL; + act.deviceId = 0; + act.streamId = 1; + act.name = "kernel"; + act.gridX = act.gridY = act.gridZ = 1; + act.blockX = act.blockY = act.blockZ = 1; + activities.push_back(reinterpret_cast(&act)); + } + + void addMemcpyActivity( + int64_t start_us, int64_t end_us, int64_t correlation) { + auto& act = createActivity( + start_us, end_us, correlation); + act.kind = CUPTI_ACTIVITY_KIND_MEMCPY; + act.deviceId = 0; + act.streamId = 2; + act.copyKind = CUPTI_ACTIVITY_MEMCPY_KIND_HTOD; + act.srcKind = CUPTI_ACTIVITY_MEMORY_KIND_PINNED; + act.dstKind = CUPTI_ACTIVITY_MEMORY_KIND_DEVICE; + activities.push_back(reinterpret_cast(&act)); + } + + template + T& createActivity( + int64_t start_us, int64_t end_us, int64_t correlation) { + T& act = *static_cast(malloc(sizeof(T))); + bzero(&act, sizeof(act)); + act.start = start_us * 1000; + act.end = end_us * 1000; + act.correlationId = correlation; + return act; + } + + ~MockCuptiActivityBuffer() { + for (CUpti_Activity* act : activities) { + free(act); + } + } + + std::vector activities; +}; + +// Mock parts of the CuptiActivityApi +class MockCuptiActivities : public CuptiActivityApi { + public: + virtual int smCount() override { + return 10; + } + + virtual const std::pair processActivities( + CuptiActivityBufferMap&, /*unused*/ + std::function handler) override { + for (CUpti_Activity* act : activityBuffer->activities) { + handler(act); + } + return {activityBuffer->activities.size(), 100}; + } + + virtual std::unique_ptr + activityBuffers() override { + auto map = std::make_unique(); + auto buf = std::make_unique(100); + uint8_t* addr = buf->data(); + (*map)[addr] = std::move(buf); + return map; + } + + void bufferRequestedOverride(uint8_t** buffer, size_t* size, size_t* maxNumRecords) { + this->bufferRequested(buffer, size, maxNumRecords); + } + + std::unique_ptr activityBuffer; +}; + + +// Common setup / teardown and helper functions +class CuptiActivityProfilerTest : public ::testing::Test { + protected: + void SetUp() override { + profiler_ = std::make_unique( + cuptiActivities_, /*cpu only*/ false); + cfg_ = std::make_unique(); + cfg_->validate(std::chrono::system_clock::now()); + loggerFactory.addProtocol("file", [](const std::string& url) { + return std::unique_ptr(new ChromeTraceLogger(url)); + }); + } + + std::unique_ptr cfg_; + MockCuptiActivities cuptiActivities_; + std::unique_ptr profiler_; + ActivityLoggerFactory loggerFactory; +}; + +void checkTracefile(const char* filename) { +#ifdef __linux__ + // Check that the expected file was written and that it has some content + int fd = open(filename, O_RDONLY); + if (!fd) { + perror(filename); + } + EXPECT_TRUE(fd); + // Should expect at least 100 bytes + struct stat buf{}; + fstat(fd, &buf); + EXPECT_GT(buf.st_size, 100); + close(fd); +#endif +} + +TEST(CuptiActivityProfiler, AsyncTrace) { + std::vector log_modules( + {"CuptiActivityProfiler.cpp", "output_json.cpp"}); + SET_LOG_VERBOSITY_LEVEL(1, log_modules); + + MockCuptiActivities activities; + CuptiActivityProfiler profiler(activities, /*cpu only*/ true); + + char filename[] = "/tmp/libkineto_testXXXXXX.json"; + mkstemps(filename, 5); + + Config cfg; + + int iter = 0; + int warmup = 5; + auto now = system_clock::now(); + auto startTime = now + seconds(10); + + bool success = cfg.parse(fmt::format(R"CFG( + ACTIVITIES_WARMUP_PERIOD_SECS = {} + ACTIVITIES_DURATION_SECS = 1 + ACTIVITIES_LOG_FILE = {} + PROFILE_START_TIME = {} + )CFG", warmup, filename, duration_cast(startTime.time_since_epoch()).count())); + + EXPECT_TRUE(success); + EXPECT_FALSE(profiler.isActive()); + + auto logger = std::make_unique(cfg.activitiesLogFile()); + + // Usually configuration is done when now is startTime - warmup to kick off warmup + // but start right away in the test + profiler.configure(cfg, now); + profiler.setLogger(logger.get()); + + EXPECT_TRUE(profiler.isActive()); + + // fast forward in time and we have reached the startTime + now = startTime; + + // Run the profiler + // Warmup + // performRunLoopStep is usually called by the controller loop and takes + // the current time and the controller's next wakeup time. + profiler.performRunLoopStep( + /* Current time */ now, /* Next wakeup time */ now); + + auto next = now + milliseconds(1000); + + // performRunLoopStep can also be called by an application thread to update iteration count + // since this config does not use iteration this should have no effect on the state + while (++iter < 20) { + profiler.performRunLoopStep(now, now, iter); + } + + // Runloop should now be in collect state, so start workload + // Perform another runloop step, passing in the end profile time as current. + // This should terminate collection + profiler.performRunLoopStep( + /* Current time */ next, /* Next wakeup time */ next); + // One step needed for each of the Process and Finalize phases + // Doesn't really matter what times we pass in here. + + EXPECT_TRUE(profiler.isActive()); + + auto nextnext = next + milliseconds(1000); + + while (++iter < 40) { + profiler.performRunLoopStep(next, next, iter); + } + + EXPECT_TRUE(profiler.isActive()); + + profiler.performRunLoopStep(nextnext,nextnext); + profiler.performRunLoopStep(nextnext,nextnext); + + // Assert that tracing has completed + EXPECT_FALSE(profiler.isActive()); + + checkTracefile(filename); +} + +TEST(CuptiActivityProfiler, AsyncTraceUsingIter) { + std::vector log_modules( + {"CuptiActivityProfiler.cpp", "output_json.cpp"}); + SET_LOG_VERBOSITY_LEVEL(1, log_modules); + + auto runIterTest = [&]( + int start_iter, int warmup_iters, int trace_iters) { + + LOG(INFO ) << "Async Trace Test: start_iteration = " << start_iter + << " warmup iterations = " << warmup_iters + << " trace iterations = " << trace_iters; + + MockCuptiActivities activities; + CuptiActivityProfiler profiler(activities, /*cpu only*/ true); + + char filename[] = "/tmp/libkineto_testXXXXXX.json"; + mkstemps(filename, 5); + + Config cfg; + + int iter = 0; + auto now = system_clock::now(); + + bool success = cfg.parse(fmt::format(R"CFG( + PROFILE_START_ITERATION = {} + ACTIVITIES_WARMUP_ITERATIONS={} + ACTIVITIES_ITERATIONS={} + ACTIVITIES_DURATION_SECS = 1 + ACTIVITIES_LOG_FILE = {} + )CFG", start_iter, warmup_iters, trace_iters, filename)); + + EXPECT_TRUE(success); + EXPECT_FALSE(profiler.isActive()); + + auto logger = std::make_unique(cfg.activitiesLogFile()); + + // Usually configuration is done when now is startIter - warmup iter to kick off warmup + // but start right away in the test + while (iter < (start_iter - warmup_iters)) { + profiler.performRunLoopStep(now, now, iter++); + } + + profiler.configure(cfg, now); + profiler.setLogger(logger.get()); + + EXPECT_TRUE(profiler.isActive()); + + // fast forward in time, mimicking what will happen in reality + now += seconds(10); + auto next = now + milliseconds(1000); + + // this call to runloop step should not be effecting the state + profiler.performRunLoopStep(now, next); + EXPECT_TRUE(profiler.isActive()); + + // start trace collection + while (iter < start_iter) { + profiler.performRunLoopStep(now, next, iter++); + } + + // Runloop should now be in collect state, so start workload + + while (iter < (start_iter + trace_iters)) { + profiler.performRunLoopStep(now, next, iter++); + } + + // One step is required for each of the Process and Finalize phases + // Doesn't really matter what times we pass in here. + if (iter >= (start_iter + trace_iters)) { + profiler.performRunLoopStep(now, next, iter++); + } + EXPECT_TRUE(profiler.isActive()); + + auto nextnext = next + milliseconds(1000); + + profiler.performRunLoopStep(nextnext, nextnext); + profiler.performRunLoopStep(nextnext, nextnext); + + // Assert that tracing has completed + EXPECT_FALSE(profiler.isActive()); + + checkTracefile(filename); + }; + + // start iter = 50, warmup iters = 5, trace iters = 10 + runIterTest(50, 5, 10); + // should be able to start at 0 iteration + runIterTest(0, 0, 2); + runIterTest(0, 5, 5); +} + +TEST_F(CuptiActivityProfilerTest, SyncTrace) { + using ::testing::Return; + using ::testing::ByMove; + + // Verbose logging is useful for debugging + std::vector log_modules( + {"CuptiActivityProfiler.cpp"}); + SET_LOG_VERBOSITY_LEVEL(2, log_modules); + + // Start and stop profiling + CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); + int64_t start_time_us = 100; + int64_t duration_us = 300; + auto start_time = time_point(microseconds(start_time_us)); + profiler.configure(*cfg_, start_time); + profiler.startTrace(start_time); + profiler.stopTrace(start_time + microseconds(duration_us)); + + profiler.recordThreadInfo(); + + // Log some cpu ops + auto cpuOps = std::make_unique( + start_time_us, start_time_us + duration_us); + cpuOps->addOp("op1", 120, 150, 1); + cpuOps->addOp("op2", 130, 140, 2); + cpuOps->addOp("op3", 200, 250, 3); + profiler.transferCpuTrace(std::move(cpuOps)); + + // And some GPU ops + auto gpuOps = std::make_unique(); + gpuOps->addRuntimeActivity(CUDA_LAUNCH_KERNEL, 133, 138, 1); + gpuOps->addRuntimeActivity(CUDA_MEMCPY, 210, 220, 2); + gpuOps->addRuntimeActivity(CUDA_LAUNCH_KERNEL, 230, 245, 3); + gpuOps->addKernelActivity(150, 170, 1); + gpuOps->addMemcpyActivity(240, 250, 2); + gpuOps->addKernelActivity(260, 320, 3); + cuptiActivities_.activityBuffer = std::move(gpuOps); + + // Have the profiler process them + auto logger = std::make_unique(*cfg_); + profiler.processTrace(*logger); + + // Profiler can be reset at this point - logger owns the activities + profiler_->reset(); + + // Wrapper that allows iterating over the activities + ActivityTrace trace(std::move(logger), loggerFactory); + EXPECT_EQ(trace.activities()->size(), 9); + std::map activityCounts; + std::map resourceIds; + for (auto& activity : *trace.activities()) { + activityCounts[activity->name()]++; + resourceIds[activity->resourceId()]++; + } + for (const auto& p : activityCounts) { + LOG(INFO) << p.first << ": " << p.second; + } + EXPECT_EQ(activityCounts["op1"], 1); + EXPECT_EQ(activityCounts["op2"], 1); + EXPECT_EQ(activityCounts["op3"], 1); + EXPECT_EQ(activityCounts["cudaLaunchKernel"], 2); + EXPECT_EQ(activityCounts["cudaMemcpy"], 1); + EXPECT_EQ(activityCounts["kernel"], 2); + EXPECT_EQ(activityCounts["Memcpy HtoD (Pinned -> Device)"], 1); + + auto sysTid = systemThreadId(); + // Ops and runtime events are on thread sysTid + EXPECT_EQ(resourceIds[sysTid], 6); + // Kernels are on stream 1, memcpy on stream 2 + EXPECT_EQ(resourceIds[1], 2); + EXPECT_EQ(resourceIds[2], 1); + +#ifdef __linux__ + char filename[] = "/tmp/libkineto_testXXXXXX.json"; + mkstemps(filename, 5); + trace.save(filename); + // Check that the expected file was written and that it has some content + int fd = open(filename, O_RDONLY); + if (!fd) { + perror(filename); + } + EXPECT_TRUE(fd); + // Should expect at least 100 bytes + struct stat buf{}; + fstat(fd, &buf); + EXPECT_GT(buf.st_size, 100); +#endif +} + +TEST_F(CuptiActivityProfilerTest, GpuUserAnnotationTest) { + // Verbose logging is useful for debugging + std::vector log_modules( + {"CuptiActivityProfiler.cpp"}); + SET_LOG_VERBOSITY_LEVEL(2, log_modules); + + // Start and stop profiling + CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); + int64_t start_time_us = 100; + int64_t duration_us = 300; + auto start_time = time_point(microseconds(start_time_us)); + profiler.configure(*cfg_, start_time); + profiler.startTrace(start_time); + profiler.stopTrace(start_time + microseconds(duration_us)); + + int64_t kernelLaunchTime = 120; + profiler.recordThreadInfo(); + + // set up CPU event + auto cpuOps = std::make_unique( + start_time_us, start_time_us + duration_us); + cpuOps->addOp("annotation", kernelLaunchTime, kernelLaunchTime + 10, 1); + profiler.transferCpuTrace(std::move(cpuOps)); + + // set up a couple of GPU events and correlate with above CPU event. + // CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1 is used for user annotations. + auto gpuOps = std::make_unique(); + gpuOps->addCorrelationActivity(1, CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, 1); + gpuOps->addKernelActivity(kernelLaunchTime + 5, kernelLaunchTime + 10, 1); + gpuOps->addCorrelationActivity(1, CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, 1); + gpuOps->addKernelActivity(kernelLaunchTime + 15, kernelLaunchTime + 25, 1); + cuptiActivities_.activityBuffer = std::move(gpuOps); + + // process trace + auto logger = std::make_unique(*cfg_); + profiler.processTrace(*logger); + + ActivityTrace trace(std::move(logger), loggerFactory); + std::map counts; + for (auto& activity : *trace.activities()) { + counts[activity->name()]++; + } + + // We should now have an additional annotation activity created + // on the GPU timeline. + EXPECT_EQ(counts["annotation"], 2); + EXPECT_EQ(counts["kernel"], 2); + + auto& annotation = trace.activities()->at(0); + auto& kernel1 = trace.activities()->at(1); + auto& kernel2 = trace.activities()->at(2); + auto& gpu_annotation = trace.activities()->at(3); + EXPECT_EQ(gpu_annotation->type(), ActivityType::GPU_USER_ANNOTATION); + EXPECT_EQ(gpu_annotation->timestamp(), kernel1->timestamp()); + EXPECT_EQ( + gpu_annotation->duration(), + kernel2->timestamp() + kernel2->duration() - kernel1->timestamp()); + EXPECT_EQ(gpu_annotation->deviceId(), kernel1->deviceId()); + EXPECT_EQ(gpu_annotation->resourceId(), kernel1->resourceId()); + EXPECT_EQ(gpu_annotation->correlationId(), annotation->correlationId()); + EXPECT_EQ(gpu_annotation->name(), annotation->name()); +} + +TEST_F(CuptiActivityProfilerTest, SubActivityProfilers) { + using ::testing::Return; + using ::testing::ByMove; + + // Verbose logging is useful for debugging + std::vector log_modules( + {"CuptiActivityProfiler.cpp"}); + SET_LOG_VERBOSITY_LEVEL(2, log_modules); + + // Setup example events to test + GenericTraceActivity ev{defaultTraceSpan(), ActivityType::GLOW_RUNTIME, ""}; + ev.device = 1; + ev.resource = 0; + + int64_t start_time_us = 100; + int64_t duration_us = 1000; + auto start_time = time_point(microseconds(start_time_us)); + + std::vector test_activities{3, ev}; + test_activities[0].startTime = start_time_us; + test_activities[0].endTime = start_time_us + 5000; + test_activities[0].activityName = "SubGraph A execution"; + test_activities[1].startTime = start_time_us; + test_activities[1].endTime = start_time_us + 2000; + test_activities[1].activityName = "Operator foo"; + test_activities[2].startTime = start_time_us + 2500; + test_activities[2].endTime = start_time_us + 2900; + test_activities[2].activityName = "Operator bar"; + + auto mock_activity_profiler = + std::make_unique(test_activities); + + MockCuptiActivities activities; + CuptiActivityProfiler profiler(activities, /*cpu only*/ true); + profiler.addChildActivityProfiler( + std::move(mock_activity_profiler)); + + profiler.configure(*cfg_, start_time); + profiler.startTrace(start_time); + EXPECT_TRUE(profiler.isActive()); + + profiler.stopTrace(start_time + microseconds(duration_us)); + EXPECT_TRUE(profiler.isActive()); + + char filename[] = "/tmp/libkineto_testXXXXXX.json"; + mkstemps(filename, 5); + LOG(INFO) << "Logging to tmp file " << filename; + + // process trace + auto logger = std::make_unique(*cfg_); + profiler.processTrace(*logger); + profiler.setLogger(logger.get()); + + ActivityTrace trace(std::move(logger), loggerFactory); + trace.save(filename); + const auto& traced_activites = trace.activities(); + + // Test we have all the events + EXPECT_EQ(traced_activites->size(), test_activities.size()); + + // Check that the expected file was written and that it has some content + int fd = open(filename, O_RDONLY); + if (!fd) { + perror(filename); + } + EXPECT_TRUE(fd); + + // Should expect at least 100 bytes + struct stat buf{}; + fstat(fd, &buf); + EXPECT_GT(buf.st_size, 100); +} + +TEST_F(CuptiActivityProfilerTest, BufferSizeLimitTestWarmup) { + CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); + + auto now = system_clock::now(); + auto startTime = now + seconds(10); + + int maxBufferSizeMB = 3; + + auto startTimeEpoch = std::to_string(duration_cast(startTime.time_since_epoch()).count()); + std::string maxBufferSizeMBStr = std::to_string(maxBufferSizeMB); + cfg_->handleOption("ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB", maxBufferSizeMBStr); + cfg_->handleOption("PROFILE_START_TIME", startTimeEpoch); + + + EXPECT_FALSE(profiler.isActive()); + profiler.configure(*cfg_, now); + EXPECT_TRUE(profiler.isActive()); + + for (size_t i = 0; i < maxBufferSizeMB; i++) { + uint8_t* buf; + size_t gpuBufferSize; + size_t maxNumRecords; + cuptiActivities_.bufferRequestedOverride(&buf, &gpuBufferSize, &maxNumRecords); + } + + // fast forward to startTime and profiler is now running + now = startTime; + + profiler.performRunLoopStep(now, now); + + auto next = now + milliseconds(1000); + profiler.performRunLoopStep(next, next); + profiler.performRunLoopStep(next, next); + profiler.performRunLoopStep(next, next); + + EXPECT_FALSE(profiler.isActive()); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..253b696da54d1919e9c0076c5691a11e35345686 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp @@ -0,0 +1,239 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "src/Logger.h" +#include "src/CuptiCallbackApi.h" + +#include +#include +#include +#include + +using namespace std::chrono; +using namespace KINETO_NAMESPACE; +using namespace libkineto; + +const size_t some_data = 42; + +std::atomic simple_cb_calls = 0; + +void simple_cb( + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + const CUpti_CallbackData* cbInfo) { + + // simple arg check + EXPECT_EQ(domain, CUPTI_CB_DOMAIN_RUNTIME_API); + EXPECT_EQ(cbid, CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); + EXPECT_EQ(*reinterpret_cast(cbInfo), some_data); + + simple_cb_calls++; +} + +void atomic_cb( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId /*cbid*/, + const CUpti_CallbackData* /*cbInfo)*/) { + // do some atomics in a loop + for (int i = 0; i < 1000; i++) { + // would have used release consistency but this is fine + simple_cb_calls++; + } +} + +void empty_cb( + CUpti_CallbackDomain /*domain*/, + CUpti_CallbackId /*cbid*/, + const CUpti_CallbackData* /*cbInfo*/) { +} + +TEST(CuptiCallbackApiTest, SimpleTest) { + auto& api = CuptiCallbackApi::singleton(); + + auto addSimpleCallback = [&]() -> bool { + bool ret = api.registerCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CuptiCallbackApi::CUDA_LAUNCH_KERNEL, + &simple_cb + ); + return ret; + }; + EXPECT_TRUE(addSimpleCallback()) << "Failed to add callback"; + + // duplicate add should be okay + EXPECT_TRUE(addSimpleCallback()) << "Failed to re-add callback"; + + simple_cb_calls = 0; + + // simulate callback + api.__callback_switchboard( + CUPTI_CB_DOMAIN_RUNTIME_API, + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, + reinterpret_cast(&some_data)); + + EXPECT_EQ(simple_cb_calls, 1); + + bool ret = api.deleteCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CuptiCallbackApi::CUDA_LAUNCH_KERNEL, + &simple_cb + ); + + EXPECT_TRUE(ret) << "Failed to remove callback"; + + ret = api.deleteCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CuptiCallbackApi::CUDA_LAUNCH_KERNEL, + &atomic_cb + ); + + EXPECT_FALSE(ret) << "oops! deleted a callback that was never added"; +} + +TEST(CuptiCallbackApiTest, AllCallbacks) { + auto& api = CuptiCallbackApi::singleton(); + + auto testCallback = [&]( + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid, + CuptiCallbackApi::CuptiCallBackID kineto_cbid) -> bool { + + bool ret = api.registerCallback(domain, kineto_cbid, atomic_cb); + EXPECT_TRUE(ret) << "Failed to add callback"; + + if (!ret) { + return false; + } + + simple_cb_calls = 0; + api.__callback_switchboard(domain, cbid, nullptr); + EXPECT_EQ(simple_cb_calls, 1000); + ret = simple_cb_calls == 1000; + + EXPECT_TRUE(api.deleteCallback(domain, kineto_cbid, atomic_cb)); + + return ret; + }; + + EXPECT_TRUE( + testCallback( + CUPTI_CB_DOMAIN_RESOURCE, + CUPTI_CBID_RESOURCE_CONTEXT_CREATED, + CuptiCallbackApi::RESOURCE_CONTEXT_CREATED)) + << "Failed to run callback for RESOURCE_CONTEXT_CREATED"; + + EXPECT_TRUE( + testCallback( + CUPTI_CB_DOMAIN_RESOURCE, + CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING, + CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED)) + << "Failed to run callback for RESOURCE_CONTEXT_DESTROYED"; + + EXPECT_TRUE( + testCallback( + CUPTI_CB_DOMAIN_RUNTIME_API, + CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, + CuptiCallbackApi::CUDA_LAUNCH_KERNEL)) + << "Failed to run callback for CUDA_LAUNCH_KERNEL"; + +} + +TEST(CuptiCallbackApiTest, ContentionTest) { + auto& api = CuptiCallbackApi::singleton(); + const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RUNTIME_API; + const CUpti_CallbackId cbid = CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000; + const CuptiCallbackApi::CuptiCallBackID kineto_cbid = + CuptiCallbackApi::CUDA_LAUNCH_KERNEL; + + bool ret = api.registerCallback(domain, kineto_cbid, empty_cb); + EXPECT_TRUE(ret) << "Failed to add callback"; + + const int iters = 10000; + const int num_readers = 8; + + simple_cb_calls = 0; + + // simulate callbacks being executed on multiple threads in parallel + // during this interval add a new atomic_callback. + // this test ensured mutual exclusion is working fine + auto read_fn = [&](int tid){ + auto start_ts = high_resolution_clock::now(); + for (int i = 0; i < iters; i++) { + api.__callback_switchboard(domain, cbid, nullptr); + } + auto runtime_ms = duration_cast( + high_resolution_clock::now() - start_ts); + LOG(INFO) << "th " << tid << " done in " << runtime_ms.count() << " ms"; + }; + + + std::vector read_ths; + for (int i = 0; i< num_readers; i++) { + read_ths.emplace_back(read_fn, i); + } + + ret = api.registerCallback(domain, kineto_cbid, atomic_cb); + EXPECT_TRUE(ret) << "Failed to add callback"; + + for (auto& t : read_ths) { + t.join(); + } + + //EXPECT_GT(simple_cb_calls, 0) + // << "Atomic callback should have been called at least once."; + + api.deleteCallback(domain, kineto_cbid, empty_cb); + api.deleteCallback(domain, kineto_cbid, atomic_cb); +} + +TEST(CuptiCallbackApiTest, Bechmark) { + + constexpr int iters = 1000; + // atomic bench a number of times to get a baseline + + const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RUNTIME_API; + const CUpti_CallbackId cbid = CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000; + const CuptiCallbackApi::CuptiCallBackID kineto_cbid = + CuptiCallbackApi::CUDA_LAUNCH_KERNEL; + + LOG(INFO) << "Iteration count = " << iters; + + const bool use_empty = true; + auto cbfn = use_empty ? &empty_cb : &atomic_cb; + + // warmup + for (int i = 0; i < 50; i++) { + (*cbfn)(domain, cbid, nullptr); + } + + auto start_ts = high_resolution_clock::now(); + for (int i = 0; i < iters; i++) { + (*cbfn)(domain, cbid, nullptr); + } + auto delta_baseline_ns = duration_cast( + high_resolution_clock::now() - start_ts); + LOG(INFO) << "Baseline runtime = " << delta_baseline_ns.count() << " ns"; + + + auto& api = CuptiCallbackApi::singleton(); + bool ret = api.registerCallback(domain, kineto_cbid, cbfn); + EXPECT_TRUE(ret) << "Failed to add callback"; + + // warmup + for (int i = 0; i < 50; i++) { + api.__callback_switchboard(domain, cbid, nullptr); + } + + start_ts = high_resolution_clock::now(); + for (int i = 0; i < iters; i++) { + api.__callback_switchboard(domain, cbid, nullptr); + } + + auto delta_callback_ns = duration_cast( + high_resolution_clock::now() - start_ts); + LOG(INFO) << "Callback runtime = " << delta_callback_ns.count() << " ns"; + + LOG(INFO) << "Callback runtime per iteration = " << + (delta_callback_ns.count() - delta_baseline_ns.count()) / (double) iters + << " ns"; + +} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu b/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu new file mode 100644 index 0000000000000000000000000000000000000000..54ad51b0a1fc9a6a54585d1cad4674943c874b98 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu @@ -0,0 +1,353 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "src/Logger.h" +#include "src/CuptiRangeProfilerApi.h" + +#define DRIVER_API_CALL(apiFuncCall) \ + do { \ + CUresult _status = apiFuncCall; \ + if (_status != CUDA_SUCCESS) { \ + LOG(ERROR) << "Failed invoking CUDA driver function " \ + << #apiFuncCall << " status = " \ + << _status; \ + exit(-1); \ + } \ + } while (0) + +#define EXPECT(expr)\ + if (!(expr)) {\ + }; + +using namespace KINETO_NAMESPACE; + +static int numRanges = 1; + +using Type = double; + +// Device code +__global__ void VecAdd(const Type* A, const Type* B, Type* C, int N) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < N) { + C[i] = A[i] + B[i]; + } +} + +// Device code +__global__ void VecSub(const Type* A, const Type* B, Type* C, int N) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < N) { + C[i] = A[i] - B[i]; + } +} + +static void initVec(Type* vec, int n) { + for (int i = 0; i < n; i++) { + vec[i] = i; + } +} + +static void cleanUp( + Type* h_A, + Type* h_B, + Type* h_C, + Type* h_D, + Type* d_A, + Type* d_B, + Type* d_C, + Type* d_D) { + if (d_A) + cudaFree(d_A); + if (d_B) + cudaFree(d_B); + if (d_C) + cudaFree(d_C); + if (d_D) + cudaFree(d_D); + + // Free host memory + if (h_A) + free(h_A); + if (h_B) + free(h_B); + if (h_C) + free(h_C); + if (h_D) + free(h_D); +} + +/* Benchmark application used to test profiler measurements + * This simply runs two kernels vector Add and Vector Subtract + */ + +void VectorAddSubtract() { + int N = 50000; + size_t size = N * sizeof(Type); + int threadsPerBlock = 0; + int blocksPerGrid = 0; + Type *h_A, *h_B, *h_C, *h_D; + Type *d_A, *d_B, *d_C, *d_D; + int i; + Type sum, diff; + + // Allocate input vectors h_A and h_B in host memory + h_A = (Type*)malloc(size); + h_B = (Type*)malloc(size); + h_C = (Type*)malloc(size); + h_D = (Type*)malloc(size); + + // Initialize input vectors + initVec(h_A, N); + initVec(h_B, N); + memset(h_C, 0, size); + memset(h_D, 0, size); + + // Allocate vectors in device memory + cudaMalloc((void**)&d_A, size); + cudaMalloc((void**)&d_B, size); + cudaMalloc((void**)&d_C, size); + cudaMalloc((void**)&d_D, size); + + // Copy vectors from host memory to device memory + cudaMemcpy(d_A, h_A, size, cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B, size, cudaMemcpyHostToDevice); + + // Invoke kernel + threadsPerBlock = 256; + blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock; + LOG(INFO) << fmt::format( + "Launching kernel: blocks {}, thread/block {}", + blocksPerGrid, + threadsPerBlock); + + VecAdd<<>>(d_A, d_B, d_C, N); + + VecSub<<>>(d_A, d_B, d_D, N); + + // Copy result from device memory to host memory + // h_C contains the result in host memory + cudaMemcpy(h_C, d_C, size, cudaMemcpyDeviceToHost); + cudaMemcpy(h_D, d_D, size, cudaMemcpyDeviceToHost); + + // Verify result + for (i = 0; i < N; ++i) { + sum = h_A[i] + h_B[i]; + diff = h_A[i] - h_B[i]; + if (h_C[i] != sum || h_D[i] != diff) { + LOG(ERROR) << "Result verification failed"; + break; + } + } + + cleanUp(h_A, h_B, h_C, h_D, d_A, d_B, d_C, d_D); +} + +#if HAS_CUPTI_RANGE_PROFILER +bool runTestWithAutoRange( + int deviceNum, + const std::vector& metricNames, + CUcontext cuContext, + bool async) { + + // create a CUPTI range based profiling profiler + // this configures the counter data as well + CuptiRBProfilerSession profiler( + metricNames, deviceNum, 2, 1, async ? nullptr : cuContext); + + CUpti_ProfilerRange profilerRange = CUPTI_AutoRange; + CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay; + + if (async) { + profiler.asyncStartAndEnable(profilerRange, profilerReplayMode); + } else { + profiler.start(profilerRange, profilerReplayMode); + profiler.enable(); + } + + VectorAddSubtract(); + + if (!async) { + profiler.disable(); + // stop profiler + profiler.stop(); + } else { + profiler.asyncDisableAndStop(); + } + + auto result = profiler.evaluateMetrics(true); + + // check results + EXPECT_EQ(result.metricNames.size(), 3); + EXPECT_EQ(result.rangeVals.size(), 2); + + for (const auto& measurement : result.rangeVals) { + EXPECT_EQ(measurement.values.size(), 3); + + if (measurement.values.size() == 3) { + // smsp__warps_launched.avg + EXPECT_NE(measurement.values[0], 0); + // smsp__sass_thread_inst_executed_op_dadd_pred_on.sum + // each kernel has 50000 dadd ops + EXPECT_EQ(measurement.values[1], 50000); + // sm__inst_executed_pipe_tensor.sum + //EXPECT_EQ(measurement.values[2], 0); + } + } + return true; +} + +bool runTestWithUserRange( + int deviceNum, + const std::vector& metricNames, + CUcontext cuContext, + bool async = false) { + + // create a CUPTI range based profiling profiler + // this configures the counter data as well + CuptiRBProfilerSession profiler( + metricNames, deviceNum, numRanges, 1, async ? nullptr : cuContext); + + CUpti_ProfilerRange profilerRange = CUPTI_UserRange; + CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_UserReplay; + + if (async) { + profiler.asyncStartAndEnable(profilerRange, profilerReplayMode); + { VectorAddSubtract(); } + profiler.disableAndStop(); + } else { + profiler.start(profilerRange, profilerReplayMode); + + /* User takes the resposiblity of replaying the kernel launches */ + bool replay = true; + do { + profiler.beginPass(); + { + profiler.enable(); + + std::string rangeName = "vecAddSub"; + profiler.pushRange(rangeName); + + { VectorAddSubtract(); } + + profiler.popRange(); + profiler.disable(); + } + LOG(INFO) << "Replay starting."; + replay = profiler.endPass(); + + } while (!replay); + + // stop profiler + profiler.stop(); + } + VectorAddSubtract(); + auto result = profiler.evaluateMetrics(true); + + // check results + EXPECT_EQ(result.metricNames.size(), 3); + EXPECT_EQ(result.rangeVals.size(), 1); + + if (result.rangeVals.size() > 0) { + const auto& measurement = result.rangeVals[0]; + EXPECT_EQ(measurement.values.size(), 3); + + if (measurement.values.size() == 3) { + // smsp__warps_launched.avg + EXPECT_NE(measurement.values[0], 0); + // smsp__sass_thread_inst_executed_op_dadd_pred_on.sum + // in async mode multiple passes are not supported yet + if (!async) { + EXPECT_EQ(measurement.values[1], 100000); + } + // sm__inst_executed_pipe_tensor.sum + //EXPECT_EQ(measurement.values[2], 0); + } + } + return true; +} +#endif // HAS_CUPTI_RANGE_PROFILER + +int main(int argc, char* argv[]) { + + CUdevice cuDevice; + + int deviceCount, deviceNum; + int computeCapabilityMajor = 0, computeCapabilityMinor = 0; + + printf("Usage: %s [device_num]\n", argv[0]); + + DRIVER_API_CALL(cuInit(0)); + DRIVER_API_CALL(cuDeviceGetCount(&deviceCount)); + + if (deviceCount == 0) { + LOG(ERROR) << "There is no device supporting CUDA."; + return -2; + } + + if (argc > 1) + deviceNum = atoi(argv[1]); + else + deviceNum = 0; + LOG(INFO) << "CUDA Device Number: " << deviceNum; + + DRIVER_API_CALL(cuDeviceGet(&cuDevice, deviceNum)); + DRIVER_API_CALL(cuDeviceGetAttribute( + &computeCapabilityMajor, + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + cuDevice)); + DRIVER_API_CALL(cuDeviceGetAttribute( + &computeCapabilityMinor, + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, + cuDevice)); + + LOG(INFO) << "Compute Cabapbility = " + << fmt::format("{},{}",computeCapabilityMajor, computeCapabilityMinor); + + if (computeCapabilityMajor < 7) { + LOG(ERROR) << "CUPTI Profiler is not supported with compute capability < 7.0"; + return -2; + } + + CuptiRBProfilerSession::staticInit(); + + // metrics to profile + std::vector metricNames = { + "smsp__warps_launched.avg", + "smsp__sass_thread_inst_executed_op_dadd_pred_on.sum", + "sm__inst_executed_pipe_tensor.sum", + }; + + CUcontext cuContext; + DRIVER_API_CALL(cuCtxCreate(&cuContext, 0, cuDevice)); + + VectorAddSubtract(); + +#if HAS_CUPTI_RANGE_PROFILER + CuptiRBProfilerSession::staticInit(); + + if (!runTestWithUserRange(deviceNum, metricNames, cuContext, false)) { + LOG(ERROR) << "Failed to profiler test benchmark in user range"; + } else if (!runTestWithAutoRange(deviceNum, metricNames, cuContext, false)) { + LOG(ERROR) << "Failed to profiler test benchmark in auto range"; + } else if (!runTestWithUserRange(deviceNum, metricNames, cuContext, true)) { + LOG(ERROR) << "Failed to profiler test benchmark in user range async"; + } else if (!runTestWithAutoRange(deviceNum, metricNames, cuContext, true)) { + LOG(ERROR) << "Failed to profiler test benchmark in auto range async"; + } + + CuptiRBProfilerSession::deInitCupti(); +#else + LOG(WARNING) << "CuptiRBProfilerSession is not supported."; +#endif // HAS_CUPTI_RANGE_PROFILER + DRIVER_API_CALL(cuCtxDestroy(cuContext)); + + + return 0; +} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28cad722c53ee5defaa7c24cbe0d6b2cbc840a30 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp @@ -0,0 +1,113 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +#include "include/libkineto.h" +#include "include/Config.h" +#include "src/CuptiRangeProfilerApi.h" + +#include "src/Logger.h" +#include "test/CuptiRangeProfilerTestUtil.h" + +using namespace KINETO_NAMESPACE; + +#if HAS_CUPTI_PROFILER + +TEST(CuptiRangeProfilerApiTest, contextTracking) { + std::vector log_modules( + {"CuptiRangeProfilerApi.cpp"}); + SET_LOG_VERBOSITY_LEVEL(1, log_modules); + + std::array data; + std::array contexts; + for (int i = 0; i < data.size(); i++) { + contexts[i] = reinterpret_cast(&data[i]); + } + + // simulate creating contexts, this calls the trackCudaContexts + // function that would otherwise be called via a callback + uint32_t dev = 0; + for (auto ctx : contexts) { + simulateCudaContextCreate(ctx, dev++); + } + + EXPECT_EQ( + CuptiRBProfilerSession::getActiveDevices(), + std::set({0, 1, 2})); + + simulateCudaContextDestroy(contexts[1], 1); + + EXPECT_EQ( + CuptiRBProfilerSession::getActiveDevices(), + std::set({0, 2})); + + simulateCudaContextDestroy(contexts[0], 0); + simulateCudaContextDestroy(contexts[2], 2); + + EXPECT_TRUE( + CuptiRBProfilerSession::getActiveDevices().empty()); +} + +TEST(CuptiRangeProfilerApiTest, asyncLaunchUserRange) { + std::vector log_modules( + {"CuptiRangeProfilerApi.cpp"}); + SET_LOG_VERBOSITY_LEVEL(1, log_modules); + + // this is bad but the pointer is never accessed + CUcontext ctx0 = reinterpret_cast(10); + simulateCudaContextCreate(ctx0, 0 /*device_id*/); + + auto session = std::make_unique(0, ctx0); + session->asyncStartAndEnable(CUPTI_UserRange, CUPTI_UserReplay); + + simulateKernelLaunch(ctx0, "hello"); + simulateKernelLaunch(ctx0, "foo"); + simulateKernelLaunch(ctx0, "bar"); + + session->asyncDisableAndStop(); + // stop happens after next kernel is run + simulateKernelLaunch(ctx0, "bar"); + simulateCudaContextDestroy(ctx0, 0 /*device_id*/); + + EXPECT_EQ(session->passes_ended, 1); + EXPECT_EQ(session->ranges_ended, 1); + EXPECT_TRUE(session->enabled); +} + +TEST(CuptiRangeProfilerApiTest, asyncLaunchAutoRange) { + std::vector log_modules( + {"CuptiRangeProfilerApi.cpp"}); + SET_LOG_VERBOSITY_LEVEL(1, log_modules); + + // this is bad but the pointer is never accessed + CUcontext ctx0 = reinterpret_cast(10); + CUcontext ctx1 = reinterpret_cast(11); + + simulateCudaContextCreate(ctx0, 0 /*device_id*/); + + auto session = std::make_unique(0, ctx0); + session->asyncStartAndEnable(CUPTI_AutoRange, CUPTI_KernelReplay); + + simulateKernelLaunch(ctx0, "hello"); + simulateKernelLaunch(ctx0, "foo"); + simulateKernelLaunch(ctx1, "kernel_on_different_device"); + simulateKernelLaunch(ctx0, "bar"); + + session->asyncDisableAndStop(); + // stop happens after next kernel is run + simulateKernelLaunch(ctx0, "bar"); + simulateCudaContextDestroy(ctx0, 0 /*device_id*/); + + EXPECT_EQ(session->passes_ended, 0); + EXPECT_EQ(session->ranges_ended, 0); + EXPECT_TRUE(session->enabled); + + EXPECT_EQ( + session->getKernelNames(), + std::vector({"hello", "foo", "bar"})) + << "Kernel names were not tracked"; +} + +#endif // HAS_CUPTI_PROFILER diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3f568968238a0e376ab3bae621af00a162af0d25 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp @@ -0,0 +1,67 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "include/Config.h" +#include "src/CuptiRangeProfilerConfig.h" + +#include +#include +#include +#include + +using namespace std::chrono; +using namespace KINETO_NAMESPACE; + +class CuptiRangeProfilerConfigTest : public ::testing::Test { + protected: + void SetUp() override { + CuptiRangeProfilerConfig::registerFactory(); + } +}; + +TEST_F(CuptiRangeProfilerConfigTest, ConfigureProfiler) { + Config cfg; + std::vector metrics = { + "kineto__cuda_core_flops", + "sm__inst_executed.sum", + "l1tex__data_bank_conflicts_pipe_lsu.sum", + }; + auto metricsConfigStr = + fmt::format("CUPTI_PROFILER_METRICS = {}", fmt::join(metrics, ",")); + + EXPECT_TRUE(cfg.parse(metricsConfigStr)); + EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = true")); + EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_MAX_RANGES = 42")); + + const CuptiRangeProfilerConfig& cupti_cfg = + CuptiRangeProfilerConfig::get(cfg); + + EXPECT_EQ(cupti_cfg.activitiesCuptiMetrics(), metrics); + EXPECT_EQ(cupti_cfg.cuptiProfilerPerKernel(), true); + EXPECT_EQ(cupti_cfg.cuptiProfilerMaxRanges(), 42); + +} + +TEST_F(CuptiRangeProfilerConfigTest, RangesDefaults) { + Config cfg, cfg_auto; + + // do not set max ranges in config, check defaults are sane + EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_METRICS = kineto__cuda_core_flops")); + EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = false")); + + cfg.setSignalDefaults(); + + EXPECT_TRUE(cfg_auto.parse("CUPTI_PROFILER_METRICS = kineto__cuda_core_flops")); + EXPECT_TRUE(cfg_auto.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = true")); + + cfg_auto.setClientDefaults(); + + int user_ranges, auto_ranges; + + user_ranges = CuptiRangeProfilerConfig::get(cfg).cuptiProfilerMaxRanges(); + auto_ranges = CuptiRangeProfilerConfig::get(cfg_auto).cuptiProfilerMaxRanges(); + + EXPECT_GE(user_ranges, 1) << " in user range mode default to at least 1 ranges"; + EXPECT_GE(auto_ranges, 1000) << " in auto range mode default to at least 1000 ranges"; + + EXPECT_GT(auto_ranges, user_ranges); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..861b65fd701bf69373df657ab2a22d9dba0b27df --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h @@ -0,0 +1,96 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "CuptiRangeProfilerApi.h" + +namespace KINETO_NAMESPACE { + +#if HAS_CUPTI_PROFILER + +class MockCuptiRBProfilerSession : public CuptiRBProfilerSession { + public: + MockCuptiRBProfilerSession(int deviceId, CUcontext ctx) + : CuptiRBProfilerSession(deviceId, ctx) {} + + void beginPass() override { + LOG(INFO) << " Mock CUPTI begin pass"; + passes_started++; + } + + bool endPass() override { + passes_ended++; + return true; + } + + void flushCounterData() override {} + + void pushRange(const std::string& rangeName) override { + LOG(INFO) << " Mock CUPTI pushrange ( " << rangeName << " )"; + ranges_started++; + } + + void popRange() override { + LOG(INFO) << " Mock CUPTI poprange"; + ranges_ended++; + } + + void stop() override { + runChecks(); + } + + void enable() override { + enabled = true; + } + void disable() override {} + + CuptiProfilerResult evaluateMetrics(bool /*verbose*/) override { + return result; + } + +protected: + void startInternal( + CUpti_ProfilerRange profilerRange, + CUpti_ProfilerReplayMode profilerReplayMode) override { + curRange_ = profilerRange; + curReplay_ = profilerReplayMode; + } + +private: + void runChecks() { + EXPECT_EQ(passes_started, passes_ended); + EXPECT_EQ(ranges_started, ranges_ended); + } + + public: + int passes_started = 0; + int passes_ended = 0; + int ranges_started = 0; + int ranges_ended = 0; + bool enabled = false; + + CuptiProfilerResult result; + +}; + +inline void simulateCudaContextCreate(CUcontext context, uint32_t dev) { + testing::trackCudaCtx( + context, dev, CUPTI_CBID_RESOURCE_CONTEXT_CREATED); +} + +inline void simulateCudaContextDestroy(CUcontext context, uint32_t dev) { + testing::trackCudaCtx( + context, dev, CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); +} + +inline void simulateKernelLaunch( + CUcontext context, const std::string& kernelName) { + testing::trackCudaKernelLaunch(context, kernelName.c_str()); +} + +#endif // HAS_CUPTI_PROFILER + +} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..405f9404a49a5bf8b7433930b0ad2fe898ea2d89 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp @@ -0,0 +1,29 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include + +#include "src/cupti_strings.h" + +using namespace KINETO_NAMESPACE; + +TEST(CuptiStringsTest, Valid) { + ASSERT_STREQ( + runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_INVALID), "INVALID"); + ASSERT_STREQ( + runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_cudaDriverGetVersion_v3020), + "cudaDriverGetVersion"); + ASSERT_STREQ(runtimeCbidName + (CUPTI_RUNTIME_TRACE_CBID_cudaDeviceSynchronize_v3020), + "cudaDeviceSynchronize"); + ASSERT_STREQ( + runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_cudaStreamSetAttribute_ptsz_v11000), + "cudaStreamSetAttribute_ptsz"); +} + +TEST(CuptiStringsTest, Invalid) { + ASSERT_STREQ(runtimeCbidName(-1), "INVALID"); + // We can't actually use CUPTI_RUNTIME_TRACE_CBID_SIZE here until we + // auto-generate the string table, since it may have more entries than + // the enum in the version used to compile. + ASSERT_STREQ(runtimeCbidName(1000), "INVALID"); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp b/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cb36c826a7f32b2fe6732e73eae3b6a006b0cd3d --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp @@ -0,0 +1,578 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "src/EventProfiler.h" + +#include +#include +#include + +using namespace std::chrono; +using namespace KINETO_NAMESPACE; + +TEST(PercentileTest, Create) { + PercentileList pct = {{10, SampleValue(0)}, + {49, SampleValue(0)}, + {50, SampleValue(0)}, + {90, SampleValue(0)}}; + + percentiles({0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, pct); + EXPECT_EQ(pct[0].second.getInt(), 10); + EXPECT_EQ(pct[1].second.getInt(), 50); + EXPECT_EQ(pct[2].second.getInt(), 50); + EXPECT_EQ(pct[3].second.getInt(), 90); + + percentiles({80, 10, 20, 70, 60, 40, 90, 30, 50, 0, 100}, pct); + EXPECT_EQ(pct[0].second.getInt(), 10); + EXPECT_EQ(pct[1].second.getInt(), 50); + EXPECT_EQ(pct[2].second.getInt(), 50); + EXPECT_EQ(pct[3].second.getInt(), 90); + + percentiles({80}, pct); + EXPECT_EQ(pct[0].second.getInt(), 80); + EXPECT_EQ(pct[1].second.getInt(), 80); + EXPECT_EQ(pct[2].second.getInt(), 80); + EXPECT_EQ(pct[3].second.getInt(), 80); + + percentiles({80, 50}, pct); + EXPECT_EQ(pct[0].second.getInt(), 50); + EXPECT_EQ(pct[1].second.getInt(), 50); + EXPECT_EQ(pct[2].second.getInt(), 80); + EXPECT_EQ(pct[3].second.getInt(), 80); +} + +TEST(PercentileTest, Normalize) { + PercentileList pct = { + {10, SampleValue(10)}, {50, SampleValue(100.0)}, {90, SampleValue(2000)}}; + + normalize(pct, 2.5); + + EXPECT_EQ(pct[0].second.getInt(), 25); + EXPECT_EQ((int)pct[1].second.getDouble(), 250); + EXPECT_EQ(pct[2].second.getInt(), 5000); +} + +TEST(EventTest, SumSamples) { + Event ev; + ev.instanceCount = 4; + auto t = system_clock::now(); + ev.addSample(t, {1, 2, 3, 4}); + ev.addSample(t, {10, 20, 30, 40}); + ev.addSample(t, {100, 200, 300, 400}); + + EXPECT_EQ(ev.sumInstance(0, {0, 0, 3}), 1); + EXPECT_EQ(ev.sumInstance(0, {0, 1, 3}), 10); + EXPECT_EQ(ev.sumInstance(0, {0, 2, 3}), 100); + + EXPECT_EQ(ev.sumInstance(0, {0, 0, 1}), 111); + + EXPECT_EQ(ev.sumInstance(3, {0, 0, 1}), 444); + + // Non-zero offset + EXPECT_EQ(ev.sumInstance(0, {1, 0, 2}), 10); + EXPECT_EQ(ev.sumInstance(0, {1, 1, 2}), 100); + EXPECT_EQ(ev.sumInstance(0, {1, 0, 1}), 110); + + ev.addSample(t, {1000, 2000, 3000, 4000}); + + EXPECT_EQ(ev.sumInstance(0, {1, 0, 3}), 10); + EXPECT_EQ(ev.sumInstance(0, {1, 1, 3}), 100); + EXPECT_EQ(ev.sumInstance(0, {2, 1, 2}), 1000); + EXPECT_EQ(ev.sumInstance(0, {2, 0, 1}), 1100); + + EXPECT_EQ(ev.sumAll({0, 0, 4}), 10); + EXPECT_EQ(ev.sumAll({1, 0, 3}), 100); + EXPECT_EQ(ev.sumAll({2, 1, 2}), 10000); + EXPECT_EQ(ev.sumAll({0, 1, 2}), 11000); + EXPECT_EQ(ev.sumAll({0, 0, 1}), 11110); +} + +TEST(EventTest, Percentiles) { + Event ev; + ev.instanceCount = 4; + auto t = system_clock::now(); + ev.addSample(t, {3, 2, 1, 4}); + ev.addSample(t, {30, 20, 10, 40}); + ev.addSample(t, {300, 200, 100, 400}); + + PercentileList pct = { + {10, SampleValue(0)}, {50, SampleValue(0)}, {90, SampleValue(0)}}; + + ev.percentiles(pct, {0, 0, 3}); + EXPECT_EQ(pct[0].second.getInt(), 1); + EXPECT_EQ(pct[1].second.getInt(), 3); + EXPECT_EQ(pct[2].second.getInt(), 4); + + ev.percentiles(pct, {0, 0, 1}); + EXPECT_EQ(pct[0].second.getInt(), 111); + EXPECT_EQ(pct[1].second.getInt(), 333); + EXPECT_EQ(pct[2].second.getInt(), 444); +} + +class MockCuptiMetrics : public CuptiMetricApi { + public: + MockCuptiMetrics() : CuptiMetricApi(0) {} + MOCK_METHOD1(idFromName, CUpti_MetricID(const std::string& name)); + MOCK_METHOD1( + events, + std::map(CUpti_MetricID metric_id)); + MOCK_METHOD1(valueKind, CUpti_MetricValueKind(CUpti_MetricID metric)); + MOCK_METHOD1( + evaluationMode, + CUpti_MetricEvaluationMode(CUpti_MetricID metric)); + MOCK_METHOD5( + calculate, + SampleValue( + CUpti_MetricID metric, + CUpti_MetricValueKind kind, + std::vector& events, + std::vector& values, + int64_t duration)); +}; + +TEST(MetricTest, Calculate) { + using ::testing::Return; + MockCuptiMetrics metrics; + + // The events used for the ipc metrics: instructions and cycles + // Pretend we have 2 SMs and 2 samples of each event + Event instr("instructions"); + instr.instanceCount = 2; + auto t = system_clock::now(); + instr.addSample(t, {100, 200}); + instr.addSample(t, {300, 400}); + + Event cycles("cycles"); + cycles.instanceCount = 2; + cycles.addSample(t, {1000, 1200}); + cycles.addSample(t, {1300, 1300}); + + // 2 & 3 are the event ids we specified in the metric + std::map events; + events[2] = std::move(instr); + events[3] = std::move(cycles); + + // Define an ipc metric + EXPECT_CALL(metrics, valueKind(1)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); + Metric m( + "ipc", 1, {2, 3}, CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE, metrics); + + // Calculate metric for first sample + // Since evaluation mode is CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE, + // Cupti API will be called three times: once for each SM (2) and once + // to get the total across SMs. + std::vector ids = {2, 3}; + std::vector vals = {100, 1000}; + EXPECT_CALL( + metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) + .Times(1) + .WillOnce(Return(SampleValue(0.1))); + vals = {200, 1200}; + EXPECT_CALL( + metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) + .Times(1) + .WillOnce(Return(SampleValue(0.17))); + vals = {300, 2200}; + EXPECT_CALL( + metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) + .Times(1) + .WillOnce(Return(SampleValue(0.14))); + auto v = m.calculate(events, nanoseconds(1000), {0, 0, 2}); + + EXPECT_EQ(v.perInstance.size(), 2); + EXPECT_EQ(v.perInstance[0].getDouble(), 0.1); + EXPECT_EQ(v.perInstance[1].getDouble(), 0.17); + EXPECT_EQ(v.total.getDouble(), 0.14); + + // Calculate second sample. + // Change evaluation mode to CUPTI_METRIC_EVALUATION_MODE_AGGREGATE. + // Now we should get only one call to the Cupti API for the total. + EXPECT_CALL(metrics, valueKind(1)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); + Metric m2("ipc", 1, {2, 3}, CUPTI_METRIC_EVALUATION_MODE_AGGREGATE, metrics); + vals = {700, 2600}; + EXPECT_CALL( + metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) + .Times(1) + .WillOnce(Return(SampleValue(0.27))); + v = m2.calculate(events, nanoseconds(1000), {0, 1, 2}); + + EXPECT_EQ(v.perInstance.size(), 1); + EXPECT_EQ(v.perInstance[0].getDouble(), 0.27); + EXPECT_EQ(v.total.getDouble(), 0.27); +} + +class MockCuptiEvents : public CuptiEventApi { + public: + MOCK_METHOD1( + createGroupSets, + CUpti_EventGroupSets*(std::vector& ids)); + MOCK_METHOD1(destroyGroupSets, void(CUpti_EventGroupSets* sets)); + MOCK_METHOD0(setContinuousMode, bool()); + MOCK_METHOD1(enablePerInstance, void(CUpti_EventGroup eventGroup)); + MOCK_METHOD1(instanceCount, uint32_t(CUpti_EventGroup eventGroup)); + MOCK_METHOD1(enableGroupSet, void(CUpti_EventGroupSet& set)); + MOCK_METHOD1(disableGroupSet, void(CUpti_EventGroupSet& set)); + MOCK_METHOD3( + readEvent, + void(CUpti_EventGroup g, CUpti_EventID id, std::vector& vals)); + MOCK_METHOD1(eventsInGroup, std::vector(CUpti_EventGroup g)); + MOCK_METHOD1(eventId, CUpti_EventID(const std::string& name)); +}; + +TEST(EventGroupSetTest, CollectSample) { + using ::testing::_; + using ::testing::Return; + using ::testing::SetArgPointee; + const CUpti_EventGroup g1{nullptr}; + const CUpti_EventGroup g2{reinterpret_cast(0x1000)}; + CUpti_EventGroup groups[] = {g1, g2}; + CUpti_EventGroupSet set; + set.eventGroups = groups; + set.numEventGroups = 2; + + std::map events; + Event instr("instructions"); + events[4] = std::move(instr); + Event cycles("cycles"); + events[5] = std::move(cycles); + Event branches("branches"); + events[10] = std::move(branches); + + MockCuptiEvents cupti_events; + EXPECT_CALL(cupti_events, enablePerInstance(g1)).Times(1); + EXPECT_CALL(cupti_events, enablePerInstance(g2)).Times(1); + EXPECT_CALL(cupti_events, instanceCount(g1)).Times(1).WillOnce(Return(80)); + EXPECT_CALL(cupti_events, instanceCount(g2)).Times(1).WillOnce(Return(40)); + std::vector events_in_group1 = {4, 5}; + EXPECT_CALL(cupti_events, eventsInGroup(g1)) + .Times(1) + .WillOnce(Return(events_in_group1)); + std::vector events_in_group2 = {10}; + EXPECT_CALL(cupti_events, eventsInGroup(g2)) + .Times(1) + .WillOnce(Return(events_in_group2)); + EventGroupSet group_set(set, events, cupti_events); + + EXPECT_EQ(group_set.groupCount(), 2); + EXPECT_EQ(events[4].instanceCount, 80); + EXPECT_EQ(events[5].instanceCount, 80); + EXPECT_EQ(events[10].instanceCount, 40); + + // This should not cause any Cupti API action as the group + // set is already disabled + group_set.setEnabled(false); + + // Activate group set - if activated twice, only the first + // should cause cupti API to be called + EXPECT_CALL(cupti_events, enableGroupSet(_)).Times(1); + group_set.setEnabled(false); + group_set.setEnabled(true); + + EXPECT_CALL(cupti_events, eventsInGroup(g1)) + .Times(1) + .WillOnce(Return(events_in_group1)); + EXPECT_CALL(cupti_events, eventsInGroup(g2)) + .Times(1) + .WillOnce(Return(events_in_group2)); + EXPECT_CALL(cupti_events, readEvent(g1, 4, _)).Times(1); + EXPECT_CALL(cupti_events, readEvent(g1, 5, _)).Times(1); + EXPECT_CALL(cupti_events, readEvent(g2, 10, _)).Times(1); + group_set.collectSample(); + + EXPECT_EQ(events[4].sampleCount(), 1); + EXPECT_EQ(events[5].sampleCount(), 1); + EXPECT_EQ(events[10].sampleCount(), 1); +} + +class MockLogger : public SampleListener { + public: + MOCK_METHOD3(handleSample, void(int device, const Sample& sample, bool from_new_version)); + MOCK_METHOD1(update, void(const Config& config)); +}; + +class EventProfilerTest : public ::testing::Test { + protected: + void SetUp() override { + auto cupti_events_ptr = std::make_unique(); + auto cupti_metrics_ptr = std::make_unique(); + cuptiEvents_ = cupti_events_ptr.get(); + cuptiMetrics_ = cupti_metrics_ptr.get(); + loggers_.push_back(std::make_unique()); + onDemandLoggers_.push_back(std::make_unique()); + profiler_ = std::make_unique( + std::move(cupti_events_ptr), + std::move(cupti_metrics_ptr), + loggers_, + onDemandLoggers_); + + for (int i = 0; i < kEventGroupCount; i++) { + eventGroups_[i] = &eventGroups_[i]; + } + for (int i = 0; i < kGroupSetCount; i++) { + // Default size to 1 but can be changed by test + groupSet_[i].numEventGroups = 1; + // Two groups per set + groupSet_[i].eventGroups = &eventGroups_[i * 2]; + } + groupSets_.numSets = 1; + groupSets_.sets = groupSet_; + } + + MockCuptiEvents* cuptiEvents_; + MockCuptiMetrics* cuptiMetrics_; + std::vector> loggers_; + std::vector> onDemandLoggers_; + constexpr static int kEventGroupCount = 4; + constexpr static int kGroupSetCount = 2; + CUpti_EventGroup eventGroups_[kEventGroupCount]; + CUpti_EventGroupSet groupSet_[kGroupSetCount]; + CUpti_EventGroupSets groupSets_; + std::unique_ptr profiler_; +}; + +TEST_F(EventProfilerTest, ConfigureFailure) { + using namespace testing; + + // Default config has no counters enabled. + // Check that profiler remains disabled. + Config cfg; + profiler_->configure(cfg, nullptr); + + EXPECT_FALSE(profiler_->enabled()); + + // There is no event named "cycles" + // In this case the profiler should print a warning and remain disabled + bool parsed = cfg.parse("EVENTS = cycles"); + EXPECT_TRUE(parsed); + + // EventProfiler should handle exception thrown from createGroupSets + // Configuration will be applied twice - once for combined base + on-demand + // and then again falling back to base + EXPECT_CALL(*cuptiEvents_, eventId("cycles")) + .Times(2) + .WillRepeatedly(Return(0)); + std::vector ids = {0}; + EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) + .Times(2) + .WillRepeatedly(Throw( + std::system_error(EINVAL, std::generic_category(), "Event ID"))); + profiler_->configure(cfg, nullptr); + + EXPECT_FALSE(profiler_->enabled()); +} + +TEST_F(EventProfilerTest, ConfigureBase) { + using namespace testing; + + // Test normal path, simple base config + Config cfg; + bool parsed = cfg.parse("EVENTS = elapsed_cycles_sm"); + EXPECT_TRUE(parsed); + + // One valid event - expect one call to eventId and createGroupSets + EXPECT_CALL(*cuptiEvents_, eventId("elapsed_cycles_sm")) + .Times(1) + .WillOnce(Return(5)); + std::vector ids = {5}; + EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) + .Times(1) + .WillOnce(Return(&groupSets_)); + EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[0])).Times(1); + EXPECT_CALL(*cuptiEvents_, instanceCount(eventGroups_[0])) + .Times(1) + .WillOnce(Return(80)); + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) + .Times(1) + .WillOnce(Return(ids)); + EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); + + profiler_->configure(cfg, nullptr); + + EXPECT_TRUE(profiler_->enabled()); +} + +TEST_F(EventProfilerTest, ConfigureOnDemand) { + using namespace testing; + + // Test base + on-demand config, one event and one metric + Config cfg, on_demand_cfg; + bool parsed = cfg.parse(R"( + EVENTS = active_cycles + SAMPLE_PERIOD_MSECS=500 + REPORT_PERIOD_SECS=10 + SAMPLES_PER_REPORT=5 + )"); + EXPECT_TRUE(parsed); + + parsed = on_demand_cfg.parse(R"( + METRICS = ipc + EVENTS_DURATION_SECS=60 + SAMPLE_PERIOD_MSECS=200 + MULTIPLEX_PERIOD_MSECS=2000 + REPORT_PERIOD_SECS=3 + SAMPLES_PER_REPORT=10 + )"); + EXPECT_TRUE(parsed); + + // One event + EXPECT_CALL(*cuptiEvents_, eventId("active_cycles")) + .Times(1) + .WillOnce(Return(3)); + // One metric + EXPECT_CALL(*cuptiMetrics_, idFromName("ipc")).Times(1).WillOnce(Return(10)); + std::map ipc_events; + ipc_events[4] = "instructions"; + ipc_events[5] = "elapsed_cycles_sm"; + EXPECT_CALL(*cuptiMetrics_, events(10)).Times(1).WillOnce(Return(ipc_events)); + EXPECT_CALL(*cuptiMetrics_, evaluationMode(10)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE)); + EXPECT_CALL(*cuptiMetrics_, valueKind(10)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); + std::vector ids = {3, 4, 5}; + groupSet_[0].numEventGroups = 2; + groupSets_.numSets = 2; + EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) + .Times(1) + .WillOnce(Return(&groupSets_)); + // Specified CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE per instance above + // So check that it's enabled + EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[0])).Times(1); + EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[1])).Times(1); + EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[2])).Times(1); + std::vector ids_g1{3}, ids_g2{4}, ids_g3{5}; + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) + .Times(1) + .WillOnce(Return(ids_g1)); + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[1])) + .Times(1) + .WillOnce(Return(ids_g2)); + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[2])) + .Times(1) + .WillOnce(Return(ids_g3)); + EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); + + profiler_->configure(cfg, &on_demand_cfg); + + EXPECT_TRUE(profiler_->enabled()); + EXPECT_EQ(profiler_->samplePeriod().count(), 250); + EXPECT_EQ(profiler_->multiplexPeriod().count(), 1000); + EXPECT_EQ(profiler_->reportPeriod().count(), 10000); + EXPECT_EQ(profiler_->onDemandReportPeriod().count(), 4000); +} + +TEST_F(EventProfilerTest, ReportSample) { + using namespace testing; + + // Test base + on-demand config, one event and one metric + Config cfg, on_demand_cfg; + bool parsed = cfg.parse("EVENTS = active_cycles"); + EXPECT_TRUE(parsed); + + parsed = on_demand_cfg.parse(R"( + METRICS = ipc + EVENTS_DURATION_SECS=60 + )"); + EXPECT_TRUE(parsed); + + // One event + EXPECT_CALL(*cuptiEvents_, eventId("active_cycles")) + .Times(1) + .WillOnce(Return(3)); + // One metric + EXPECT_CALL(*cuptiMetrics_, idFromName("ipc")).Times(1).WillOnce(Return(10)); + std::map ipc_events; + ipc_events[4] = "instructions"; + ipc_events[5] = "elapsed_cycles_sm"; + EXPECT_CALL(*cuptiMetrics_, events(10)).Times(1).WillOnce(Return(ipc_events)); + EXPECT_CALL(*cuptiMetrics_, evaluationMode(10)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE)); + EXPECT_CALL(*cuptiMetrics_, valueKind(10)) + .Times(1) + .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); + std::vector ids = {3, 4, 5}; + groupSet_[0].numEventGroups = 2; + groupSets_.numSets = 2; + EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) + .Times(1) + .WillOnce(Return(&groupSets_)); + EXPECT_CALL(*cuptiEvents_, instanceCount(_)) + .Times(3) + .WillRepeatedly(Return(4)); + std::vector ids_g1{3}, ids_g2{4}, ids_g3{5}; + // These will be called by collectSample() as well, which is called twice + // per group set + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) + .Times(3) + .WillRepeatedly(Return(ids_g1)); + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[1])) + .Times(3) + .WillRepeatedly(Return(ids_g2)); + EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[2])) + .Times(3) + .WillRepeatedly(Return(ids_g3)); + EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); + + profiler_->configure(cfg, &on_demand_cfg); + + EXPECT_TRUE(profiler_->enabled()); + + EXPECT_CALL(*cuptiEvents_, readEvent(_, _, _)) + .Times(6) + .WillRepeatedly(Invoke( + [](CUpti_EventGroup g, CUpti_EventID id, std::vector& vals) { + vals = {1, 2, 3, 4}; + })); + + // Need to collect four times - twice for each group set + profiler_->collectSample(); + profiler_->collectSample(); + EXPECT_CALL(*cuptiEvents_, disableGroupSet(_)).Times(1); + EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); + profiler_->enableNextCounterSet(); + profiler_->collectSample(); + profiler_->collectSample(); + + std::vector ipc_ids = {4, 5}; + // Called once for each instance (4) and once for the total. + // x2 since we recompute per logger. + EXPECT_CALL( + *cuptiMetrics_, + calculate(10, CUPTI_METRIC_VALUE_KIND_DOUBLE, ipc_ids, _, 2000000000)) + .Times(10) + .WillRepeatedly(Return(SampleValue(0.3))); + auto& logger = dynamic_cast(*loggers_[0]); + EXPECT_CALL(logger, handleSample(0, _, _)) + .Times(1) + .WillOnce(Invoke([](int device, const Sample& sample, bool from_new_version) { + // Sample will include all stats - logger must pick the + // ones it wants. + EXPECT_EQ(sample.stats.size(), 4); + EXPECT_EQ(sample.stats[0].name, "active_cycles"); + EXPECT_EQ(sample.stats[1].name, "instructions"); + EXPECT_EQ(sample.stats[2].name, "elapsed_cycles_sm"); + EXPECT_EQ(sample.stats[3].name, "ipc"); + // 2 samples, each with values {1, 2, 3, 4} + // i.e. {2, 4, 6, 8} total + EXPECT_EQ(sample.stats[0].total.getInt(), 20); + EXPECT_EQ(sample.stats[0].percentileValues[0].second.getInt(), 2); + EXPECT_EQ(sample.stats[0].percentileValues.back().second.getInt(), 8); + // ipc is always 0.3 from mocked calculate function above + EXPECT_EQ(sample.stats[3].total.getDouble(), 0.3); + EXPECT_EQ(sample.stats[3].percentileValues[0].second.getDouble(), 0.3); + EXPECT_EQ( + sample.stats[3].percentileValues.back().second.getDouble(), 0.3); + })); + profiler_->reportSamples(); + + auto& on_demand_logger = dynamic_cast(*onDemandLoggers_[0]); + EXPECT_CALL(on_demand_logger, handleSample(0, _, _)).Times(1); + profiler_->reportOnDemandSamples(); + + EXPECT_CALL(*cuptiEvents_, disableGroupSet(_)).Times(1); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp b/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30ba4a824af10401a45100b0b39cec54fcf98680 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp @@ -0,0 +1,96 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "include/libkineto.h" +#include "src/Logger.h" +#include "LoggerCollector.h" + +using namespace KINETO_NAMESPACE; + +#if !USE_GOOGLE_LOG + +constexpr char InfoTestStr[] = "Checking LOG(INFO)"; +constexpr char WarningTestStr[] = "Checking LOG(WARNING)"; +constexpr char ErrorTestStr[] = "Checking LOG(ERROR)"; + +TEST(LoggerObserverTest, SingleCollectorObserver) { + // Add a LoggerObserverCollector to collect all logs during the trace. + std::unique_ptr lCollector = std::make_unique(); + Logger::addLoggerObserver(lCollector.get()); + + LOG(INFO) << InfoTestStr; + LOG(WARNING) << WarningTestStr; + LOG(ERROR) << ErrorTestStr; + + auto LoggerMD = lCollector->extractCollectorMetadata(); + EXPECT_TRUE(LoggerMD[LoggerOutputType::INFO][0].find(InfoTestStr) != std::string::npos); + EXPECT_TRUE(LoggerMD[LoggerOutputType::WARNING][0].find(WarningTestStr) != std::string::npos); + EXPECT_TRUE(LoggerMD[LoggerOutputType::ERROR][0].find(ErrorTestStr) != std::string::npos); + + Logger::removeLoggerObserver(lCollector.get()); +} + +#define NUM_OF_MESSAGES_FOR_EACH_TYPE 10 +#define NUM_OF_WRITE_THREADS 200 + +// Writes NUM_OF_MESSAGES_FOR_EACH_TYPE messages for each INFO, WARNING, and ERROR. +// NOLINTNEXTLINE(clang-diagnostic-unused-parameter) +void* writeSeveralMessages(void* ptr) { + for(int i=0; i lc1 = std::make_unique(); + std::unique_ptr lc2 = std::make_unique(); + std::unique_ptr lc3 = std::make_unique(); + std::unique_ptr lc4 = std::make_unique(); + Logger::addLoggerObserver(lc1.get()); + Logger::addLoggerObserver(lc2.get()); + Logger::addLoggerObserver(lc3.get()); + Logger::addLoggerObserver(lc4.get()); + + // Launch NUM_OF_WRITE_THREADS threads writing several messages. + pthread_t ListOfThreads[NUM_OF_WRITE_THREADS]; + for (int i=0; iextractCollectorMetadata(); + int InfoCount = 0, WarnCount = 0, ErrorCount = 0; + for (auto& md : lc1MD) { + InfoCount += md.first == LoggerOutputType::INFO ? md.second.size() : 0; + WarnCount += md.first == LoggerOutputType::WARNING ? md.second.size() : 0; + ErrorCount += md.first == LoggerOutputType::ERROR ? md.second.size() : 0; + } + + EXPECT_EQ(InfoCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); + EXPECT_EQ(WarnCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); + EXPECT_EQ(ErrorCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); + + Logger::removeLoggerObserver(lc1.get()); + Logger::removeLoggerObserver(lc2.get()); + Logger::removeLoggerObserver(lc3.get()); + Logger::removeLoggerObserver(lc4.get()); +} + +#endif // !USE_GOOGLE_LOG + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89f1d536ca8d6d794b7ffc7402001d0e3d4d9c06 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp @@ -0,0 +1,49 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +#include "test/MockActivitySubProfiler.h" + +namespace libkineto { + +const std::set supported_activities {ActivityType::CPU_OP}; +const std::string profile_name{"MockProfiler"}; + +void MockProfilerSession::processTrace(ActivityLogger& logger) { + for (const auto& activity: activities()) { + activity.log(logger); + } +} + +const std::string& MockActivityProfiler::name() const { + return profile_name; +} + +const std::set& MockActivityProfiler::availableActivities() const { + return supported_activities; +} + +MockActivityProfiler::MockActivityProfiler( + std::vector& activities) : + test_activities_(activities) {}; + +std::unique_ptr MockActivityProfiler::configure( + const std::set& /*activity_types*/, + const Config& /*config*/) { + auto session = std::make_unique(); + session->set_test_activities(std::move(test_activities_)); + return session; +}; + +std::unique_ptr MockActivityProfiler::configure( + int64_t /*ts_ms*/, + int64_t /*duration_ms*/, + const std::set& activity_types, + const Config& config) { + return configure(activity_types, config); +}; + +} // namespace libkineto + diff --git a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..36eaa13d1a544c624a2f4bb053891d055686ebf4 --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h @@ -0,0 +1,72 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +#include +#include +#include + +#include "include/IActivityProfiler.h" + +namespace libkineto { + +class MockProfilerSession: public IActivityProfilerSession { + + public: + explicit MockProfilerSession() {} + + void start() override { + start_count++; + status_ = TraceStatus::RECORDING; + } + + void stop() override { + stop_count++; + status_ = TraceStatus::PROCESSING; + } + + std::vector& activities() override { + return test_activities_; + } + + std::vector errors() override { + return {}; + } + + void processTrace(ActivityLogger& logger) override; + + void set_test_activities(std::vector&& acs) { + test_activities_ = std::move(acs); + } + + int start_count = 0; + int stop_count = 0; + private: + std::vector test_activities_; +}; + + +class MockActivityProfiler: public IActivityProfiler { + + public: + explicit MockActivityProfiler(std::vector& activities); + + const std::string& name() const override; + + const std::set& availableActivities() const override; + + std::unique_ptr configure( + const std::set& activity_types, + const Config& config) override; + + std::unique_ptr configure( + int64_t ts_ms, + int64_t duration_ms, + const std::set& activity_types, + const Config& config) override; + + private: + std::vector test_activities_; +}; + +} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp b/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b86cfb36d0581ba9a8a03a09724b181c2fd2e88a --- /dev/null +++ b/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp @@ -0,0 +1,27 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "include/ThreadUtil.h" + +#include +#include + +#include +#include + +using namespace KINETO_NAMESPACE; + +TEST(ThreadNameTest, setAndGet) { + setThreadName("ThreadNameTest"); + EXPECT_EQ(getThreadName(), "ThreadNameTest"); + + setThreadName(""); + EXPECT_EQ(getThreadName(), ""); + + // Spaces etc are ok + setThreadName("Name w/ spaces"); + EXPECT_EQ(getThreadName(), "Name w/ spaces"); + + // More than 16 chars is not OK + setThreadName("More than 16 characters"); + EXPECT_EQ(getThreadName(), "Name w/ spaces"); +} diff --git a/plugins/tensorboard-plugins/tb_plugin/README.md b/plugins/tensorboard-plugins/tb_plugin/README.md index b4b417c4d21899c195674be08d196a5d9f11021a..fc0c2bc5514bba9b4507b9c07919f8ef28048d5e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/README.md +++ b/plugins/tensorboard-plugins/tb_plugin/README.md @@ -17,12 +17,12 @@ 2. 从源代码安装 * 从仓库下载源码: - `git clone https://gitee.com/ascend/mstt.git` + `git clone https://gitee.com/ascend/att.git` - * 进入目录 `/plugins/tensorboard_plugins/tb_plugin` 下. + * 进入目录 `/plugins/tensorboard-plugins/tb_plugin` 下. * 编译前端代码 - `python setup.py build_fe` \ - **注意**: build_fe步骤需要安装yarn和Node.js环境 + **注意**: build_fe步骤需要安装[Node.js](https://nodejs.org/zh-cn/download)和[yarn](https://yarn.bootcss.com/docs/install/index.html)环境 * 执行安装命令可直接安装: - `pip install .` * 或: 构建whl包安装 @@ -128,37 +128,25 @@ ##### Kernel View - Kernel View 展示算子在加速核上运行的详细信息。此视图包含两张饼图和两张表,可通过 Group By 切换表格数据:算子的详情表以及统计表。 - - * 上方为饼图,展示耗时最多的数个算子耗时比例信息(左侧饼图)和算子执行在各类加速核上耗时百分比(右侧饼图) + Kernel View展示算子在加速核上运行的详细信息。 ![Alt text](./docs/images/kernel_view.PNG) - * 选择 Group By 为 All 时,展示算子详情表,部分字段说明如下: + * Calls: 算子调度的次数。 + + * Accelerator Core: 计算核。 - | 字段名 | 说明 | - | ---------------- | -------------------------------------- | - | Step Id | 标识在哪个 Step 采集的数据 | - | Name | 运行在 npu 上的算子名称 | - | Type | 算子类型 | - | Accelerator Core | AI 加速核类型,包括 AI Core、AI CPU 等 | - | Start Time(us) | 算子执行开始时间 | - | Duration(us) | 当前算子执行耗时 | - | Wait Time(us) | 算子执行等待时间 | - | Block Dim | 运行切分数量,对应任务执行时的核数 | + * Block Dim: Task运行切分数量,对应Task运行时核数。 ![Alt text](./docs/images/kernel_view_group_by_statistic.PNG) - * 选择 Group By 为 Statistic 时,展示算子信息统计表,此表格展示各算子的执行统计信息,字段说明如下: + * Accelerator Core Utilization: 算子执行在各类core上耗时百分比。 - | 字段名 | 说明 | - | ---------------- | -------| - | Name | 运行在 npu 上的算子名称 | - | Calls | 算子执行次数 | - | Total Duration(us) | 算子执行总时间 | - | Min Duration(us) | 算子执行的最小时间 | - | Max Duration(us) | 算子执行的最大时间 | - | Avg Duration(us) | 算子执行平均时间 | + * Name: 运行在npu上的算子名称。 + + * Total Duration、 Max Duration、Avg Duration、Min Duration: 算子调用总耗时、最大耗时、平均耗时以及最小耗时。 + + 此视图包含两张饼图和两张表,可通过Group By切换表格数据:算子的详细表以及统计表。 ##### Trace View @@ -174,7 +162,7 @@ ![Alt text](./docs/images/trace_view_launch.PNG) - 选择只展示async_npu,可以查看框架侧算子与昇腾硬件上执行的算子的下发执行关系。 + 选择只展示async_nup,可以查看框架侧算子与昇腾硬件上执行的算子的关联关系。 ![Alt text](./docs/images/trace_view_npu_utilization.PNG) @@ -280,7 +268,7 @@ ###### 文件导入 界面分为左侧边栏和右侧展示界面。点击左侧的Import Files或在左侧未勾选文件时点击右侧界面中心的Import Files字体,将会弹出系统文件资源管理窗,可以上传需要比对的模型网络训练日志文件。 - **注:当前最多支持上传6个文件,单个文件大小不能超过50MB。** + 注:当前最多支持上传6个文件,单个文件大小不能超过50MB。 ![Alt text](./docs/images/accuracy.PNG) ###### 已上传文件操作 @@ -331,8 +319,8 @@ * 比对方式有三种,通过Comparison Setting进行设定。 * Comparison Normal:相同iteration,后选择文件的loss值减去先选择文件的loss值。 - * Comparison Absolute:相同iteration,两个文件的loss的差值的绝对值。 - * Comparison Relative:相同iteration,两个文件的loss的差值的绝对值 / 先选择文件的loss值。 + * Comparison Normal:相同iteration,两个文件的loss的差值的绝对值。 + * Comparison Normal:相同iteration,两个文件的loss的差值的绝对值 / 先选择文件的loss值。 ### 公网URL说明 diff --git "a/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" "b/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" index de0bb25fe155aa188e5670a377311e96168586e8..b7a8bf1fd0e7eec640e46af76e16c6a228f335ba 100644 Binary files "a/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" and "b/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" differ diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json b/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json index ef5789da9458a66e7dacc1dfdeeb764642331734..6049640793f6907bbd38c7065360df0ac24d64d4 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json +++ b/plugins/tensorboard-plugins/tb_plugin/fe/prettier.json @@ -1,12 +1,12 @@ { - "parser": "typescript", - "semi": true, - "singleQuote": true, - "jsxSingleQuote": false, - "bracketSpacing": true, - "tabWidth": 2, - "useTabs": false, - "trailingComma": "all", - "proseWrap": "always", - "endOfLine": "lf" + "parser": "typescript", + "semi": false, + "singleQuote": true, + "jsxSingleQuote": false, + "bracketSpacing": true, + "tabWidth": 2, + "useTabs": false, + "trailingComma": "none", + "proseWrap": "always", + "endOfLine": "lf" } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/scripts/add_header.py b/plugins/tensorboard-plugins/tb_plugin/fe/scripts/add_header.py index 69bc6c05541cbaff0fc88eb7456f501fb5bd4f71..03fb7c15aea6bf361b241910fa4529bc0996286c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/scripts/add_header.py +++ b/plugins/tensorboard-plugins/tb_plugin/fe/scripts/add_header.py @@ -1,23 +1,4 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. -# Copyright(c) 2023 Huawei Technologies. -# All rights reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Modifications: Add visualization of PyTorch Ascend profiling. -# -------------------------------------------------------------------------- -# !/usr/bin/env python +#!/usr/bin/env python import glob import os import sys diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/api.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/api.ts index 29cde96ebbde928cde967b3b1b365d12e74ee734..b00601fba8852eeed9be052c6ed8adc106d49215 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/api.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/api.ts @@ -15,7 +15,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ @@ -33,11 +33,11 @@ * Do not edit the file manually. */ -import * as url from 'url'; -import * as portableFetch from 'portable-fetch'; -import { Configuration } from './configuration'; +import * as url from 'url' +import * as portableFetch from 'portable-fetch' +import { Configuration } from './configuration' -const BASE_PATH = '.'.replace(/\/+$/, ''); +const BASE_PATH = '.'.replace(/\/+$/, '') /** * @@ -47,8 +47,8 @@ export const COLLECTION_FORMATS = { csv: ',', ssv: ' ', tsv: '\t', - pipes: '|', -}; + pipes: '|' +} /** * @@ -56,7 +56,7 @@ export const COLLECTION_FORMATS = { * @interface FetchAPI */ export interface FetchAPI { - (url: string, init?: any): Promise; + (url: string, init?: any): Promise } /** @@ -65,8 +65,8 @@ export interface FetchAPI { * @interface FetchArgs */ export interface FetchArgs { - url: string; - options: any; + url: string + options: any } /** @@ -75,7 +75,7 @@ export interface FetchArgs { * @class BaseAPI */ export class BaseAPI { - protected configuration: Configuration; + protected configuration: Configuration constructor( configuration?: Configuration, @@ -83,8 +83,8 @@ export class BaseAPI { protected fetch: FetchAPI = portableFetch ) { if (configuration) { - this.configuration = configuration; - this.basePath = configuration.basePath || this.basePath; + this.configuration = configuration + this.basePath = configuration.basePath || this.basePath } } } @@ -96,9 +96,9 @@ export class BaseAPI { * @extends {Error} */ export class RequiredError extends Error { - name: 'RequiredError'; + name: 'RequiredError' constructor(public field: string, msg?: string) { - super(msg); + super(msg) } } @@ -107,7 +107,7 @@ export class RequiredError extends Error { * @export * @interface CallStackTableData */ -export interface CallStackTableData extends Array {} +export interface CallStackTableData extends Array { } /** * * @export @@ -119,67 +119,67 @@ export interface CallStackTableDataInner { * @type {string} * @memberof CallStackTableDataInner */ - name: string; + name: string /** * * @type {string} * @memberof CallStackTableDataInner */ - input_shape?: string; + input_shape?: string /** * * @type {number} * @memberof CallStackTableDataInner */ - calls: number; + calls: number /** * * @type {number} * @memberof CallStackTableDataInner */ - device_self_duration?: number; + device_self_duration?: number /** * * @type {number} * @memberof CallStackTableDataInner */ - device_total_duration?: number; + device_total_duration?: number /** * * @type {number} * @memberof CallStackTableDataInner */ - host_self_duration: number; + host_self_duration: number /** * * @type {number} * @memberof CallStackTableDataInner */ - host_total_duration: number; + host_total_duration: number /** * * @type {string} * @memberof CallStackTableDataInner */ - call_stack?: string; + call_stack?: string /** * * @type {string} * @memberof CallStackTableDataInner */ - tc_eligible?: string; + tc_eligible?: string /** * * @type {number} * @memberof CallStackTableDataInner */ - tc_self_ratio?: number; + tc_self_ratio?: number /** * * @type {number} * @memberof CallStackTableDataInner */ - tc_total_ratio?: number; + tc_total_ratio?: number } /** * @@ -192,25 +192,25 @@ export interface DiffNode { * @type {OpStats} * @memberof DiffNode */ - left: OpStats; + left: OpStats /** * * @type {OpStats} * @memberof DiffNode */ - right: OpStats; + right: OpStats /** * * @type {string} * @memberof DiffNode */ - path: string; + path: string /** * * @type {Array} * @memberof DiffNode */ - children: Array; + children: Array } /** * @@ -223,13 +223,13 @@ export interface DistributedGraph { * @type {DistributedGraphMetadata} * @memberof DistributedGraph */ - metadata: DistributedGraphMetadata; + metadata: DistributedGraphMetadata /** * * @type {any} * @memberof DistributedGraph */ - data: any; + data: any } /** * @@ -242,19 +242,19 @@ export interface DistributedGraphMetadata { * @type {string} * @memberof DistributedGraphMetadata */ - title: string; + title: string /** * * @type {Array} * @memberof DistributedGraphMetadata */ - legends: Array; + legends: Array /** * * @type {string} * @memberof DistributedGraphMetadata */ - units: string; + units: string } /** * @@ -267,13 +267,13 @@ export interface Environment { * @type {string} * @memberof Environment */ - title: string; + title: string /** * * @type {string} * @memberof Environment */ - value: string; + value: string } /** * @@ -286,13 +286,13 @@ export interface GpuInfo { * @type {GpuInfoMetadata} * @memberof GpuInfo */ - metadata: GpuInfoMetadata; + metadata: GpuInfoMetadata /** * * @type {any} * @memberof GpuInfo */ - data: any; + data: any } /** * @@ -305,7 +305,7 @@ export interface GpuInfoMetadata { * @type {string} * @memberof GpuInfoMetadata */ - title: string; + title: string } /** * @@ -318,13 +318,13 @@ export interface GpuMetric { * @type {string} * @memberof GpuMetric */ - title: string; + title: string /** * * @type {string} * @memberof GpuMetric */ - value: string; + value: string } /** * @@ -337,13 +337,13 @@ export interface GpuMetrics { * @type {Array} * @memberof GpuMetrics */ - data: Array; + data: Array /** * * @type {string} * @memberof GpuMetrics */ - tooltip: string; + tooltip: string } /** * @@ -356,19 +356,19 @@ export interface Graph { * @type {string} * @memberof Graph */ - title?: string; + title?: string /** * * @type {Array} * @memberof Graph */ - columns: Array; + columns: Array /** * * @type {Array>} * @memberof Graph */ - rows: Array>; + rows: Array> } /** * @@ -381,13 +381,13 @@ export interface ValueAndTooltip { * @type {string | number} * @memberof ValueAndTooltip */ - value: string | number; + value: string | number /** * * @type {string} * @memberof ValueAndTooltip */ - tooltip?: string; + tooltip?: string } /** * @@ -400,19 +400,19 @@ export interface StepedGraph { * @type {string} * @memberof StepedGraph */ - title?: string; + title?: string /** * * @type {Array} * @memberof StepedGraph */ - columns: Array; + columns: Array /** * * @type {Array>} * @memberof StepedGraph */ - rows: Array>; + rows: Array> } /** * @@ -425,19 +425,19 @@ export interface GraphAscend { * @type {string} * @memberof GraphAscend */ - title?: string; + title?: string /** * * @type {Array} * @memberof GraphAscend */ - columns: Array; + columns: Array /** * * @type {any} * @memberof GraphAscend */ - rows: any; + rows: any } /** * @@ -450,25 +450,25 @@ export interface GraphColumn { * @type {string} * @memberof GraphColumn */ - type: string; + type: string /** * * @type {string} * @memberof GraphColumn */ - name: string; + name: string /** * * @type {string} * @memberof GraphColumn */ - role?: string; + role?: string /** * * @type {GraphColumnP} * @memberof GraphColumn */ - p?: GraphColumnP; + p?: GraphColumnP } /** * @@ -481,7 +481,7 @@ export interface GraphColumnP { * @type {boolean} * @memberof GraphColumnP */ - html?: boolean; + html?: boolean } /** * @@ -494,13 +494,13 @@ export interface InlineResponse200 { * @type {TableMetadata} * @memberof InlineResponse200 */ - metadata: TableMetadata; + metadata: TableMetadata /** * * @type {OperationTableData} * @memberof InlineResponse200 */ - data: OperationTableData; + data: OperationTableData } /** * @@ -513,13 +513,13 @@ export interface InlineResponse2001 { * @type {TableMetadata} * @memberof InlineResponse2001 */ - metadata: TableMetadata; + metadata: TableMetadata /** * * @type {CallStackTableData} * @memberof InlineResponse2001 */ - data: CallStackTableData; + data: CallStackTableData } /** * @@ -532,13 +532,13 @@ export interface InlineResponse2002 { * @type {GpuInfoMetadata} * @memberof InlineResponse2002 */ - metadata: GpuInfoMetadata; + metadata: GpuInfoMetadata /** * * @type {any} * @memberof InlineResponse2002 */ - data: any; + data: any } /** * @@ -551,8 +551,8 @@ export interface KernelGraph { * @type {Graph} * @memberof KernelGraph */ - total: Graph; - device_target: string; + total: Graph, + device_target: string } /** * @@ -565,50 +565,50 @@ export interface KeyedColumn { * @type {string} * @memberof KeyedColumn */ - type: string; + type: string /** * * @type {string} * @memberof KeyedColumn */ - name: string; + name: string /** * * @type {string} * @memberof KeyedColumn */ - key: string; + key: string } /** - * + * * @export * @interface MemoryCurveDataAll */ export interface MemoryCurveDataAll { /** - * + * * @type {string} * @memberof MemoryCurveDataAll */ - default_device: string; + default_device: string /** - * + * * @type {Array} * @memberof MemoryCurveDataAll */ - devices: Array; + devices: Array /** * * @type {MemoryCurveDataAscend} * @memberof MemoryCurveDataAll */ - total: MemoryCurveDataAscend; + total: MemoryCurveDataAscend /** * * @type {MemoryCurveDataAscend} * @memberof MemoryCurveDataAll */ - ptaGe: MemoryCurveDataAscend; + ptaGe: MemoryCurveDataAscend } /** * @@ -621,19 +621,19 @@ export interface MemoryCurveData { * @type {MemoryCurveDataMetadata} * @memberof MemoryCurveData */ - metadata: MemoryCurveDataMetadata; + metadata: MemoryCurveDataMetadata /** * * @type {Array} * @memberof MemoryCurveData */ - columns: Array; + columns: Array /** * * @type {any} * @memberof MemoryCurveData */ - rows: any; + rows: any } /** * @@ -646,19 +646,19 @@ export interface MemoryCurveDataAscend { * @type {MemoryCurveDataMetadata} * @memberof MemoryCurveDataAscend */ - metadata: MemoryCurveDataMetadata; + metadata: MemoryCurveDataMetadata /** * * @type {any} * @memberof MemoryCurveDataAscend */ - columns: any; + columns: any /** * * @type {any} * @memberof MemoryCurveDataAscend */ - rows: any; + rows: any } /** * @@ -671,55 +671,55 @@ export interface MemoryCurveDataMetadata { * @type {string} * @memberof MemoryCurveDataMetadata */ - default_device: string; + default_device: string /** * * @type {Array} * @memberof MemoryCurveDataMetadata */ - devices: Array; + devices: Array /** * * @type {any} * @memberof MemoryCurveDataMetadata */ - peaks: any; + peaks: any /** * * @type {any} * @memberof MemoryCurveDataMetadata */ - totals: any; + totals: any /** * * @type {number} * @memberof MemoryCurveDataMetadata */ - first_ts: number; + first_ts: number /** * * @type {string} * @memberof MemoryCurveDataMetadata */ - time_metric: string; + time_metric: string /** * * @type {string} * @memberof MemoryCurveDataMetadata */ - memory_metric: string; + memory_metric: string /** * * @type {number} * @memberof MemoryCurveDataMetadata */ - time_factor: number; + time_factor: number /** * * @type {number} * @memberof MemoryCurveDataMetadata */ - memory_factor: number; + memory_factor: number } /** * @@ -732,38 +732,38 @@ export interface MemoryEventsData { * @type {MemoryEventsTableMetadata} * @memberof MemoryEventsData */ - metadata: MemoryEventsTableMetadata; + metadata: MemoryEventsTableMetadata /** * * @type {Array} * @memberof MemoryEventsData */ - columns: Array; + columns: Array /** * * @type {any} * @memberof MemoryEventsData */ - rows: any; + rows: any } /** - * + * * @exports * @interface MemoryEventsDataAll */ export interface MemoryEventsDataAll { /** - * + * * @type {MemoryEventsData} * @memberof MemoryEventsDataAll */ - operator: MemoryEventsData; + operator: MemoryEventsData /** - * + * * @type {MemoryEventsData} * @memberof MemoryEventsDataAll */ - component: MemoryEventsData; + component: MemoryEventsData } /** * @@ -776,25 +776,25 @@ export interface MemoryEventsTableMetadata { * @type {string} * @memberof MemoryEventsTableMetadata */ - title: string; + title: string /** * * @type {string} * @memberof MemoryEventsTableMetadata */ - default_device: string; + default_device: string /** * * @type {string} * @memberof MemoryEventsTableMetadata */ - search?: string; + search?: string /** * * @type {string} * @memberof MemoryEventsTableMetadata */ - sort?: string; + sort?: string } /** * @@ -807,19 +807,19 @@ export interface MemoryStatsData { * @type {MemoryStatsTableMetadata} * @memberof MemoryStatsData */ - metadata: MemoryStatsTableMetadata; + metadata: MemoryStatsTableMetadata /** * * @type {Array} * @memberof MemoryStatsData */ - columns: Array; + columns: Array /** * * @type {any} * @memberof MemoryStatsData */ - rows: any; + rows: any } /** * @@ -832,25 +832,25 @@ export interface MemoryStatsTableMetadata { * @type {string} * @memberof MemoryStatsTableMetadata */ - title: string; + title: string /** * * @type {string} * @memberof MemoryStatsTableMetadata */ - default_device: string; + default_device: string /** * * @type {string} * @memberof MemoryStatsTableMetadata */ - search: string; + search: string /** * * @type {string} * @memberof MemoryStatsTableMetadata */ - sort: string; + sort: string } /** * @@ -863,61 +863,61 @@ export interface ModuleStats { * @type {string} * @memberof ModuleStats */ - name: string; + name: string /** * * @type {string} * @memberof ModuleStats */ - id: string; + id: string /** * * @type {number} * @memberof ModuleStats */ - occurences: number; + occurences: number /** * * @type {number} * @memberof ModuleStats */ - operators: number; + operators: number /** * * @type {number} * @memberof ModuleStats */ - host_duration: number; + host_duration: number /** * * @type {number} * @memberof ModuleStats */ - self_host_duration: number; + self_host_duration: number /** * * @type {number} * @memberof ModuleStats */ - device_duration: number; + device_duration: number /** * * @type {number} * @memberof ModuleStats */ - self_device_duration: number; + self_device_duration: number /** * * @type {number} * @memberof ModuleStats */ - avg_duration: number; + avg_duration: number /** * * @type {Array} * @memberof ModuleStats */ - children: Array; + children: Array } /** * @@ -930,13 +930,13 @@ export interface ModuleViewData { * @type {Array} * @memberof ModuleViewData */ - columns: Array; + columns: Array /** * * @type {Array} * @memberof ModuleViewData */ - data: Array; + data: Array } /** * @@ -949,37 +949,37 @@ export interface OpAgg { * @type {string} * @memberof OpAgg */ - name: string; + name: string /** * * @type {number} * @memberof OpAgg */ - calls: number; + calls: number /** * * @type {number} * @memberof OpAgg */ - host_duration: number; + host_duration: number /** * * @type {number} * @memberof OpAgg */ - device_duration: number; + device_duration: number /** * * @type {number} * @memberof OpAgg */ - self_host_duration: number; + self_host_duration: number /** * * @type {number} * @memberof OpAgg */ - self_device_duration: number; + self_device_duration: number } /** * @@ -992,38 +992,38 @@ export interface OpStats { * @type {string} * @memberof OpStats */ - name: string; + name: string /** * * @type {number} * @memberof OpStats */ - duration: number; + duration: number /** * * @type {number} * @memberof OpStats */ - device_duration: number; + device_duration: number /** * * @type {number} * @memberof OpStats */ - total_duration: number; + total_duration: number /** * * @type {Array} * @memberof OpStats */ - aggs: Array; + aggs: Array } /** * * @export * @interface OperationTableData */ -export interface OperationTableData extends Array {} +export interface OperationTableData extends Array { } /** * * @export @@ -1035,67 +1035,67 @@ export interface OperationTableDataInner { * @type {string} * @memberof OperationTableDataInner */ - name: string; + name: string /** * * @type {string} * @memberof OperationTableDataInner */ - input_shape?: string; + input_shape?: string /** * * @type {number} * @memberof OperationTableDataInner */ - calls: number; + calls: number /** * * @type {number} * @memberof OperationTableDataInner */ - device_self_duration?: number; + device_self_duration?: number /** * * @type {number} * @memberof OperationTableDataInner */ - device_total_duration?: number; + device_total_duration?: number /** * * @type {number} * @memberof OperationTableDataInner */ - host_self_duration: number; + host_self_duration: number /** * * @type {number} * @memberof OperationTableDataInner */ - host_total_duration: number; + host_total_duration: number /** * * @type {boolean} * @memberof OperationTableDataInner */ - has_call_stack: boolean; + has_call_stack: boolean /** * * @type {string} * @memberof OperationTableDataInner */ - tc_eligible?: string; + tc_eligible?: string /** * * @type {number} * @memberof OperationTableDataInner */ - tc_self_ratio?: number; + tc_self_ratio?: number /** * * @type {number} * @memberof OperationTableDataInner */ - tc_total_ratio?: number; + tc_total_ratio?: number } /** * @@ -1108,25 +1108,25 @@ export interface OperatorGraph { * @type {Graph} * @memberof OperatorGraph */ - device_total_time: Graph; + device_total_time: Graph /** * * @type {Graph} * @memberof OperatorGraph */ - device_self_time: Graph; + device_self_time: Graph /** * * @type {Graph} * @memberof OperatorGraph */ - host_total_time: Graph; + host_total_time: Graph /** * * @type {Graph} * @memberof OperatorGraph */ - host_self_time: Graph; + host_self_time: Graph } /** * @@ -1139,37 +1139,37 @@ export interface OperatorNode { * @type {string} * @memberof OperatorNode */ - name: string; + name: string /** * * @type {number} * @memberof OperatorNode */ - start_time: number; + start_time: number /** * * @type {number} * @memberof OperatorNode */ - end_time: number; + end_time: number /** * * @type {string} * @memberof OperatorNode */ - type: string; + type: string /** * * @type {number} * @memberof OperatorNode */ - tid: number; + tid: number /** * * @type {Array} * @memberof OperatorNode */ - children: Array; + children: Array } /** * @@ -1182,31 +1182,31 @@ export interface Overview { * @type {Array} * @memberof Overview */ - performance: Array; + performance: Array /** * * @type {Array} * @memberof Overview */ - environments: Array; + environments: Array /** * * @type {StepedGraph} * @memberof Overview */ - steps: StepedGraph; + steps: StepedGraph /** * * @type {string} * @memberof Overview */ - recommendations: string; + recommendations: string /** * * @type {GpuMetrics} * @memberof Overview */ - gpu_metrics?: GpuMetrics; + gpu_metrics?: GpuMetrics } /** * @@ -1219,31 +1219,31 @@ export interface Performance { * @type {string} * @memberof Performance */ - name: string; + name: string /** * * @type {string} * @memberof Performance */ - description?: string; + description?: string /** * * @type {string} * @memberof Performance */ - value?: string; + value?: string /** * * @type {string} * @memberof Performance */ - extra?: string; + extra?: string /** * * @type {Array} * @memberof Performance */ - children?: Array; + children?: Array } /** * @@ -1256,13 +1256,13 @@ export interface Runs { * @type {Array} * @memberof Runs */ - runs: Array; + runs: Array /** * * @type {boolean} * @memberof Runs */ - loading: boolean; + loading: boolean } /** * @@ -1275,13 +1275,13 @@ export interface TableData { * @type {Graph} * @memberof TableData */ - data: Graph; + data: Graph /** * * @type {TableMetadata} * @memberof TableData */ - metadata: TableMetadata; + metadata: TableMetadata } /** * @@ -1294,13 +1294,13 @@ export interface TableMetadata { * @type {string} * @memberof TableMetadata */ - sort: string; + sort: string /** * * @type {any} * @memberof TableMetadata */ - tooltips?: any; + tooltips?: any } /** * @@ -1313,7 +1313,7 @@ export interface TensorCoresGraph { * @type {Graph} * @memberof TensorCoresGraph */ - total: Graph; + total: Graph } /** * @@ -1326,32 +1326,32 @@ export interface ValueAndFormat { * @type {string | number | boolean} * @memberof ValueAndFormat */ - v: string | number | boolean; + v: string | number | boolean /** * * @type {string} * @memberof ValueAndFormat */ - f: string; + f: string } /** - * + * * @exports * @interface Views */ export interface Views { /** - * + * * @type {string} * @memberof Views */ - device_target: string; + device_target: string /** - * + * * @type {Array} * @memberof Views */ - views: Array; + views: Array } /** * DefaultApi - fetch parameter creator @@ -1388,75 +1388,75 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling diffnodeGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling diffnodeGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling diffnodeGet.' - ); + ) } // verify required parameter 'exp_run' is not null or undefined if (exp_run === null || exp_run === undefined) { throw new RequiredError( 'exp_run', 'Required parameter exp_run was null or undefined when calling diffnodeGet.' - ); + ) } // verify required parameter 'exp_worker' is not null or undefined if (exp_worker === null || exp_worker === undefined) { throw new RequiredError( 'exp_worker', 'Required parameter exp_worker was null or undefined when calling diffnodeGet.' - ); + ) } // verify required parameter 'exp_span' is not null or undefined if (exp_span === null || exp_span === undefined) { throw new RequiredError( 'exp_span', 'Required parameter exp_span was null or undefined when calling diffnodeGet.' - ); + ) } - const localVarPath = `/diffnode`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/diffnode` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (exp_run !== undefined) { - localVarQueryParameter.exp_run = exp_run; + localVarQueryParameter['exp_run'] = exp_run } if (exp_worker !== undefined) { - localVarQueryParameter.exp_worker = exp_worker; + localVarQueryParameter['exp_worker'] = exp_worker } if (exp_span !== undefined) { - localVarQueryParameter.exp_span = exp_span; + localVarQueryParameter['exp_span'] = exp_span } if (path !== undefined) { - localVarQueryParameter.path = path; + localVarQueryParameter['path'] = path } localVarUrlObj.query = Object.assign( @@ -1464,19 +1464,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1497,38 +1497,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling distributedCommopsGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling distributedCommopsGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling distributedCommopsGet.' - ); + ) } - const localVarPath = `/distributed/commops`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/distributed/commops` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -1536,19 +1536,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1569,38 +1569,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling distributedGpuinfoGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling distributedGpuinfoGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling distributedGpuinfoGet.' - ); + ) } - const localVarPath = `/distributed/gpuinfo`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/distributed/gpuinfo` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -1608,19 +1608,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1641,38 +1641,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling distributedOverlapGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling distributedOverlapGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling distributedOverlapGet.' - ); + ) } - const localVarPath = `/distributed/overlap`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/distributed/overlap` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -1680,19 +1680,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1713,38 +1713,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling distributedWaittimeGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling distributedWaittimeGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling distributedWaittimeGet.' - ); + ) } - const localVarPath = `/distributed/waittime`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/distributed/waittime` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -1752,19 +1752,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1787,49 +1787,49 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling kernelGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling kernelGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling kernelGet.' - ); + ) } // verify required parameter 'group_by' is not null or undefined if (group_by === null || group_by === undefined) { throw new RequiredError( 'group_by', 'Required parameter group_by was null or undefined when calling kernelGet.' - ); + ) } - const localVarPath = `/kernel`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/kernel` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (group_by !== undefined) { - localVarQueryParameter.group_by = group_by; + localVarQueryParameter['group_by'] = group_by } localVarUrlObj.query = Object.assign( @@ -1837,19 +1837,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1872,42 +1872,42 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling kernelTableGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling kernelTableGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling kernelTableGet.' - ); + ) } - const localVarPath = `/kernel/table`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/kernel/table` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (group_by !== undefined) { - localVarQueryParameter.group_by = group_by; + localVarQueryParameter['group_by'] = group_by } localVarUrlObj.query = Object.assign( @@ -1915,19 +1915,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -1948,38 +1948,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling kernelTcPieGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling kernelTcPieGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling kernelTcPieGet.' - ); + ) } - const localVarPath = `/kernel/tc_pie`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/kernel/tc_pie` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -1987,19 +1987,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2020,38 +2020,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling memoryCurveGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling memoryCurveGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling memoryCurveGet.' - ); + ) } - const localVarPath = `/memory_curve`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/memory_curve` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -2059,19 +2059,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2096,46 +2096,46 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling memoryEventsGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling memoryEventsGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling memoryEventsGet.' - ); + ) } - const localVarPath = `/memory_events`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/memory_events` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (start_ts !== undefined) { - localVarQueryParameter.start_ts = start_ts; + localVarQueryParameter['start_ts'] = start_ts } if (end_ts !== undefined) { - localVarQueryParameter.end_ts = end_ts; + localVarQueryParameter['end_ts'] = end_ts } localVarUrlObj.query = Object.assign( @@ -2143,19 +2143,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2180,46 +2180,46 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling memoryGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling memoryGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling memoryGet.' - ); + ) } - const localVarPath = `/memory`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/memory` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (start_ts !== undefined) { - localVarQueryParameter.start_ts = start_ts; + localVarQueryParameter['start_ts'] = start_ts } if (end_ts !== undefined) { - localVarQueryParameter.end_ts = end_ts; + localVarQueryParameter['end_ts'] = end_ts } localVarUrlObj.query = Object.assign( @@ -2227,19 +2227,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2260,38 +2260,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling moduleGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling moduleGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling moduleGet.' - ); + ) } - const localVarPath = `/module`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/module` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -2299,19 +2299,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2334,49 +2334,49 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling operationGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling operationGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling operationGet.' - ); + ) } // verify required parameter 'group_by' is not null or undefined if (group_by === null || group_by === undefined) { throw new RequiredError( 'group_by', 'Required parameter group_by was null or undefined when calling operationGet.' - ); + ) } - const localVarPath = `/operation`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/operation` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (group_by !== undefined) { - localVarQueryParameter.group_by = group_by; + localVarQueryParameter['group_by'] = group_by } localVarUrlObj.query = Object.assign( @@ -2384,19 +2384,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2423,64 +2423,64 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling operationStackGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling operationStackGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling operationStackGet.' - ); + ) } // verify required parameter 'group_by' is not null or undefined if (group_by === null || group_by === undefined) { throw new RequiredError( 'group_by', 'Required parameter group_by was null or undefined when calling operationStackGet.' - ); + ) } // verify required parameter 'op_name' is not null or undefined if (op_name === null || op_name === undefined) { throw new RequiredError( 'op_name', 'Required parameter op_name was null or undefined when calling operationStackGet.' - ); + ) } - const localVarPath = `/operation/stack`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/operation/stack` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (group_by !== undefined) { - localVarQueryParameter.group_by = group_by; + localVarQueryParameter['group_by'] = group_by } if (op_name !== undefined) { - localVarQueryParameter.op_name = op_name; + localVarQueryParameter['op_name'] = op_name } if (input_shape !== undefined) { - localVarQueryParameter.input_shape = input_shape; + localVarQueryParameter['input_shape'] = input_shape } localVarUrlObj.query = Object.assign( @@ -2488,19 +2488,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2523,49 +2523,49 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling operationTableGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling operationTableGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling operationTableGet.' - ); + ) } // verify required parameter 'group_by' is not null or undefined if (group_by === null || group_by === undefined) { throw new RequiredError( 'group_by', 'Required parameter group_by was null or undefined when calling operationTableGet.' - ); + ) } - const localVarPath = `/operation/table`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/operation/table` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } if (group_by !== undefined) { - localVarQueryParameter.group_by = group_by; + localVarQueryParameter['group_by'] = group_by } localVarUrlObj.query = Object.assign( @@ -2573,19 +2573,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2606,38 +2606,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling overviewGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling overviewGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling overviewGet.' - ); + ) } - const localVarPath = `/overview`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/overview` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -2645,19 +2645,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2665,30 +2665,30 @@ export const DefaultApiFetchParamCreator = function ( * @throws {RequiredError} */ runsGet(options: any = {}): FetchArgs { - const localVarPath = `/runs`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/runs` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any localVarUrlObj.query = Object.assign( {}, localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2703,27 +2703,27 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling spansGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling spansGet.' - ); + ) } - const localVarPath = `/spans`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/spans` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } localVarUrlObj.query = Object.assign( @@ -2731,19 +2731,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2764,38 +2764,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling traceGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling traceGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling traceGet.' - ); + ) } - const localVarPath = `/trace`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/trace` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -2803,19 +2803,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2836,38 +2836,38 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling treeGet.' - ); + ) } // verify required parameter 'worker' is not null or undefined if (worker === null || worker === undefined) { throw new RequiredError( 'worker', 'Required parameter worker was null or undefined when calling treeGet.' - ); + ) } // verify required parameter 'span' is not null or undefined if (span === null || span === undefined) { throw new RequiredError( 'span', 'Required parameter span was null or undefined when calling treeGet.' - ); + ) } - const localVarPath = `/tree`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/tree` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (worker !== undefined) { - localVarQueryParameter.worker = worker; + localVarQueryParameter['worker'] = worker } if (span !== undefined) { - localVarQueryParameter.span = span; + localVarQueryParameter['span'] = span } localVarUrlObj.query = Object.assign( @@ -2875,19 +2875,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2901,16 +2901,16 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling viewsGet.' - ); + ) } - const localVarPath = `/views`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/views` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } localVarUrlObj.query = Object.assign( @@ -2918,19 +2918,19 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; + options: localVarRequestOptions + } }, /** * @@ -2945,27 +2945,27 @@ export const DefaultApiFetchParamCreator = function ( throw new RequiredError( 'run', 'Required parameter run was null or undefined when calling workersGet.' - ); + ) } // verify required parameter 'view' is not null or undefined if (view === null || view === undefined) { throw new RequiredError( 'view', 'Required parameter view was null or undefined when calling workersGet.' - ); + ) } - const localVarPath = `/workers`; - const localVarUrlObj = url.parse(localVarPath, true); - const localVarRequestOptions = Object.assign({ method: 'GET' }, options); - const localVarHeaderParameter = {} as any; - const localVarQueryParameter = {} as any; + const localVarPath = `/workers` + const localVarUrlObj = url.parse(localVarPath, true) + const localVarRequestOptions = Object.assign({ method: 'GET' }, options) + const localVarHeaderParameter = {} as any + const localVarQueryParameter = {} as any if (run !== undefined) { - localVarQueryParameter.run = run; + localVarQueryParameter['run'] = run } if (view !== undefined) { - localVarQueryParameter.view = view; + localVarQueryParameter['view'] = view } localVarUrlObj.query = Object.assign( @@ -2973,22 +2973,22 @@ export const DefaultApiFetchParamCreator = function ( localVarUrlObj.query, localVarQueryParameter, options.query - ); + ) // fix override query string Detail: https://stackoverflow.com/a/7517673/1077943 - delete localVarUrlObj.search; + delete localVarUrlObj.search localVarRequestOptions.headers = Object.assign( {}, localVarHeaderParameter, options.headers - ); + ) return { url: url.format(localVarUrlObj), - options: localVarRequestOptions, - }; - }, - }; -}; + options: localVarRequestOptions + } + } + } +} /** * DefaultApi - functional programming interface @@ -3029,7 +3029,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { exp_span, path, options - ); + ) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3039,12 +3039,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3062,7 +3062,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).distributedCommopsGet(run, worker, span, options); + ).distributedCommopsGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3072,12 +3072,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3095,7 +3095,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).distributedGpuinfoGet(run, worker, span, options); + ).distributedGpuinfoGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3105,12 +3105,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3128,7 +3128,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).distributedOverlapGet(run, worker, span, options); + ).distributedOverlapGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3138,12 +3138,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3161,7 +3161,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).distributedWaittimeGet(run, worker, span, options); + ).distributedWaittimeGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3171,12 +3171,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3196,7 +3196,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).kernelGet(run, worker, span, group_by, options); + ).kernelGet(run, worker, span, group_by, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3206,12 +3206,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3231,7 +3231,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).kernelTableGet(run, worker, span, group_by, options); + ).kernelTableGet(run, worker, span, group_by, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3241,12 +3241,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3264,7 +3264,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).kernelTcPieGet(run, worker, span, options); + ).kernelTcPieGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3274,12 +3274,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3294,13 +3294,10 @@ export const DefaultApiFp = function (configuration?: Configuration) { worker: string, span: string, options?: any - ): ( - fetch?: FetchAPI, - basePath?: string - ) => Promise { + ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).memoryCurveGet(run, worker, span, options); + ).memoryCurveGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3310,12 +3307,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3334,13 +3331,10 @@ export const DefaultApiFp = function (configuration?: Configuration) { start_ts?: number, end_ts?: number, options?: any - ): ( - fetch?: FetchAPI, - basePath?: string - ) => Promise { + ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).memoryEventsGet(run, worker, span, start_ts, end_ts, options); + ).memoryEventsGet(run, worker, span, start_ts, end_ts, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3350,12 +3344,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3377,7 +3371,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).memoryGet(run, worker, span, start_ts, end_ts, options); + ).memoryGet(run, worker, span, start_ts, end_ts, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3387,12 +3381,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3410,7 +3404,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).moduleGet(run, worker, span, options); + ).moduleGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3420,12 +3414,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3445,7 +3439,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).operationGet(run, worker, span, group_by, options); + ).operationGet(run, worker, span, group_by, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3455,12 +3449,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3492,7 +3486,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { op_name, input_shape, options - ); + ) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3502,12 +3496,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3527,7 +3521,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).operationTableGet(run, worker, span, group_by, options); + ).operationTableGet(run, worker, span, group_by, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3537,12 +3531,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3560,7 +3554,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).overviewGet(run, worker, span, options); + ).overviewGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3570,12 +3564,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3585,8 +3579,9 @@ export const DefaultApiFp = function (configuration?: Configuration) { runsGet( options?: any ): (fetch?: FetchAPI, basePath?: string) => Promise { - const localVarFetchArgs = - DefaultApiFetchParamCreator(configuration).runsGet(options); + const localVarFetchArgs = DefaultApiFetchParamCreator( + configuration + ).runsGet(options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3596,12 +3591,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3617,7 +3612,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise> { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).spansGet(run, worker, options); + ).spansGet(run, worker, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3627,12 +3622,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3650,7 +3645,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).traceGet(run, worker, span, options); + ).traceGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3660,12 +3655,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3683,7 +3678,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).treeGet(run, worker, span, options); + ).treeGet(run, worker, span, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3693,12 +3688,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3712,7 +3707,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).viewsGet(run, options); + ).viewsGet(run, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3722,12 +3717,12 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; + }) + } }, /** * @@ -3743,7 +3738,7 @@ export const DefaultApiFp = function (configuration?: Configuration) { ): (fetch?: FetchAPI, basePath?: string) => Promise> { const localVarFetchArgs = DefaultApiFetchParamCreator( configuration - ).workersGet(run, view, options); + ).workersGet(run, view, options) return ( fetch: FetchAPI = portableFetch, basePath: string = BASE_PATH @@ -3753,15 +3748,15 @@ export const DefaultApiFp = function (configuration?: Configuration) { localVarFetchArgs.options ).then((response) => { if (response.status >= 200 && response.status < 300) { - return response.json(); + return response.json() } else { - throw response; + throw response } - }); - }; - }, - }; -}; + }) + } + } + } +} /** * DefaultApi - factory interface @@ -3804,7 +3799,7 @@ export const DefaultApiFactory = function ( exp_span, path, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3825,7 +3820,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3846,7 +3841,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3867,7 +3862,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3888,7 +3883,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3912,7 +3907,7 @@ export const DefaultApiFactory = function ( span, group_by, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3936,7 +3931,7 @@ export const DefaultApiFactory = function ( span, group_by, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3952,7 +3947,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3968,7 +3963,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -3995,7 +3990,7 @@ export const DefaultApiFactory = function ( start_ts, end_ts, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4022,7 +4017,7 @@ export const DefaultApiFactory = function ( start_ts, end_ts, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4038,7 +4033,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4062,7 +4057,7 @@ export const DefaultApiFactory = function ( span, group_by, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4092,7 +4087,7 @@ export const DefaultApiFactory = function ( op_name, input_shape, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4116,7 +4111,7 @@ export const DefaultApiFactory = function ( span, group_by, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4132,7 +4127,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4140,7 +4135,7 @@ export const DefaultApiFactory = function ( * @throws {RequiredError} */ runsGet(options?: any) { - return DefaultApiFp(configuration).runsGet(options)(fetch, basePath); + return DefaultApiFp(configuration).runsGet(options)(fetch, basePath) }, /** * @@ -4154,7 +4149,7 @@ export const DefaultApiFactory = function ( run, worker, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4170,7 +4165,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4186,7 +4181,7 @@ export const DefaultApiFactory = function ( worker, span, options - )(fetch, basePath); + )(fetch, basePath) }, /** * @@ -4195,10 +4190,7 @@ export const DefaultApiFactory = function ( * @throws {RequiredError} */ viewsGet(run: string, options?: any) { - return DefaultApiFp(configuration).viewsGet(run, options)( - fetch, - basePath - ); + return DefaultApiFp(configuration).viewsGet(run, options)(fetch, basePath) }, /** * @@ -4212,10 +4204,10 @@ export const DefaultApiFactory = function ( run, view, options - )(fetch, basePath); - }, - }; -}; + )(fetch, basePath) + } + } +} /** * DefaultApi - object-oriented interface @@ -4256,7 +4248,7 @@ export class DefaultApi extends BaseAPI { exp_span, path, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4279,7 +4271,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4302,7 +4294,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4325,7 +4317,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4348,7 +4340,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4374,7 +4366,7 @@ export class DefaultApi extends BaseAPI { span, group_by, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4400,7 +4392,7 @@ export class DefaultApi extends BaseAPI { span, group_by, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4423,7 +4415,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4446,7 +4438,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4475,7 +4467,7 @@ export class DefaultApi extends BaseAPI { start_ts, end_ts, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4504,7 +4496,7 @@ export class DefaultApi extends BaseAPI { start_ts, end_ts, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4522,7 +4514,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4548,7 +4540,7 @@ export class DefaultApi extends BaseAPI { span, group_by, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4580,7 +4572,7 @@ export class DefaultApi extends BaseAPI { op_name, input_shape, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4606,7 +4598,7 @@ export class DefaultApi extends BaseAPI { span, group_by, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4624,7 +4616,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4637,7 +4629,7 @@ export class DefaultApi extends BaseAPI { return DefaultApiFp(this.configuration).runsGet(options)( this.fetch, this.basePath - ); + ) } /** @@ -4653,7 +4645,7 @@ export class DefaultApi extends BaseAPI { run, worker, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4671,7 +4663,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4689,7 +4681,7 @@ export class DefaultApi extends BaseAPI { worker, span, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } /** @@ -4703,7 +4695,7 @@ export class DefaultApi extends BaseAPI { return DefaultApiFp(this.configuration).viewsGet(run, options)( this.fetch, this.basePath - ); + ) } /** @@ -4719,6 +4711,6 @@ export class DefaultApi extends BaseAPI { run, view, options - )(this.fetch, this.basePath); + )(this.fetch, this.basePath) } } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/configuration.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/configuration.ts index 85b77bf651c049ec5a2ec85379414f619904c6dd..edec57eed84498fa3dcaa804ada9787b0202066c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/configuration.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/configuration.ts @@ -14,12 +14,13 @@ * https://github.com/swagger-api/swagger-codegen.git * Do not edit the file manually. */ + export interface ConfigurationParameters { - apiKey?: string | ((name: string) => string); - username?: string; - password?: string; - accessToken?: string | ((name: string, scopes?: string[]) => string); - basePath?: string; + apiKey?: string | ((name: string) => string) + username?: string + password?: string + accessToken?: string | ((name: string, scopes?: string[]) => string) + basePath?: string } export class Configuration { @@ -28,41 +29,41 @@ export class Configuration { * @param name security name * @memberof Configuration */ - apiKey?: string | ((name: string) => string); + apiKey?: string | ((name: string) => string) /** * parameter for basic security * * @type {string} * @memberof Configuration */ - username?: string; + username?: string /** * parameter for basic security * * @type {string} * @memberof Configuration */ - password?: string; + password?: string /** * parameter for oauth2 security * @param name security name * @param scopes oauth2 scope * @memberof Configuration */ - accessToken?: string | ((name: string, scopes?: string[]) => string); + accessToken?: string | ((name: string, scopes?: string[]) => string) /** * override base path * * @type {string} * @memberof Configuration */ - basePath?: string; + basePath?: string constructor(param: ConfigurationParameters = {}) { - this.apiKey = param.apiKey; - this.username = param.username; - this.password = param.password; - this.accessToken = param.accessToken; - this.basePath = param.basePath; + this.apiKey = param.apiKey + this.username = param.username + this.password = param.password + this.accessToken = param.accessToken + this.basePath = param.basePath } } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/custom.d.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/custom.d.ts index 992af468898f15bee4f609a8cb752e21f0a9ad48..bfe6a59d9df208845d2fb5a43edb7a2f3d8721ae 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/custom.d.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/custom.d.ts @@ -2,5 +2,5 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -declare module 'portable-fetch'; -declare module 'url'; +declare module 'portable-fetch' +declare module 'url' diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/index.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/index.ts index 7ad784e60de2777174cea9d902ad9cf2550fad68..1ab79fb65f34d7c33099bac7e54378c3f54fdb35 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/index.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/generated/index.ts @@ -14,5 +14,6 @@ * https://github.com/swagger-api/swagger-codegen.git * Do not edit the file manually. */ -export * from './api'; -export * from './configuration'; + +export * from './api' +export * from './configuration' diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/index.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/index.ts index 98b35abfbc09785ffa09b1bbaa48c73685ec84f5..f43336a583b81998422facba8787270d6cee7673 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/index.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/index.ts @@ -2,7 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as api from './generated'; +import * as api from './generated' -export const defaultApi = new api.DefaultApi(undefined, undefined, fetch); -export * from './generated/api'; +export const defaultApi = new api.DefaultApi(undefined, undefined, fetch) +export * from './generated/api' diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/mock.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/mock.ts index 4b4b447d97192b7c7c00784dd9176faeed25d64b..744c222a0266eed6359bb60fc0f6ba9601ba8edc 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/api/mock.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/api/mock.ts @@ -6,8 +6,8 @@ export class MockAPI { runsGet() { return { runs: ['resnet50_num_workers_0', 'resnet50_num_workers_4'], - loading: false, - }; + loading: false + } } viewsGet(run: string) { @@ -16,16 +16,16 @@ export class MockAPI { 'Operator', 'Kernel', 'Trace', - 'Memory', - ]); + 'Memory' + ]) } - spansGet(run: string, view: string): Promise { - return Promise.resolve(['1', '2']); + spansGet(run: string, view: String) { + return Promise.resolve(['1', '2']) } - workersGet(run: string, view: string): Promise { - return Promise.resolve(['worker0']); + workersGet(run: string, view: String) { + return Promise.resolve(['worker0']) } overviewGet(run: string, worker: string, span: string) { @@ -46,7 +46,7 @@ export class MockAPI { { type: 'number', name: 'CPU Exec' }, { type: 'string', role: 'tooltip', p: { html: 'true' } }, { type: 'number', name: 'Other' }, - { type: 'string', role: 'tooltip', p: { html: 'true' } }, + { type: 'string', role: 'tooltip', p: { html: 'true' } } ], rows: [ [ @@ -64,7 +64,7 @@ export class MockAPI { 14091, '
Step 5
Total: 187948us
CPU Exec: 14091us
Percentage: 7.5%
', 1115, - '
Step 5
Total: 187948us
Other: 1115us
Percentage: 0.59%
', + '
Step 5
Total: 187948us
Other: 1115us
Percentage: 0.59%
' ], [ '6', @@ -81,7 +81,7 @@ export class MockAPI { 12968, '
Step 6
Total: 175153us
CPU Exec: 12968us
Percentage: 7.4%
', 1148, - '
Step 6
Total: 175153us
Other: 1148us
Percentage: 0.66%
', + '
Step 6
Total: 175153us
Other: 1148us
Percentage: 0.66%
' ], [ '7', @@ -98,7 +98,7 @@ export class MockAPI { 13768, '
Step 7
Total: 179733us
CPU Exec: 13768us
Percentage: 7.66%
', 1213, - '
Step 7
Total: 179733us
Other: 1213us
Percentage: 0.67%
', + '
Step 7
Total: 179733us
Other: 1213us
Percentage: 0.67%
' ], [ '8', @@ -115,7 +115,7 @@ export class MockAPI { 13420, '
Step 8
Total: 174564us
CPU Exec: 13420us
Percentage: 7.69%
', 1200, - '
Step 8
Total: 174564us
Other: 1200us
Percentage: 0.69%
', + '
Step 8
Total: 174564us
Other: 1200us
Percentage: 0.69%
' ], [ '9', @@ -132,7 +132,7 @@ export class MockAPI { 15025, '
Step 9
Total: 182172us
CPU Exec: 15025us
Percentage: 8.25%
', 1141, - '
Step 9
Total: 182172us
Other: 1141us
Percentage: 0.63%
', + '
Step 9
Total: 182172us
Other: 1141us
Percentage: 0.63%
' ], [ '10', @@ -149,9 +149,9 @@ export class MockAPI { 12773, '
Step 10
Total: 165983us
CPU Exec: 12773us
Percentage: 7.7%
', 1117, - '
Step 10
Total: 165983us
Other: 1117us
Percentage: 0.67%
', - ], - ], + '
Step 10
Total: 165983us
Other: 1117us
Percentage: 0.67%
' + ] + ] }, performance: [ { @@ -166,15 +166,15 @@ export class MockAPI { { name: 'Runtime', description: '', value: 2908, extra: 1.64 }, { name: 'DataLoader', description: '', value: 59262, extra: 33.37 }, { name: 'CPU Exec', description: '', value: 13674, extra: 7.7 }, - { name: 'Other', description: '', value: 1156, extra: 0.65 }, - ], - }, + { name: 'Other', description: '', value: 1156, extra: 0.65 } + ] + } ], recommendations: '
  • This run has high time cost on input data loading. 33.4% of the step time is in DataLoader. You could try to set num_workers on DataLoader\'s construction and enable multi-processes on data loading.
  • Kernels with 68% time are launched by Tensor Cores eligible operators. You could enable Automatic Mixed Precision to speedup by using FP16.
', environments: [ { title: 'Number of Worker(s)', value: '1' }, - { title: 'Device Type', value: 'GPU' }, + { title: 'Device Type', value: 'GPU' } ], gpu_metrics: { title: 'GPU Summary', @@ -186,12 +186,12 @@ export class MockAPI { { title: 'GPU Utilization', value: '55.51 %' }, { title: 'Est. SM Efficiency', value: '54.68 %' }, { title: 'Est. Achieved Occupancy', value: '49.13 %' }, - { title: 'Kernel Time using Tensor Cores', value: '0.0 %' }, + { title: 'Kernel Time using Tensor Cores', value: '0.0 %' } ], tooltip: - "The GPU usage metrics:\n\nGPU Utilization:\nGPU busy time / All steps time. The higher, the better. GPU busy time is the time during which there is at least one GPU kernel running on it. All steps time is the total time of all profiler steps(or called as iterations).\n\nEst. SM Efficiency:\nEstimated Stream Multiprocessor Efficiency. The higher, the better. This metric of a kernel, SM_Eff_K = min(blocks of this kernel / SM number of this GPU, 100%). This overall number is the sum of all kernels' SM_Eff_K weighted by kernel's execution duration, divided by all steps time.\n\nEst. Achieved Occupancy:\nFor most cases such as memory bandwidth bounded kernels, the higher the better. Occupancy is the ratio of active warps on an SM to the maximum number of active warps supported by the SM. The theoretical occupancy of a kernel is upper limit occupancy of this kernel, limited by multiple factors such as kernel shape, kernel used resource, and the GPU compute capability.\nEst. Achieved Occupancy of a kernel, OCC_K = min(threads of the kernel / SM number / max threads per SM, theoretical occupancy of the kernel). This overall number is the weighted average of all kernels' OCC_K using kernel's execution duration as weight. It shows fine-grained low-level GPU utilization.\n\nKernel using Tensor Cores:\nTotal GPU Time for Tensor Core kernels / Total GPU Time for all kernels.\n", - }, - }); + "The GPU usage metrics:\n\nGPU Utilization:\nGPU busy time / All steps time. The higher, the better. GPU busy time is the time during which there is at least one GPU kernel running on it. All steps time is the total time of all profiler steps(or called as iterations).\n\nEst. SM Efficiency:\nEstimated Stream Multiprocessor Efficiency. The higher, the better. This metric of a kernel, SM_Eff_K = min(blocks of this kernel / SM number of this GPU, 100%). This overall number is the sum of all kernels' SM_Eff_K weighted by kernel's execution duration, divided by all steps time.\n\nEst. Achieved Occupancy:\nFor most cases such as memory bandwidth bounded kernels, the higher the better. Occupancy is the ratio of active warps on an SM to the maximum number of active warps supported by the SM. The theoretical occupancy of a kernel is upper limit occupancy of this kernel, limited by multiple factors such as kernel shape, kernel used resource, and the GPU compute capability.\nEst. Achieved Occupancy of a kernel, OCC_K = min(threads of the kernel / SM number / max threads per SM, theoretical occupancy of the kernel). This overall number is the weighted average of all kernels' OCC_K using kernel's execution duration as weight. It shows fine-grained low-level GPU utilization.\n\nKernel using Tensor Cores:\nTotal GPU Time for Tensor Core kernels / Total GPU Time for all kernels.\n" + } + }) } diffnodeGet( @@ -216,7 +216,7 @@ export class MockAPI { host_duration: 186312, device_duration: 0, self_host_duration: 186312, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -224,7 +224,7 @@ export class MockAPI { host_duration: 31902, device_duration: 736, self_host_duration: 17460, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -232,7 +232,7 @@ export class MockAPI { host_duration: 62713, device_duration: 0, self_host_duration: 32640, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -240,7 +240,7 @@ export class MockAPI { host_duration: 1711486, device_duration: 8796, self_host_duration: 37162, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -248,7 +248,7 @@ export class MockAPI { host_duration: 4379, device_duration: 0, self_host_duration: 4379, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -256,7 +256,7 @@ export class MockAPI { host_duration: 10596, device_duration: 0, self_host_duration: 6217, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -264,7 +264,7 @@ export class MockAPI { host_duration: 8470, device_duration: 0, self_host_duration: 8470, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -272,7 +272,7 @@ export class MockAPI { host_duration: 19150, device_duration: 0, self_host_duration: 16142, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -280,7 +280,7 @@ export class MockAPI { host_duration: 50043, device_duration: 0, self_host_duration: 50043, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -288,7 +288,7 @@ export class MockAPI { host_duration: 1518205, device_duration: 8796, self_host_duration: 1509009, - self_device_duration: 8796, + self_device_duration: 8796 }, { name: 'aten::_to_copy', @@ -296,7 +296,7 @@ export class MockAPI { host_duration: 1674324, device_duration: 8796, self_host_duration: 104788, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -304,7 +304,7 @@ export class MockAPI { host_duration: 460479, device_duration: 0, self_host_duration: 421547, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -312,7 +312,7 @@ export class MockAPI { host_duration: 9401, device_duration: 0, self_host_duration: 8211, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -320,7 +320,7 @@ export class MockAPI { host_duration: 31311, device_duration: 0, self_host_duration: 31311, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -328,7 +328,7 @@ export class MockAPI { host_duration: 17762, device_duration: 0, self_host_duration: 15082, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -336,7 +336,7 @@ export class MockAPI { host_duration: 4194, device_duration: 0, self_host_duration: 4194, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -344,7 +344,7 @@ export class MockAPI { host_duration: 14514, device_duration: 0, self_host_duration: 10320, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -352,7 +352,7 @@ export class MockAPI { host_duration: 1734, device_duration: 0, self_host_duration: 1734, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -360,7 +360,7 @@ export class MockAPI { host_duration: 86249, device_duration: 0, self_host_duration: 78373, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -368,7 +368,7 @@ export class MockAPI { host_duration: 183533, device_duration: 0, self_host_duration: 75637, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -376,7 +376,7 @@ export class MockAPI { host_duration: 71284, device_duration: 0, self_host_duration: 49575, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -384,7 +384,7 @@ export class MockAPI { host_duration: 4948, device_duration: 0, self_host_duration: 4948, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -392,7 +392,7 @@ export class MockAPI { host_duration: 20922, device_duration: 0, self_host_duration: 15974, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -400,7 +400,7 @@ export class MockAPI { host_duration: 27934, device_duration: 0, self_host_duration: 10747, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -408,7 +408,7 @@ export class MockAPI { host_duration: 168214, device_duration: 75, self_host_duration: 146203, - self_device_duration: 75, + self_device_duration: 75 }, { name: 'aten::resize_', @@ -416,7 +416,7 @@ export class MockAPI { host_duration: 248, device_duration: 0, self_host_duration: 248, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -424,7 +424,7 @@ export class MockAPI { host_duration: 280, device_duration: 0, self_host_duration: 99, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -432,7 +432,7 @@ export class MockAPI { host_duration: 92993, device_duration: 0, self_host_duration: 92405, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -440,7 +440,7 @@ export class MockAPI { host_duration: 93282, device_duration: 0, self_host_duration: 289, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -448,7 +448,7 @@ export class MockAPI { host_duration: 124757, device_duration: 0, self_host_duration: 22050, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -456,7 +456,7 @@ export class MockAPI { host_duration: 44043, device_duration: 71832, self_host_duration: 35027, - self_device_duration: 71832, + self_device_duration: 71832 }, { name: 'aten::_convolution', @@ -464,7 +464,7 @@ export class MockAPI { host_duration: 51312, device_duration: 71832, self_host_duration: 7269, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -472,7 +472,7 @@ export class MockAPI { host_duration: 55287, device_duration: 71832, self_host_duration: 3975, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -480,7 +480,7 @@ export class MockAPI { host_duration: 59323, device_duration: 71832, self_host_duration: 4036, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -488,7 +488,7 @@ export class MockAPI { host_duration: 17461, device_duration: 10540, self_host_duration: 15188, - self_device_duration: 10540, + self_device_duration: 10540 }, { name: 'aten::empty_like', @@ -496,7 +496,7 @@ export class MockAPI { host_duration: 11504, device_duration: 0, self_host_duration: 4865, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -504,7 +504,7 @@ export class MockAPI { host_duration: 3589, device_duration: 0, self_host_duration: 3589, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -512,7 +512,7 @@ export class MockAPI { host_duration: 71328, device_duration: 25802, self_host_duration: 40944, - self_device_duration: 25802, + self_device_duration: 25802 }, { name: 'aten::_batch_norm_impl_index', @@ -520,7 +520,7 @@ export class MockAPI { host_duration: 76354, device_duration: 25802, self_host_duration: 5026, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -528,7 +528,7 @@ export class MockAPI { host_duration: 79832, device_duration: 25802, self_host_duration: 3478, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -536,7 +536,7 @@ export class MockAPI { host_duration: 5417, device_duration: 12000, self_host_duration: 3885, - self_device_duration: 12000, + self_device_duration: 12000 }, { name: 'aten::clamp_min_', @@ -544,7 +544,7 @@ export class MockAPI { host_duration: 8537, device_duration: 12000, self_host_duration: 3120, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -552,7 +552,7 @@ export class MockAPI { host_duration: 16708, device_duration: 12000, self_host_duration: 8171, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -560,7 +560,7 @@ export class MockAPI { host_duration: 442, device_duration: 940, self_host_duration: 405, - self_device_duration: 940, + self_device_duration: 940 }, { name: 'aten::max_pool2d', @@ -568,7 +568,7 @@ export class MockAPI { host_duration: 542, device_duration: 940, self_host_duration: 100, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -576,7 +576,7 @@ export class MockAPI { host_duration: 72931, device_duration: 13090, self_host_duration: 57558, - self_device_duration: 13090, + self_device_duration: 13090 }, { name: 'aten::mean', @@ -584,7 +584,7 @@ export class MockAPI { host_duration: 376, device_duration: 133, self_host_duration: 339, - self_device_duration: 133, + self_device_duration: 133 }, { name: 'aten::adaptive_avg_pool2d', @@ -592,7 +592,7 @@ export class MockAPI { host_duration: 465, device_duration: 133, self_host_duration: 89, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -600,7 +600,7 @@ export class MockAPI { host_duration: 170, device_duration: 0, self_host_duration: 170, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -608,7 +608,7 @@ export class MockAPI { host_duration: 207, device_duration: 0, self_host_duration: 103, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -616,7 +616,7 @@ export class MockAPI { host_duration: 587, device_duration: 0, self_host_duration: 465, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -624,7 +624,7 @@ export class MockAPI { host_duration: 1068, device_duration: 0, self_host_duration: 481, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -632,7 +632,7 @@ export class MockAPI { host_duration: 277, device_duration: 0, self_host_duration: 227, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -640,7 +640,7 @@ export class MockAPI { host_duration: 809, device_duration: 84, self_host_duration: 604, - self_device_duration: 84, + self_device_duration: 84 }, { name: 'aten::linear', @@ -648,7 +648,7 @@ export class MockAPI { host_duration: 1185, device_duration: 84, self_host_duration: 137, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -656,7 +656,7 @@ export class MockAPI { host_duration: 308, device_duration: 14, self_host_duration: 271, - self_device_duration: 14, + self_device_duration: 14 }, { name: 'aten::log_softmax', @@ -664,7 +664,7 @@ export class MockAPI { host_duration: 472, device_duration: 14, self_host_duration: 153, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -672,7 +672,7 @@ export class MockAPI { host_duration: 522, device_duration: 8, self_host_duration: 476, - self_device_duration: 8, + self_device_duration: 8 }, { name: 'aten::nll_loss', @@ -680,7 +680,7 @@ export class MockAPI { host_duration: 590, device_duration: 8, self_host_duration: 68, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -688,7 +688,7 @@ export class MockAPI { host_duration: 641, device_duration: 8, self_host_duration: 51, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -696,7 +696,7 @@ export class MockAPI { host_duration: 1234, device_duration: 22, self_host_duration: 121, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -704,7 +704,7 @@ export class MockAPI { host_duration: 14541, device_duration: 738, self_host_duration: 10083, - self_device_duration: 738, + self_device_duration: 738 }, { name: 'aten::ones_like', @@ -712,7 +712,7 @@ export class MockAPI { host_duration: 516, device_duration: 2, self_host_duration: 142, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -720,7 +720,7 @@ export class MockAPI { host_duration: 573, device_duration: 8, self_host_duration: 310, - self_device_duration: 6, + self_device_duration: 6 }, { name: 'NllLossBackward0', @@ -728,7 +728,7 @@ export class MockAPI { host_duration: 774, device_duration: 8, self_host_duration: 201, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -736,7 +736,7 @@ export class MockAPI { host_duration: 1025, device_duration: 8, self_host_duration: 251, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -744,7 +744,7 @@ export class MockAPI { host_duration: 236, device_duration: 18, self_host_duration: 196, - self_device_duration: 18, + self_device_duration: 18 }, { name: 'LogSoftmaxBackward0', @@ -752,7 +752,7 @@ export class MockAPI { host_duration: 385, device_duration: 18, self_host_duration: 149, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', @@ -760,7 +760,7 @@ export class MockAPI { host_duration: 632, device_duration: 18, self_host_duration: 247, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -768,7 +768,7 @@ export class MockAPI { host_duration: 668, device_duration: 140, self_host_duration: 547, - self_device_duration: 140, + self_device_duration: 140 }, { name: 'AddmmBackward0', @@ -776,7 +776,7 @@ export class MockAPI { host_duration: 1698, device_duration: 140, self_host_duration: 417, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -784,7 +784,7 @@ export class MockAPI { host_duration: 370, device_duration: 15, self_host_duration: 328, - self_device_duration: 15, + self_device_duration: 15 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -792,7 +792,7 @@ export class MockAPI { host_duration: 2710, device_duration: 155, self_host_duration: 567, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'torch::autograd::AccumulateGrad', @@ -800,15 +800,16 @@ export class MockAPI { host_duration: 41184, device_duration: 997, self_host_duration: 16159, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 322, host_duration: 70946, device_duration: 997, self_host_duration: 29762, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -816,7 +817,7 @@ export class MockAPI { host_duration: 280, device_duration: 0, self_host_duration: 64, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -824,7 +825,7 @@ export class MockAPI { host_duration: 428, device_duration: 0, self_host_duration: 148, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -832,7 +833,7 @@ export class MockAPI { host_duration: 170, device_duration: 0, self_host_duration: 104, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -840,7 +841,7 @@ export class MockAPI { host_duration: 264, device_duration: 0, self_host_duration: 94, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', @@ -848,7 +849,7 @@ export class MockAPI { host_duration: 402, device_duration: 0, self_host_duration: 138, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'MeanBackward1', @@ -856,7 +857,7 @@ export class MockAPI { host_duration: 1036, device_duration: 75, self_host_duration: 231, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -864,7 +865,7 @@ export class MockAPI { host_duration: 1254, device_duration: 75, self_host_duration: 218, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -872,7 +873,7 @@ export class MockAPI { host_duration: 13838, device_duration: 17984, self_host_duration: 12131, - self_device_duration: 17984, + self_device_duration: 17984 }, { name: 'ReluBackward0', @@ -880,7 +881,7 @@ export class MockAPI { host_duration: 21183, device_duration: 17984, self_host_duration: 7345, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -888,7 +889,7 @@ export class MockAPI { host_duration: 33492, device_duration: 17984, self_host_duration: 12309, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -896,7 +897,7 @@ export class MockAPI { host_duration: 251, device_duration: 0, self_host_duration: 251, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -904,7 +905,7 @@ export class MockAPI { host_duration: 2579, device_duration: 0, self_host_duration: 2328, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -912,7 +913,7 @@ export class MockAPI { host_duration: 62175, device_duration: 44433, self_host_duration: 36053, - self_device_duration: 44433, + self_device_duration: 44433 }, { name: 'CudnnBatchNormBackward0', @@ -920,15 +921,16 @@ export class MockAPI { host_duration: 69160, device_duration: 44433, self_host_duration: 6985, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 106, host_duration: 88613, device_duration: 44433, self_host_duration: 19453, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -936,7 +938,7 @@ export class MockAPI { host_duration: 40820, device_duration: 76620, self_host_duration: 30768, - self_device_duration: 76620, + self_device_duration: 76620 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -944,7 +946,7 @@ export class MockAPI { host_duration: 44875, device_duration: 90108, self_host_duration: 27458, - self_device_duration: 90108, + self_device_duration: 90108 }, { name: 'aten::cudnn_convolution_backward', @@ -952,7 +954,7 @@ export class MockAPI { host_duration: 101020, device_duration: 166728, self_host_duration: 15325, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -960,15 +962,16 @@ export class MockAPI { host_duration: 107964, device_duration: 166728, self_host_duration: 6944, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 106, host_duration: 129129, device_duration: 177161, self_host_duration: 16746, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -976,7 +979,7 @@ export class MockAPI { host_duration: 483, device_duration: 3048, self_host_duration: 257, - self_device_duration: 2588, + self_device_duration: 2588 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -984,15 +987,16 @@ export class MockAPI { host_duration: 599, device_duration: 3048, self_host_duration: 116, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 2, host_duration: 836, device_duration: 3048, self_host_duration: 237, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -1000,9 +1004,9 @@ export class MockAPI { host_duration: 23818, device_duration: 797, self_host_duration: 19073, - self_device_duration: 797, - }, - ], + self_device_duration: 797 + } + ] }, right: { name: 'multiple nodes', @@ -1016,7 +1020,7 @@ export class MockAPI { host_duration: 31594, device_duration: 0, self_host_duration: 31594, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -1024,7 +1028,7 @@ export class MockAPI { host_duration: 6010, device_duration: 864, self_host_duration: 1910, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -1032,7 +1036,7 @@ export class MockAPI { host_duration: 10338, device_duration: 0, self_host_duration: 2951, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -1040,7 +1044,7 @@ export class MockAPI { host_duration: 47031, device_duration: 8684, self_host_duration: 4258, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -1048,7 +1052,7 @@ export class MockAPI { host_duration: 701, device_duration: 0, self_host_duration: 698, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -1056,7 +1060,7 @@ export class MockAPI { host_duration: 1374, device_duration: 0, self_host_duration: 676, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -1064,7 +1068,7 @@ export class MockAPI { host_duration: 1013, device_duration: 0, self_host_duration: 1013, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -1072,7 +1076,7 @@ export class MockAPI { host_duration: 2074, device_duration: 0, self_host_duration: 1723, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -1080,7 +1084,7 @@ export class MockAPI { host_duration: 6859, device_duration: 0, self_host_duration: 6859, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -1088,7 +1092,7 @@ export class MockAPI { host_duration: 25248, device_duration: 8684, self_host_duration: 16166, - self_device_duration: 8684, + self_device_duration: 8684 }, { name: 'aten::_to_copy', @@ -1096,7 +1100,7 @@ export class MockAPI { host_duration: 42773, device_duration: 8684, self_host_duration: 10227, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -1104,7 +1108,7 @@ export class MockAPI { host_duration: 51788, device_duration: 0, self_host_duration: 46788, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -1112,7 +1116,7 @@ export class MockAPI { host_duration: 1035, device_duration: 0, self_host_duration: 895, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -1120,7 +1124,7 @@ export class MockAPI { host_duration: 11074, device_duration: 0, self_host_duration: 11074, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -1128,7 +1132,7 @@ export class MockAPI { host_duration: 1892, device_duration: 0, self_host_duration: 1600, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -1136,7 +1140,7 @@ export class MockAPI { host_duration: 278, device_duration: 0, self_host_duration: 244, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -1144,7 +1148,7 @@ export class MockAPI { host_duration: 1341, device_duration: 0, self_host_duration: 1097, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -1152,7 +1156,7 @@ export class MockAPI { host_duration: 317, device_duration: 0, self_host_duration: 317, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -1160,7 +1164,7 @@ export class MockAPI { host_duration: 8857, device_duration: 0, self_host_duration: 7959, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -1168,7 +1172,7 @@ export class MockAPI { host_duration: 17840, device_duration: 0, self_host_duration: 7688, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -1176,7 +1180,7 @@ export class MockAPI { host_duration: 6903, device_duration: 0, self_host_duration: 4901, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -1184,7 +1188,7 @@ export class MockAPI { host_duration: 395, device_duration: 0, self_host_duration: 395, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -1192,7 +1196,7 @@ export class MockAPI { host_duration: 2532, device_duration: 0, self_host_duration: 2130, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -1200,7 +1204,7 @@ export class MockAPI { host_duration: 3601, device_duration: 0, self_host_duration: 1427, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -1208,7 +1212,7 @@ export class MockAPI { host_duration: 11707, device_duration: 75, self_host_duration: 9531, - self_device_duration: 75, + self_device_duration: 75 }, { name: 'aten::resize_', @@ -1216,7 +1220,7 @@ export class MockAPI { host_duration: 79, device_duration: 0, self_host_duration: 79, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -1224,7 +1228,7 @@ export class MockAPI { host_duration: 37, device_duration: 0, self_host_duration: 16, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -1232,7 +1236,7 @@ export class MockAPI { host_duration: 9241, device_duration: 0, self_host_duration: 9113, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -1240,7 +1244,7 @@ export class MockAPI { host_duration: 9286, device_duration: 0, self_host_duration: 45, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -1248,7 +1252,7 @@ export class MockAPI { host_duration: 16195, device_duration: 0, self_host_duration: 6105, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -1256,7 +1260,7 @@ export class MockAPI { host_duration: 17357, device_duration: 71414, self_host_duration: 13601, - self_device_duration: 71414, + self_device_duration: 71414 }, { name: 'aten::_convolution', @@ -1264,7 +1268,7 @@ export class MockAPI { host_duration: 18514, device_duration: 71414, self_host_duration: 1157, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -1272,7 +1276,7 @@ export class MockAPI { host_duration: 19185, device_duration: 71414, self_host_duration: 671, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -1280,7 +1284,7 @@ export class MockAPI { host_duration: 19750, device_duration: 71414, self_host_duration: 565, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -1288,7 +1292,7 @@ export class MockAPI { host_duration: 4973, device_duration: 10567, self_host_duration: 3157, - self_device_duration: 10567, + self_device_duration: 10567 }, { name: 'aten::empty_like', @@ -1296,7 +1300,7 @@ export class MockAPI { host_duration: 1924, device_duration: 0, self_host_duration: 598, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -1304,7 +1308,7 @@ export class MockAPI { host_duration: 596, device_duration: 0, self_host_duration: 596, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -1312,7 +1316,7 @@ export class MockAPI { host_duration: 11083, device_duration: 25737, self_host_duration: 5031, - self_device_duration: 25737, + self_device_duration: 25737 }, { name: 'aten::_batch_norm_impl_index', @@ -1320,7 +1324,7 @@ export class MockAPI { host_duration: 11856, device_duration: 25737, self_host_duration: 773, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -1328,7 +1332,7 @@ export class MockAPI { host_duration: 12386, device_duration: 25737, self_host_duration: 530, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -1336,7 +1340,7 @@ export class MockAPI { host_duration: 2189, device_duration: 12010, self_host_duration: 1030, - self_device_duration: 12010, + self_device_duration: 12010 }, { name: 'aten::clamp_min_', @@ -1344,7 +1348,7 @@ export class MockAPI { host_duration: 2614, device_duration: 12010, self_host_duration: 425, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -1352,7 +1356,7 @@ export class MockAPI { host_duration: 3880, device_duration: 12010, self_host_duration: 1266, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -1360,7 +1364,7 @@ export class MockAPI { host_duration: 112, device_duration: 938, self_host_duration: 82, - self_device_duration: 938, + self_device_duration: 938 }, { name: 'aten::max_pool2d', @@ -1368,7 +1372,7 @@ export class MockAPI { host_duration: 127, device_duration: 938, self_host_duration: 15, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -1376,7 +1380,7 @@ export class MockAPI { host_duration: 21459, device_duration: 13178, self_host_duration: 11041, - self_device_duration: 13178, + self_device_duration: 13178 }, { name: 'aten::mean', @@ -1384,7 +1388,7 @@ export class MockAPI { host_duration: 104, device_duration: 126, self_host_duration: 76, - self_device_duration: 126, + self_device_duration: 126 }, { name: 'aten::adaptive_avg_pool2d', @@ -1392,7 +1396,7 @@ export class MockAPI { host_duration: 117, device_duration: 126, self_host_duration: 13, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -1400,7 +1404,7 @@ export class MockAPI { host_duration: 26, device_duration: 0, self_host_duration: 26, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -1408,7 +1412,7 @@ export class MockAPI { host_duration: 31, device_duration: 0, self_host_duration: 15, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -1416,7 +1420,7 @@ export class MockAPI { host_duration: 85, device_duration: 0, self_host_duration: 68, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -1424,7 +1428,7 @@ export class MockAPI { host_duration: 145, device_duration: 0, self_host_duration: 60, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -1432,7 +1436,7 @@ export class MockAPI { host_duration: 30, device_duration: 0, self_host_duration: 25, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -1440,7 +1444,7 @@ export class MockAPI { host_duration: 334, device_duration: 84, self_host_duration: 234, - self_device_duration: 84, + self_device_duration: 84 }, { name: 'aten::linear', @@ -1448,7 +1452,7 @@ export class MockAPI { host_duration: 386, device_duration: 84, self_host_duration: 19, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -1456,7 +1460,7 @@ export class MockAPI { host_duration: 83, device_duration: 14, self_host_duration: 55, - self_device_duration: 14, + self_device_duration: 14 }, { name: 'aten::log_softmax', @@ -1464,7 +1468,7 @@ export class MockAPI { host_duration: 106, device_duration: 14, self_host_duration: 20, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -1472,7 +1476,7 @@ export class MockAPI { host_duration: 96, device_duration: 8, self_host_duration: 68, - self_device_duration: 8, + self_device_duration: 8 }, { name: 'aten::nll_loss', @@ -1480,7 +1484,7 @@ export class MockAPI { host_duration: 105, device_duration: 8, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -1488,7 +1492,7 @@ export class MockAPI { host_duration: 113, device_duration: 8, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -1496,7 +1500,7 @@ export class MockAPI { host_duration: 243, device_duration: 22, self_host_duration: 24, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -1504,7 +1508,7 @@ export class MockAPI { host_duration: 4140, device_duration: 866, self_host_duration: 1851, - self_device_duration: 866, + self_device_duration: 866 }, { name: 'aten::ones_like', @@ -1512,7 +1516,7 @@ export class MockAPI { host_duration: 104, device_duration: 2, self_host_duration: 14, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -1520,7 +1524,7 @@ export class MockAPI { host_duration: 192, device_duration: 9, self_host_duration: 84, - self_device_duration: 6, + self_device_duration: 6 }, { name: 'NllLossBackward0', @@ -1528,7 +1532,7 @@ export class MockAPI { host_duration: 297, device_duration: 9, self_host_duration: 105, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -1536,7 +1540,7 @@ export class MockAPI { host_duration: 352, device_duration: 9, self_host_duration: 55, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -1544,7 +1548,7 @@ export class MockAPI { host_duration: 71, device_duration: 18, self_host_duration: 43, - self_device_duration: 18, + self_device_duration: 18 }, { name: 'LogSoftmaxBackward0', @@ -1552,7 +1556,7 @@ export class MockAPI { host_duration: 91, device_duration: 18, self_host_duration: 20, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', @@ -1560,7 +1564,7 @@ export class MockAPI { host_duration: 126, device_duration: 18, self_host_duration: 35, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -1568,7 +1572,7 @@ export class MockAPI { host_duration: 283, device_duration: 134, self_host_duration: 186, - self_device_duration: 134, + self_device_duration: 134 }, { name: 'AddmmBackward0', @@ -1576,7 +1580,7 @@ export class MockAPI { host_duration: 418, device_duration: 134, self_host_duration: 47, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -1584,7 +1588,7 @@ export class MockAPI { host_duration: 92, device_duration: 14, self_host_duration: 62, - self_device_duration: 14, + self_device_duration: 14 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -1592,7 +1596,7 @@ export class MockAPI { host_duration: 594, device_duration: 148, self_host_duration: 75, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'torch::autograd::AccumulateGrad', @@ -1600,15 +1604,16 @@ export class MockAPI { host_duration: 10317, device_duration: 1069, self_host_duration: 2127, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 322, host_duration: 15128, device_duration: 1069, self_host_duration: 4811, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -1616,7 +1621,7 @@ export class MockAPI { host_duration: 30, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -1624,7 +1629,7 @@ export class MockAPI { host_duration: 45, device_duration: 0, self_host_duration: 15, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -1632,7 +1637,7 @@ export class MockAPI { host_duration: 20, device_duration: 0, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -1640,7 +1645,7 @@ export class MockAPI { host_duration: 31, device_duration: 0, self_host_duration: 11, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', @@ -1648,7 +1653,7 @@ export class MockAPI { host_duration: 48, device_duration: 0, self_host_duration: 17, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'MeanBackward1', @@ -1656,7 +1661,7 @@ export class MockAPI { host_duration: 172, device_duration: 75, self_host_duration: 18, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -1664,7 +1669,7 @@ export class MockAPI { host_duration: 201, device_duration: 75, self_host_duration: 29, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -1672,7 +1677,7 @@ export class MockAPI { host_duration: 3652, device_duration: 18018, self_host_duration: 2361, - self_device_duration: 18018, + self_device_duration: 18018 }, { name: 'ReluBackward0', @@ -1680,7 +1685,7 @@ export class MockAPI { host_duration: 4567, device_duration: 18018, self_host_duration: 915, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -1688,7 +1693,7 @@ export class MockAPI { host_duration: 6457, device_duration: 18018, self_host_duration: 1890, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -1696,7 +1701,7 @@ export class MockAPI { host_duration: 26, device_duration: 0, self_host_duration: 26, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -1704,7 +1709,7 @@ export class MockAPI { host_duration: 261, device_duration: 0, self_host_duration: 235, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -1712,7 +1717,7 @@ export class MockAPI { host_duration: 9943, device_duration: 44401, self_host_duration: 4355, - self_device_duration: 44401, + self_device_duration: 44401 }, { name: 'CudnnBatchNormBackward0', @@ -1720,15 +1725,16 @@ export class MockAPI { host_duration: 11132, device_duration: 44401, self_host_duration: 1189, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 106, host_duration: 14696, device_duration: 44401, self_host_duration: 3564, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -1736,7 +1742,7 @@ export class MockAPI { host_duration: 18813, device_duration: 75568, self_host_duration: 13997, - self_device_duration: 75568, + self_device_duration: 75568 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -1744,7 +1750,7 @@ export class MockAPI { host_duration: 18792, device_duration: 88992, self_host_duration: 11101, - self_device_duration: 88992, + self_device_duration: 88992 }, { name: 'aten::cudnn_convolution_backward', @@ -1752,7 +1758,7 @@ export class MockAPI { host_duration: 40064, device_duration: 164560, self_host_duration: 2459, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -1760,15 +1766,16 @@ export class MockAPI { host_duration: 41205, device_duration: 164560, self_host_duration: 1141, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 106, host_duration: 45209, device_duration: 175014, self_host_duration: 2826, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -1776,7 +1783,7 @@ export class MockAPI { host_duration: 145, device_duration: 3016, self_host_duration: 61, - self_device_duration: 2556, + self_device_duration: 2556 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -1784,15 +1791,16 @@ export class MockAPI { host_duration: 165, device_duration: 3016, self_host_duration: 20, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 2, host_duration: 209, device_duration: 3016, self_host_duration: 44, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -1800,9 +1808,9 @@ export class MockAPI { host_duration: 6835, device_duration: 803, self_host_duration: 3630, - self_device_duration: 803, - }, - ], + self_device_duration: 803 + } + ] }, path: '0', children: [ @@ -1819,7 +1827,7 @@ export class MockAPI { host_duration: 100, device_duration: 0, self_host_duration: 100, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -1827,7 +1835,7 @@ export class MockAPI { host_duration: 4, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -1835,9 +1843,9 @@ export class MockAPI { host_duration: 119, device_duration: 0, self_host_duration: 64, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'multiple nodes', @@ -1851,7 +1859,7 @@ export class MockAPI { host_duration: 17, device_duration: 0, self_host_duration: 17, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -1859,7 +1867,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -1867,11 +1875,11 @@ export class MockAPI { host_duration: 15, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-0', + path: '0-0' }, { left: { @@ -1886,7 +1894,7 @@ export class MockAPI { host_duration: 62288, device_duration: 0, self_host_duration: 62288, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -1894,7 +1902,7 @@ export class MockAPI { host_duration: 959, device_duration: 0, self_host_duration: 959, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -1902,7 +1910,7 @@ export class MockAPI { host_duration: 35273, device_duration: 0, self_host_duration: 16154, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -1910,7 +1918,7 @@ export class MockAPI { host_duration: 877101, device_duration: 0, self_host_duration: 18482, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -1918,7 +1926,7 @@ export class MockAPI { host_duration: 2191, device_duration: 0, self_host_duration: 2191, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -1926,7 +1934,7 @@ export class MockAPI { host_duration: 5301, device_duration: 0, self_host_duration: 3110, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -1934,7 +1942,7 @@ export class MockAPI { host_duration: 4175, device_duration: 0, self_host_duration: 4175, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -1942,7 +1950,7 @@ export class MockAPI { host_duration: 9560, device_duration: 0, self_host_duration: 8045, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -1950,7 +1958,7 @@ export class MockAPI { host_duration: 24689, device_duration: 0, self_host_duration: 24689, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -1958,7 +1966,7 @@ export class MockAPI { host_duration: 780214, device_duration: 0, self_host_duration: 780214, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_to_copy', @@ -1966,7 +1974,7 @@ export class MockAPI { host_duration: 858619, device_duration: 0, self_host_duration: 53009, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -1974,7 +1982,7 @@ export class MockAPI { host_duration: 224031, device_duration: 0, self_host_duration: 204660, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -1982,7 +1990,7 @@ export class MockAPI { host_duration: 4719, device_duration: 0, self_host_duration: 4119, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -1990,7 +1998,7 @@ export class MockAPI { host_duration: 16028, device_duration: 0, self_host_duration: 16028, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -1998,7 +2006,7 @@ export class MockAPI { host_duration: 8918, device_duration: 0, self_host_duration: 7569, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -2006,7 +2014,7 @@ export class MockAPI { host_duration: 2092, device_duration: 0, self_host_duration: 2092, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -2014,7 +2022,7 @@ export class MockAPI { host_duration: 7228, device_duration: 0, self_host_duration: 5136, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -2022,7 +2030,7 @@ export class MockAPI { host_duration: 884, device_duration: 0, self_host_duration: 884, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -2030,7 +2038,7 @@ export class MockAPI { host_duration: 43030, device_duration: 0, self_host_duration: 39068, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -2038,7 +2046,7 @@ export class MockAPI { host_duration: 91440, device_duration: 0, self_host_duration: 37676, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -2046,7 +2054,7 @@ export class MockAPI { host_duration: 35514, device_duration: 0, self_host_duration: 24706, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -2054,7 +2062,7 @@ export class MockAPI { host_duration: 2467, device_duration: 0, self_host_duration: 2467, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -2062,7 +2070,7 @@ export class MockAPI { host_duration: 10375, device_duration: 0, self_host_duration: 7908, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -2070,7 +2078,7 @@ export class MockAPI { host_duration: 13905, device_duration: 0, self_host_duration: 5383, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -2078,7 +2086,7 @@ export class MockAPI { host_duration: 87841, device_duration: 0, self_host_duration: 76794, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -2086,7 +2094,7 @@ export class MockAPI { host_duration: 117, device_duration: 0, self_host_duration: 117, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -2094,7 +2102,7 @@ export class MockAPI { host_duration: 142, device_duration: 0, self_host_duration: 51, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -2102,7 +2110,7 @@ export class MockAPI { host_duration: 51526, device_duration: 0, self_host_duration: 51229, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -2110,7 +2118,7 @@ export class MockAPI { host_duration: 51674, device_duration: 0, self_host_duration: 148, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -2118,9 +2126,9 @@ export class MockAPI { host_duration: 75677, device_duration: 0, self_host_duration: 19330, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__', @@ -2134,7 +2142,7 @@ export class MockAPI { host_duration: 12399, device_duration: 0, self_host_duration: 12399, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -2142,7 +2150,7 @@ export class MockAPI { host_duration: 98, device_duration: 0, self_host_duration: 98, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -2150,7 +2158,7 @@ export class MockAPI { host_duration: 7665, device_duration: 0, self_host_duration: 1689, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -2158,7 +2166,7 @@ export class MockAPI { host_duration: 21137, device_duration: 0, self_host_duration: 2377, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -2166,7 +2174,7 @@ export class MockAPI { host_duration: 364, device_duration: 0, self_host_duration: 361, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -2174,7 +2182,7 @@ export class MockAPI { host_duration: 745, device_duration: 0, self_host_duration: 384, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -2182,7 +2190,7 @@ export class MockAPI { host_duration: 527, device_duration: 0, self_host_duration: 527, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -2190,7 +2198,7 @@ export class MockAPI { host_duration: 1050, device_duration: 0, self_host_duration: 869, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -2198,7 +2206,7 @@ export class MockAPI { host_duration: 3689, device_duration: 0, self_host_duration: 3689, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -2206,7 +2214,7 @@ export class MockAPI { host_duration: 8695, device_duration: 0, self_host_duration: 8695, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_to_copy', @@ -2214,7 +2222,7 @@ export class MockAPI { host_duration: 18760, device_duration: 0, self_host_duration: 6122, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -2222,7 +2230,7 @@ export class MockAPI { host_duration: 20349, device_duration: 0, self_host_duration: 17634, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -2230,7 +2238,7 @@ export class MockAPI { host_duration: 562, device_duration: 0, self_host_duration: 487, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -2238,7 +2246,7 @@ export class MockAPI { host_duration: 6658, device_duration: 0, self_host_duration: 6658, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -2246,7 +2254,7 @@ export class MockAPI { host_duration: 1028, device_duration: 0, self_host_duration: 870, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -2254,7 +2262,7 @@ export class MockAPI { host_duration: 142, device_duration: 0, self_host_duration: 129, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -2262,7 +2270,7 @@ export class MockAPI { host_duration: 755, device_duration: 0, self_host_duration: 626, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -2270,7 +2278,7 @@ export class MockAPI { host_duration: 168, device_duration: 0, self_host_duration: 168, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -2278,7 +2286,7 @@ export class MockAPI { host_duration: 4922, device_duration: 0, self_host_duration: 4440, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -2286,7 +2294,7 @@ export class MockAPI { host_duration: 9959, device_duration: 0, self_host_duration: 4339, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -2294,7 +2302,7 @@ export class MockAPI { host_duration: 3848, device_duration: 0, self_host_duration: 2737, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -2302,7 +2310,7 @@ export class MockAPI { host_duration: 209, device_duration: 0, self_host_duration: 209, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -2310,7 +2318,7 @@ export class MockAPI { host_duration: 1398, device_duration: 0, self_host_duration: 1187, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -2318,7 +2326,7 @@ export class MockAPI { host_duration: 2013, device_duration: 0, self_host_duration: 812, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -2326,7 +2334,7 @@ export class MockAPI { host_duration: 7421, device_duration: 0, self_host_duration: 6234, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -2334,7 +2342,7 @@ export class MockAPI { host_duration: 36, device_duration: 0, self_host_duration: 36, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -2342,7 +2350,7 @@ export class MockAPI { host_duration: 19, device_duration: 0, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -2350,7 +2358,7 @@ export class MockAPI { host_duration: 4628, device_duration: 0, self_host_duration: 4566, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -2358,7 +2366,7 @@ export class MockAPI { host_duration: 4649, device_duration: 0, self_host_duration: 21, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -2366,11 +2374,11 @@ export class MockAPI { host_duration: 10884, device_duration: 0, self_host_duration: 5859, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-1', + path: '0-1' }, { left: { @@ -2385,7 +2393,7 @@ export class MockAPI { host_duration: 209, device_duration: 0, self_host_duration: 209, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -2393,7 +2401,7 @@ export class MockAPI { host_duration: 4696, device_duration: 4402, self_host_duration: 93, - self_device_duration: 4402, + self_device_duration: 4402 }, { name: 'aten::_to_copy', @@ -2401,7 +2409,7 @@ export class MockAPI { host_duration: 5111, device_duration: 4402, self_host_duration: 206, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -2409,9 +2417,9 @@ export class MockAPI { host_duration: 5170, device_duration: 4402, self_host_duration: 59, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'multiple nodes', @@ -2425,7 +2433,7 @@ export class MockAPI { host_duration: 65, device_duration: 0, self_host_duration: 65, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -2433,7 +2441,7 @@ export class MockAPI { host_duration: 4575, device_duration: 4350, self_host_duration: 26, - self_device_duration: 4350, + self_device_duration: 4350 }, { name: 'aten::_to_copy', @@ -2441,7 +2449,7 @@ export class MockAPI { host_duration: 4670, device_duration: 4350, self_host_duration: 30, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -2449,11 +2457,11 @@ export class MockAPI { host_duration: 4681, device_duration: 4350, self_host_duration: 11, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-2', + path: '0-2' }, { left: { @@ -2468,7 +2476,7 @@ export class MockAPI { host_duration: 14161, device_duration: 0, self_host_duration: 14161, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -2476,7 +2484,7 @@ export class MockAPI { host_duration: 22091, device_duration: 36599, self_host_duration: 17567, - self_device_duration: 36599, + self_device_duration: 36599 }, { name: 'aten::_convolution', @@ -2484,7 +2492,7 @@ export class MockAPI { host_duration: 25744, device_duration: 36599, self_host_duration: 3653, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -2492,7 +2500,7 @@ export class MockAPI { host_duration: 27753, device_duration: 36599, self_host_duration: 2009, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -2500,7 +2508,7 @@ export class MockAPI { host_duration: 29777, device_duration: 36599, self_host_duration: 2024, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -2508,7 +2516,7 @@ export class MockAPI { host_duration: 6519, device_duration: 54, self_host_duration: 5666, - self_device_duration: 54, + self_device_duration: 54 }, { name: 'aten::empty_like', @@ -2516,7 +2524,7 @@ export class MockAPI { host_duration: 5624, device_duration: 0, self_host_duration: 2390, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -2524,7 +2532,7 @@ export class MockAPI { host_duration: 826, device_duration: 0, self_host_duration: 826, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -2532,7 +2540,7 @@ export class MockAPI { host_duration: 35818, device_duration: 12974, self_host_duration: 20557, - self_device_duration: 12974, + self_device_duration: 12974 }, { name: 'aten::_batch_norm_impl_index', @@ -2540,7 +2548,7 @@ export class MockAPI { host_duration: 38324, device_duration: 12974, self_host_duration: 2506, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -2548,7 +2556,7 @@ export class MockAPI { host_duration: 40105, device_duration: 12974, self_host_duration: 1781, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -2556,7 +2564,7 @@ export class MockAPI { host_duration: 2702, device_duration: 6002, self_host_duration: 1935, - self_device_duration: 6002, + self_device_duration: 6002 }, { name: 'aten::clamp_min_', @@ -2564,7 +2572,7 @@ export class MockAPI { host_duration: 4273, device_duration: 6002, self_host_duration: 1571, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -2572,7 +2580,7 @@ export class MockAPI { host_duration: 8371, device_duration: 6002, self_host_duration: 4098, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -2580,7 +2588,7 @@ export class MockAPI { host_duration: 230, device_duration: 474, self_host_duration: 212, - self_device_duration: 474, + self_device_duration: 474 }, { name: 'aten::max_pool2d', @@ -2588,7 +2596,7 @@ export class MockAPI { host_duration: 280, device_duration: 474, self_host_duration: 50, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -2596,7 +2604,7 @@ export class MockAPI { host_duration: 1546, device_duration: 5141, self_host_duration: 1290, - self_device_duration: 5141, + self_device_duration: 5141 }, { name: 'aten::mean', @@ -2604,7 +2612,7 @@ export class MockAPI { host_duration: 189, device_duration: 69, self_host_duration: 170, - self_device_duration: 69, + self_device_duration: 69 }, { name: 'aten::adaptive_avg_pool2d', @@ -2612,7 +2620,7 @@ export class MockAPI { host_duration: 234, device_duration: 69, self_host_duration: 45, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -2620,7 +2628,7 @@ export class MockAPI { host_duration: 52, device_duration: 0, self_host_duration: 52, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -2628,7 +2636,7 @@ export class MockAPI { host_duration: 106, device_duration: 0, self_host_duration: 54, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -2636,7 +2644,7 @@ export class MockAPI { host_duration: 23, device_duration: 0, self_host_duration: 23, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -2644,7 +2652,7 @@ export class MockAPI { host_duration: 55, device_duration: 0, self_host_duration: 41, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -2652,7 +2660,7 @@ export class MockAPI { host_duration: 119, device_duration: 0, self_host_duration: 64, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -2660,7 +2668,7 @@ export class MockAPI { host_duration: 49, device_duration: 0, self_host_duration: 40, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -2668,7 +2676,7 @@ export class MockAPI { host_duration: 404, device_duration: 43, self_host_duration: 302, - self_device_duration: 43, + self_device_duration: 43 }, { name: 'aten::linear', @@ -2676,9 +2684,9 @@ export class MockAPI { host_duration: 591, device_duration: 43, self_host_duration: 68, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: ResNet', @@ -2692,7 +2700,7 @@ export class MockAPI { host_duration: 2292, device_duration: 0, self_host_duration: 2292, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -2700,7 +2708,7 @@ export class MockAPI { host_duration: 8713, device_duration: 36205, self_host_duration: 6819, - self_device_duration: 36205, + self_device_duration: 36205 }, { name: 'aten::_convolution', @@ -2708,7 +2716,7 @@ export class MockAPI { host_duration: 9298, device_duration: 36205, self_host_duration: 585, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -2716,7 +2724,7 @@ export class MockAPI { host_duration: 9653, device_duration: 36205, self_host_duration: 355, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -2724,7 +2732,7 @@ export class MockAPI { host_duration: 9932, device_duration: 36205, self_host_duration: 279, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -2732,7 +2740,7 @@ export class MockAPI { host_duration: 1897, device_duration: 58, self_host_duration: 1201, - self_device_duration: 58, + self_device_duration: 58 }, { name: 'aten::empty_like', @@ -2740,7 +2748,7 @@ export class MockAPI { host_duration: 933, device_duration: 0, self_host_duration: 284, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -2748,7 +2756,7 @@ export class MockAPI { host_duration: 130, device_duration: 0, self_host_duration: 130, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -2756,7 +2764,7 @@ export class MockAPI { host_duration: 5540, device_duration: 12913, self_host_duration: 2504, - self_device_duration: 12913, + self_device_duration: 12913 }, { name: 'aten::_batch_norm_impl_index', @@ -2764,7 +2772,7 @@ export class MockAPI { host_duration: 5942, device_duration: 12913, self_host_duration: 402, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -2772,7 +2780,7 @@ export class MockAPI { host_duration: 6219, device_duration: 12913, self_host_duration: 277, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -2780,7 +2788,7 @@ export class MockAPI { host_duration: 1108, device_duration: 6006, self_host_duration: 523, - self_device_duration: 6006, + self_device_duration: 6006 }, { name: 'aten::clamp_min_', @@ -2788,7 +2796,7 @@ export class MockAPI { host_duration: 1315, device_duration: 6006, self_host_duration: 207, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -2796,7 +2804,7 @@ export class MockAPI { host_duration: 1939, device_duration: 6006, self_host_duration: 624, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -2804,7 +2812,7 @@ export class MockAPI { host_duration: 53, device_duration: 472, self_host_duration: 38, - self_device_duration: 472, + self_device_duration: 472 }, { name: 'aten::max_pool2d', @@ -2812,7 +2820,7 @@ export class MockAPI { host_duration: 61, device_duration: 472, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -2820,7 +2828,7 @@ export class MockAPI { host_duration: 448, device_duration: 5140, self_host_duration: 268, - self_device_duration: 5140, + self_device_duration: 5140 }, { name: 'aten::mean', @@ -2828,7 +2836,7 @@ export class MockAPI { host_duration: 53, device_duration: 63, self_host_duration: 39, - self_device_duration: 63, + self_device_duration: 63 }, { name: 'aten::adaptive_avg_pool2d', @@ -2836,7 +2844,7 @@ export class MockAPI { host_duration: 59, device_duration: 63, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -2844,7 +2852,7 @@ export class MockAPI { host_duration: 8, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -2852,7 +2860,7 @@ export class MockAPI { host_duration: 15, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -2860,7 +2868,7 @@ export class MockAPI { host_duration: 3, device_duration: 0, self_host_duration: 3, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -2868,7 +2876,7 @@ export class MockAPI { host_duration: 8, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -2876,7 +2884,7 @@ export class MockAPI { host_duration: 15, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -2884,7 +2892,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -2892,7 +2900,7 @@ export class MockAPI { host_duration: 173, device_duration: 42, self_host_duration: 123, - self_device_duration: 42, + self_device_duration: 42 }, { name: 'aten::linear', @@ -2900,11 +2908,11 @@ export class MockAPI { host_duration: 198, device_duration: 42, self_host_duration: 10, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-3', + path: '0-3' }, { left: { @@ -2919,7 +2927,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -2927,7 +2935,7 @@ export class MockAPI { host_duration: 158, device_duration: 7, self_host_duration: 139, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::log_softmax', @@ -2935,7 +2943,7 @@ export class MockAPI { host_duration: 241, device_duration: 7, self_host_duration: 78, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -2943,7 +2951,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -2951,7 +2959,7 @@ export class MockAPI { host_duration: 256, device_duration: 4, self_host_duration: 233, - self_device_duration: 4, + self_device_duration: 4 }, { name: 'aten::nll_loss', @@ -2959,7 +2967,7 @@ export class MockAPI { host_duration: 290, device_duration: 4, self_host_duration: 34, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -2967,7 +2975,7 @@ export class MockAPI { host_duration: 313, device_duration: 4, self_host_duration: 23, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -2975,9 +2983,9 @@ export class MockAPI { host_duration: 614, device_duration: 11, self_host_duration: 60, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: CrossEntropyLoss', @@ -2991,7 +2999,7 @@ export class MockAPI { host_duration: 2, device_duration: 0, self_host_duration: 2, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -2999,7 +3007,7 @@ export class MockAPI { host_duration: 42, device_duration: 7, self_host_duration: 28, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::log_softmax', @@ -3007,7 +3015,7 @@ export class MockAPI { host_duration: 54, device_duration: 7, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -3015,7 +3023,7 @@ export class MockAPI { host_duration: 0, device_duration: 0, self_host_duration: 0, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -3023,7 +3031,7 @@ export class MockAPI { host_duration: 47, device_duration: 4, self_host_duration: 34, - self_device_duration: 4, + self_device_duration: 4 }, { name: 'aten::nll_loss', @@ -3031,7 +3039,7 @@ export class MockAPI { host_duration: 52, device_duration: 4, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -3039,7 +3047,7 @@ export class MockAPI { host_duration: 56, device_duration: 4, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -3047,11 +3055,11 @@ export class MockAPI { host_duration: 119, device_duration: 11, self_host_duration: 9, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-4', + path: '0-4' }, { left: { @@ -3066,7 +3074,7 @@ export class MockAPI { host_duration: 47, device_duration: 0, self_host_duration: 47, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -3074,7 +3082,7 @@ export class MockAPI { host_duration: 4, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -3082,9 +3090,9 @@ export class MockAPI { host_duration: 119, device_duration: 0, self_host_duration: 68, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::zeros', @@ -3098,7 +3106,7 @@ export class MockAPI { host_duration: 8, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -3106,7 +3114,7 @@ export class MockAPI { host_duration: 2, device_duration: 0, self_host_duration: 2, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -3114,11 +3122,11 @@ export class MockAPI { host_duration: 17, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-5', + path: '0-5' }, { left: { @@ -3133,7 +3141,7 @@ export class MockAPI { host_duration: 38, device_duration: 0, self_host_duration: 38, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -3141,7 +3149,7 @@ export class MockAPI { host_duration: 7097, device_duration: 142, self_host_duration: 4914, - self_device_duration: 142, + self_device_duration: 142 }, { name: 'aten::zero_', @@ -3149,9 +3157,9 @@ export class MockAPI { host_duration: 14725, device_duration: 142, self_host_duration: 7628, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'Optimizer.zero_grad#SGD.zero_grad', @@ -3165,7 +3173,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -3173,7 +3181,7 @@ export class MockAPI { host_duration: 2036, device_duration: 264, self_host_duration: 909, - self_device_duration: 264, + self_device_duration: 264 }, { name: 'aten::zero_', @@ -3181,11 +3189,11 @@ export class MockAPI { host_duration: 2855, device_duration: 264, self_host_duration: 819, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-6', + path: '0-6' }, { left: { @@ -3200,7 +3208,7 @@ export class MockAPI { host_duration: 79, device_duration: 0, self_host_duration: 79, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_like', @@ -3208,7 +3216,7 @@ export class MockAPI { host_duration: 126, device_duration: 0, self_host_duration: 47, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -3216,7 +3224,7 @@ export class MockAPI { host_duration: 50, device_duration: 1, self_host_duration: 35, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::ones_like', @@ -3224,9 +3232,9 @@ export class MockAPI { host_duration: 253, device_duration: 1, self_host_duration: 77, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::ones_like', @@ -3240,7 +3248,7 @@ export class MockAPI { host_duration: 18, device_duration: 0, self_host_duration: 18, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_like', @@ -3248,7 +3256,7 @@ export class MockAPI { host_duration: 26, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -3256,7 +3264,7 @@ export class MockAPI { host_duration: 20, device_duration: 1, self_host_duration: 8, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::ones_like', @@ -3264,11 +3272,11 @@ export class MockAPI { host_duration: 53, device_duration: 1, self_host_duration: 7, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-7', + path: '0-7' }, { left: { @@ -3283,7 +3291,7 @@ export class MockAPI { host_duration: 69, device_duration: 1, self_host_duration: 43, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::zero_', @@ -3291,7 +3299,7 @@ export class MockAPI { host_duration: 120, device_duration: 1, self_host_duration: 51, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -3299,7 +3307,7 @@ export class MockAPI { host_duration: 304, device_duration: 4, self_host_duration: 168, - self_device_duration: 3, + self_device_duration: 3 }, { name: 'NllLossBackward0', @@ -3307,7 +3315,7 @@ export class MockAPI { host_duration: 368, device_duration: 4, self_host_duration: 64, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -3315,7 +3323,7 @@ export class MockAPI { host_duration: 503, device_duration: 4, self_host_duration: 135, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -3323,7 +3331,7 @@ export class MockAPI { host_duration: 127, device_duration: 9, self_host_duration: 105, - self_device_duration: 9, + self_device_duration: 9 }, { name: 'LogSoftmaxBackward0', @@ -3331,17 +3339,18 @@ export class MockAPI { host_duration: 207, device_duration: 9, self_host_duration: 80, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', + name: + 'autograd::engine::evaluate_function: LogSoftmaxBackward0', calls: 1, host_duration: 349, device_duration: 9, self_host_duration: 142, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: CrossEntropyLoss.backward', @@ -3355,7 +3364,7 @@ export class MockAPI { host_duration: 36, device_duration: 2, self_host_duration: 13, - self_device_duration: 2, + self_device_duration: 2 }, { name: 'aten::zero_', @@ -3363,7 +3372,7 @@ export class MockAPI { host_duration: 45, device_duration: 2, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -3371,7 +3380,7 @@ export class MockAPI { host_duration: 99, device_duration: 5, self_host_duration: 43, - self_device_duration: 3, + self_device_duration: 3 }, { name: 'NllLossBackward0', @@ -3379,7 +3388,7 @@ export class MockAPI { host_duration: 112, device_duration: 5, self_host_duration: 13, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -3387,7 +3396,7 @@ export class MockAPI { host_duration: 141, device_duration: 5, self_host_duration: 29, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -3395,7 +3404,7 @@ export class MockAPI { host_duration: 35, device_duration: 9, self_host_duration: 21, - self_device_duration: 9, + self_device_duration: 9 }, { name: 'LogSoftmaxBackward0', @@ -3403,19 +3412,20 @@ export class MockAPI { host_duration: 46, device_duration: 9, self_host_duration: 11, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', + name: + 'autograd::engine::evaluate_function: LogSoftmaxBackward0', calls: 1, host_duration: 64, device_duration: 9, self_host_duration: 18, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-8', + path: '0-8' }, { left: { @@ -3430,7 +3440,7 @@ export class MockAPI { host_duration: 61, device_duration: 0, self_host_duration: 61, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -3438,7 +3448,7 @@ export class MockAPI { host_duration: 226, device_duration: 0, self_host_duration: 180, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -3446,7 +3456,7 @@ export class MockAPI { host_duration: 399, device_duration: 0, self_host_duration: 173, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -3454,7 +3464,7 @@ export class MockAPI { host_duration: 345, device_duration: 72, self_host_duration: 282, - self_device_duration: 72, + self_device_duration: 72 }, { name: 'AddmmBackward0', @@ -3462,7 +3472,7 @@ export class MockAPI { host_duration: 854, device_duration: 72, self_host_duration: 208, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -3470,7 +3480,7 @@ export class MockAPI { host_duration: 173, device_duration: 8, self_host_duration: 153, - self_device_duration: 8, + self_device_duration: 8 }, { name: 'aten::view', @@ -3478,7 +3488,7 @@ export class MockAPI { host_duration: 971, device_duration: 0, self_host_duration: 971, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -3486,7 +3496,7 @@ export class MockAPI { host_duration: 1333, device_duration: 80, self_host_duration: 271, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -3494,7 +3504,7 @@ export class MockAPI { host_duration: 12621, device_duration: 501, self_host_duration: 9839, - self_device_duration: 501, + self_device_duration: 501 }, { name: 'torch::autograd::AccumulateGrad', @@ -3502,15 +3512,16 @@ export class MockAPI { host_duration: 20767, device_duration: 501, self_host_duration: 8146, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 161, host_duration: 35735, device_duration: 501, self_host_duration: 14968, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -3518,7 +3529,7 @@ export class MockAPI { host_duration: 128, device_duration: 0, self_host_duration: 30, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -3526,7 +3537,7 @@ export class MockAPI { host_duration: 197, device_duration: 0, self_host_duration: 69, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -3534,7 +3545,7 @@ export class MockAPI { host_duration: 31, device_duration: 0, self_host_duration: 31, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -3542,7 +3553,7 @@ export class MockAPI { host_duration: 79, device_duration: 0, self_host_duration: 48, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -3550,15 +3561,16 @@ export class MockAPI { host_duration: 131, device_duration: 0, self_host_duration: 52, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', + name: + 'autograd::engine::evaluate_function: ReshapeAliasBackward0', calls: 1, host_duration: 197, device_duration: 0, self_host_duration: 66, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -3566,7 +3578,7 @@ export class MockAPI { host_duration: 84, device_duration: 0, self_host_duration: 69, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -3574,7 +3586,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -3582,7 +3594,7 @@ export class MockAPI { host_duration: 289, device_duration: 38, self_host_duration: 267, - self_device_duration: 38, + self_device_duration: 38 }, { name: 'MeanBackward1', @@ -3590,7 +3602,7 @@ export class MockAPI { host_duration: 489, device_duration: 38, self_host_duration: 110, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -3598,7 +3610,7 @@ export class MockAPI { host_duration: 592, device_duration: 38, self_host_duration: 103, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -3606,7 +3618,7 @@ export class MockAPI { host_duration: 6958, device_duration: 8972, self_host_duration: 6094, - self_device_duration: 8972, + self_device_duration: 8972 }, { name: 'ReluBackward0', @@ -3614,7 +3626,7 @@ export class MockAPI { host_duration: 10647, device_duration: 8972, self_host_duration: 3689, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -3622,7 +3634,7 @@ export class MockAPI { host_duration: 16826, device_duration: 8972, self_host_duration: 6179, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -3630,7 +3642,7 @@ export class MockAPI { host_duration: 129, device_duration: 0, self_host_duration: 129, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -3638,7 +3650,7 @@ export class MockAPI { host_duration: 1301, device_duration: 0, self_host_duration: 1172, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty', @@ -3646,7 +3658,7 @@ export class MockAPI { host_duration: 20319, device_duration: 0, self_host_duration: 20319, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -3654,7 +3666,7 @@ export class MockAPI { host_duration: 31300, device_duration: 22267, self_host_duration: 18144, - self_device_duration: 22267, + self_device_duration: 22267 }, { name: 'CudnnBatchNormBackward0', @@ -3662,15 +3674,16 @@ export class MockAPI { host_duration: 34805, device_duration: 22267, self_host_duration: 3505, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 53, host_duration: 44607, device_duration: 22267, self_host_duration: 9802, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -3678,7 +3691,7 @@ export class MockAPI { host_duration: 20324, device_duration: 38733, self_host_duration: 15252, - self_device_duration: 38733, + self_device_duration: 38733 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -3686,7 +3699,7 @@ export class MockAPI { host_duration: 21997, device_duration: 45837, self_host_duration: 13786, - self_device_duration: 45837, + self_device_duration: 45837 }, { name: 'aten::cudnn_convolution_backward', @@ -3694,7 +3707,7 @@ export class MockAPI { host_duration: 50059, device_duration: 84570, self_host_duration: 7738, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -3702,15 +3715,16 @@ export class MockAPI { host_duration: 53558, device_duration: 84570, self_host_duration: 3499, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 53, host_duration: 64252, device_duration: 89775, self_host_duration: 8462, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -3718,7 +3732,7 @@ export class MockAPI { host_duration: 2232, device_duration: 5205, self_host_duration: 1944, - self_device_duration: 5205, + self_device_duration: 5205 }, { name: 'aten::fill_', @@ -3726,7 +3740,7 @@ export class MockAPI { host_duration: 61, device_duration: 230, self_host_duration: 44, - self_device_duration: 230, + self_device_duration: 230 }, { name: 'aten::zero_', @@ -3734,7 +3748,7 @@ export class MockAPI { host_duration: 104, device_duration: 230, self_host_duration: 43, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -3742,7 +3756,7 @@ export class MockAPI { host_duration: 246, device_duration: 1544, self_host_duration: 128, - self_device_duration: 1314, + self_device_duration: 1314 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -3750,17 +3764,18 @@ export class MockAPI { host_duration: 304, device_duration: 1544, self_host_duration: 58, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 1, host_duration: 425, device_duration: 1544, self_host_duration: 121, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: ResNet.backward', @@ -3774,7 +3789,7 @@ export class MockAPI { host_duration: 9, device_duration: 0, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -3782,7 +3797,7 @@ export class MockAPI { host_duration: 38, device_duration: 0, self_host_duration: 31, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -3790,7 +3805,7 @@ export class MockAPI { host_duration: 59, device_duration: 0, self_host_duration: 21, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -3798,7 +3813,7 @@ export class MockAPI { host_duration: 139, device_duration: 67, self_host_duration: 90, - self_device_duration: 67, + self_device_duration: 67 }, { name: 'AddmmBackward0', @@ -3806,7 +3821,7 @@ export class MockAPI { host_duration: 210, device_duration: 67, self_host_duration: 23, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -3814,7 +3829,7 @@ export class MockAPI { host_duration: 47, device_duration: 7, self_host_duration: 32, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::view', @@ -3822,7 +3837,7 @@ export class MockAPI { host_duration: 166, device_duration: 0, self_host_duration: 166, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -3830,7 +3845,7 @@ export class MockAPI { host_duration: 299, device_duration: 74, self_host_duration: 37, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -3838,7 +3853,7 @@ export class MockAPI { host_duration: 4087, device_duration: 534, self_host_duration: 2037, - self_device_duration: 534, + self_device_duration: 534 }, { name: 'torch::autograd::AccumulateGrad', @@ -3846,15 +3861,16 @@ export class MockAPI { host_duration: 5134, device_duration: 534, self_host_duration: 1047, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 161, host_duration: 7473, device_duration: 534, self_host_duration: 2339, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -3862,7 +3878,7 @@ export class MockAPI { host_duration: 14, device_duration: 0, self_host_duration: 3, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -3870,7 +3886,7 @@ export class MockAPI { host_duration: 21, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -3878,7 +3894,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -3886,7 +3902,7 @@ export class MockAPI { host_duration: 10, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -3894,15 +3910,16 @@ export class MockAPI { host_duration: 14, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', + name: + 'autograd::engine::evaluate_function: ReshapeAliasBackward0', calls: 1, host_duration: 21, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -3910,7 +3927,7 @@ export class MockAPI { host_duration: 9, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -3918,7 +3935,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -3926,7 +3943,7 @@ export class MockAPI { host_duration: 70, device_duration: 38, self_host_duration: 49, - self_device_duration: 38, + self_device_duration: 38 }, { name: 'MeanBackward1', @@ -3934,7 +3951,7 @@ export class MockAPI { host_duration: 89, device_duration: 38, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -3942,7 +3959,7 @@ export class MockAPI { host_duration: 102, device_duration: 38, self_host_duration: 13, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -3950,7 +3967,7 @@ export class MockAPI { host_duration: 1789, device_duration: 9015, self_host_duration: 1158, - self_device_duration: 9015, + self_device_duration: 9015 }, { name: 'ReluBackward0', @@ -3958,7 +3975,7 @@ export class MockAPI { host_duration: 2237, device_duration: 9015, self_host_duration: 448, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -3966,7 +3983,7 @@ export class MockAPI { host_duration: 3144, device_duration: 9015, self_host_duration: 907, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -3974,7 +3991,7 @@ export class MockAPI { host_duration: 12, device_duration: 0, self_host_duration: 12, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -3982,7 +3999,7 @@ export class MockAPI { host_duration: 126, device_duration: 0, self_host_duration: 114, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty', @@ -3990,7 +4007,7 @@ export class MockAPI { host_duration: 3292, device_duration: 0, self_host_duration: 3292, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -3998,7 +4015,7 @@ export class MockAPI { host_duration: 4896, device_duration: 22157, self_host_duration: 2136, - self_device_duration: 22157, + self_device_duration: 22157 }, { name: 'CudnnBatchNormBackward0', @@ -4006,15 +4023,16 @@ export class MockAPI { host_duration: 5495, device_duration: 22157, self_host_duration: 599, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 53, host_duration: 7289, device_duration: 22157, self_host_duration: 1794, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -4022,7 +4040,7 @@ export class MockAPI { host_duration: 9468, device_duration: 37714, self_host_duration: 7052, - self_device_duration: 37714, + self_device_duration: 37714 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -4030,7 +4048,7 @@ export class MockAPI { host_duration: 8906, device_duration: 44342, self_host_duration: 5723, - self_device_duration: 44342, + self_device_duration: 44342 }, { name: 'aten::cudnn_convolution_backward', @@ -4038,7 +4056,7 @@ export class MockAPI { host_duration: 19611, device_duration: 82056, self_host_duration: 1237, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -4046,15 +4064,16 @@ export class MockAPI { host_duration: 20205, device_duration: 82056, self_host_duration: 594, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 53, host_duration: 22185, device_duration: 87283, self_host_duration: 1386, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -4062,7 +4081,7 @@ export class MockAPI { host_duration: 594, device_duration: 5227, self_host_duration: 380, - self_device_duration: 5227, + self_device_duration: 5227 }, { name: 'aten::fill_', @@ -4070,7 +4089,7 @@ export class MockAPI { host_duration: 24, device_duration: 230, self_host_duration: 11, - self_device_duration: 230, + self_device_duration: 230 }, { name: 'aten::zero_', @@ -4078,7 +4097,7 @@ export class MockAPI { host_duration: 32, device_duration: 230, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -4086,7 +4105,7 @@ export class MockAPI { host_duration: 72, device_duration: 1503, self_host_duration: 31, - self_device_duration: 1273, + self_device_duration: 1273 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -4094,19 +4113,20 @@ export class MockAPI { host_duration: 82, device_duration: 1503, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 1, host_duration: 103, device_duration: 1503, self_host_duration: 21, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-9', + path: '0-9' }, { left: { @@ -4121,7 +4141,7 @@ export class MockAPI { host_duration: 75, device_duration: 0, self_host_duration: 75, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4129,7 +4149,7 @@ export class MockAPI { host_duration: 4, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4137,9 +4157,9 @@ export class MockAPI { host_duration: 154, device_duration: 0, self_host_duration: 75, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::zeros', @@ -4153,7 +4173,7 @@ export class MockAPI { host_duration: 32, device_duration: 0, self_host_duration: 32, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4161,7 +4181,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4169,11 +4189,11 @@ export class MockAPI { host_duration: 42, device_duration: 0, self_host_duration: 9, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-10', + path: '0-10' }, { left: { @@ -4188,7 +4208,7 @@ export class MockAPI { host_duration: 40, device_duration: 0, self_host_duration: 40, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -4196,7 +4216,7 @@ export class MockAPI { host_duration: 11873, device_duration: 396, self_host_duration: 9505, - self_device_duration: 396, + self_device_duration: 396 }, { name: 'aten::add_', @@ -4204,9 +4224,9 @@ export class MockAPI { host_duration: 22327, device_duration: 893, self_host_duration: 17668, - self_device_duration: 893, - }, - ], + self_device_duration: 893 + } + ] }, right: { name: 'Optimizer.step#SGD.step', @@ -4220,7 +4240,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -4228,7 +4248,7 @@ export class MockAPI { host_duration: 3395, device_duration: 399, self_host_duration: 1806, - self_device_duration: 399, + self_device_duration: 399 }, { name: 'aten::add_', @@ -4236,11 +4256,11 @@ export class MockAPI { host_duration: 6217, device_duration: 906, self_host_duration: 3246, - self_device_duration: 906, - }, - ], + self_device_duration: 906 + } + ] }, - path: '0-11', + path: '0-11' }, { left: { @@ -4255,7 +4275,7 @@ export class MockAPI { host_duration: 79, device_duration: 0, self_host_duration: 79, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4263,7 +4283,7 @@ export class MockAPI { host_duration: 4, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4271,9 +4291,9 @@ export class MockAPI { host_duration: 106, device_duration: 0, self_host_duration: 62, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'multiple nodes', @@ -4287,7 +4307,7 @@ export class MockAPI { host_duration: 10, device_duration: 0, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4295,7 +4315,7 @@ export class MockAPI { host_duration: 0, device_duration: 0, self_host_duration: 0, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4303,11 +4323,11 @@ export class MockAPI { host_duration: 9, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-12', + path: '0-12' }, { left: { @@ -4322,7 +4342,7 @@ export class MockAPI { host_duration: 53837, device_duration: 0, self_host_duration: 53837, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4330,7 +4350,7 @@ export class MockAPI { host_duration: 955, device_duration: 0, self_host_duration: 955, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4338,7 +4358,7 @@ export class MockAPI { host_duration: 26673, device_duration: 0, self_host_duration: 16083, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -4346,7 +4366,7 @@ export class MockAPI { host_duration: 824006, device_duration: 0, self_host_duration: 18525, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -4354,7 +4374,7 @@ export class MockAPI { host_duration: 2188, device_duration: 0, self_host_duration: 2188, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -4362,7 +4382,7 @@ export class MockAPI { host_duration: 5295, device_duration: 0, self_host_duration: 3107, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -4370,7 +4390,7 @@ export class MockAPI { host_duration: 4123, device_duration: 0, self_host_duration: 4123, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -4378,7 +4398,7 @@ export class MockAPI { host_duration: 9590, device_duration: 0, self_host_duration: 8097, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -4386,7 +4406,7 @@ export class MockAPI { host_duration: 24764, device_duration: 0, self_host_duration: 24764, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -4394,7 +4414,7 @@ export class MockAPI { host_duration: 728608, device_duration: 0, self_host_duration: 728608, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_to_copy', @@ -4402,7 +4422,7 @@ export class MockAPI { host_duration: 805481, device_duration: 0, self_host_duration: 51350, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -4410,7 +4430,7 @@ export class MockAPI { host_duration: 236448, device_duration: 0, self_host_duration: 216887, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -4418,7 +4438,7 @@ export class MockAPI { host_duration: 4682, device_duration: 0, self_host_duration: 4092, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -4426,7 +4446,7 @@ export class MockAPI { host_duration: 15283, device_duration: 0, self_host_duration: 15283, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -4434,7 +4454,7 @@ export class MockAPI { host_duration: 8844, device_duration: 0, self_host_duration: 7513, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -4442,7 +4462,7 @@ export class MockAPI { host_duration: 2102, device_duration: 0, self_host_duration: 2102, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -4450,7 +4470,7 @@ export class MockAPI { host_duration: 7286, device_duration: 0, self_host_duration: 5184, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -4458,7 +4478,7 @@ export class MockAPI { host_duration: 850, device_duration: 0, self_host_duration: 850, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -4466,7 +4486,7 @@ export class MockAPI { host_duration: 43219, device_duration: 0, self_host_duration: 39305, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -4474,7 +4494,7 @@ export class MockAPI { host_duration: 92093, device_duration: 0, self_host_duration: 37961, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -4482,7 +4502,7 @@ export class MockAPI { host_duration: 35770, device_duration: 0, self_host_duration: 24869, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -4490,7 +4510,7 @@ export class MockAPI { host_duration: 2481, device_duration: 0, self_host_duration: 2481, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -4498,7 +4518,7 @@ export class MockAPI { host_duration: 10547, device_duration: 0, self_host_duration: 8066, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -4506,7 +4526,7 @@ export class MockAPI { host_duration: 14029, device_duration: 0, self_host_duration: 5364, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -4514,7 +4534,7 @@ export class MockAPI { host_duration: 79760, device_duration: 0, self_host_duration: 68841, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -4522,7 +4542,7 @@ export class MockAPI { host_duration: 121, device_duration: 0, self_host_duration: 121, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -4530,7 +4550,7 @@ export class MockAPI { host_duration: 138, device_duration: 0, self_host_duration: 48, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -4538,7 +4558,7 @@ export class MockAPI { host_duration: 41467, device_duration: 0, self_host_duration: 41176, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -4546,7 +4566,7 @@ export class MockAPI { host_duration: 41608, device_duration: 0, self_host_duration: 141, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -4554,9 +4574,9 @@ export class MockAPI { host_duration: 49080, device_duration: 0, self_host_duration: 2720, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__', @@ -4570,7 +4590,7 @@ export class MockAPI { host_duration: 6528, device_duration: 0, self_host_duration: 6528, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -4578,7 +4598,7 @@ export class MockAPI { host_duration: 94, device_duration: 0, self_host_duration: 94, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -4586,7 +4606,7 @@ export class MockAPI { host_duration: 2448, device_duration: 0, self_host_duration: 1214, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -4594,7 +4614,7 @@ export class MockAPI { host_duration: 16544, device_duration: 0, self_host_duration: 1856, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach', @@ -4602,7 +4622,7 @@ export class MockAPI { host_duration: 337, device_duration: 0, self_host_duration: 337, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach', @@ -4610,7 +4630,7 @@ export class MockAPI { host_duration: 629, device_duration: 0, self_host_duration: 292, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -4618,7 +4638,7 @@ export class MockAPI { host_duration: 464, device_duration: 0, self_host_duration: 464, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::unsqueeze', @@ -4626,7 +4646,7 @@ export class MockAPI { host_duration: 1024, device_duration: 0, self_host_duration: 854, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_strided', @@ -4634,7 +4654,7 @@ export class MockAPI { host_duration: 3009, device_duration: 0, self_host_duration: 3009, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -4642,7 +4662,7 @@ export class MockAPI { host_duration: 7419, device_duration: 0, self_host_duration: 7419, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_to_copy', @@ -4650,7 +4670,7 @@ export class MockAPI { host_duration: 14688, device_duration: 0, self_host_duration: 4039, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::upsample_bilinear2d', @@ -4658,7 +4678,7 @@ export class MockAPI { host_duration: 31439, device_duration: 0, self_host_duration: 29154, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::squeeze', @@ -4666,7 +4686,7 @@ export class MockAPI { host_duration: 473, device_duration: 0, self_host_duration: 408, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::round', @@ -4674,7 +4694,7 @@ export class MockAPI { host_duration: 4416, device_duration: 0, self_host_duration: 4416, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::slice', @@ -4682,7 +4702,7 @@ export class MockAPI { host_duration: 864, device_duration: 0, self_host_duration: 730, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'detach_', @@ -4690,7 +4710,7 @@ export class MockAPI { host_duration: 136, device_duration: 0, self_host_duration: 115, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::detach_', @@ -4698,7 +4718,7 @@ export class MockAPI { host_duration: 586, device_duration: 0, self_host_duration: 471, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::result_type', @@ -4706,7 +4726,7 @@ export class MockAPI { host_duration: 149, device_duration: 0, self_host_duration: 149, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::pow', @@ -4714,7 +4734,7 @@ export class MockAPI { host_duration: 3935, device_duration: 0, self_host_duration: 3519, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sub', @@ -4722,7 +4742,7 @@ export class MockAPI { host_duration: 7881, device_duration: 0, self_host_duration: 3349, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::gt', @@ -4730,7 +4750,7 @@ export class MockAPI { host_duration: 3055, device_duration: 0, self_host_duration: 2164, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_local_scalar_dense', @@ -4738,7 +4758,7 @@ export class MockAPI { host_duration: 186, device_duration: 0, self_host_duration: 186, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::item', @@ -4746,7 +4766,7 @@ export class MockAPI { host_duration: 1134, device_duration: 0, self_host_duration: 943, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::is_nonzero', @@ -4754,7 +4774,7 @@ export class MockAPI { host_duration: 1588, device_duration: 0, self_host_duration: 615, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -4762,7 +4782,7 @@ export class MockAPI { host_duration: 4153, device_duration: 0, self_host_duration: 3203, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -4770,7 +4790,7 @@ export class MockAPI { host_duration: 42, device_duration: 0, self_host_duration: 42, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::narrow', @@ -4778,7 +4798,7 @@ export class MockAPI { host_duration: 18, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_cat', @@ -4786,7 +4806,7 @@ export class MockAPI { host_duration: 4613, device_duration: 0, self_host_duration: 4547, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cat', @@ -4794,7 +4814,7 @@ export class MockAPI { host_duration: 4637, device_duration: 0, self_host_duration: 24, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::stack', @@ -4802,11 +4822,11 @@ export class MockAPI { host_duration: 5311, device_duration: 0, self_host_duration: 246, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-13', + path: '0-13' }, { left: { @@ -4821,7 +4841,7 @@ export class MockAPI { host_duration: 203, device_duration: 0, self_host_duration: 203, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -4829,7 +4849,7 @@ export class MockAPI { host_duration: 4687, device_duration: 4394, self_host_duration: 94, - self_device_duration: 4394, + self_device_duration: 4394 }, { name: 'aten::_to_copy', @@ -4837,7 +4857,7 @@ export class MockAPI { host_duration: 5113, device_duration: 4394, self_host_duration: 223, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -4845,9 +4865,9 @@ export class MockAPI { host_duration: 5185, device_duration: 4394, self_host_duration: 72, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'multiple nodes', @@ -4861,7 +4881,7 @@ export class MockAPI { host_duration: 60, device_duration: 0, self_host_duration: 60, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::copy_', @@ -4869,7 +4889,7 @@ export class MockAPI { host_duration: 4559, device_duration: 4334, self_host_duration: 26, - self_device_duration: 4334, + self_device_duration: 4334 }, { name: 'aten::_to_copy', @@ -4877,7 +4897,7 @@ export class MockAPI { host_duration: 4655, device_duration: 4334, self_host_duration: 36, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -4885,11 +4905,11 @@ export class MockAPI { host_duration: 4664, device_duration: 4334, self_host_duration: 9, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-14', + path: '0-14' }, { left: { @@ -4904,7 +4924,7 @@ export class MockAPI { host_duration: 13992, device_duration: 0, self_host_duration: 13992, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -4912,7 +4932,7 @@ export class MockAPI { host_duration: 21952, device_duration: 35233, self_host_duration: 17460, - self_device_duration: 35233, + self_device_duration: 35233 }, { name: 'aten::_convolution', @@ -4920,7 +4940,7 @@ export class MockAPI { host_duration: 25568, device_duration: 35233, self_host_duration: 3616, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -4928,7 +4948,7 @@ export class MockAPI { host_duration: 27534, device_duration: 35233, self_host_duration: 1966, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -4936,7 +4956,7 @@ export class MockAPI { host_duration: 29546, device_duration: 35233, self_host_duration: 2012, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -4944,7 +4964,7 @@ export class MockAPI { host_duration: 6523, device_duration: 53, self_host_duration: 5669, - self_device_duration: 53, + self_device_duration: 53 }, { name: 'aten::empty_like', @@ -4952,7 +4972,7 @@ export class MockAPI { host_duration: 5605, device_duration: 0, self_host_duration: 2378, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -4960,7 +4980,7 @@ export class MockAPI { host_duration: 829, device_duration: 0, self_host_duration: 829, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -4968,7 +4988,7 @@ export class MockAPI { host_duration: 35510, device_duration: 12828, self_host_duration: 20387, - self_device_duration: 12828, + self_device_duration: 12828 }, { name: 'aten::_batch_norm_impl_index', @@ -4976,7 +4996,7 @@ export class MockAPI { host_duration: 38030, device_duration: 12828, self_host_duration: 2520, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -4984,7 +5004,7 @@ export class MockAPI { host_duration: 39727, device_duration: 12828, self_host_duration: 1697, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -4992,7 +5012,7 @@ export class MockAPI { host_duration: 2715, device_duration: 5998, self_host_duration: 1950, - self_device_duration: 5998, + self_device_duration: 5998 }, { name: 'aten::clamp_min_', @@ -5000,7 +5020,7 @@ export class MockAPI { host_duration: 4264, device_duration: 5998, self_host_duration: 1549, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -5008,7 +5028,7 @@ export class MockAPI { host_duration: 8337, device_duration: 5998, self_host_duration: 4073, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -5016,7 +5036,7 @@ export class MockAPI { host_duration: 212, device_duration: 466, self_host_duration: 193, - self_device_duration: 466, + self_device_duration: 466 }, { name: 'aten::max_pool2d', @@ -5024,7 +5044,7 @@ export class MockAPI { host_duration: 262, device_duration: 466, self_host_duration: 50, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -5032,7 +5052,7 @@ export class MockAPI { host_duration: 1553, device_duration: 5165, self_host_duration: 1297, - self_device_duration: 5165, + self_device_duration: 5165 }, { name: 'aten::mean', @@ -5040,7 +5060,7 @@ export class MockAPI { host_duration: 187, device_duration: 64, self_host_duration: 169, - self_device_duration: 64, + self_device_duration: 64 }, { name: 'aten::adaptive_avg_pool2d', @@ -5048,7 +5068,7 @@ export class MockAPI { host_duration: 231, device_duration: 64, self_host_duration: 44, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -5056,7 +5076,7 @@ export class MockAPI { host_duration: 52, device_duration: 0, self_host_duration: 52, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -5064,7 +5084,7 @@ export class MockAPI { host_duration: 101, device_duration: 0, self_host_duration: 49, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -5072,7 +5092,7 @@ export class MockAPI { host_duration: 21, device_duration: 0, self_host_duration: 21, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -5080,7 +5100,7 @@ export class MockAPI { host_duration: 51, device_duration: 0, self_host_duration: 40, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -5088,7 +5108,7 @@ export class MockAPI { host_duration: 120, device_duration: 0, self_host_duration: 69, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -5096,7 +5116,7 @@ export class MockAPI { host_duration: 49, device_duration: 0, self_host_duration: 39, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -5104,7 +5124,7 @@ export class MockAPI { host_duration: 405, device_duration: 41, self_host_duration: 302, - self_device_duration: 41, + self_device_duration: 41 }, { name: 'aten::linear', @@ -5112,9 +5132,9 @@ export class MockAPI { host_duration: 594, device_duration: 41, self_host_duration: 69, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: ResNet', @@ -5128,7 +5148,7 @@ export class MockAPI { host_duration: 2234, device_duration: 0, self_host_duration: 2234, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution', @@ -5136,7 +5156,7 @@ export class MockAPI { host_duration: 8644, device_duration: 35209, self_host_duration: 6782, - self_device_duration: 35209, + self_device_duration: 35209 }, { name: 'aten::_convolution', @@ -5144,7 +5164,7 @@ export class MockAPI { host_duration: 9216, device_duration: 35209, self_host_duration: 572, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::convolution', @@ -5152,7 +5172,7 @@ export class MockAPI { host_duration: 9532, device_duration: 35209, self_host_duration: 316, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::conv2d', @@ -5160,7 +5180,7 @@ export class MockAPI { host_duration: 9818, device_duration: 35209, self_host_duration: 286, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -5168,7 +5188,7 @@ export class MockAPI { host_duration: 1898, device_duration: 55, self_host_duration: 1202, - self_device_duration: 55, + self_device_duration: 55 }, { name: 'aten::empty_like', @@ -5176,7 +5196,7 @@ export class MockAPI { host_duration: 941, device_duration: 0, self_host_duration: 300, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::view', @@ -5184,7 +5204,7 @@ export class MockAPI { host_duration: 137, device_duration: 0, self_host_duration: 137, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm', @@ -5192,7 +5212,7 @@ export class MockAPI { host_duration: 5543, device_duration: 12824, self_host_duration: 2527, - self_device_duration: 12824, + self_device_duration: 12824 }, { name: 'aten::_batch_norm_impl_index', @@ -5200,7 +5220,7 @@ export class MockAPI { host_duration: 5914, device_duration: 12824, self_host_duration: 371, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::batch_norm', @@ -5208,7 +5228,7 @@ export class MockAPI { host_duration: 6167, device_duration: 12824, self_host_duration: 253, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::clamp_min', @@ -5216,7 +5236,7 @@ export class MockAPI { host_duration: 1081, device_duration: 6004, self_host_duration: 507, - self_device_duration: 6004, + self_device_duration: 6004 }, { name: 'aten::clamp_min_', @@ -5224,7 +5244,7 @@ export class MockAPI { host_duration: 1299, device_duration: 6004, self_host_duration: 218, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::relu_', @@ -5232,7 +5252,7 @@ export class MockAPI { host_duration: 1941, device_duration: 6004, self_host_duration: 642, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices', @@ -5240,7 +5260,7 @@ export class MockAPI { host_duration: 59, device_duration: 466, self_host_duration: 44, - self_device_duration: 466, + self_device_duration: 466 }, { name: 'aten::max_pool2d', @@ -5248,7 +5268,7 @@ export class MockAPI { host_duration: 66, device_duration: 466, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -5256,7 +5276,7 @@ export class MockAPI { host_duration: 443, device_duration: 5169, self_host_duration: 267, - self_device_duration: 5169, + self_device_duration: 5169 }, { name: 'aten::mean', @@ -5264,7 +5284,7 @@ export class MockAPI { host_duration: 51, device_duration: 63, self_host_duration: 37, - self_device_duration: 63, + self_device_duration: 63 }, { name: 'aten::adaptive_avg_pool2d', @@ -5272,7 +5292,7 @@ export class MockAPI { host_duration: 58, device_duration: 63, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -5280,7 +5300,7 @@ export class MockAPI { host_duration: 8, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::flatten', @@ -5288,7 +5308,7 @@ export class MockAPI { host_duration: 16, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::as_strided', @@ -5296,7 +5316,7 @@ export class MockAPI { host_duration: 3, device_duration: 0, self_host_duration: 3, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -5304,7 +5324,7 @@ export class MockAPI { host_duration: 10, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -5312,7 +5332,7 @@ export class MockAPI { host_duration: 18, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -5320,7 +5340,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::addmm', @@ -5328,7 +5348,7 @@ export class MockAPI { host_duration: 161, device_duration: 42, self_host_duration: 111, - self_device_duration: 42, + self_device_duration: 42 }, { name: 'aten::linear', @@ -5336,11 +5356,11 @@ export class MockAPI { host_duration: 188, device_duration: 42, self_host_duration: 9, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-15', + path: '0-15' }, { left: { @@ -5355,7 +5375,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -5363,7 +5383,7 @@ export class MockAPI { host_duration: 150, device_duration: 7, self_host_duration: 132, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::log_softmax', @@ -5371,7 +5391,7 @@ export class MockAPI { host_duration: 231, device_duration: 7, self_host_duration: 75, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -5379,7 +5399,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -5387,7 +5407,7 @@ export class MockAPI { host_duration: 266, device_duration: 4, self_host_duration: 243, - self_device_duration: 4, + self_device_duration: 4 }, { name: 'aten::nll_loss', @@ -5395,7 +5415,7 @@ export class MockAPI { host_duration: 300, device_duration: 4, self_host_duration: 34, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -5403,7 +5423,7 @@ export class MockAPI { host_duration: 328, device_duration: 4, self_host_duration: 28, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -5411,9 +5431,9 @@ export class MockAPI { host_duration: 620, device_duration: 11, self_host_duration: 61, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: CrossEntropyLoss', @@ -5427,7 +5447,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax', @@ -5435,7 +5455,7 @@ export class MockAPI { host_duration: 41, device_duration: 7, self_host_duration: 27, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::log_softmax', @@ -5443,7 +5463,7 @@ export class MockAPI { host_duration: 52, device_duration: 7, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::resize_', @@ -5451,7 +5471,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_forward', @@ -5459,7 +5479,7 @@ export class MockAPI { host_duration: 49, device_duration: 4, self_host_duration: 34, - self_device_duration: 4, + self_device_duration: 4 }, { name: 'aten::nll_loss', @@ -5467,7 +5487,7 @@ export class MockAPI { host_duration: 53, device_duration: 4, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_nd', @@ -5475,7 +5495,7 @@ export class MockAPI { host_duration: 57, device_duration: 4, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cross_entropy_loss', @@ -5483,11 +5503,11 @@ export class MockAPI { host_duration: 124, device_duration: 11, self_host_duration: 15, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-16', + path: '0-16' }, { left: { @@ -5502,7 +5522,7 @@ export class MockAPI { host_duration: 39, device_duration: 0, self_host_duration: 39, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -5510,7 +5530,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -5518,9 +5538,9 @@ export class MockAPI { host_duration: 109, device_duration: 0, self_host_duration: 65, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::zeros', @@ -5534,7 +5554,7 @@ export class MockAPI { host_duration: 13, device_duration: 0, self_host_duration: 13, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -5542,7 +5562,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -5550,11 +5570,11 @@ export class MockAPI { host_duration: 23, device_duration: 0, self_host_duration: 9, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-17', + path: '0-17' }, { left: { @@ -5569,7 +5589,7 @@ export class MockAPI { host_duration: 44, device_duration: 0, self_host_duration: 44, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -5577,7 +5597,7 @@ export class MockAPI { host_duration: 7104, device_duration: 132, self_host_duration: 4941, - self_device_duration: 132, + self_device_duration: 132 }, { name: 'aten::zero_', @@ -5585,9 +5605,9 @@ export class MockAPI { host_duration: 14806, device_duration: 132, self_host_duration: 7702, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'Optimizer.zero_grad#SGD.zero_grad', @@ -5601,7 +5621,7 @@ export class MockAPI { host_duration: 6, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -5609,7 +5629,7 @@ export class MockAPI { host_duration: 1945, device_duration: 137, self_host_duration: 878, - self_device_duration: 137, + self_device_duration: 137 }, { name: 'aten::zero_', @@ -5617,11 +5637,11 @@ export class MockAPI { host_duration: 2805, device_duration: 137, self_host_duration: 860, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-18', + path: '0-18' }, { left: { @@ -5636,7 +5656,7 @@ export class MockAPI { host_duration: 99, device_duration: 0, self_host_duration: 99, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_like', @@ -5644,7 +5664,7 @@ export class MockAPI { host_duration: 149, device_duration: 0, self_host_duration: 50, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -5652,7 +5672,7 @@ export class MockAPI { host_duration: 49, device_duration: 1, self_host_duration: 34, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::ones_like', @@ -5660,9 +5680,9 @@ export class MockAPI { host_duration: 263, device_duration: 1, self_host_duration: 65, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::ones_like', @@ -5676,7 +5696,7 @@ export class MockAPI { host_duration: 18, device_duration: 0, self_host_duration: 18, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty_like', @@ -5684,7 +5704,7 @@ export class MockAPI { host_duration: 24, device_duration: 0, self_host_duration: 6, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::fill_', @@ -5692,7 +5712,7 @@ export class MockAPI { host_duration: 20, device_duration: 1, self_host_duration: 8, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::ones_like', @@ -5700,11 +5720,11 @@ export class MockAPI { host_duration: 51, device_duration: 1, self_host_duration: 7, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-19', + path: '0-19' }, { left: { @@ -5719,7 +5739,7 @@ export class MockAPI { host_duration: 58, device_duration: 1, self_host_duration: 36, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::zero_', @@ -5727,7 +5747,7 @@ export class MockAPI { host_duration: 112, device_duration: 1, self_host_duration: 54, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -5735,7 +5755,7 @@ export class MockAPI { host_duration: 269, device_duration: 4, self_host_duration: 142, - self_device_duration: 3, + self_device_duration: 3 }, { name: 'NllLossBackward0', @@ -5743,7 +5763,7 @@ export class MockAPI { host_duration: 406, device_duration: 4, self_host_duration: 137, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -5751,7 +5771,7 @@ export class MockAPI { host_duration: 522, device_duration: 4, self_host_duration: 116, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -5759,7 +5779,7 @@ export class MockAPI { host_duration: 109, device_duration: 9, self_host_duration: 91, - self_device_duration: 9, + self_device_duration: 9 }, { name: 'LogSoftmaxBackward0', @@ -5767,17 +5787,18 @@ export class MockAPI { host_duration: 178, device_duration: 9, self_host_duration: 69, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', + name: + 'autograd::engine::evaluate_function: LogSoftmaxBackward0', calls: 1, host_duration: 283, device_duration: 9, self_host_duration: 105, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: CrossEntropyLoss.backward', @@ -5791,7 +5812,7 @@ export class MockAPI { host_duration: 33, device_duration: 1, self_host_duration: 12, - self_device_duration: 1, + self_device_duration: 1 }, { name: 'aten::zero_', @@ -5799,7 +5820,7 @@ export class MockAPI { host_duration: 41, device_duration: 1, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::nll_loss_backward', @@ -5807,7 +5828,7 @@ export class MockAPI { host_duration: 93, device_duration: 4, self_host_duration: 41, - self_device_duration: 3, + self_device_duration: 3 }, { name: 'NllLossBackward0', @@ -5815,7 +5836,7 @@ export class MockAPI { host_duration: 185, device_duration: 4, self_host_duration: 92, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: NllLossBackward0', @@ -5823,7 +5844,7 @@ export class MockAPI { host_duration: 211, device_duration: 4, self_host_duration: 26, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_log_softmax_backward_data', @@ -5831,7 +5852,7 @@ export class MockAPI { host_duration: 36, device_duration: 9, self_host_duration: 22, - self_device_duration: 9, + self_device_duration: 9 }, { name: 'LogSoftmaxBackward0', @@ -5839,19 +5860,20 @@ export class MockAPI { host_duration: 45, device_duration: 9, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: LogSoftmaxBackward0', + name: + 'autograd::engine::evaluate_function: LogSoftmaxBackward0', calls: 1, host_duration: 62, device_duration: 9, self_host_duration: 17, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-20', + path: '0-20' }, { left: { @@ -5866,7 +5888,7 @@ export class MockAPI { host_duration: 67, device_duration: 0, self_host_duration: 67, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -5874,7 +5896,7 @@ export class MockAPI { host_duration: 255, device_duration: 0, self_host_duration: 204, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -5882,7 +5904,7 @@ export class MockAPI { host_duration: 430, device_duration: 0, self_host_duration: 175, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -5890,7 +5912,7 @@ export class MockAPI { host_duration: 323, device_duration: 68, self_host_duration: 265, - self_device_duration: 68, + self_device_duration: 68 }, { name: 'AddmmBackward0', @@ -5898,7 +5920,7 @@ export class MockAPI { host_duration: 844, device_duration: 68, self_host_duration: 209, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -5906,7 +5928,7 @@ export class MockAPI { host_duration: 197, device_duration: 7, self_host_duration: 175, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::view', @@ -5914,7 +5936,7 @@ export class MockAPI { host_duration: 963, device_duration: 0, self_host_duration: 963, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -5922,7 +5944,7 @@ export class MockAPI { host_duration: 1377, device_duration: 75, self_host_duration: 296, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -5930,7 +5952,7 @@ export class MockAPI { host_duration: 12404, device_duration: 496, self_host_duration: 9659, - self_device_duration: 496, + self_device_duration: 496 }, { name: 'torch::autograd::AccumulateGrad', @@ -5938,15 +5960,16 @@ export class MockAPI { host_duration: 20417, device_duration: 496, self_host_duration: 8013, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 161, host_duration: 35211, device_duration: 496, self_host_duration: 14794, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -5954,7 +5977,7 @@ export class MockAPI { host_duration: 152, device_duration: 0, self_host_duration: 34, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -5962,7 +5985,7 @@ export class MockAPI { host_duration: 231, device_duration: 0, self_host_duration: 79, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -5970,7 +5993,7 @@ export class MockAPI { host_duration: 35, device_duration: 0, self_host_duration: 35, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -5978,7 +6001,7 @@ export class MockAPI { host_duration: 91, device_duration: 0, self_host_duration: 56, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -5986,15 +6009,16 @@ export class MockAPI { host_duration: 133, device_duration: 0, self_host_duration: 42, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', + name: + 'autograd::engine::evaluate_function: ReshapeAliasBackward0', calls: 1, host_duration: 205, device_duration: 0, self_host_duration: 72, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -6002,7 +6026,7 @@ export class MockAPI { host_duration: 95, device_duration: 0, self_host_duration: 79, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -6010,7 +6034,7 @@ export class MockAPI { host_duration: 7, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -6018,7 +6042,7 @@ export class MockAPI { host_duration: 324, device_duration: 37, self_host_duration: 301, - self_device_duration: 37, + self_device_duration: 37 }, { name: 'MeanBackward1', @@ -6026,7 +6050,7 @@ export class MockAPI { host_duration: 547, device_duration: 37, self_host_duration: 121, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -6034,7 +6058,7 @@ export class MockAPI { host_duration: 662, device_duration: 37, self_host_duration: 115, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -6042,7 +6066,7 @@ export class MockAPI { host_duration: 6880, device_duration: 9012, self_host_duration: 6037, - self_device_duration: 9012, + self_device_duration: 9012 }, { name: 'ReluBackward0', @@ -6050,7 +6074,7 @@ export class MockAPI { host_duration: 10536, device_duration: 9012, self_host_duration: 3656, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -6058,7 +6082,7 @@ export class MockAPI { host_duration: 16666, device_duration: 9012, self_host_duration: 6130, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -6066,7 +6090,7 @@ export class MockAPI { host_duration: 122, device_duration: 0, self_host_duration: 122, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -6074,7 +6098,7 @@ export class MockAPI { host_duration: 1278, device_duration: 0, self_host_duration: 1156, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty', @@ -6082,7 +6106,7 @@ export class MockAPI { host_duration: 21126, device_duration: 0, self_host_duration: 21126, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -6090,7 +6114,7 @@ export class MockAPI { host_duration: 30875, device_duration: 22166, self_host_duration: 17909, - self_device_duration: 22166, + self_device_duration: 22166 }, { name: 'CudnnBatchNormBackward0', @@ -6098,15 +6122,16 @@ export class MockAPI { host_duration: 34355, device_duration: 22166, self_host_duration: 3480, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 53, host_duration: 44006, device_duration: 22166, self_host_duration: 9651, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -6114,7 +6139,7 @@ export class MockAPI { host_duration: 20496, device_duration: 37887, self_host_duration: 15516, - self_device_duration: 37887, + self_device_duration: 37887 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -6122,7 +6147,7 @@ export class MockAPI { host_duration: 22878, device_duration: 44271, self_host_duration: 13672, - self_device_duration: 44271, + self_device_duration: 44271 }, { name: 'aten::cudnn_convolution_backward', @@ -6130,7 +6155,7 @@ export class MockAPI { host_duration: 50961, device_duration: 82158, self_host_duration: 7587, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -6138,15 +6163,16 @@ export class MockAPI { host_duration: 54406, device_duration: 82158, self_host_duration: 3445, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 53, host_duration: 64877, device_duration: 87386, self_host_duration: 8284, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -6154,7 +6180,7 @@ export class MockAPI { host_duration: 2187, device_duration: 5228, self_host_duration: 1909, - self_device_duration: 5228, + self_device_duration: 5228 }, { name: 'aten::fill_', @@ -6162,7 +6188,7 @@ export class MockAPI { host_duration: 53, device_duration: 230, self_host_duration: 36, - self_device_duration: 230, + self_device_duration: 230 }, { name: 'aten::zero_', @@ -6170,7 +6196,7 @@ export class MockAPI { host_duration: 96, device_duration: 230, self_host_duration: 43, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -6178,7 +6204,7 @@ export class MockAPI { host_duration: 237, device_duration: 1504, self_host_duration: 129, - self_device_duration: 1274, + self_device_duration: 1274 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -6186,17 +6212,18 @@ export class MockAPI { host_duration: 295, device_duration: 1504, self_host_duration: 58, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 1, host_duration: 411, device_duration: 1504, self_host_duration: 116, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'nn.Module: ResNet.backward', @@ -6210,7 +6237,7 @@ export class MockAPI { host_duration: 7, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::transpose', @@ -6218,7 +6245,7 @@ export class MockAPI { host_duration: 29, device_duration: 0, self_host_duration: 23, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::t', @@ -6226,7 +6253,7 @@ export class MockAPI { host_duration: 53, device_duration: 0, self_host_duration: 24, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mm', @@ -6234,7 +6261,7 @@ export class MockAPI { host_duration: 144, device_duration: 67, self_host_duration: 96, - self_device_duration: 67, + self_device_duration: 67 }, { name: 'AddmmBackward0', @@ -6242,7 +6269,7 @@ export class MockAPI { host_duration: 208, device_duration: 67, self_host_duration: 24, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::sum', @@ -6250,7 +6277,7 @@ export class MockAPI { host_duration: 45, device_duration: 7, self_host_duration: 30, - self_device_duration: 7, + self_device_duration: 7 }, { name: 'aten::view', @@ -6258,7 +6285,7 @@ export class MockAPI { host_duration: 163, device_duration: 0, self_host_duration: 163, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddmmBackward0', @@ -6266,7 +6293,7 @@ export class MockAPI { host_duration: 295, device_duration: 74, self_host_duration: 38, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add_', @@ -6274,7 +6301,7 @@ export class MockAPI { host_duration: 4103, device_duration: 535, self_host_duration: 2037, - self_device_duration: 535, + self_device_duration: 535 }, { name: 'torch::autograd::AccumulateGrad', @@ -6282,15 +6309,16 @@ export class MockAPI { host_duration: 5183, device_duration: 535, self_host_duration: 1080, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', + name: + 'autograd::engine::evaluate_function: torch::autograd::AccumulateGrad', calls: 161, host_duration: 7655, device_duration: 535, self_host_duration: 2472, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'TBackward0', @@ -6298,7 +6326,7 @@ export class MockAPI { host_duration: 16, device_duration: 0, self_host_duration: 3, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: TBackward0', @@ -6306,7 +6334,7 @@ export class MockAPI { host_duration: 24, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::_reshape_alias', @@ -6314,7 +6342,7 @@ export class MockAPI { host_duration: 5, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::reshape', @@ -6322,7 +6350,7 @@ export class MockAPI { host_duration: 10, device_duration: 0, self_host_duration: 5, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'ReshapeAliasBackward0', @@ -6330,15 +6358,16 @@ export class MockAPI { host_duration: 17, device_duration: 0, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: ReshapeAliasBackward0', + name: + 'autograd::engine::evaluate_function: ReshapeAliasBackward0', calls: 1, host_duration: 27, device_duration: 0, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::expand', @@ -6346,7 +6375,7 @@ export class MockAPI { host_duration: 10, device_duration: 0, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::to', @@ -6354,7 +6383,7 @@ export class MockAPI { host_duration: 1, device_duration: 0, self_host_duration: 1, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::div', @@ -6362,7 +6391,7 @@ export class MockAPI { host_duration: 63, device_duration: 37, self_host_duration: 45, - self_device_duration: 37, + self_device_duration: 37 }, { name: 'MeanBackward1', @@ -6370,7 +6399,7 @@ export class MockAPI { host_duration: 83, device_duration: 37, self_host_duration: 9, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: MeanBackward1', @@ -6378,7 +6407,7 @@ export class MockAPI { host_duration: 99, device_duration: 37, self_host_duration: 16, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::threshold_backward', @@ -6386,7 +6415,7 @@ export class MockAPI { host_duration: 1863, device_duration: 9003, self_host_duration: 1203, - self_device_duration: 9003, + self_device_duration: 9003 }, { name: 'ReluBackward0', @@ -6394,7 +6423,7 @@ export class MockAPI { host_duration: 2330, device_duration: 9003, self_host_duration: 467, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: ReluBackward0', @@ -6402,7 +6431,7 @@ export class MockAPI { host_duration: 3313, device_duration: 9003, self_host_duration: 983, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'AddBackward0', @@ -6410,7 +6439,7 @@ export class MockAPI { host_duration: 14, device_duration: 0, self_host_duration: 14, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'autograd::engine::evaluate_function: AddBackward0', @@ -6418,7 +6447,7 @@ export class MockAPI { host_duration: 135, device_duration: 0, self_host_duration: 121, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::empty', @@ -6426,7 +6455,7 @@ export class MockAPI { host_duration: 4638, device_duration: 0, self_host_duration: 4638, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_batch_norm_backward', @@ -6434,7 +6463,7 @@ export class MockAPI { host_duration: 5047, device_duration: 22244, self_host_duration: 2219, - self_device_duration: 22244, + self_device_duration: 22244 }, { name: 'CudnnBatchNormBackward0', @@ -6442,15 +6471,16 @@ export class MockAPI { host_duration: 5637, device_duration: 22244, self_host_duration: 590, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', + name: + 'autograd::engine::evaluate_function: CudnnBatchNormBackward0', calls: 53, host_duration: 7407, device_duration: 22244, self_host_duration: 1770, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::cudnn_convolution_backward_input', @@ -6458,7 +6488,7 @@ export class MockAPI { host_duration: 9345, device_duration: 37854, self_host_duration: 6945, - self_device_duration: 37854, + self_device_duration: 37854 }, { name: 'aten::cudnn_convolution_backward_weight', @@ -6466,7 +6496,7 @@ export class MockAPI { host_duration: 9886, device_duration: 44650, self_host_duration: 5378, - self_device_duration: 44650, + self_device_duration: 44650 }, { name: 'aten::cudnn_convolution_backward', @@ -6474,7 +6504,7 @@ export class MockAPI { host_duration: 20453, device_duration: 82504, self_host_duration: 1222, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'CudnnConvolutionBackward0', @@ -6482,15 +6512,16 @@ export class MockAPI { host_duration: 21000, device_duration: 82504, self_host_duration: 547, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', + name: + 'autograd::engine::evaluate_function: CudnnConvolutionBackward0', calls: 53, host_duration: 23024, device_duration: 87731, self_host_duration: 1440, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::add', @@ -6498,7 +6529,7 @@ export class MockAPI { host_duration: 584, device_duration: 5227, self_host_duration: 374, - self_device_duration: 5227, + self_device_duration: 5227 }, { name: 'aten::fill_', @@ -6506,7 +6537,7 @@ export class MockAPI { host_duration: 26, device_duration: 230, self_host_duration: 12, - self_device_duration: 230, + self_device_duration: 230 }, { name: 'aten::zero_', @@ -6514,7 +6545,7 @@ export class MockAPI { host_duration: 33, device_duration: 230, self_host_duration: 7, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::max_pool2d_with_indices_backward', @@ -6522,7 +6553,7 @@ export class MockAPI { host_duration: 73, device_duration: 1513, self_host_duration: 30, - self_device_duration: 1283, + self_device_duration: 1283 }, { name: 'MaxPool2DWithIndicesBackward0', @@ -6530,19 +6561,20 @@ export class MockAPI { host_duration: 83, device_duration: 1513, self_host_duration: 10, - self_device_duration: 0, + self_device_duration: 0 }, { - name: 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', + name: + 'autograd::engine::evaluate_function: MaxPool2DWithIndicesBackward0', calls: 1, host_duration: 106, device_duration: 1513, self_host_duration: 23, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-21', + path: '0-21' }, { left: { @@ -6557,7 +6589,7 @@ export class MockAPI { host_duration: 87, device_duration: 0, self_host_duration: 87, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -6565,7 +6597,7 @@ export class MockAPI { host_duration: 4, device_duration: 0, self_host_duration: 4, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -6573,9 +6605,9 @@ export class MockAPI { host_duration: 160, device_duration: 0, self_host_duration: 69, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, right: { name: 'aten::zeros', @@ -6589,7 +6621,7 @@ export class MockAPI { host_duration: 105, device_duration: 0, self_host_duration: 105, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zero_', @@ -6597,7 +6629,7 @@ export class MockAPI { host_duration: 2, device_duration: 0, self_host_duration: 2, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::zeros', @@ -6605,11 +6637,11 @@ export class MockAPI { host_duration: 119, device_duration: 0, self_host_duration: 12, - self_device_duration: 0, - }, - ], + self_device_duration: 0 + } + ] }, - path: '0-22', + path: '0-22' }, { left: { @@ -6624,7 +6656,7 @@ export class MockAPI { host_duration: 40, device_duration: 0, self_host_duration: 40, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -6632,7 +6664,7 @@ export class MockAPI { host_duration: 11945, device_duration: 401, self_host_duration: 9568, - self_device_duration: 401, + self_device_duration: 401 }, { name: 'aten::add_', @@ -6640,9 +6672,9 @@ export class MockAPI { host_duration: 22480, device_duration: 894, self_host_duration: 17805, - self_device_duration: 894, - }, - ], + self_device_duration: 894 + } + ] }, right: { name: 'Optimizer.step#SGD.step', @@ -6656,7 +6688,7 @@ export class MockAPI { host_duration: 8, device_duration: 0, self_host_duration: 8, - self_device_duration: 0, + self_device_duration: 0 }, { name: 'aten::mul_', @@ -6664,7 +6696,7 @@ export class MockAPI { host_duration: 3440, device_duration: 404, self_host_duration: 1824, - self_device_duration: 404, + self_device_duration: 404 }, { name: 'aten::add_', @@ -6672,13 +6704,13 @@ export class MockAPI { host_duration: 6161, device_duration: 894, self_host_duration: 3186, - self_device_duration: 894, - }, - ], - }, - path: '0-23', - }, - ], - }); + self_device_duration: 894 + } + ] + }, + path: '0-23' + } + ] + }) } } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/app.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/app.tsx index 19eb4b112529073c6b8db9a86b8d68a7633598db..c8cd2ddec26fee10f0a6d448a2051e749ae20696 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/app.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/app.tsx @@ -15,52 +15,51 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import Box from '@material-ui/core/Box'; -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import ClickAwayListener from '@material-ui/core/ClickAwayListener'; -import CssBaseline from '@material-ui/core/CssBaseline'; -import Divider from '@material-ui/core/Divider'; -import Drawer from '@material-ui/core/Drawer'; -import Fab from '@material-ui/core/Fab'; -import FormControl from '@material-ui/core/FormControl'; -import IconButton from '@material-ui/core/IconButton'; -import ListSubheader from '@material-ui/core/ListSubheader'; -import MenuItem from '@material-ui/core/MenuItem'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import { makeStyles } from '@material-ui/core/styles'; -import Tab from '@material-ui/core/Tab'; -import Tabs from '@material-ui/core/Tabs'; -import Typography from '@material-ui/core/Typography'; -import ChevronLeftIcon from '@material-ui/icons/ChevronLeft'; -import ChevronRightIcon from '@material-ui/icons/ChevronRight'; -import { message } from 'antd'; -import 'antd/es/button/style/css'; -import 'antd/es/list/style/css'; -import 'antd/es/table/style/css'; -import clsx from 'clsx'; -import * as React from 'react'; -import * as api from './api'; -import { AccuracyLeftPanel } from './components/Accuracy/AccuracyLeftPanel'; -import { FileInfo } from './components/Accuracy/entity'; -import { LossComparison } from './components/Accuracy/LossComparison'; -import { DiffOverview } from './components/DiffOverview'; -import { DistributedView } from './components/DistributedView'; -import { FullCircularProgress } from './components/FullCircularProgress'; -import { Kernel as KernelView } from './components/Kernel'; -import { MemoryView } from './components/MemoryView'; -import { ModuleView } from './components/ModuleView'; -import { Operator as OperatorView } from './components/Operator'; -import { Overview as OverviewPage } from './components/Overview'; -import { TraceView } from './components/TraceView'; -import { setup } from './setup'; -import './styles.css'; -import { firstOrUndefined, sleep } from './utils'; +import Box from '@material-ui/core/Box' +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import ClickAwayListener from '@material-ui/core/ClickAwayListener' +import CssBaseline from '@material-ui/core/CssBaseline' +import Divider from '@material-ui/core/Divider' +import Drawer from '@material-ui/core/Drawer' +import Fab from '@material-ui/core/Fab' +import FormControl from '@material-ui/core/FormControl' +import IconButton from '@material-ui/core/IconButton' +import ListSubheader from '@material-ui/core/ListSubheader' +import MenuItem from '@material-ui/core/MenuItem' +import Select, { SelectProps } from '@material-ui/core/Select' +import { makeStyles } from '@material-ui/core/styles' +import Tab from '@material-ui/core/Tab' +import Tabs from '@material-ui/core/Tabs' +import Typography from '@material-ui/core/Typography' +import ChevronLeftIcon from '@material-ui/icons/ChevronLeft' +import ChevronRightIcon from '@material-ui/icons/ChevronRight' +import 'antd/es/button/style/css' +import 'antd/es/list/style/css' +import 'antd/es/table/style/css' +import clsx from 'clsx' +import * as React from 'react' +import * as api from './api' +import { AccuracyLeftPanel } from './components/Accuracy/AccuracyLeftPanel' +import { FileInfo } from './components/Accuracy/entity' +import { LossComparison } from './components/Accuracy/LossComparison' +import { DiffOverview } from './components/DiffOverview' +import { DistributedView } from './components/DistributedView' +import { FullCircularProgress } from './components/FullCircularProgress' +import { Kernel } from './components/Kernel' +import { MemoryView } from './components/MemoryView' +import { ModuleView } from './components/ModuleView' +import { Operator } from './components/Operator' +import { Overview } from './components/Overview' +import { TraceView } from './components/TraceView' +import { setup } from './setup' +import './styles.css' +import { firstOrUndefined, sleep } from './utils' export enum Views { Overview = 'Overview', @@ -70,10 +69,10 @@ export enum Views { Distributed = 'Distributed', Memory = 'Memory', Module = 'Module', - Lightning = 'Lightning', + Lightning = 'Lightning' } -const viewNames = { +const ViewNames = { [Views.Overview]: Views.Overview, [Views.Operator]: Views.Operator, [Views.Kernel]: 'Kernel', @@ -81,59 +80,61 @@ const viewNames = { [Views.Distributed]: Views.Distributed, [Views.Memory]: Views.Memory, [Views.Module]: Views.Module, - [Views.Lightning]: Views.Lightning, -}; + [Views.Lightning]: Views.Lightning +} + +const accViews = ['Loss Comparison'] -const drawerWidth = 340; +const drawerWidth = 340 const useStyles = makeStyles((theme) => ({ root: { display: 'flex', - height: '100%', + height: '100%' }, appBar: { zIndex: theme.zIndex.drawer + 1, transition: theme.transitions.create(['width', 'margin'], { easing: theme.transitions.easing.sharp, - duration: theme.transitions.duration.leavingScreen, - }), + duration: theme.transitions.duration.leavingScreen + }) }, appBarShift: { marginLeft: drawerWidth, width: `calc(100% - ${drawerWidth}px)`, transition: theme.transitions.create(['width', 'margin'], { easing: theme.transitions.easing.sharp, - duration: theme.transitions.duration.enteringScreen, - }), + duration: theme.transitions.duration.enteringScreen + }) }, menuButton: { - marginRight: 36, + marginRight: 36 }, hide: { - display: 'none', + display: 'none' }, drawer: { width: drawerWidth, flexShrink: 0, - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, drawerOpen: { width: drawerWidth, zIndex: 999, transition: theme.transitions.create('width', { easing: theme.transitions.easing.sharp, - duration: theme.transitions.duration.enteringScreen, - }), + duration: theme.transitions.duration.enteringScreen + }) }, drawerClose: { transition: theme.transitions.create('width', { easing: theme.transitions.easing.sharp, - duration: theme.transitions.duration.leavingScreen, + duration: theme.transitions.duration.leavingScreen }), overflowX: 'hidden', width: 0, [theme.breakpoints.up('sm')]: { - width: 0, - }, + width: 0 + } }, toolbar: { display: 'flex', @@ -141,304 +142,322 @@ const useStyles = makeStyles((theme) => ({ justifyContent: 'flex-end', padding: theme.spacing(0, 1), // necessary for content to be below app bar - ...theme.mixins.toolbar, + ...theme.mixins.toolbar }, content: { flexGrow: 1, padding: theme.spacing(3), - overflowX: 'hidden', + overflowX: 'hidden' }, formControl: { margin: theme.spacing(1), - minWidth: 120, + minWidth: 120 }, fab: { marginLeft: theme.spacing(1), marginTop: theme.spacing(1), - position: 'absolute', + position: 'absolute' }, iconButton: { - padding: '8px', - }, -})); + padding: '8px' + } +})) -export const App = (): JSX.Element => { - const classes = useStyles(); +export const App = () => { + const classes = useStyles() // #region - State - const [selectedTab, setSelectedTab] = React.useState(0); - - const [run, setRun] = React.useState(''); - const [runs, setRuns] = React.useState([]); - const [runsLoading, setRunsLoading] = React.useState(true); - - const [workers, setWorkers] = React.useState([]); - const [worker, setWorker] = React.useState(''); - - const [spans, setSpans] = React.useState([]); - const [span, setSpan] = React.useState(''); - const [views, setViews] = React.useState([]); - const [view, setView] = React.useState(''); - const [loaded, setLoaded] = React.useState(false); - const iframeRef = React.useRef(null); - const [deviceTarget, setDeviceTarget] = React.useState('GPU'); + const [selectedTab, setSelectedTab] = React.useState(0) + + const [run, setRun] = React.useState('') + const [runs, setRuns] = React.useState([]) + const [runsLoading, setRunsLoading] = React.useState(true) + + const [workers, setWorkers] = React.useState([]) + const [worker, setWorker] = React.useState('') + + const [spans, setSpans] = React.useState([]) + const [span, setSpan] = React.useState('') + + const [views, setViews] = React.useState([]) + const [view, setView] = React.useState('') + const [loaded, setLoaded] = React.useState(false) + const iframeRef = React.useRef(null) + const [deviceTarget, setDeviceTarget] = React.useState('GPU') + + const [diffLeftWorkerOptions, setDiffLeftWorkerOptions] = React.useState< + string[] + >([]) + const [diffLeftSpansOptions, setDiffLeftSpansOptions] = React.useState< + string[] + >([]) + const [diffLeftRun, setDiffLeftRun] = React.useState('') + const [diffLeftWorker, setDiffLeftWorker] = React.useState('') + const [diffLeftSpan, setDiffLeftSpan] = React.useState('') + + const [diffRightWorkerOptions, setDiffRightWorkerOptions] = React.useState< + string[] + >([]) + const [diffRightSpansOptions, setDiffRightSpansOptions] = React.useState< + string[] + >([]) + const [diffRightRun, setDiffRightRun] = React.useState('') + const [diffRightWorker, setDiffRightWorker] = React.useState('') + const [diffRightSpan, setDiffRightSpan] = React.useState('') + + const [open, setOpen] = React.useState(true) + + const [topTab, setTopTab] = React.useState(0) + const [fileList, setFileList] = React.useState([]) + const [uploadedCount, setUploadedCount] = React.useState(0) - const [diffLeftWorkerOptions, setDiffLeftWorkerOptions] = React.useState([]); - const [diffLeftSpansOptions, setDiffLeftSpansOptions] = React.useState([]); - const [diffLeftRun, setDiffLeftRun] = React.useState(''); - const [diffLeftWorker, setDiffLeftWorker] = React.useState(''); - const [diffLeftSpan, setDiffLeftSpan] = React.useState(''); - - const [diffRightWorkerOptions, setDiffRightWorkerOptions] = React.useState([]); - const [diffRightSpansOptions, setDiffRightSpansOptions] = React.useState([]); - const [diffRightRun, setDiffRightRun] = React.useState(''); - const [diffRightWorker, setDiffRightWorker] = React.useState(''); - const [diffRightSpan, setDiffRightSpan] = React.useState(''); - - const [open, setOpen] = React.useState(true); - - const [topTab, setTopTab] = React.useState(0); - const [fileList, setFileList] = React.useState([]); - const [uploadedCount, setUploadedCount] = React.useState(0); // #endregion + // #endregion React.useEffect(() => { - setup() - .catch(() => { - message.warning('google chart is not supported offline'); - }) - .finally(() => { - setLoaded(true); - }); - }, []); - - const continuouslyFetchRuns = async (): Promise => { + setup().catch(() => { + console.log('google chart is not supported offline') + }).finally(() => { + setLoaded(true) + }) + }, []) + + const continuouslyFetchRuns = async () => { while (true) { try { - const result = await api.defaultApi.runsGet(); - setRuns(result.runs); - setRunsLoading(result.loading); + const runs = await api.defaultApi.runsGet() + setRuns(runs.runs) + setRunsLoading(runs.loading) } catch (e) { - message.warning(`Cannot fetch runs: ${e}`); + console.info('Cannot fetch runs: ', e) } - await sleep(5000); + await sleep(5000) } - }; + } React.useEffect(() => { - continuouslyFetchRuns(); - }, []); + continuouslyFetchRuns() + }, []) React.useEffect(() => { if (!run || !runs.includes(run)) { - setRun(firstOrUndefined(runs) ?? ''); + setRun(firstOrUndefined(runs) ?? '') } - }, [runs]); // #region - Diff Left + }, [runs]) + + // #region - Diff Left React.useEffect(() => { if (diffLeftRun) { - api.defaultApi.workersGet(diffLeftRun, Views.Overview).then((data) => { - setDiffLeftWorkerOptions(data); - }); + api.defaultApi.workersGet(diffLeftRun, Views.Overview).then((workers) => { + setDiffLeftWorkerOptions(workers) + }) } - }, [diffLeftRun]); + }, [diffLeftRun]) React.useEffect(() => { if (diffLeftRun && diffLeftWorker) { - api.defaultApi.spansGet(diffLeftRun, diffLeftWorker).then((data) => { - setDiffLeftSpansOptions(data); - }); + api.defaultApi.spansGet(diffLeftRun, diffLeftWorker).then((spans) => { + setDiffLeftSpansOptions(spans) + }) } - }, [diffLeftRun, diffLeftWorker]); + }, [diffLeftRun, diffLeftWorker]) // #endregion + // #region - Diff Right + React.useEffect(() => { if (diffRightRun) { - api.defaultApi.workersGet(diffRightRun, Views.Overview).then((data) => { - setDiffRightWorkerOptions(data); - }); + api.defaultApi + .workersGet(diffRightRun, Views.Overview) + .then((workers) => { + setDiffRightWorkerOptions(workers) + }) } - }, [diffRightRun]); + }, [diffRightRun]) React.useEffect(() => { if (diffRightRun && diffRightWorker) { - api.defaultApi.spansGet(diffRightRun, diffRightWorker).then((data) => { - setDiffRightSpansOptions(data); - }); + api.defaultApi.spansGet(diffRightRun, diffRightWorker).then((spans) => { + setDiffRightSpansOptions(spans) + }) } - }, [diffRightRun, diffRightWorker]); + }, [diffRightRun, diffRightWorker]) // #endregion + // #region - normal + React.useEffect(() => { if (run) { api.defaultApi.viewsGet(run).then((rawViews) => { - const result = rawViews.views.map((v) => Views[Views[v as Views]]).filter(Boolean); - setDeviceTarget(rawViews.device_target); - setViews(result); - }); + const views = rawViews.views + .map((v) => Views[Views[v as Views]]) + .filter(Boolean) + setDeviceTarget(rawViews.device_target) + setViews(views) + }) } - }, [run]); + }, [run]) React.useEffect(() => { - setView(firstOrUndefined(views) ?? ''); - }, [views]); + setView(firstOrUndefined(views) ?? '') + }, [views]) React.useEffect(() => { if (run && view) { - api.defaultApi.workersGet(run, view).then((data) => { - setWorkers(data); - }); + api.defaultApi.workersGet(run, view).then((workers) => { + setWorkers(workers) + }) } - }, [run, view]); + }, [run, view]) React.useEffect(() => { - setWorker(firstOrUndefined(workers) ?? ''); - }, [workers]); + setWorker(firstOrUndefined(workers) ?? '') + }, [workers]) React.useEffect(() => { if (run && worker) { - api.defaultApi.spansGet(run, worker).then((data) => { - setSpans(data); - }); + api.defaultApi.spansGet(run, worker).then((spans) => { + setSpans(spans) + }) } - }, [run, worker]); + }, [run, worker]) React.useEffect(() => { - setSpan(firstOrUndefined(spans) ?? ''); - }, [spans]); + setSpan(firstOrUndefined(spans) ?? '') + }, [spans]) // #endregion // #region - Event Handler - const handleTabChange = (event: React.ChangeEvent>, value: any): void => { - setSelectedTab(value as number); - }; + const handleTabChange = (event: React.ChangeEvent<{}>, value: any) => { + setSelectedTab(value as number) + } - const handleTopTabChange = (event: React.ChangeEvent>, value: any): void => { - setTopTab(value as number); - }; + const handleTopTabChange = (event: React.ChangeEvent<{}>, value: any) => { + setTopTab(value as number) + } const handleRunChange: SelectProps['onChange'] = (event) => { - setRun(event.target.value as string); - setView(''); - setWorker(''); - setSpan(''); - }; + setRun(event.target.value as string) + setView('') + setWorker('') + setSpan('') + } const handleViewChange: SelectProps['onChange'] = (event) => { - setView(event.target.value as Views); - setWorker(''); - setSpan(''); - }; + setView(event.target.value as Views) + setWorker('') + setSpan('') + } const handleWorkerChange: SelectProps['onChange'] = (event) => { - setWorker(event.target.value as string); - setSpan(''); - }; + setWorker(event.target.value as string) + setSpan('') + } const handleSpanChange: SelectProps['onChange'] = (event) => { - setSpan(event.target.value as string); - }; + setSpan(event.target.value as string) + } const handleDiffLeftRunChange: SelectProps['onChange'] = (event) => { - setDiffLeftRun(event.target.value as string); - setDiffLeftWorker(''); - setDiffLeftSpan(''); - }; + setDiffLeftRun(event.target.value as string) + setDiffLeftWorker('') + setDiffLeftSpan('') + } const handleDiffLeftWorkerChange: SelectProps['onChange'] = (event) => { - setDiffLeftWorker(event.target.value as string); - setDiffLeftSpan(''); - }; + setDiffLeftWorker(event.target.value as string) + setDiffLeftSpan('') + } const handleDiffLeftSpanChange: SelectProps['onChange'] = (event) => { - setDiffLeftSpan(event.target.value as string); - }; + setDiffLeftSpan(event.target.value as string) + } const handleDiffRightRunChange: SelectProps['onChange'] = (event) => { - setDiffRightRun(event.target.value as string); - setDiffRightWorker(''); - setDiffRightSpan(''); - }; + setDiffRightRun(event.target.value as string) + setDiffRightWorker('') + setDiffRightSpan('') + } const handleDiffRightWorkerChange: SelectProps['onChange'] = (event) => { - setDiffRightWorker(event.target.value as string); - setDiffRightSpan(''); - }; + setDiffRightWorker(event.target.value as string) + setDiffRightSpan('') + } const handleDiffRightSpanChange: SelectProps['onChange'] = (event) => { - setDiffRightSpan(event.target.value as string); - }; + setDiffRightSpan(event.target.value as string) + } - const handleDrawerOpen = (): void => { - setOpen(true); - setIframeActive(); - }; + const handleDrawerOpen = () => { + setOpen(true) + SetIframeActive() + } - const handleDrawerClose = (): void => { - setOpen(false); - setIframeActive(); - }; + const handleDrawerClose = () => { + setOpen(false) + SetIframeActive() + } - const setIframeActive = (): void => { - iframeRef.current?.focus(); - }; + const SetIframeActive = () => { + iframeRef.current?.focus() + } - const _changeFileList = (files: FileInfo[]): void => { + const _changeFileList = (files: FileInfo[]) => { if (JSON.stringify(files) !== JSON.stringify(fileList)) { - setFileList(files); + setFileList(files) } - }; + } - const _getViews = (viewName: Views): string => { - if (viewName === Views.Kernel) { - return deviceTarget === 'Ascend' ? `NPU ${viewNames[viewName]}` : `GPU ${viewNames[viewName]}`; - } else { - return viewNames[viewName]; - } - }; + const _changeUploadCount = (count: number) => { + setUploadedCount(count) + } - const _changeUploadCount = (count: number): void => { - setUploadedCount(count); - }; // #endregion + // #endregion - const renderContent = (): JSX.Element => { - if (!runsLoading && runs.length === 0) { + const renderContent = () => { + if (!runsLoading && runs.length == 0) { return ( - - + + There are not any runs in the log folder. - ); + ) } - const notReady = !loaded || !run || !worker || !view || !span; - if (notReady) { - return ; + + if (!loaded || !run || !worker || !view || !span) { + return } if (selectedTab === 0) { switch (view) { case Views.Overview: - return ; + return case Views.Operator: - return ; + return case Views.Kernel: - return ; + return case Views.Trace: - return ; + return ( + + ) case Views.Distributed: - return ; + return case Views.Memory: - return ; + return case Views.Module: case Views.Lightning: - return ; - default: - return <>; + return } } else { return ( @@ -450,99 +469,112 @@ export const App = (): JSX.Element => { expWorker={diffRightWorker} expSpan={diffRightSpan} /> - ); + ) } - }; + } - const spanComponent = (): JSX.Element => { + const spanComponent = () => { const spanFragment = ( Spans - - + + - ); + ) if (!spans || spans.length <= 1) { - return
{spanFragment}
; + return
{spanFragment}
} else { - return spanFragment; + return spanFragment } - }; + } return (
- +
- - - + + + {topTab === 0 ? ( <> - - - + + + - {selectedTab === 0 ? ( + {selectedTab == 0 ? ( <> Runs - - + + Views - - + + Workers - - + + @@ -551,75 +583,93 @@ export const App = (): JSX.Element => { ) : ( <> -   Baseline +   Baseline Runs - + Workers - - - - Spans - - - + + + + Spans + + + - + -   Experimental +   Experimental Runs - + Workers - - + {diffRightWorkerOptions.map((worker) => ( + {worker} ))} Spans - - + {diffRightSpansOptions.map((span) => ( + {span} ))} )} - ) : ( - - )} + ) : + + }
{!open && ( - + )}
{topTab === 0 ? renderContent() : }
-
- ); -}; + + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/AccuracyLeftPanel.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/AccuracyLeftPanel.tsx index c7b7d7cf0841e7dc3686138b584e101e5052f4a6..ef9b170ec7a3de46039e5345ddf574f6fd620077 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/AccuracyLeftPanel.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/AccuracyLeftPanel.tsx @@ -17,32 +17,38 @@ * limitations under the License. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { useState, useEffect, useCallback, useRef } from 'react'; -import { makeStyles } from '@material-ui/core/styles'; -import { Button, Checkbox, Spin, Modal, message } from 'antd'; -import { CheckboxChangeEvent } from 'antd/es/checkbox'; -import { DeleteOutlined, DownloadOutlined, ImportOutlined, SettingOutlined, WarningTwoTone } from '@ant-design/icons'; -import { RegexConfigModal } from './RegexConfigModal'; -import { FileInfo } from './entity'; +import * as React from 'react' +import { useState, useEffect, useCallback, useRef } from 'react' +import { makeStyles } from '@material-ui/core/styles' +import { Button, Checkbox, Spin, Modal, message } from 'antd' +import { CheckboxChangeEvent } from 'antd/es/checkbox' +import { + DeleteOutlined, + DownloadOutlined, + ImportOutlined, + SettingOutlined, + WarningTwoTone, +} from '@ant-design/icons' +import { RegexConfigModal } from './RegexConfigModal' +import { FileInfo } from './entity' interface IProps { - onChangeCheckedFileList: (files: FileInfo[]) => void; - onChangeUploadedCount: (count: number) => void; + onChangeCheckedFileList: (files: FileInfo[]) => void + onChangeUploadedCount: (count: number) => void } // 匹配数字包括科学计数法 -const LOSS_REG_EXP = /[+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/; +const LOSS_REG_EXP = /[+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/ // 匹配自然数 -const ITER_REG_EXP = /\d+/; +const ITER_REG_EXP = /\d+/ // 单个文件最大大小 -const FILE_MAX_SIZE = 50 * 1024 * 1024; +const FILE_MAX_SIZE = 50 * 1024 * 1024 // 最大文件上传数量 -export const MAX_FILE_COUNT = 6; +export const MAX_FILE_COUNT = 6 const useStyles = makeStyles(() => ({ root: { - height: '100%', + height: '100%' }, btnPanel: { height: 50, @@ -50,8 +56,8 @@ const useStyles = makeStyles(() => ({ borderBottom: '1px solid #DFE5EF', display: 'flex', '& .ant-btn': { - margin: 'auto', - }, + margin: 'auto' + } }, fileContainer: { height: 54, @@ -65,7 +71,7 @@ const useStyles = makeStyles(() => ({ fontSize: 14, overflow: 'hidden', textOverflow: 'ellipsis', - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, '& .btns': { display: 'inline-block', @@ -73,17 +79,17 @@ const useStyles = makeStyles(() => ({ '& .icon': { cursor: 'pointer', '&:hover': { - color: '#1890ff', - }, + color: '#1890ff' + } }, '& .iconLeft': { - marginRight: 8, - }, + marginRight: 8 + } }, }, deleteModal: { '& .ant-modal-title': { - fontWeight: 'bold', + fontWeight: 'bold' }, '& .deleteModalBody': { display: 'flex', @@ -91,210 +97,203 @@ const useStyles = makeStyles(() => ({ height: 80, '& .warningIcon': { display: 'inline-block', - fontSize: 50, + fontSize: 50 }, '& .warningText': { display: 'inline-block', marginLeft: 16, overflow: 'hidden', wordBreak: 'break-all', - flex: 1, - }, - }, - }, -})); + flex: 1 + } + } + } +})) export const AccuracyLeftPanel: React.FC = (props) => { - const { onChangeCheckedFileList, onChangeUploadedCount } = props; - const classes = useStyles(); - const [configModalVis, setConfigModalVis] = useState(false); - const [deleteModalVis, setDeleteModalVis] = useState(false); - const [fileList, setFileList] = useState([]); - const [importSpin, setImportSpin] = useState(false); - const [selectedFile, setSelectedFile] = useState(undefined); - const downLoadRef = useRef(null); + const { onChangeCheckedFileList, onChangeUploadedCount } = props + const classes = useStyles() + const [configModalVis, setConfigModalVis] = useState(false) + const [deleteModalVis, setDeleteModalVis] = useState(false) + const [fileList, setFileList] = useState([]) + const [importSpin, setImportSpin] = useState(false) + const [selectedFile, setSelectedFile] = useState(undefined) + const downLoadRef = useRef(null) const parseFile = (file: FileInfo): FileInfo => { - file.losses = []; - file.iterLosses = {}; - file.iters = []; - const lines = file.fileContent.split(/\r\n|\n|\r/); + file.losses = [] + file.iterLosses = {} + file.iters = [] + const lines = file.fileContent.split(/\r\n|\n|\r/) for (let i = 0; i < lines.length; i++) { - const iter = parseByTag(lines[i], file.iterTag, false); - const loss = parseByTag(lines[i], file.lossTag, true); + const iter = parseByTag(lines[i], file.iterTag, false) + const loss = parseByTag(lines[i], file.lossTag, true) if (iter !== null && loss !== null) { - file.iters.push(iter); - file.losses.push([iter, loss]); - file.iterLosses[iter] = loss; + file.iters.push(iter) + file.losses.push([iter, loss]) + file.iterLosses[iter] = loss } } - return file; - }; + return file + } const parseByTag = (line: string, tag: string, isLoss: boolean): number | null => { - let pos = line.indexOf(tag); - let result: number | null = null; + let pos = line.indexOf(tag) + let result: number | null = null if (pos !== -1) { - const res = (isLoss ? LOSS_REG_EXP : ITER_REG_EXP).exec( - line - .substring(pos + tag.length) - .trim() - .split(/\s+/)[0] - ); + const res = (isLoss ? LOSS_REG_EXP : ITER_REG_EXP) + .exec(line.substring(pos + tag.length).trim().split(/\s+/)[0]) if (res !== null) { if (isLoss) { - result = parseFloat(res[0]); + result = parseFloat(res[0]) } else { - result = parseInt(res[0]); + result = parseInt(res[0]) } } else { - console.warn(`Found ${isLoss ? 'loss' : 'iteration'} text, but parse value with error: [${line}]`); + console.log(`Found ${isLoss ? 'loss' : 'iteration'} text, but parse value with error: [${line}]`) } } - return result; - }; + return result + } - const importFile = (): void => { - document.getElementById('accComparisonSelectFile')?.click(); - }; + const importFile = () => { + document.getElementById('accComparisonSelectFile')?.click() + } - const uploadFile = (e: React.ChangeEvent): void => { - setImportSpin(true); - const file = e.target.files?.[0]; + const uploadFile = (e: React.ChangeEvent) => { + setImportSpin(true) + const file = e.target.files?.[0] if (file) { if (file.size > FILE_MAX_SIZE) { - message.warn('Sorry, the file size cannot be greater than 50MB.'); - setImportSpin(false); + message.warn('Sorry, the file size cannot be greater than 50MB.') + setImportSpin(false) // 防止同名文件不触发事件 - e.target.value = ''; - return; + e.target.value = '' + return } - const reader = new FileReader(); - reader.onload = ((loadedFile) => { - return (event) => { - addFile(loadedFile.name.trim(), event.target?.result as string); - setImportSpin(false); - }; + const reader = new FileReader() + reader.onload = ((selectedFile) => { + return (e) => { + addFile(selectedFile.name.trim(), e.target?.result as string) + setImportSpin(false) + } })(file); - reader.readAsText(file); + reader.readAsText(file) } // 防止同名文件不触发事件 - e.target.value = ''; - }; + e.target.value = '' + } - const addFile = (fileName: string, fileContent: string): void => { - const fileLength = fileName.length; - const tempList: FileInfo[] = JSON.parse(JSON.stringify(fileList)); - let updatedFileName = fileName; // 新变量用于存储更新后的文件名 + const addFile = (fileName: string, fileContent: string) => { + const fileLength = fileName.length + const tempList: FileInfo[] = JSON.parse(JSON.stringify(fileList)) // 上传同名文件加上(1~最大文件数减1)标识 - if (!!tempList.find((item) => item.fileName === fileName)) { + if (!!tempList.find(item => item.fileName === fileName)) { for (let i = 1; i < MAX_FILE_COUNT; i++) { - let temp = `${fileName.slice(0, fileLength - 4)}(${i})${fileName.slice(fileLength - 4)}`; - if (tempList.find((item) => item.fileName === temp) === undefined) { - updatedFileName = temp; - break; + let temp = `${fileName.slice(0, fileLength - 4)}(${i})${fileName.slice(fileLength - 4)}` + if (tempList.find(item => item.fileName === temp) === undefined) { + fileName = temp + break } } } const file: FileInfo = { id: fileList.length, - fileName: updatedFileName, + fileName: fileName, fileContent, checked: true, lossTag: 'loss:', iterTag: 'iteration', iters: [], losses: [], - iterLosses: {}, - }; - tempList.push(parseFile(file)); - setFileList(tempList); - }; + iterLosses: {} + } + tempList.push(parseFile(file)) + setFileList(tempList) + } - const exportCsv = (data: FileInfo): void => { - let csvContent = `data:text/csv;charset=utf-8,${data.iterTag},${data.lossTag}\n`; - data.losses.forEach((item) => { - csvContent += `${item[0]},${item[1]}\n`; - }); - downLoadRef.current?.setAttribute('href', encodeURI(csvContent)); - downLoadRef.current?.setAttribute('download', `${data.fileName}.csv`); - downLoadRef.current?.click(); - }; + const exportCsv = (data: FileInfo) => { + let csvContent = `data:text/csv;charset=utf-8,${data.iterTag},${data.lossTag}\n` + data.losses.forEach(item => { + csvContent += `${item[0]},${item[1]}\n` + }) + downLoadRef.current?.setAttribute('href', encodeURI(csvContent)) + downLoadRef.current?.setAttribute('download', `${data.fileName}.csv`) + downLoadRef.current?.click() + } - const onCheckChange = (e: CheckboxChangeEvent, index: number): void => { - const tempList: FileInfo[] = JSON.parse(JSON.stringify(fileList)); - tempList[index].checked = e.target.checked; - setFileList(tempList); - }; + const onCheckChange = (e: CheckboxChangeEvent, index: number) => { + const tempList: FileInfo[] = JSON.parse(JSON.stringify(fileList)) + tempList[index].checked = e.target.checked + setFileList(tempList) + } - const onConfigIconClick = (data: FileInfo): void => { - setSelectedFile(data); - setConfigModalVis(true); - }; + const onConfigIconClick = (data: FileInfo) => { + setSelectedFile(data) + setConfigModalVis(true) + } - const onDeleteIconClick = (data: FileInfo): void => { - setSelectedFile(data); - setDeleteModalVis(true); - }; + const onDeleteIconClick = (data: FileInfo) => { + setSelectedFile(data) + setDeleteModalVis(true) + } - const configModalOk = (data: FileInfo): void => { - const tempList = fileList.map((item) => { - return item.id === data.id ? parseFile(data) : item; - }); - setFileList(tempList); - setConfigModalVis(false); - }; + const configModalOk = (data: FileInfo) => { + const tempList = fileList.map(item => { + return item.id === data.id ? parseFile(data) : item + }) + setFileList(tempList) + setConfigModalVis(false) + } - const configModalCancel = (): void => { - setConfigModalVis(false); - }; + const configModalCancel = () => { + setConfigModalVis(false) + } - const deleteModalOk = (): void => { - const tempList = JSON.parse(JSON.stringify(fileList)); - let founded = false; - let index = 0; + const deleteModalOk = () => { + const tempList = JSON.parse(JSON.stringify(fileList)) + let founded = false + let index = 0 for (let i = 0; i < tempList.length; i++) { if (founded) { - tempList[i].id -= 1; - continue; + tempList[i].id -= 1 + continue } if (tempList[i].id === selectedFile?.id) { - founded = true; - index = i; + founded = true + index = i } } - tempList.splice(index, 1); - setFileList(tempList); - setSelectedFile(undefined); - setDeleteModalVis(false); - }; + tempList.splice(index, 1) + setFileList(tempList) + setSelectedFile(undefined) + setDeleteModalVis(false) + } const renderFileItems = useCallback(() => { return fileList.map((item) => { return (
- onCheckChange(e, item.id)} /> - - {item.fileName} - -
- onConfigIconClick(item)} /> - exportCsv(item)} /> - onDeleteIconClick(item)} /> + onCheckChange(e, item.id)} /> + {item.fileName} +
+ onConfigIconClick(item)} /> + exportCsv(item)} /> + onDeleteIconClick(item)} />
- ); - }); - }, [JSON.stringify(fileList)]); + ) + }) + }, [JSON.stringify(fileList)]) useEffect(() => { - onChangeCheckedFileList(fileList.filter((item) => item.checked)); - onChangeUploadedCount(fileList.length); - }, [JSON.stringify(fileList)]); + onChangeCheckedFileList(fileList.filter(item => item.checked)) + onChangeUploadedCount(fileList.length) + }, [JSON.stringify(fileList)]) return (
- +
- +
{renderFileItems()}
- {configModalVis && ( - - )} + {configModalVis && + + } setDeleteModalVis(false)} + onCancel={() => setDeleteModalVis(false)} onOk={deleteModalOk} width={500} className={classes.deleteModal} > -
- - +
+ + Are you sure to delete "{selectedFile?.fileName}"?
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/ComparisonPanel.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/ComparisonPanel.tsx index 500d29764c5209958ba19630ac1d4e08c10f24a5..a9c9d34feb585cac7c6aa26f9e962c0ed9d11d88 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/ComparisonPanel.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/ComparisonPanel.tsx @@ -17,23 +17,23 @@ * limitations under the License. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { useState, useLayoutEffect, useRef, useEffect } from 'react'; -import { makeStyles } from '@material-ui/core/styles'; -import { FileInfo } from './entity'; -import { Empty, Popover, Radio, RadioChangeEvent, Select, Table } from 'antd'; -import { ColumnsType } from 'antd/es/table'; -import * as echarts from 'echarts'; -import { InfoCircleOutlined } from '@ant-design/icons'; +import * as React from 'react' +import { useState, useLayoutEffect, useRef, useEffect } from 'react' +import { makeStyles } from '@material-ui/core/styles' +import { FileInfo } from './entity' +import { Empty, Popover, Radio, RadioChangeEvent, Select, Table } from 'antd' +import { ColumnsType } from 'antd/es/table' +import * as echarts from 'echarts' +import { InfoCircleOutlined } from '@ant-design/icons' interface IProps { - fileList: FileInfo[]; + fileList: FileInfo[] } interface ILineDataList { - normal: number[][]; - absolute: number[][]; - relative: number[][]; + normal: number[][] + absolute: number[][] + relative: number[][] } const useStyles = makeStyles(() => ({ @@ -49,26 +49,26 @@ const useStyles = makeStyles(() => ({ lineHeight: '24px', fontFamily: 'sans-serif', fontSize: 16, - fontWeight: 700, + fontWeight: 700 }, filter: { height: 40, lineHeight: '40px', '& .comparisonSelect': { - margin: '0 8px', + margin: '0 8px' }, '& .comparisonLabel': { - marginRight: 8, + marginRight: 8 }, '& .comparisonBtn': { - marginLeft: 20, + marginLeft: 20 }, '& .infoLabel': { - fontSize: 20, - }, + fontSize: 20 + } }, empty: { - marginTop: 60, + marginTop: 60 }, content: { flex: 1, @@ -76,11 +76,11 @@ const useStyles = makeStyles(() => ({ }, lossChart: { height: '100%', - flex: 1, + flex: 1 }, lossTable: { height: '100%', - width: '32%', + width: '32%' }, tableHeader: { display: 'inline-block', @@ -90,163 +90,149 @@ const useStyles = makeStyles(() => ({ transform: 'translateY(-50%)', overflow: 'hidden', textOverflow: 'ellipsis', - whiteSpace: 'nowrap', - }, -})); + whiteSpace: 'nowrap' + } +})) export const ComparisonPanel: React.FC = (props) => { - const { fileList } = props; - const classes = useStyles(); - const [selectedFiles, setSelectedFiles] = useState([]); - const [compareWay, setCompareWay] = useState(0); - const [pageSize, setPageSize] = useState(20); - const [lineData, setLineData] = useState(undefined); - const [tableData, setTableData] = useState([]); - const chartRef = useRef(null); + const { fileList } = props + const classes = useStyles() + const [selectedFiles, setSelectedFiles] = useState([]) + const [compareWay, setCompareWay] = useState(0) + const [pageSize, setPageSize] = useState(20) + const [lineData, setLineData] = useState(undefined) + const [tableData, setTableData] = useState([]) + const chartRef = useRef(null) const getColumns = (): ColumnsType => { - const columns: ColumnsType = [ - { - title: 'Iteration', - key: 'iter', - dataIndex: 'iter', - }, - ]; + const columns: ColumnsType = [{ + title: 'Iteration', + key: 'iter', + dataIndex: 'iter', + }] selectedFiles.forEach((item, index) => { columns.push({ title: () => ( -
- {item} -
+
{item}
), key: index, dataIndex: item, - width: '40%', - }); - }); - return columns; - }; + width: '40%' + }) + }) + return columns + } - const compareFile = (fileNames: string[]): void => { + const compareFile = (fileNames: string[]) => { if (fileNames.length < 2) { - return; + return } - const baseFile = fileList.find((item) => item.fileName === fileNames[0]); - const expFile = fileList.find((item) => item.fileName === fileNames[1]); + const baseFile = fileList.find(item => item.fileName === fileNames[0]) + const expFile = fileList.find(item => item.fileName === fileNames[1]) if (!!baseFile && !!expFile) { - const commonIters: number[] = []; - const lessIters = baseFile.iters.length <= expFile.iters.length ? baseFile.iters : expFile.iters; - const moreIters = baseFile.iters.length > expFile.iters.length ? baseFile.iters : expFile.iters; - lessIters.forEach((iter) => { + const commonIters: number[] = [] + const lessIters = baseFile.iters.length <= expFile.iters.length ? baseFile.iters : expFile.iters + const moreIters = baseFile.iters.length > expFile.iters.length ? baseFile.iters : expFile.iters + lessIters.forEach(iter => { if (moreIters.includes(iter)) { - commonIters.push(iter); + commonIters.push(iter) } - }); - commonIters.sort((a, b) => a - b); - const tempTableData: any[] = []; + }) + commonIters.sort((a, b) => a - b) + const tempTableData: any[] = [] const tempChartData: ILineDataList = { normal: [], absolute: [], - relative: [], - }; + relative: [] + } commonIters.forEach((iter, index) => { - const baseLoss = baseFile.iterLosses[iter]; - const expLoss = expFile.iterLosses[iter]; + const baseLoss = baseFile.iterLosses[iter] + const expLoss = expFile.iterLosses[iter] tempTableData.push({ key: `${iter}_${index}`, iter, [baseFile.fileName]: baseLoss, - [expFile.fileName]: expLoss, - }); - tempChartData.normal.push([iter, expLoss - baseLoss]); - tempChartData.absolute.push([iter, Math.abs(expLoss - baseLoss)]); - tempChartData.relative.push([iter, baseLoss === 0 ? 0 : Math.abs(expLoss - baseLoss) / baseLoss]); - }); - setTableData(tempTableData); - setLineData(tempChartData); + [expFile.fileName]: expLoss + }) + tempChartData.normal.push([iter, expLoss - baseLoss]) + tempChartData.absolute.push([iter, Math.abs(expLoss - baseLoss)]) + tempChartData.relative.push([iter, baseLoss === 0 ? 0 : Math.abs(expLoss - baseLoss) / baseLoss]) + }) + setTableData(tempTableData) + setLineData(tempChartData) } - }; + } - const onSelectChange = (value: string[]): void => { - setSelectedFiles(value); - compareFile(value); - }; + const onSelectChange = (value: string[]) => { + setSelectedFiles(value) + compareFile(value) + } - const onRadioChange = (e: RadioChangeEvent): void => { - setCompareWay(e.target.value); - }; + const onRadioChange = (e: RadioChangeEvent) => { + setCompareWay(e.target.value) + } - const onShowSizeChange = (current: number, size: number): void => { - setPageSize(size); - }; + const onShowSizeChange = (current: number, size: number) => { + setPageSize(size) + } useLayoutEffect(() => { - const element = chartRef.current; + const element = chartRef.current if (!element || !lineData) { - return undefined; - } - const echart = echarts.init(element); - let dataSource: number[][] = []; - if (compareWay === 0) { - dataSource = lineData.normal; - } else if (compareWay === 1) { - dataSource = lineData.absolute; - } else { - dataSource = lineData.relative; + return } + const echart = echarts.init(element) const option: echarts.EChartsOption = { title: { text: 'Comparison Chart', textStyle: { fontSize: 12, - color: '#000', - }, + color: '#000' + } }, legend: { bottom: 0 }, xAxis: { type: 'category', boundaryGap: false, - name: 'Iteration', + name: 'Iteration' }, yAxis: { type: 'value', name: 'Difference', - scale: true, + scale: true }, tooltip: { trigger: 'axis', - valueFormatter: (value) => (value as number).toFixed(6), + valueFormatter: (value) => (value as number).toFixed(6) }, dataZoom: { - type: 'inside', + type: 'inside' }, dataset: { - source: dataSource, + source: compareWay === 0 ? lineData.normal : (compareWay === 1 ? lineData.absolute : lineData.relative) }, series: { type: 'line', name: 'Difference', - symbol: 'none', - }, - }; - - if (option) { - echart.setOption(option, true); + symbol: 'none' + } } + + option && echart.setOption(option, true) return () => { - echart.dispose(); - }; - }, [compareWay, lineData]); + echart.dispose() + } + }, [compareWay, lineData]) useEffect(() => { - const tempValue = selectedFiles.filter((item) => { - return !!fileList.find((file) => file.fileName === item); - }); + const tempValue = selectedFiles.filter(item => { + return !!fileList.find(file => file.fileName === item) + }) if (JSON.stringify(tempValue) === JSON.stringify(selectedFiles)) { - compareFile(tempValue); + compareFile(tempValue) } - setSelectedFiles(tempValue); - }, [fileList]); + setSelectedFiles(tempValue) + }, [fileList]) return (
@@ -254,23 +240,25 @@ export const ComparisonPanel: React.FC = (props) => {
Comparison objects:
- Iteration Tag + Iteration Tag
- ); -}; + ) +} \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts index 270c4cb6535633f9a03e5b9fe02dca6121cd3ba7..0a0a1ee4b28661799aea5a9233c4f3a90f4a251e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Accuracy/entity.ts @@ -18,13 +18,13 @@ *--------------------------------------------------------------------------------------------*/ export interface FileInfo { - id: number; - fileName: string; - fileContent: string; - checked: boolean; - lossTag: string; - iterTag: string; - iters: number[]; - losses: number[][]; - iterLosses: { [iter: number]: number }; + id: number + fileName: string + fileContent: string + checked: boolean + lossTag: string + iterTag: string + iters: number[] + losses: number[][] + iterLosses: { [iter: number]: number } } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DataLoading.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DataLoading.tsx index 3c5d353ce641c409b51a7aaef8c00ff2f57df6e8..e2967bdf74196ad74a13f2d2f8b1799911d3b553 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DataLoading.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DataLoading.tsx @@ -2,18 +2,18 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { FullCircularProgress } from './FullCircularProgress'; +import * as React from 'react' +import { FullCircularProgress } from './FullCircularProgress' interface IProps { - value?: T | null; - children: (t: T) => JSX.Element; + value: T | undefined | null + children: (t: T) => JSX.Element } -export function DataLoading(props: IProps): JSX.Element { +export function DataLoading(props: IProps) { if (props.value === undefined || props.value === null) { - return ; + return } - return props.children(props.value); + return props.children(props.value) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DiffOverview.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DiffOverview.tsx index ed029d5020ed1eaf8caea159b25d33c7a5ad03e3..e8071b2c5966d944804b4d8abd780d8389042d38 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DiffOverview.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DiffOverview.tsx @@ -2,101 +2,130 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import Button from '@material-ui/core/Button'; -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import Grid from '@material-ui/core/Grid'; -import { makeStyles } from '@material-ui/core/styles'; -import Typography from '@material-ui/core/Typography'; -import ChevronLeftIcon from '@material-ui/icons/ChevronLeft'; -import { Select, Table } from 'antd'; -import * as React from 'react'; -import * as api from '../api'; -import { useResizeEventDependency } from '../utils/resize'; -import { FullCircularProgress } from './FullCircularProgress'; -import * as echarts from 'echarts'; - -const { Option } = Select; - -const topGraphHeight = 230; +import Button from '@material-ui/core/Button' +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import Grid from '@material-ui/core/Grid' +import { makeStyles } from '@material-ui/core/styles' +import Typography from '@material-ui/core/Typography' +import ChevronLeftIcon from '@material-ui/icons/ChevronLeft' +import { Select, Table } from 'antd' +import * as React from 'react' +import * as api from '../api' +import { useResizeEventDependency } from '../utils/resize' +import { FullCircularProgress } from './FullCircularProgress' +import * as echarts from 'echarts' + +const { Option } = Select + +const topGraphHeight = 230 const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, pre: { '& ul': { margin: 0, paddingLeft: theme.spacing(3), - ...theme.typography.body1, + ...theme.typography.body1 }, '& li': {}, '& a': { - color: '#ffa726', + color: '#ffa726' }, '& a:active': { - color: '#ffa726', + color: '#ffa726' }, '& p': { margin: 0, ...theme.typography.subtitle1, - fontWeight: theme.typography.fontWeightBold, - }, + fontWeight: theme.typography.fontWeightBold + } }, topGraph: { - height: topGraphHeight + 40, + height: topGraphHeight + 40 }, iconButton: { - padding: '8px', - }, -})); + padding: '8px' + } +})) -const getAngleByDataLength = (data: number): number => { +const getAngleByDataLength = (data: number) => { if (data < 10) { - return 0; + return 0 } else { // 数量越大越趋近于旋转90度 - return 90 * (1 - (10 / data)); + return 90 * (1 - 10 / data) } -}; +} export interface DiffColumnChartIProps { - rawData: any[]; - selectCallback: (row: number, column: number) => void; + rawData: any[] + selectCallback: (row: number, column: number) => void } export interface DiffStepChartIProps { - rawData: any[]; + rawData: any[] } -const DiffColumnChart: React.FC = (props: DiffColumnChartIProps) => { - const { rawData, selectCallback } = props; - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); +const DiffColumnChart: React.FC = ( + props: DiffColumnChartIProps +) => { + const { rawData, selectCallback } = props + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; + const element = graphRef.current + if (!element) return + + let left_duration_data: number[] = [] + let left_accumulated_duration_data: number[] = [] + + let right_duration_data: number[] = [] + let right_accumulated_duration_data: number[] = [] + + for (let i = 0; i < rawData.length; i++) { + let curr = rawData[i] + left_duration_data.push(curr[1]) + right_duration_data.push(curr[2]) + left_accumulated_duration_data.push(curr[3]) + right_accumulated_duration_data.push(curr[4]) } - const chart = echarts.init(element); + let left_duration_max = Math.max(...left_duration_data) + let right_duration_max = Math.max(...right_duration_data) + let duration_max = Math.max(left_duration_max, right_duration_max) + + let left_accumulated_duration_max = Math.max( + ...left_accumulated_duration_data + ) + let right_accumulated_duration_max = Math.max( + ...right_accumulated_duration_data + ) + let accumulated_max = Math.max( + left_accumulated_duration_max, + right_accumulated_duration_max + ) + + const chart = echarts.init(element) const options: echarts.EChartsOption = { title: { - text: 'Execution Comparsion', + text: 'Execution Comparsion' }, legend: { top: 10, - right: 10, + right: 10 }, tooltip: { trigger: 'axis', formatter: function (params: any) { - const index = params[0].name.indexOf('@'); - const safeName = params[0].name.replace(//g, '>'); - let res = `${index > -1 ? safeName.slice(index + 1) : safeName}
`; + const index = params[0].name.indexOf('@') + const safeName = params[0].name.replace(//g, '>') + var res = `${index > -1 ? safeName.slice(index + 1) : safeName}
` for (const item of params) { if (typeof item.value[item.encode.y[0]] === 'number') { res += ` - ${item.seriesName}: ${item.value[item.encode.y[0]]}
`; + ${item.seriesName}: ${item.value[item.encode.y[0]]}
` } } - return res; - }, + return res + } }, series: [ { type: 'bar', itemStyle: { - color: '#3366cc', + color: '#3366cc' }, yAxisIndex: 0, + }, { type: 'bar', itemStyle: { - color: '#dc3912', + color: '#dc3912' }, - yAxisIndex: 0, + yAxisIndex: 0 }, { type: 'line', itemStyle: { - color: '#ff9900', + color: '#ff9900' }, - yAxisIndex: 1, + yAxisIndex: 1 }, { type: 'line', itemStyle: { - color: '#109618', + color: '#109618' }, - yAxisIndex: 1, - }, + yAxisIndex: 1 + } ], xAxis: { type: 'category', @@ -148,81 +178,78 @@ const DiffColumnChart: React.FC = (props: DiffColumnChart interval: 0, rotate: getAngleByDataLength(rawData.length), formatter: (name: string) => { - const index = name.indexOf('@'); - const displayName = index > -1 ? name.slice(index + 1) : name; // 创建新变量 - return displayName.length > 16 ? `${displayName.slice(0, 14)}...` : displayName; - }, - }, + const index = name.indexOf('@') + if (index > -1) { + name = name.slice(index + 1) + } + return name.length > 16 ? name.slice(0, 14) + "..." : name; + } + } }, - yAxis: [ - { - type: 'value', - name: 'Time Difference(us)', - scale: true, - }, - { - type: 'value', - name: 'Accumulated Difference(us)', - scale: true, - }, - ], + yAxis: [{ + type: 'value', + name: 'Time Difference(us)', + scale: true + }, { + type: 'value', + name: 'Accumulated Difference(us)', + scale: true + }], dataset: { source: rawData.map((item, idx) => { // 添加索引保证x轴刻度不重复 - let param: any[] = [...item]; - param[0] = `${idx}@${param[0]}`; - return param; - }), - }, - }; - - if (options) { - chart.setOption(options, true); + let param: any[] = [...item] + param[0] = `${idx}@${param[0]}` + return param + }) + } } + + options && chart.setOption(options, true) chart.on('click', (param) => { if (param.seriesIndex !== undefined) { - selectCallback(param.dataIndex, param.seriesIndex + 1); + selectCallback(param.dataIndex, param.seriesIndex + 1) } - }); + }) return () => { - chart.dispose(); - }; - }, [rawData, resizeEventDependency]); + chart.dispose() + } + }, [rawData, resizeEventDependency]) return (
- ); -}; + ) +} -const DiffStepChart: React.FC = (props: DiffStepChartIProps) => { - const { rawData } = props; - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); +const DiffStepChart: React.FC = ( + props: DiffStepChartIProps +) => { + const { rawData } = props + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } - const chart = echarts.init(element); + const element = graphRef.current + if (!element) return + const chart = echarts.init(element) const options: echarts.EChartsOption = { title: { - text: 'Execution Diff', + text: 'Execution Diff' }, legend: { top: 10, - right: 10, + right: 10 }, dataset: { source: rawData.map((item, idx) => { // 添加索引保证x轴刻度不重复 - let param: any[] = [...item]; - param[0] = `${idx}@${param[0]}`; - return param; - }), + let param: any[] = [...item] + param[0] = `${idx}@${param[0]}` + return param + }) }, xAxis: { type: 'category', @@ -230,22 +257,24 @@ const DiffStepChart: React.FC = (props: DiffStepChartIProps interval: 0, rotate: getAngleByDataLength(rawData.length), formatter: (name: string) => { - const index = name.indexOf('@'); - const displayName = index > -1 ? name.slice(index + 1) : name; // 创建新变量 - return displayName.length > 16 ? `${displayName.slice(0, 14)}...` : displayName; - }, - }, + const index = name.indexOf('@') + if (index > -1) { + name = name.slice(index + 1) + } + return name.length > 16 ? name.slice(0, 14) + "..." : name; + } + } }, yAxis: { type: 'value', - scale: true, + scale: true }, tooltip: { trigger: 'axis', formatter: function (params: any) { - const index = params[0].name.indexOf('@'); - const safeName = params[0].name.replace(//g, '>'); - let res = `${index > -1 ? safeName.slice(index + 1) : safeName}
`; + const index = params[0].name.indexOf('@') + const safeName = params[0].name.replace(//g, '>') + var res = `${index > -1 ? safeName.slice(index + 1) : safeName}
` for (const item of params) { if (typeof item.value[item.encode.y[0]] === 'number') { res += ` - ${item.seriesName}: ${item.value[item.encode.y[0]]}
`; + ${item.seriesName}: ${item.value[item.encode.y[0]]}
` } } - return res; - }, + return res + } }, series: [ { @@ -269,411 +298,413 @@ const DiffStepChart: React.FC = (props: DiffStepChartIProps step: 'middle', areaStyle: { color: '#c1d1ef', - opacity: 1, - }, - }, - { + opacity: 1 + } + }, { type: 'line', color: '#dc3912', symbolSize: 0, step: 'middle', areaStyle: { color: '#f4c3b7', - opacity: 1, - }, - }, - ], - }; - - if (options) { - chart.setOption(options, true); + opacity: 1 + } + } + ] } + + options && chart.setOption(options, true) return () => { - chart.dispose(); - }; - }, [rawData, resizeEventDependency]); + chart.dispose() + } + }, [rawData, resizeEventDependency]) return (
- ); -}; + ) +} export interface IProps { - run: string; - worker: string; - span: string; - expRun: string; - expWorker: string; - expSpan: string; + run: string + worker: string + span: string + expRun: string + expWorker: string + expSpan: string } export interface ColumnUnderlyingData { - name: string; - path: string; - leftAggs: any[]; - rightAggs: any[]; + name: string + path: string + leftAggs: any[] + rightAggs: any[] } export interface TableRow { - key: number; - - operator: string; - baselineCalls?: number; - expCalls?: number; - deltaCalls?: number; - deltaCallsPercentNumber?: number; - deltaCallsPercent?: string; - - baselineHostDuration: number; - expHostDuration: number; - deltaHostDuration: number; - deltaHostDurationPercentNumber: number; - deltaHostDurationPercent: string; - - baselineSelfHostDuration: number; - expSelfHostDuration: number; - deltaSelfHostDuration: number; - deltaSelfHostDurationPercentNumber: number; - deltaSelfHostDurationPercent: string; - - baselineDeviceDuration: number; - expDeviceDuration: number; - deltaDeviceDuration: number; - deltaDeviceDurationPercentNumber: number; - deltaDeviceDurationPercent: string; - - baselineSelfDeviceDuration: number; - expSelfDeviceDuration: number; - deltaSelfDeviceDuration: number; - deltaSelfDeviceDurationPercentNumber: number; - deltaSelfDeviceDurationPercent: string; + key: number + + operator: string + baselineCalls?: number + expCalls?: number + deltaCalls?: number + deltaCallsPercentNumber?: number + deltaCallsPercent?: string + + baselineHostDuration: number + expHostDuration: number + deltaHostDuration: number + deltaHostDurationPercentNumber: number + deltaHostDurationPercent: string + + baselineSelfHostDuration: number + expSelfHostDuration: number + deltaSelfHostDuration: number + deltaSelfHostDurationPercentNumber: number + deltaSelfHostDurationPercent: string + + baselineDeviceDuration: number + expDeviceDuration: number + deltaDeviceDuration: number + deltaDeviceDurationPercentNumber: number + deltaDeviceDurationPercent: string + + baselineSelfDeviceDuration: number + expSelfDeviceDuration: number + deltaSelfDeviceDuration: number + deltaSelfDeviceDurationPercentNumber: number + deltaSelfDeviceDurationPercent: string } -let columnChartDataStack: any[][] = []; -let stepChartDataStack: any[][] = []; -let columnUnderlyingDataStack: ColumnUnderlyingData[][] = []; -let columnTableDataSourceStack: TableRow[][] = []; +let columnChartDataStack: any[][] = [] +let stepChartDataStack: any[][] = [] +let columnUnderlyingDataStack: ColumnUnderlyingData[][] = [] +let columnTableDataSourceStack: TableRow[][] = [] export const DiffOverview: React.FC = (props: IProps) => { // #region - Constant - const COMPOSITE_NODES_NAME = 'CompositeNodes'; + + const COMPOSITE_NODES_NAME = 'CompositeNodes' const hostDurationColumns = [ { title: 'Baseline Host Duration (us)', dataIndex: 'baselineHostDuration', key: 'baselineHostDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aBaselineHost = a.baselineHostDuration ?? 0; - const bBaselineHost = b.baselineHostDuration ?? 0; - return aBaselineHost - bBaselineHost; - }, + sorter: (a: TableRow, b: TableRow) => + a.baselineHostDuration - b.baselineHostDuration }, { title: 'Exp Host Duration (us)', dataIndex: 'expHostDuration', key: 'expHostDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aExpHost = a.expHostDuration ?? 0; - const bExpHost = b.expHostDuration ?? 0; - return aExpHost - bExpHost; - }, + sorter: (a: TableRow, b: TableRow) => + a.expHostDuration - b.expHostDuration }, { title: 'Delta Host Duration (us)', dataIndex: 'deltaHostDuration', key: 'deltaHostDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaHost = a.deltaHostDuration ?? 0; - const bDeltaHost = b.deltaHostDuration ?? 0; - return aDeltaHost - bDeltaHost; - }, + sorter: (a: TableRow, b: TableRow) => + a.deltaHostDuration! - b.deltaHostDuration! }, { title: 'Delta Host Duration%', dataIndex: 'deltaHostDurationPercent', key: 'deltaHostDurationPercent', - sorter: (a: TableRow, b: TableRow): number => { - const aPercent = a.deltaHostDurationPercentNumber ?? 0; - const bPercent = b.deltaHostDurationPercentNumber ?? 0; - return aPercent - bPercent; - }, - }, - ]; + sorter: (a: TableRow, b: TableRow) => + a.deltaHostDurationPercentNumber! - b.deltaHostDurationPercentNumber! + } + ] const selfHostDurationColumns = [ { title: 'Baseline Self Host Duration (us)', dataIndex: 'baselineSelfHostDuration', key: 'baselineSelfHostDuration', - sorter: (a: TableRow, b: TableRow): number => a.baselineSelfHostDuration - b.baselineSelfHostDuration, + sorter: (a: TableRow, b: TableRow) => + a.baselineSelfHostDuration - b.baselineSelfHostDuration }, { title: 'Exp Self Host Duration (us)', dataIndex: 'expSelfHostDuration', key: 'expSelfHostDuration', - sorter: (a: TableRow, b: TableRow): number => a.expSelfHostDuration - b.expSelfHostDuration, + sorter: (a: TableRow, b: TableRow) => + a.expSelfHostDuration - b.expSelfHostDuration }, { title: 'Delta Self Host Duration (us)', dataIndex: 'deltaSelfHostDuration', key: 'deltaSelfHostDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaSelfHost = a.deltaSelfHostDuration ?? 0; - const bDeltaSelfHost = b.deltaSelfHostDuration ?? 0; - return aDeltaSelfHost - bDeltaSelfHost; - }, + sorter: (a: TableRow, b: TableRow) => + a.deltaSelfHostDuration! - b.deltaSelfHostDuration! }, { title: 'Delta Self Host Duration%', dataIndex: 'deltaSelfHostDurationPercent', key: 'deltaSelfHostDurationPercent', - sorter: (a: TableRow, b: TableRow): number => { - const aSelfPercent = a.deltaSelfHostDurationPercentNumber ?? 0; - const bSelfPercent = b.deltaSelfHostDurationPercentNumber ?? 0; - return aSelfPercent - bSelfPercent; - }, - }, - ]; + sorter: (a: TableRow, b: TableRow) => + a.deltaSelfHostDurationPercentNumber! - + b.deltaSelfHostDurationPercentNumber! + } + ] const deviceDurationColumns = [ { title: 'Baseline Device Duration (us)', dataIndex: 'baselineDeviceDuration', key: 'baselineDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => a.baselineDeviceDuration - b.baselineDeviceDuration, + sorter: (a: TableRow, b: TableRow) => + a.baselineDeviceDuration - b.baselineDeviceDuration }, { title: 'Exp Device Duration (us)', dataIndex: 'expDeviceDuration', key: 'expDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => a.expDeviceDuration - b.expDeviceDuration, + sorter: (a: TableRow, b: TableRow) => + a.expDeviceDuration - b.expDeviceDuration }, { title: 'Delta Device Duration (us)', dataIndex: 'deltaDeviceDuration', key: 'deltaDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaDeviceDuration = a.deltaDeviceDuration ?? 0; - const bdeltaDeviceDuration = b.deltaDeviceDuration ?? 0; - return aDeltaDeviceDuration - bdeltaDeviceDuration; - }, + sorter: (a: TableRow, b: TableRow) => + a.deltaDeviceDuration! - b.deltaDeviceDuration! }, { title: 'Delta Device Duration%', dataIndex: 'deltaDeviceDurationPercent', key: 'deltaDeviceDurationPercent', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaDeviceDurationPercentNumber = a.deltaDeviceDurationPercentNumber ?? 0; - const bDeltaDeviceDurationPercentNumber = b.deltaDeviceDurationPercentNumber ?? 0; - return aDeltaDeviceDurationPercentNumber - bDeltaDeviceDurationPercentNumber; - }, - }, - ]; + sorter: (a: TableRow, b: TableRow) => + a.deltaDeviceDurationPercentNumber! - + b.deltaDeviceDurationPercentNumber! + } + ] const selfDeviceDurationColumns = [ { title: 'Baseline Self Device Duration (us)', dataIndex: 'baselineSelfDeviceDuration', key: 'baselineSelfDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => a.baselineSelfDeviceDuration - b.baselineSelfDeviceDuration, + sorter: (a: TableRow, b: TableRow) => + a.baselineSelfDeviceDuration - b.baselineSelfDeviceDuration }, { title: 'Exp Self Device Duration (us)', dataIndex: 'expSelfDeviceDuration', key: 'expSelfDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => a.expSelfDeviceDuration - b.expSelfDeviceDuration, + sorter: (a: TableRow, b: TableRow) => + a.expSelfDeviceDuration - b.expSelfDeviceDuration }, { title: 'Delta Self Device Duration (us)', dataIndex: 'deltaSelfDeviceDuration', key: 'deltaSelfDeviceDuration', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaSelfDeviceDuration = a.deltaSelfDeviceDuration ?? 0; - const bDeltaSelfDeviceDuration = b.deltaSelfDeviceDuration ?? 0; - return aDeltaSelfDeviceDuration - bDeltaSelfDeviceDuration; - }, + sorter: (a: TableRow, b: TableRow) => + a.deltaSelfDeviceDuration! - b.deltaSelfDeviceDuration! }, { title: 'Delta Self Device Duration%', dataIndex: 'deltaSelfDeviceDurationPercent', key: 'deltaSelfDeviceDurationPercent', - sorter: (a: TableRow, b: TableRow): number => { - const aDeltaSelfDeviceDurationPercentNumber = a.deltaSelfDeviceDurationPercentNumber ?? 0; - const bDeltaSelfDeviceDurationPercentNumber = b.deltaSelfDeviceDurationPercentNumber ?? 0; - return aDeltaSelfDeviceDurationPercentNumber - bDeltaSelfDeviceDurationPercentNumber; - }, - }, - ]; + sorter: (a: TableRow, b: TableRow) => + a.deltaSelfDeviceDurationPercentNumber! - + b.deltaSelfDeviceDurationPercentNumber! + } + ] - interface IColumnMap { - [key: string]: any; - } - type IColumnMapType = IColumnMap; + type IColumnMapType = { [key: string]: any } const tableSourceColumnMap: IColumnMapType = { selfHostDuration: selfHostDurationColumns, hostDuration: hostDurationColumns, deviceDuration: deviceDurationColumns, - selfDeviceDuration: selfDeviceDurationColumns, - }; + selfDeviceDuration: selfDeviceDurationColumns + } const baseTableColumns = [ { title: 'Operator', dataIndex: 'operator', key: 'operator', - sorter: (a: TableRow, b: TableRow) => a.operator.localeCompare(b.operator), + sorter: (a: TableRow, b: TableRow) => a.operator.localeCompare(b.operator) }, { title: 'Baseline Calls', dataIndex: 'baselineCalls', key: 'baselineCalls', - sorter: (a: TableRow, b: TableRow) => a.baselineCalls ?? 0 - (b.baselineCalls ?? 0), + sorter: (a: TableRow, b: TableRow) => a.baselineCalls! - b.baselineCalls! }, { title: 'Exp Calls', dataIndex: 'expCalls', key: 'expCalls', - sorter: (a: TableRow, b: TableRow) => a.expCalls ?? 0 - (b.expCalls ?? 0), + sorter: (a: TableRow, b: TableRow) => a.expCalls! - b.expCalls! }, { title: 'Delta Calls', dataIndex: 'deltaCalls', key: 'deltaCalls', - sorter: (a: TableRow, b: TableRow) => a.deltaCalls ?? 0 - (b.deltaCalls ?? 0), + sorter: (a: TableRow, b: TableRow) => a.deltaCalls! - b.deltaCalls! }, { title: 'Delta Calls%', dataIndex: 'deltaCallsPercent', key: 'deltaCallsPercent', - sorter: (a: TableRow, b: TableRow) => a.deltaCallsPercentNumber ?? 0 - (b.deltaCallsPercentNumber ?? 0), - }, - ]; + sorter: (a: TableRow, b: TableRow) => + a.deltaCallsPercentNumber! - b.deltaCallsPercentNumber! + } + ] // #endregion // #region - State - const [tableDataSource, setTableDataSource] = React.useState([]); - const { run, worker, span, expRun, expWorker, expSpan } = props; + const [tableDataSource, setTableDataSource] = React.useState([]) + const { run, worker, span, expRun, expWorker, expSpan } = props - const [columnUnderlyingData, setColumnUnderlyingData] = React.useState([]); + const [columnUnderlyingData, setColumnUnderlyingData] = React.useState< + ColumnUnderlyingData[] + >([]) - const [rootUnderlyingData, setRootUnderlyingData] = React.useState(); + const [ + rootUnderlyingData, + setRootUnderlyingData + ] = React.useState() - const [columnChartData, setColumnChartData] = React.useState([]); - const [stepChartData, setStepChartData] = React.useState([]); + const [columnChartData, setColumnChartData] = React.useState([]) + const [stepChartData, setStepChartData] = React.useState([]) - const [selectedTableColumnsOptions, setSelectedTableColumnsOptions] = React.useState<[key: string]>(['hostDuration']); - const [selectedTableColumns, setSelectedTableColumns] = React.useState([ - ...baseTableColumns, - ...hostDurationColumns, - ]); + const [ + selectedTableColumnsOptions, + setSelectedTableColumnsOptions + ] = React.useState<[key: string]>(['hostDuration']) + const [selectedTableColumns, setSelectedTableColumns] = React.useState( + [...baseTableColumns, ...hostDurationColumns] + ) - const [dataStackLevel, setDataStackLevel] = React.useState(0); - const [loading, setLoading] = React.useState(false); + const [dataStackLevel, setDataStackLevel] = React.useState(0) + const [loading, setLoading] = React.useState(false) // #endregion - const classes = useStyles(); + const classes = useStyles() // #region - Event Handler - const handleChartColumnSelect = (row: number, column: number): void => { + const handleChartColumnSelect = (row: number, column: number) => { if (columnUnderlyingData.length === 0) { - return; + return } - let selectedUnderlyingData = columnUnderlyingData[row]; + let selectedUnderlyingData = columnUnderlyingData[row] if (!selectedUnderlyingData) { - return; + return } - let tableDataSource1 = generateDataSourceFromUnderlyingData(selectedUnderlyingData); - setTableDataSource(tableDataSource1); - columnTableDataSourceStack.push(tableDataSource1); + let tableDataSource = generateDataSourceFromUnderlyingData( + selectedUnderlyingData + ) + setTableDataSource(tableDataSource) + columnTableDataSourceStack.push(tableDataSource) - setLoading(true); + setLoading(true) api.defaultApi - .diffnodeGet(run, worker, span, expRun, expWorker, expSpan, selectedUnderlyingData.path) + .diffnodeGet( + run, + worker, + span, + expRun, + expWorker, + expSpan, + selectedUnderlyingData.path + ) .then((resp) => handleDiffNodeResp(resp)) - .finally(() => setLoading(false)); - }; + .finally(() => setLoading(false)) + } - const handleGoBack = (): void => { + const handleGoBack = () => { if (columnChartDataStack.length > 1) { - columnChartDataStack.pop(); - let top = columnChartDataStack[columnChartDataStack.length - 1]; - setColumnChartData(top); + columnChartDataStack.pop() + let top = columnChartDataStack[columnChartDataStack.length - 1] + setColumnChartData(top) } if (stepChartDataStack.length > 1) { - stepChartDataStack.pop(); - let top = stepChartDataStack[stepChartDataStack.length - 1]; - setStepChartData(top); + stepChartDataStack.pop() + let top = stepChartDataStack[stepChartDataStack.length - 1] + setStepChartData(top) } if (columnUnderlyingDataStack.length > 0) { - columnUnderlyingDataStack.pop(); - let top = columnUnderlyingDataStack[columnUnderlyingDataStack.length - 1]; - setColumnUnderlyingData(top); + columnUnderlyingDataStack.pop() + let top = columnUnderlyingDataStack[columnUnderlyingDataStack.length - 1] + setColumnUnderlyingData(top) } if (columnTableDataSourceStack.length > 0) { - columnTableDataSourceStack.pop(); - let top = columnTableDataSourceStack[columnTableDataSourceStack.length - 1]; + columnTableDataSourceStack.pop() + let top = + columnTableDataSourceStack[columnTableDataSourceStack.length - 1] if (top) { - setTableDataSource(top); + setTableDataSource(top) } else { - let tableDataSource2 = generateDataSourceFromUnderlyingData(rootUnderlyingData); - setTableDataSource(tableDataSource2); + let tableDataSource = generateDataSourceFromUnderlyingData( + rootUnderlyingData! + ) + setTableDataSource(tableDataSource) } } - setDataStackLevel(dataStackLevel - 1); - }; + setDataStackLevel(dataStackLevel - 1) + } - const toPercentString = (percentNumber: number): string => { + const toPercentString = (percentNumber: number) => { if (isNaN(percentNumber)) { - return 'N/A'; + return 'N/A' } - return `${percentNumber.toFixed(2)}%`; - }; + return `${percentNumber.toFixed(2)}%` + } - const handleColumnSelectionChange = (value: [key: string]): void => { - let columns = value.map((x) => tableSourceColumnMap[x]).flat(); - let r = [...baseTableColumns, ...columns]; - setSelectedTableColumnsOptions(value); - setSelectedTableColumns(r); - }; + const handleColumnSelectionChange = (value: [key: string]) => { + let columns = value.map((x) => tableSourceColumnMap[x]).flat() + let r = [...baseTableColumns, ...columns] + setSelectedTableColumnsOptions(value) + setSelectedTableColumns(r) + } - const generateDataSourceFromUnderlyingData = (selectedUnderlyingData?: ColumnUnderlyingData): TableRow[] => { - if (!selectedUnderlyingData) { - return []; - } - let newTableDataSource: TableRow[] = []; + const generateDataSourceFromUnderlyingData = ( + selectedUnderlyingData: ColumnUnderlyingData + ) => { + let tableDataSource: TableRow[] = [] for (let i = 0; i < selectedUnderlyingData.leftAggs.length; i++) { - let left = selectedUnderlyingData.leftAggs[i]; - let right = selectedUnderlyingData.rightAggs[i]; + let left = selectedUnderlyingData.leftAggs[i] + let right = selectedUnderlyingData.rightAggs[i] - let deltaCallsPercentNumber = ((right.calls - left.calls) / left.calls) * 100; + let deltaCallsPercentNumber = + ((right.calls - left.calls) / left.calls) * 100 - let deltaHostDurationPercentNumber = ((right.host_duration - left.host_duration) / left.host_duration) * 100; + let deltaHostDurationPercentNumber = + ((right.host_duration - left.host_duration) / left.host_duration) * 100 let deltaSelfHostDurationPercentNumber = - ((right.self_host_duration - left.self_host_duration) / left.self_host_duration) * 100; + ((right.self_host_duration - left.self_host_duration) / + left.self_host_duration) * + 100 let deltaDeviceDurationPercentNumber = - ((right.device_duration - left.device_duration) / left.device_duration) * 100; + ((right.device_duration - left.device_duration) / + left.device_duration) * + 100 let deltaSelfDeviceDurationPercentNumber = - ((right.self_device_duration - left.self_device_duration) / left.self_device_duration) * 100; + ((right.self_device_duration - left.self_device_duration) / + left.self_device_duration) * + 100 - newTableDataSource.push({ + tableDataSource.push({ key: i, operator: left.name, baselineCalls: left.calls, @@ -686,194 +717,214 @@ export const DiffOverview: React.FC = (props: IProps) => { expHostDuration: right.host_duration, deltaHostDuration: parseFloat((right.host_duration - left.host_duration).toFixed(3)), deltaHostDurationPercentNumber: deltaHostDurationPercentNumber, - deltaHostDurationPercent: toPercentString(deltaHostDurationPercentNumber), + deltaHostDurationPercent: toPercentString( + deltaHostDurationPercentNumber + ), baselineSelfHostDuration: left.self_host_duration, expSelfHostDuration: right.self_host_duration, - deltaSelfHostDuration: parseFloat((right.self_host_duration - left.self_host_duration).toFixed(3)), + deltaSelfHostDuration: + parseFloat((right.self_host_duration - left.self_host_duration).toFixed(3)), deltaSelfHostDurationPercentNumber: deltaSelfHostDurationPercentNumber, - deltaSelfHostDurationPercent: toPercentString(deltaSelfHostDurationPercentNumber), + deltaSelfHostDurationPercent: toPercentString( + deltaSelfHostDurationPercentNumber + ), baselineDeviceDuration: left.device_duration, expDeviceDuration: right.device_duration, deltaDeviceDuration: parseFloat((right.device_duration - left.device_duration).toFixed(3)), deltaDeviceDurationPercentNumber: deltaDeviceDurationPercentNumber, - deltaDeviceDurationPercent: toPercentString(deltaDeviceDurationPercentNumber), + deltaDeviceDurationPercent: toPercentString( + deltaDeviceDurationPercentNumber + ), baselineSelfDeviceDuration: left.self_device_duration, expSelfDeviceDuration: right.self_device_duration, - deltaSelfDeviceDuration: parseFloat((right.self_device_duration - left.self_device_duration).toFixed(3)), + deltaSelfDeviceDuration: + parseFloat((right.self_device_duration - left.self_device_duration).toFixed(3)), deltaSelfDeviceDurationPercentNumber: deltaSelfDeviceDurationPercentNumber, - deltaSelfDeviceDurationPercent: toPercentString(deltaSelfDeviceDurationPercentNumber), - }); + deltaSelfDeviceDurationPercent: toPercentString( + deltaSelfDeviceDurationPercentNumber + ) + }) } - return newTableDataSource; - }; + return tableDataSource + } React.useEffect(() => { - const hasData = + if ( run.length > 0 && worker.length > 0 && span.length > 0 && expRun.length > 0 && expWorker.length > 0 && - expSpan.length > 0; - if (hasData) { - setLoading(true); + expSpan.length > 0 + ) { + setLoading(true) - columnChartDataStack = []; - stepChartDataStack = []; - columnUnderlyingDataStack = []; - columnTableDataSourceStack = []; + columnChartDataStack = [] + stepChartDataStack = [] + columnUnderlyingDataStack = [] + columnTableDataSourceStack = [] api.defaultApi .diffnodeGet(run, worker, span, expRun, expWorker, expSpan) .then((resp) => { - handleDiffNodeResp(resp); - let newRootUnderlyingData = { + handleDiffNodeResp(resp) + let rootUnderlyingData = { name: 'rootNode', path: resp.path, leftAggs: resp.left.aggs, - rightAggs: resp.right.aggs, - }; + rightAggs: resp.right.aggs + } - setRootUnderlyingData(newRootUnderlyingData); - let tableDataSource3 = generateDataSourceFromUnderlyingData(newRootUnderlyingData); - setTableDataSource(tableDataSource3); + setRootUnderlyingData(rootUnderlyingData) + let tableDataSource = generateDataSourceFromUnderlyingData( + rootUnderlyingData! + ) + setTableDataSource(tableDataSource) }) - .finally(() => setLoading(false)); + .finally(() => setLoading(false)) - setSelectedTableColumns([...baseTableColumns, ...hostDurationColumns]); + setSelectedTableColumns([...baseTableColumns, ...hostDurationColumns]) } - }, [run, worker, span, expRun, expWorker, expSpan]); - - const handleDiffNodeResp = (resp: any): void => { - let newColumnChartData: any[] = []; - let newStepChartData: any[] = []; - let underlyingData: ColumnUnderlyingData[] = []; - - newColumnChartData.push(['Call', 'Baseline', 'Experiment', 'Baseline Trend', 'Exp Trend']); - newStepChartData.push(['Call', 'Diff', 'Accumulated Diff']); + }, [run, worker, span, expRun, expWorker, expSpan]) + + const handleDiffNodeResp = (resp: any) => { + let columnChartData: any[] = [] + let stepChartData: any[] = [] + let underlyingData: ColumnUnderlyingData[] = [] + + columnChartData.push([ + 'Call', + 'Baseline', + 'Experiment', + 'Baseline Trend', + 'Exp Trend' + ]) + stepChartData.push(['Call', 'Diff', 'Accumulated Diff']) if (resp.children.length > 0) { - let accumulatedLeftDuration = 0; - let accumulatedRightDuration = 0; - let accumulatedStepDiff = 0; + let accumulated_left_duration = 0 + let accumulated_right_duration = 0 + let accumulated_step_diff = 0 for (let i = 0; i < resp.children.length; i++) { - let left = resp.children[i].left; - let right = resp.children[i].right; - let currColumn: any[] = []; - let currStep: any[] = []; + let left = resp.children[i].left + let right = resp.children[i].right + let currColumn: any[] = [] + let currStep: any[] = [] - let name = left.name; + let name = left.name if (name === COMPOSITE_NODES_NAME) { - continue; + continue } if (name.startsWith('aten::')) { // Ignore aten operators - continue; + continue } if (name.startsWith('enumerate(DataLoader)')) { - name = name.substring(21); + name = name.substring(21) } if (name.startsWith('enumerate(DataPipe)')) { - name = name.substring(19); + name = name.substring(19) } if (name.startsWith('nn.Module: ')) { - name = name.substring(11); + name = name.substring(11) } if (name.startsWith('Optimizer.zero_grad')) { - name = 'Optimizer.zero_grad'; + name = 'Optimizer.zero_grad' } if (name.startsWith('Optimizer.step')) { - name = 'Optimizer.step'; + name = 'Optimizer.step' } - currColumn.push(name); - currColumn.push(left.total_duration); - currColumn.push(right.total_duration); + currColumn.push(name) + currColumn.push(left.total_duration) + currColumn.push(right.total_duration) - accumulatedLeftDuration += left.total_duration; - currColumn.push(accumulatedLeftDuration); + accumulated_left_duration += left.total_duration + currColumn.push(accumulated_left_duration) - accumulatedRightDuration += right.total_duration; - currColumn.push(accumulatedRightDuration); - newColumnChartData.push(currColumn); + accumulated_right_duration += right.total_duration + currColumn.push(accumulated_right_duration) + columnChartData.push(currColumn) underlyingData.push({ name: name, path: resp.children[i].path, leftAggs: left.aggs, - rightAggs: right.aggs, - }); + rightAggs: right.aggs + }) - currStep.push(name); - let stepDiff = right.total_duration - left.total_duration; - currStep.push(stepDiff); + currStep.push(name) + let stepDiff = right.total_duration - left.total_duration + currStep.push(stepDiff) - accumulatedStepDiff += stepDiff; - currStep.push(accumulatedStepDiff); + accumulated_step_diff += stepDiff + currStep.push(accumulated_step_diff) - newStepChartData.push(currStep); + stepChartData.push(currStep) } } else { - let left = resp.left; - let right = resp.right; - let currColumn: any[] = []; - let currStep: any[] = []; - let name = left.name; + let left = resp.left + let right = resp.right + let currColumn: any[] = [] + let currStep: any[] = [] + let name = left.name if (name.startsWith('nn.Module: ')) { - name = name.substring(11); + name = name.substring(11) } - currColumn.push(name); - currColumn.push(left.total_duration); - currColumn.push(right.total_duration); - currColumn.push(left.total_duration); - currColumn.push(right.total_duration); + currColumn.push(name) + currColumn.push(left.total_duration) + currColumn.push(right.total_duration) + currColumn.push(left.total_duration) + currColumn.push(right.total_duration) - newColumnChartData.push(currColumn); + columnChartData.push(currColumn) - currStep.push(name); - let stepDiff = right.total_duration - left.total_duration; - currStep.push(stepDiff); - currStep.push(stepDiff); - newStepChartData.push(currStep); + currStep.push(name) + let stepDiff = right.total_duration - left.total_duration + currStep.push(stepDiff) + currStep.push(stepDiff) + stepChartData.push(currStep) } - setColumnChartData(newColumnChartData); - columnChartDataStack.push(newColumnChartData); + setColumnChartData(columnChartData) + columnChartDataStack.push(columnChartData) + + setStepChartData(stepChartData) + stepChartDataStack.push(stepChartData) - setStepChartData(newStepChartData); - stepChartDataStack.push(newStepChartData); + setColumnUnderlyingData(underlyingData) + columnUnderlyingDataStack.push(underlyingData) - setColumnUnderlyingData(underlyingData); - columnUnderlyingDataStack.push(underlyingData); + setDataStackLevel(columnChartDataStack.length) + } - setDataStackLevel(columnChartDataStack.length); - }; // #endregion + // #endregion if (!loading && columnUnderlyingDataStack.length === 0) { return ( - - + + There is no run selected for diff. - ); + ) } if (loading) { - return ; + return } return ( @@ -881,62 +932,73 @@ export const DiffOverview: React.FC = (props: IProps) => { - - + + {columnChartData.length > 1 && ( <> - + )} - {columnChartData.length === 1 && No more level to show.} + {columnChartData.length === 1 && ( + No more level to show. + )} - - + +   - +
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DistributedView.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DistributedView.tsx index 096501b61bc9ce41978c65dc24f6b3640ab960f3..aad14aa29828fa1a8886ab3f68c54dd62cd396f9 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DistributedView.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/DistributedView.tsx @@ -2,54 +2,54 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import Grid from '@material-ui/core/Grid'; -import InputLabel from '@material-ui/core/InputLabel'; -import MenuItem from '@material-ui/core/MenuItem'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import { makeStyles } from '@material-ui/core/styles'; -import { Table } from 'antd'; -import { ColumnsType } from 'antd/es/table'; -import * as React from 'react'; -import * as api from '../api'; -import { DistributedGraph, GpuInfo, Graph } from '../api'; -import { firstOrUndefined } from '../utils'; -import { ColumnChart } from './charts/ColumnChart'; -import { DataLoading } from './DataLoading'; -import { GpuInfoTable } from './GpuInfoTable'; -import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers'; +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import Grid from '@material-ui/core/Grid' +import InputLabel from '@material-ui/core/InputLabel' +import MenuItem from '@material-ui/core/MenuItem' +import Select, { SelectProps } from '@material-ui/core/Select' +import { makeStyles } from '@material-ui/core/styles' +import { Table } from 'antd' +import { ColumnsType } from 'antd/es/table' +import * as React from 'react' +import * as api from '../api' +import { DistributedGraph, GpuInfo, Graph } from '../api' +import { firstOrUndefined } from '../utils' +import { ColumnChart } from './charts/ColumnChart' +import { DataLoading } from './DataLoading' +import { GpuInfoTable } from './GpuInfoTable' +import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers' import { - distributedCommopsTableTooltip, - distributedGpuInfoTableTooltip, - distributedOverlapGraphTooltip, - distributedWaittimeGraphTooltip, -} from './TooltipDescriptions'; + DistributedCommopsTableTooltip, + DistributedGpuInfoTableTooltip, + DistributedOverlapGraphTooltip, + DistributedWaittimeGraphTooltip +} from './TooltipDescriptions' export interface IProps { - run: string; - worker: string; - span: string; + run: string + worker: string + span: string } const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, verticalInput: { display: 'flex', - alignItems: 'center', + alignItems: 'center' }, inputWidth: { - width: '4em', + width: '4em' }, inputWidthOverflow: { minWidth: '15em', - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, description: { - marginLeft: theme.spacing(1), + marginLeft: theme.spacing(1) }, table: { height: '100%', @@ -58,152 +58,165 @@ const useStyles = makeStyles((theme) => ({ height: 20, fontSize: '10pt', '& > td': { - padding: '0 8px!important', - }, - }, - }, -})); + padding: '0 8px!important' + } + } + } +})) export const DistributedView: React.FC = (props) => { - const tooltipCommonClasses = useTooltipCommonStyles(); + const tooltipCommonClasses = useTooltipCommonStyles() const chartHeaderRenderer = React.useMemo( () => makeChartHeaderRenderer(tooltipCommonClasses), [tooltipCommonClasses] - ); + ) - let { run, worker, span } = props; - const classes = useStyles(); + let { run, worker, span } = props + const classes = useStyles() - const [overlapGraph, setOverlapGraph] = React.useState(undefined); - const [waittimeGraph, setWaittimeGraph] = React.useState(undefined); - const [commopsTableData, setCommopsTableData] = React.useState(undefined); - const [gpuInfo, setGpuInfo] = React.useState(undefined); - const [commopsTableTitle, setCommopsTableTitle] = React.useState(''); - const [commopsWorkers, setCommopsWorkers] = React.useState([]); - const [overlapSteps, setOverlapSteps] = React.useState([]); - const [waittimeSteps, setWaittimeSteps] = React.useState([]); - const [overlapStep, setOverlapStep] = React.useState(''); - const [waittimeStep, setWaittimeStep] = React.useState(''); - const [commopsWorker, setCommopsWorker] = React.useState(''); - const [columns, setColumns] = React.useState>([]); - const [pageSize, setPageSize] = React.useState(30); + const [overlapGraph, setOverlapGraph] = React.useState< + DistributedGraph | undefined + >(undefined) + const [waittimeGraph, setWaittimeGraph] = React.useState< + DistributedGraph | undefined + >(undefined) + const [commopsTableData, setCommopsTableData] = React.useState< + any | undefined + >(undefined) + const [gpuInfo, setGpuInfo] = React.useState(undefined) + const [commopsTableTitle, setCommopsTableTitle] = React.useState('') + const [commopsWorkers, setCommopsWorkers] = React.useState([]) + const [overlapSteps, setOverlapSteps] = React.useState([]) + const [waittimeSteps, setWaittimeSteps] = React.useState([]) + const [overlapStep, setOverlapStep] = React.useState('') + const [waittimeStep, setWaittimeStep] = React.useState('') + const [commopsWorker, setCommopsWorker] = React.useState('') + const [columns, setColumns] = React.useState>([]) + const [pageSize, setPageSize] = React.useState(30) React.useEffect(() => { if (waittimeSteps.includes('all')) { - setWaittimeStep('all'); + setWaittimeStep('all') } else { - setWaittimeStep(firstOrUndefined(waittimeSteps) ?? ''); + setWaittimeStep(firstOrUndefined(waittimeSteps) ?? '') } - }, [waittimeSteps]); + }, [waittimeSteps]) React.useEffect(() => { if (overlapSteps.includes('all')) { - setOverlapStep('all'); + setOverlapStep('all') } else { - setOverlapStep(firstOrUndefined(overlapSteps) ?? ''); + setOverlapStep(firstOrUndefined(overlapSteps) ?? '') } - }, [overlapSteps]); + }, [overlapSteps]) React.useEffect(() => { - setCommopsWorker(firstOrUndefined(commopsWorkers) ?? ''); - }, [commopsWorkers]); + setCommopsWorker(firstOrUndefined(commopsWorkers) ?? '') + }, [commopsWorkers]) React.useEffect(() => { api.defaultApi.distributedOverlapGet(run, 'All', span).then((resp) => { - setOverlapGraph(resp); - setOverlapSteps(Object.keys(resp.data)); - }); + setOverlapGraph(resp) + setOverlapSteps(Object.keys(resp.data)) + }) api.defaultApi.distributedWaittimeGet(run, 'All', span).then((resp) => { - setWaittimeGraph(resp); - setWaittimeSteps(Object.keys(resp.data)); - }); + setWaittimeGraph(resp) + setWaittimeSteps(Object.keys(resp.data)) + }) api.defaultApi.distributedCommopsGet(run, 'All', span).then((resp) => { - setCommopsTableData(resp.data); - setCommopsWorkers(Object.keys(resp.data)); - setCommopsTableTitle(resp.metadata.title); - }); + setCommopsTableData(resp.data) + setCommopsWorkers(Object.keys(resp.data)) + setCommopsTableTitle(resp.metadata.title) + }) api.defaultApi.distributedGpuinfoGet(run, 'All', span).then((resp) => { - setGpuInfo(resp); - }); - }, [run, worker, span]); + setGpuInfo(resp) + }) + }, [run, worker, span]) const onCommopsWorkerChanged: SelectProps['onChange'] = (event) => { - setCommopsWorker(event.target.value as string); - }; + setCommopsWorker(event.target.value as string) + } const onOverlapStepChanged: SelectProps['onChange'] = (event) => { - setOverlapStep(event.target.value as string); - }; + setOverlapStep(event.target.value as string) + } const onWaittimeStepChanged: SelectProps['onChange'] = (event) => { - setWaittimeStep(event.target.value as string); - }; + setWaittimeStep(event.target.value as string) + } - const getColumnChartData = (distributedGraph?: DistributedGraph, step?: string): any => { - if (!distributedGraph || !step) { - return undefined; - } - const barLabels = Object.keys(distributedGraph.data[step]); + const getColumnChartData = ( + distributedGraph?: DistributedGraph, + step?: string + ) => { + if (!distributedGraph || !step) return undefined + const barLabels = Object.keys(distributedGraph.data[step]) return { legends: distributedGraph.metadata.legends, barLabels, - barHeights: barLabels.map((label) => distributedGraph.data[step][label]), - }; - }; - const overlapData = React.useMemo(() => getColumnChartData(overlapGraph, overlapStep), [overlapGraph, overlapStep]); + barHeights: barLabels.map((label) => distributedGraph.data[step][label]) + } + } + const overlapData = React.useMemo( + () => getColumnChartData(overlapGraph, overlapStep), + [overlapGraph, overlapStep] + ) const waittimeData = React.useMemo( () => getColumnChartData(waittimeGraph, waittimeStep), [waittimeGraph, waittimeStep] - ); + ) - const getTableData = (tableData?: any, opsWorker?: string): any[] => { - if (!tableData || !opsWorker) { - return []; + const getTableData = (tableData?: any, worker?: string) => { + if (!tableData || !worker) { + return [] } - let dataInfo: api.Graph = tableData[opsWorker]; - const stringCompare = (a: string, b: string): number => a.localeCompare(b); - const numberCompare = (a: number, b: number): number => a - b; - let column: any[] = dataInfo.columns.map((item) => { + let dataInfo: api.Graph = tableData[worker] + const stringCompare = (a: string, b: string) => a.localeCompare(b) + const numberCompare = (a: number, b: number) => a - b + let column: any[] = dataInfo.columns.map(item => { return { title: item.name, key: item.name, dataIndex: item.name, - sorter: - item.type === 'string' - ? (a: any, b: any): number => stringCompare(a[item.name], b[item.name]) - : (a: any, b: any): number => numberCompare(a[item.name], b[item.name]), - }; - }); - setColumns(column); + sorter: item.type == 'string' ? (a: any, b: any) => stringCompare(a[item.name], b[item.name]) + : (a: any, b: any) => numberCompare(a[item.name], b[item.name]) + } + }) + setColumns(column) return dataInfo.rows.map((row, index) => { if (row.length !== dataInfo.columns.length) { - return null; + return null } - const dataRow: { [column: string]: number | string } = { key: index }; - dataInfo.columns.forEach((item, idx) => { - dataRow[item.name] = row[idx] as string | number; - }); - return dataRow; - }); - }; + const dataRow: { [column: string]: number | string } = { key: index } + dataInfo.columns.forEach((column, index) => { + dataRow[column.name] = row[index] as string | number + }) + return dataRow + }) + } const commopsTable: any[] = React.useMemo(() => { - return getTableData(commopsTableData, commopsWorker); - }, [commopsTableData, commopsWorker]); + return getTableData(commopsTableData, commopsWorker) + }, [commopsTableData, commopsWorker]) - const onShowSizeChange = (current: number, size: number): void => { - setPageSize(size); - }; + const onShowSizeChange = (current: number, size: number) => { + setPageSize(size) + } return (
- - + + {gpuInfo && ( - + @@ -212,15 +225,19 @@ export const DistributedView: React.FC = (props) => { )} - {(chartData): JSX.Element => ( + {(chartData) => ( - + - Step + Step - {overlapSteps.map((step) => ( {step} ))} @@ -230,25 +247,35 @@ export const DistributedView: React.FC = (props) => { {overlapGraph?.metadata?.title && ( )} - + )} - {(chartData): JSX.Element => ( + {(chartData) => ( - + - Step + Step - {waittimeSteps.map((step) => ( {step} ))} @@ -258,7 +285,10 @@ export const DistributedView: React.FC = (props) => { {waittimeGraph?.metadata?.title && ( )} = (props) => { - - + + - + - Worker + Worker - + {commopsWorkers.map((worker) => ( + {worker} ))} @@ -299,7 +338,7 @@ export const DistributedView: React.FC = (props) => { pageSize, pageSizeOptions: ['20', '30', '50', '100'], hideOnSinglePage: true, - onShowSizeChange, + onShowSizeChange }} /> @@ -309,5 +348,5 @@ export const DistributedView: React.FC = (props) => {
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/FullCircularProgress.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/FullCircularProgress.tsx index 3f4c0fbaf15a15d402aa205574a28df045d24aec..5212bd74bf9739cc171d369e6591a0c26f058f6a 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/FullCircularProgress.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/FullCircularProgress.tsx @@ -1,23 +1,23 @@ /*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import CircularProgress from '@material-ui/core/CircularProgress'; -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; +import CircularProgress from '@material-ui/core/CircularProgress' +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' const useStyles = makeStyles(() => ({ root: { width: '100%', display: 'flex', - justifyContent: 'center', - }, -})); + justifyContent: 'center' + } +})) export const FullCircularProgress: React.FC = () => { - const classes = useStyles(); + const classes = useStyles() return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/GpuInfoTable.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/GpuInfoTable.tsx index 07f6f1d78c88abab5f62f844356b47ca517a2561..4c624db0580caa466271e56505f2838637705884 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/GpuInfoTable.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/GpuInfoTable.tsx @@ -2,123 +2,127 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' export interface IProps { - gpuInfo: any; + gpuInfo: any } const useStyles = makeStyles((theme) => ({ root: { border: '1px solid #E0E0E0', borderCollapse: 'collapse', - width: '100%', + width: '100%' }, td: { borderTop: '1px solid #E0E0E0', borderBottom: '1px solid #E0E0E0', borderCollapse: 'collapse', paddingLeft: 10, - paddingRight: 10, + paddingRight: 10 }, nodeTd: { - fontWeight: 'bold', + fontWeight: 'bold' }, pidTd: { - fontWeight: 'normal', + fontWeight: 'normal' }, gpuTd: { - fontWeight: 'normal', + fontWeight: 'normal' }, keyTd: { fontWeight: 'normal', - textAlign: 'right', + textAlign: 'right' }, valueTd: { - fontWeight: 'bold', - }, -})); + fontWeight: 'bold' + } +})) interface TableCellInfo { - content: string; - rowspan: number; - cellType: 'node' | 'pid' | 'gpu' | 'key' | 'value'; - last?: boolean; + content: string + rowspan: number + cellType: 'node' | 'pid' | 'gpu' | 'key' | 'value' + last?: boolean } function makeTableCellInfo(gpuInfo: any): TableCellInfo[][] { - const rows: TableCellInfo[][] = []; - let currRow: TableCellInfo[] = []; - rows.push(currRow); - Object.keys(gpuInfo.data).forEach((nodeName) => { - const nodeCell = { - content: nodeName, + const rows: TableCellInfo[][] = [] + let curr_row: TableCellInfo[] = [] + rows.push(curr_row) + Object.keys(gpuInfo.data).forEach(function (node_name) { + const node_cell = { + content: node_name, rowspan: 0, - cellType: 'node' as const, - }; - const i = rows.length; - currRow.push(nodeCell); - Object.keys(gpuInfo.data[nodeName]).forEach((pid) => { - const pidCell = { content: pid, rowspan: 0, cellType: 'pid' as const }; - const j = rows.length; - currRow.push(pidCell); - Object.keys(gpuInfo.data[nodeName][pid]).forEach((gpu) => { - const gpuCell = { content: gpu, rowspan: 0, cellType: 'gpu' as const }; - const k = rows.length; - currRow.push(gpuCell); - Object.keys(gpuInfo.data[nodeName][pid][gpu]).forEach((keyName) => { - currRow.push({ - content: keyName, + cellType: 'node' as const + } + const i = rows.length + curr_row.push(node_cell) + Object.keys(gpuInfo.data[node_name]).forEach(function (pid) { + const pid_cell = { content: pid, rowspan: 0, cellType: 'pid' as const } + const i = rows.length + curr_row.push(pid_cell) + Object.keys(gpuInfo.data[node_name][pid]).forEach(function (gpu) { + const gpu_cell = { content: gpu, rowspan: 0, cellType: 'gpu' as const } + const i = rows.length + curr_row.push(gpu_cell) + Object.keys(gpuInfo.data[node_name][pid][gpu]).forEach(function ( + key_name + ) { + curr_row.push({ + content: key_name, rowspan: 1, - cellType: 'key' as const, - }); - const value: string = gpuInfo.data[nodeName][pid][gpu][keyName]; - currRow.push({ + cellType: 'key' as const + }) + const value: string = gpuInfo.data[node_name][pid][gpu][key_name] + curr_row.push({ content: value, rowspan: 1, - cellType: 'value' as const, - }); - currRow = []; - rows.push(currRow); - }); - gpuCell.rowspan = rows.length - k; - }); - pidCell.rowspan = rows.length - j; - }); - nodeCell.rowspan = rows.length - i; - }); - rows.pop(); - return rows; + cellType: 'value' as const + }) + curr_row = [] + rows.push(curr_row) + }) + gpu_cell.rowspan = rows.length - i + }) + pid_cell.rowspan = rows.length - i + }) + node_cell.rowspan = rows.length - i + }) + rows.pop() + return rows } export const GpuInfoTable: React.FC = (props) => { - const classes = useStyles(); - interface TableCellInfoNoLast { - content: string; - rowspan: number; - cellType: 'node' | 'pid' | 'gpu' | 'key' | 'value'; + const classes = useStyles() + interface TableCellInfo { + content: string + rowspan: number + cellType: 'node' | 'pid' | 'gpu' | 'key' | 'value' } - const rows = React.useMemo(() => makeTableCellInfo(props.gpuInfo), [props.gpuInfo]); + const rows = React.useMemo(() => makeTableCellInfo(props.gpuInfo), [ + props.gpuInfo + ]) const cellToClass = { node: classes.nodeTd, pid: classes.pidTd, gpu: classes.gpuTd, key: classes.keyTd, - value: classes.valueTd, - }; + value: classes.valueTd + } - const renderCell = function (info: TableCellInfoNoLast): JSX.Element { - let cellClass = cellToClass[info.cellType]; - let content = info.cellType === 'key' ? `${info.content}:` : info.content; + const renderCell = function (info: TableCellInfo) { + let cellClass = cellToClass[info.cellType] + let content = info.cellType == 'key' ? info.content + ':' : info.content return ( -
- ); - }; + ) + } return (
+ {content}
@@ -126,5 +130,5 @@ export const GpuInfoTable: React.FC = (props) => { {row.map(renderCell)} ))}
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Kernel.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Kernel.tsx index 66e05695153a853f68d382a2f3b6a68931861abf..62ec350b8b400a03bd64c032ee2a61a4ca9a1852 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Kernel.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Kernel.tsx @@ -15,183 +15,208 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import FormControlLabel from '@material-ui/core/FormControlLabel'; -import Grid from '@material-ui/core/Grid'; -import InputLabel from '@material-ui/core/InputLabel'; -import MenuItem from '@material-ui/core/MenuItem'; -import Radio from '@material-ui/core/Radio'; -import RadioGroup, { RadioGroupProps } from '@material-ui/core/RadioGroup'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import { makeStyles } from '@material-ui/core/styles'; -import TextField, { StandardTextFieldProps, TextFieldProps } from '@material-ui/core/TextField'; -import * as React from 'react'; -import * as api from '../api'; -import { Graph } from '../api'; -import { KernelGroupBy } from '../constants/groupBy'; -import { useSearch } from '../utils/search'; -import { topIsValid, UseTop, useTopN } from '../utils/top'; -import { AntTableChart } from './charts/AntTableChart'; -import { PieChart } from './charts/PieChart'; -import { DataLoading } from './DataLoading'; -import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers'; +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import FormControlLabel from '@material-ui/core/FormControlLabel' +import Grid from '@material-ui/core/Grid' +import InputLabel from '@material-ui/core/InputLabel' +import MenuItem from '@material-ui/core/MenuItem' +import Radio from '@material-ui/core/Radio' +import RadioGroup, { RadioGroupProps } from '@material-ui/core/RadioGroup' +import Select, { SelectProps } from '@material-ui/core/Select' +import { makeStyles } from '@material-ui/core/styles' +import TextField, { + StandardTextFieldProps, + TextFieldProps +} from '@material-ui/core/TextField' +import * as React from 'react' +import * as api from '../api' +import { Graph } from '../api' +import { KernelGroupBy } from '../constants/groupBy' +import { useSearch } from '../utils/search' +import { topIsValid, UseTop, useTopN } from '../utils/top' +import { AntTableChart } from './charts/AntTableChart' +import { PieChart } from './charts/PieChart' +import { DataLoading } from './DataLoading' +import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers' import { - gpuKernelTotalTimeTooltip, - tensorCoresPieChartTooltip, - tensorCoresPieChartTooltipAscend, -} from './TooltipDescriptions'; + GPUKernelTotalTimeTooltip, + TensorCoresPieChartTooltip, + TensorCoresPieChartTooltipAscend +} from './TooltipDescriptions' export interface IProps { - run: string; - worker: string; - span: string; - deviceTarget: string; + run: string + worker: string + span: string + deviceTarget: string } const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, verticalInput: { display: 'flex', - alignItems: 'center', + alignItems: 'center' }, inputWidth: { - width: '4em', + width: '4em' }, inputWidthOverflow: { minWidth: '15em', - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, description: { - marginLeft: theme.spacing(1), - }, -})); + marginLeft: theme.spacing(1) + } +})) export const Kernel: React.FC = (props) => { - const { run, worker, span, deviceTarget } = props; - const classes = useStyles(); - const tooltipCommonClasses = useTooltipCommonStyles(); + const { run, worker, span, deviceTarget } = props + const classes = useStyles() + const tooltipCommonClasses = useTooltipCommonStyles() const chartHeaderRenderer = React.useMemo( () => makeChartHeaderRenderer(tooltipCommonClasses), [tooltipCommonClasses] - ); + ) - const [kernelGraph, setKernelGraph] = React.useState(undefined); - const [tcGraph, setTcGraph] = React.useState(undefined); - const [kernelTable, setKernelTable] = React.useState(undefined); - const [groupBy, setGroupBy] = React.useState(KernelGroupBy.KERNEL); - const [searchKernelName, setSearchKernelName] = React.useState(''); - const [searchOpName, setSearchOpName] = React.useState(''); - const [sortColumn, setSortColumn] = React.useState(''); - const [hasStep, setHasStep] = React.useState(false); + const [kernelGraph, setKernelGraph] = React.useState( + undefined + ) + const [tcGraph, setTcGraph] = React.useState(undefined) + const [kernelTable, setKernelTable] = React.useState( + undefined + ) + const [groupBy, setGroupBy] = React.useState(KernelGroupBy.Kernel) + const [searchKernelName, setSearchKernelName] = React.useState('') + const [searchOpName, setSearchOpName] = React.useState('') + const [sortColumn, setSortColumn] = React.useState('') + const [hasStep, setHasStep] = React.useState(false) const [topText, actualTop, useTop, setTopText, setUseTop] = useTopN({ - defaultUseTop: UseTop.USE, - defaultTop: 10, - }); + defaultUseTop: UseTop.Use, + defaultTop: 10 + }) React.useEffect(() => { - setSearchOpName(''); - }, [groupBy]); + setSearchOpName('') + }, [groupBy]) React.useEffect(() => { if (kernelGraph) { - setTopText(String(Math.min(kernelGraph.rows?.length, 10))); + setTopText(String(Math.min(kernelGraph.rows?.length, 10))) } - }, [kernelGraph]); + }, [kernelGraph]) React.useEffect(() => { api.defaultApi.kernelTableGet(run, worker, span, groupBy).then((resp) => { - setSortColumn(resp.metadata.sort); - setKernelTable(resp.data); - const nameColumnIdx = resp.data.columns.findIndex((c) => c.name.toLowerCase() === 'step id'); - setHasStep(nameColumnIdx > -1); - }); - }, [run, worker, span, groupBy]); + setSortColumn(resp.metadata.sort) + setKernelTable(resp.data) + const nameColumnIdx = resp.data.columns.findIndex( + (c) => c.name.toLowerCase() === 'step id' + ) + setHasStep(nameColumnIdx > -1) + }) + }, [run, worker, span, groupBy]) React.useEffect(() => { - api.defaultApi.kernelGet(run, worker, span, KernelGroupBy.KERNEL).then((resp) => { - setKernelGraph(resp.total); - setGroupBy(resp.device_target === 'Ascend' ? KernelGroupBy.KERNEL_NAME_AND_OP_NAME : KernelGroupBy.KERNEL); - }); - }, [run, worker, span]); + api.defaultApi + .kernelGet(run, worker, span, KernelGroupBy.Kernel) + .then((resp) => { + setKernelGraph(resp.total) + setGroupBy(resp.device_target === 'Ascend' ? KernelGroupBy.KernelNameAndOpName : KernelGroupBy.Kernel) + }) + }, [run, worker, span]) React.useEffect(() => { api.defaultApi.kernelTcPieGet(run, worker, span).then((resp) => { - setTcGraph(resp.total); - }); - }, [run, worker, span]); + setTcGraph(resp.total) + }) + }, [run, worker, span]) - const [searchedKernelTable] = useSearch(searchKernelName, 'name', kernelTable); + const [searchedKernelTable] = useSearch(searchKernelName, 'name', kernelTable) const [searchedOpTable] = useSearch( searchOpName, deviceTarget === 'Ascend' ? 'step id' : 'operator', searchedKernelTable - ); + ) const onGroupByChanged: SelectProps['onChange'] = (event) => { - setGroupBy(event.target.value as KernelGroupBy); - }; + setGroupBy(event.target.value as KernelGroupBy) + } const onSearchKernelChanged: TextFieldProps['onChange'] = (event) => { - setSearchKernelName(event.target.value as string); - }; + setSearchKernelName(event.target.value as string) + } const onSearchOpChanged: TextFieldProps['onChange'] = (event) => { - setSearchOpName(event.target.value as string); - }; + setSearchOpName(event.target.value as string) + } const onUseTopChanged: RadioGroupProps['onChange'] = (event) => { - setUseTop(event.target.value as UseTop); - }; + setUseTop(event.target.value as UseTop) + } - const onTopChanged = (event: React.ChangeEvent): void => { - setTopText(event.target.value); - }; + const onTopChanged = (event: React.ChangeEvent) => { + setTopText(event.target.value) + } const inputProps: StandardTextFieldProps['inputProps'] = { - min: 1, - }; + min: 1 + } const GPUKernelTotalTimeTitle = React.useMemo( - () => chartHeaderRenderer('Total Time (us)', gpuKernelTotalTimeTooltip), + () => chartHeaderRenderer('Total Time (us)', GPUKernelTotalTimeTooltip), [chartHeaderRenderer] - ); + ) const TensorCoresTitle = React.useMemo( - () => - deviceTarget === 'Ascend' - ? chartHeaderRenderer('Accelerator Core Utilization', tensorCoresPieChartTooltipAscend) - : chartHeaderRenderer('Tensor Cores Utilization', tensorCoresPieChartTooltip), + () => deviceTarget === 'Ascend' ? + chartHeaderRenderer( + 'Accelerator Core Utilization', + TensorCoresPieChartTooltipAscend + ) + : + chartHeaderRenderer( + 'Tensor Cores Utilization', + TensorCoresPieChartTooltip + ), [chartHeaderRenderer, deviceTarget] - ); + ) return (
- - + + - } label='All kernels' /> - } label='Top kernels to show' /> + } + label="All kernels" + /> + } + label="Top kernels to show" + /> - {useTop === UseTop.USE && ( + {useTop === UseTop.Use && ( = (props) => { - {(graph): JSX.Element => ( + {(graph) => ( - + )} - {(graph): JSX.Element => ( + {(graph) => ( = (props) => { graph={graph} colors={['#0099C6', '#DD4477', '#66AA00', '#B82E2E']} top={actualTop} - tooltipMode='percentage' + tooltip_mode="percentage" /> )} - + - + - Group By - + {deviceTarget === 'Ascend' ? 'Statistic' : 'Kernel Properties + Op Name'} - + {deviceTarget === 'Ascend' ? 'All' : 'Kernel Name'} @@ -246,49 +279,50 @@ export const Kernel: React.FC = (props) => { classes={{ root: classes.inputWidthOverflow }} value={searchKernelName} onChange={onSearchKernelChanged} - type='search' - label='Search by Name' + type="search" + label="Search by Name" inputProps={{ - maxLength: 200, + maxLength: 200 }} /> - {deviceTarget === 'Ascend' - ? groupBy === KernelGroupBy.KERNEL && - hasStep && ( - - - - ) - : groupBy === KernelGroupBy.KERNEL_NAME_AND_OP_NAME && ( - - - - )} + {deviceTarget === 'Ascend' ? + (groupBy === KernelGroupBy.Kernel && hasStep && + + + ) + : + (groupBy === KernelGroupBy.KernelNameAndOpName && + + + ) + } - {(graph): JSX.Element => } + {(graph) => ( + + )} @@ -297,5 +331,5 @@ export const Kernel: React.FC = (props) => {
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/MemoryView.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/MemoryView.tsx index 225f28a931e969d7cfd40d3f490e7cb45c64a305..a8f6c458eae79adf09371fcb73ecb29d1a62d067 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/MemoryView.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/MemoryView.tsx @@ -15,22 +15,22 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import Grid from '@material-ui/core/Grid'; -import InputLabel from '@material-ui/core/InputLabel'; -import MenuItem from '@material-ui/core/MenuItem'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import Slider from '@material-ui/core/Slider'; -import { makeStyles } from '@material-ui/core/styles'; -import TextField, { TextFieldProps } from '@material-ui/core/TextField'; -import * as React from 'react'; -import * as api from '../api'; +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import Grid from '@material-ui/core/Grid' +import InputLabel from '@material-ui/core/InputLabel' +import MenuItem from '@material-ui/core/MenuItem' +import Select, { SelectProps } from '@material-ui/core/Select' +import Slider from '@material-ui/core/Slider' +import { makeStyles } from '@material-ui/core/styles' +import TextField, { TextFieldProps } from '@material-ui/core/TextField' +import * as React from 'react' +import * as api from '../api' import { Graph, GraphAscend, @@ -39,237 +39,288 @@ import { MemoryCurveDataAscend, MemoryEventsData, MemoryEventsDataAll, - MemoryStatsData, -} from '../api'; -import { useSearchDirectly } from '../utils/search'; -import { AntTableChart } from './charts/AntTableChart'; -import { LineChart } from './charts/NewLineChart'; -import { DataLoading } from './DataLoading'; -import { MemoryStatsTable } from './tables/MemoryStatsTable'; + MemoryStatsData +} from '../api' +import { useSearchDirectly } from '../utils/search' +import { AntTableChart } from './charts/AntTableChart' +import { LineChart } from './charts/NewLineChart' +import { DataLoading } from './DataLoading' +import { MemoryStatsTable } from './tables/MemoryStatsTable' const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, curve: { - marginBottom: 20, + marginBottom: 20 }, verticalInput: { display: 'flex', - alignItems: 'center', + alignItems: 'center' }, inputWidth: { - width: '4em', + width: '4em' }, inputWidthOverflow: { minWidth: '15em', - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, full: { - width: '100%', + width: '100%' }, description: { - marginLeft: theme.spacing(1), + marginLeft: theme.spacing(1) }, filterSlider: { marginTop: 15, marginRight: 6, - width: 250, + width: 250 }, filterInput: { - width: 100, - }, -})); + width: 100 + } +})) export interface IProps { - run: string; - worker: string; - span: string; - deviceTarget: string; + run: string + worker: string + span: string + deviceTarget: string } -const tags = ['Operator', 'Component']; +const tags = ['Operator', 'Component'] export const MemoryView: React.FC = React.memo((props) => { interface EventSizeFilter { - [deviceName: string]: Array; + [deviceName: string]: Array } interface MaxEventSize { - [deviceName: string]: number; + [deviceName: string]: number } - const { run, worker, span, deviceTarget } = props; - const classes = useStyles(); + const { run, worker, span, deviceTarget } = props + const classes = useStyles() - const [memoryStatsData, setMemoryStatsData] = React.useState(undefined); + const [memoryStatsData, setMemoryStatsData] = React.useState< + MemoryStatsData | undefined + >(undefined) // for backward compatability, old profile do not have events to show - const showEvents = (): boolean | undefined => { - return memoryEventsData && Object.keys(memoryEventsData.rows).length !== 0; - }; - const [memoryEventsData, setMemoryEventsData] = React.useState(undefined); + const showEvents = () => { + return memoryEventsData && Object.keys(memoryEventsData.rows).length != 0 + } + const [memoryEventsData, setMemoryEventsData] = React.useState< + MemoryEventsData | undefined + >(undefined) // for backward compatability, old profile do not have curve to show - const showCurve = (): boolean | undefined => { - return memoryCurveData && Object.keys(memoryCurveData.rows).length !== 0; - }; - const [memoryCurveData, setMemoryCurveData] = React.useState( - undefined - ); + const showCurve = () => { + return memoryCurveData && Object.keys(memoryCurveData.rows).length != 0 + } + const [memoryCurveData, setMemoryCurveData] = React.useState< + MemoryCurveData | MemoryCurveDataAscend | undefined + >(undefined) - const [lineChartData, setLineChartData] = React.useState(undefined); + const [lineChartData, setLineChartData] = React.useState( + undefined + ) - const [devices, setDevices] = React.useState([]); - const [device, setDevice] = React.useState(''); - const [tag, setTag] = React.useState('Operator'); - const memoryCurveDataAllRef = React.useRef(undefined); - const memoryEventDataAllRef = React.useRef(undefined); + const [devices, setDevices] = React.useState([]) + const [device, setDevice] = React.useState('') + const [tag, setTag] = React.useState('Operator') + const memoryCurveDataAllRef = React.useRef(undefined) + const memoryEventDataAllRef = React.useRef(undefined) interface SelectedRange { - start: number; - end: number; - startTs: number; - endTs: number; + start: number + end: number + startTs: number + endTs: number } - const [selectedRange, setSelectedRange] = React.useState(); - const [searchOperatorName, setSearchOperatorName] = React.useState(''); - const [searchEventOperatorName, setSearchEventOperatorName] = React.useState(''); - const [filterEventSize, setFilterEventSize] = React.useState({}); - const [maxSize, setMaxSize] = React.useState({}); - - const getSearchIndex = function (): number { + const [selectedRange, setSelectedRange] = React.useState< + SelectedRange | undefined + >() + const [searchOperatorName, setSearchOperatorName] = React.useState('') + const [searchEventOperatorName, setSearchEventOperatorName] = React.useState( + '' + ) + const [filterEventSize, setFilterEventSize] = React.useState( + {} + ) + const [maxSize, setMaxSize] = React.useState({}) + + const getSearchIndex = function () { if (!memoryStatsData) { - return -1; + return -1 } for (let i = 0; i < memoryStatsData.columns.length; i++) { - if (memoryStatsData.columns[i].name === memoryStatsData.metadata.search) { - return i; + if (memoryStatsData.columns[i].name == memoryStatsData.metadata.search) { + return i } } - return -1; - }; + return -1 + } - const getStep = (size: number, indexBias: number): number => { - return 10 ** (Math.floor(Math.log10(size !== 0 ? size : 1)) - indexBias); - }; + const getStep = (size: number, indexBias: number) => { + return 10 ** (Math.floor(Math.log10(size != 0 ? size : 1)) - indexBias) + } - const filterByEventSize = (rows: T[] | undefined, size: Array): T[] | undefined => { + const filterByEventSize = ( + rows: T[] | undefined, + size: Array + ) => { const result = React.useMemo(() => { if (!rows) { - return undefined; + return undefined } // workaround type system const field = (row: any): number => { - const sizeColIndex = 1; - return row[sizeColIndex]; - }; + const sizeColIndex = 1 + return row[sizeColIndex] + } return rows.filter((row) => { - return field(row) >= size[0] && field(row) <= size[1]; - }); - }, [rows, size]); + return field(row) >= size[0] && field(row) <= size[1] + }) + }, [rows, size]) - return result; - }; + return result + } - const searchIndex = getSearchIndex(); - const getName = React.useCallback((row: any) => row[searchIndex], [searchIndex]); - const getNameAscend = (row: any): any => row[0]; - const [searchedTableDataRows] = useSearchDirectly(searchOperatorName, getName, memoryStatsData?.rows[device] ?? []); + const searchIndex = getSearchIndex() + const getName = React.useCallback((row: any) => row[searchIndex], [ + searchIndex + ]) + const getNameAscend = (row: any) => row[0] + const [searchedTableDataRows] = useSearchDirectly( + searchOperatorName, + getName, + memoryStatsData?.rows[device] ?? [] + ) const [searchedEventsTableDataRows] = useSearchDirectly( searchEventOperatorName, deviceTarget === 'Ascend' ? getNameAscend : getName, - filterByEventSize(memoryEventsData?.rows[device], filterEventSize[device] ?? [0, Infinity]) ?? [] - ); + filterByEventSize( + memoryEventsData?.rows[device], + filterEventSize[device] ?? [0, Infinity] + ) ?? [] + ) const onSearchOperatorChanged: TextFieldProps['onChange'] = (event) => { - setSearchOperatorName(event.target.value as string); - }; + setSearchOperatorName(event.target.value as string) + } const onSearchEventOperatorChanged: TextFieldProps['onChange'] = (event) => { - setSearchEventOperatorName(event.target.value as string); - }; + setSearchEventOperatorName(event.target.value as string) + } - const [selectedRecord, setSelectedRecord] = React.useState(); - const onRowSelected = (record?: object, rowIndex?: number): void => { - setSelectedRecord(record); - }; + const [selectedRecord, setSelectedRecord] = React.useState() + const onRowSelected = (record?: object, rowIndex?: number) => { + setSelectedRecord(record) + } - const onFilterEventSizeChanged = (event: any, newValue: number | number[]): void => { + const onFilterEventSizeChanged = ( + event: any, + newValue: number | number[] + ) => { setFilterEventSize({ ...filterEventSize, - [device]: newValue as number[], - }); - }; + [device]: newValue as number[] + }) + } - const onFilterEventMinSizeInputChanged = (event: React.ChangeEvent): void => { + const onFilterEventMinSizeInputChanged = ( + event: React.ChangeEvent + ) => { setFilterEventSize({ ...filterEventSize, - [device]: [Number(event.target.value), filterEventSize[device][1]], - }); - }; + [device]: [Number(event.target.value), filterEventSize[device][1]] + }) + } - const onFilterEventMaxSizeInputChanged = (event: React.ChangeEvent): void => { + const onFilterEventMaxSizeInputChanged = ( + event: React.ChangeEvent + ) => { setFilterEventSize({ ...filterEventSize, - [device]: [filterEventSize[device][0], Number(event.target.value)], - }); - }; + [device]: [filterEventSize[device][0], Number(event.target.value)] + }) + } React.useEffect(() => { - if (deviceTarget !== 'Ascend') { - api.defaultApi.memoryGet(run, worker, span, selectedRange?.startTs, selectedRange?.endTs).then((resp) => { - setMemoryStatsData(resp); - if (!devices || devices.length === 0) { + deviceTarget !== 'Ascend' && api.defaultApi + .memoryGet( + run, + worker, + span, + selectedRange?.startTs, + selectedRange?.endTs + ) + .then((resp) => { + setMemoryStatsData(resp) + if (!devices || devices.length == 0) { // setDevices only execute on view load. Since selection on curve // might filter all events later, some devices might is missing. - setDevices(Object.keys(resp.rows)); - setDevice(resp.metadata.default_device); + setDevices(Object.keys(resp.rows)) + setDevice(resp.metadata.default_device) } - }); - } - }, [run, worker, span, selectedRange]); + }) + }, [run, worker, span, selectedRange]) React.useEffect(() => { - api.defaultApi.memoryEventsGet(run, worker, span, selectedRange?.startTs, selectedRange?.endTs).then((resp) => { - const tempRes = deviceTarget === 'Ascend' ? (resp as MemoryEventsDataAll).operator : (resp as MemoryEventsData); - if (deviceTarget === 'Ascend') { - memoryEventDataAllRef.current = resp as MemoryEventsDataAll; - } - let curMaxSize: MaxEventSize = {}; - let curFilterEventSize: EventSizeFilter = {}; - Object.keys(tempRes.rows).forEach((deviceName) => { - curMaxSize[deviceName] = 0; - for (let i = 0; i < tempRes.rows[deviceName].length; i++) { - curMaxSize[deviceName] = Math.max(curMaxSize[deviceName], tempRes.rows[deviceName][i][1]); + api.defaultApi + .memoryEventsGet( + run, + worker, + span, + selectedRange?.startTs, + selectedRange?.endTs + ) + .then((resp) => { + const tempRes = deviceTarget === 'Ascend' ? (resp as MemoryEventsDataAll).operator : resp as MemoryEventsData + if (deviceTarget === 'Ascend') { + memoryEventDataAllRef.current = resp as MemoryEventsDataAll + } + let curMaxSize: MaxEventSize = {} + let curFilterEventSize: EventSizeFilter = {} + for (let deviceName in tempRes.rows) { + curMaxSize[deviceName] = 0 + for (let i = 0; i < tempRes.rows[deviceName].length; i++) { + curMaxSize[deviceName] = Math.max( + curMaxSize[deviceName], + tempRes.rows[deviceName][i][1] + ) + } + curFilterEventSize[deviceName] = [ + curMaxSize[deviceName] / 4, + curMaxSize[deviceName] + ] + curMaxSize[deviceName] = curMaxSize[deviceName] } - curFilterEventSize[deviceName] = [curMaxSize[deviceName] / 4, curMaxSize[deviceName]]; - curMaxSize[deviceName] = curMaxSize[deviceName]; - }); - setMaxSize(curMaxSize); - setFilterEventSize(curFilterEventSize); - setMemoryEventsData(tempRes); - }); - }, [run, worker, span, selectedRange]); + setMaxSize(curMaxSize) + setFilterEventSize(curFilterEventSize) + setMemoryEventsData(tempRes) + }) + }, [run, worker, span, selectedRange]) React.useEffect(() => { api.defaultApi.memoryCurveGet(run, worker, span).then((resp) => { // Reset the select range to null whenever run/worker/span changes - setSelectedRange(undefined); + setSelectedRange(undefined) if (deviceTarget === 'Ascend') { - const allCurveData = resp as MemoryCurveDataAll; - memoryCurveDataAllRef.current = allCurveData; - setDevice(allCurveData.default_device); - setDevices(allCurveData.devices); - setMemoryCurveData(allCurveData.total); - setTag('Operator'); + const allCurveData = resp as MemoryCurveDataAll + memoryCurveDataAllRef.current = allCurveData + setDevice(allCurveData.default_device) + setDevices(allCurveData.devices) + setMemoryCurveData(allCurveData.total) + setTag('Operator') } else { - setMemoryCurveData(resp as MemoryCurveData); + setMemoryCurveData(resp as MemoryCurveData) } - }); - }, [run, worker, span]); + }) + }, [run, worker, span]) React.useEffect(() => { if (memoryCurveData !== undefined) { @@ -277,118 +328,127 @@ export const MemoryView: React.FC = React.memo((props) => { setLineChartData({ title: memoryCurveData.metadata.peaks[device] ?? '', columns: memoryCurveData.columns[device] ?? [], - rows: memoryCurveData.rows[device] ?? {}, - }); + rows: memoryCurveData.rows[device] ?? {} + }) } else { setLineChartData({ title: memoryCurveData.metadata.peaks[device], columns: memoryCurveData.columns, - rows: memoryCurveData.rows[device] ?? [], - }); + rows: memoryCurveData.rows[device] ?? [] + }) } } - }, [memoryCurveData, device]); + }, [memoryCurveData, device]) const onDeviceChanged: SelectProps['onChange'] = (event) => { - setDevice(event.target.value as string); - setSelectedRange(undefined); - }; + setDevice(event.target.value as string) + setSelectedRange(undefined) + } const onTagChanged: SelectProps['onChange'] = (event) => { - setTag(event.target.value as string); + setTag(event.target.value as string) if (event.target.value === 'Operator') { - setMemoryCurveData(memoryCurveDataAllRef.current?.total); - setMemoryEventsData(memoryEventDataAllRef.current?.operator); - setSelectedRange(undefined); + setMemoryCurveData(memoryCurveDataAllRef.current?.total) + setMemoryEventsData(memoryEventDataAllRef.current?.operator) + setSelectedRange(undefined) } else { - setMemoryCurveData(memoryCurveDataAllRef.current?.ptaGe); - setMemoryEventsData(memoryEventDataAllRef.current?.component); + setMemoryCurveData(memoryCurveDataAllRef.current?.ptaGe) + setMemoryEventsData(memoryEventDataAllRef.current?.component) } - }; + } - const onSelectedRangeChanged = (start: number, end: number): void => { + const onSelectedRangeChanged = (start: number, end: number) => { if (start > end) { - setSelectedRange(undefined); - return; + setSelectedRange(undefined) + return } - let allDatas = deviceTarget === 'Ascend' ? memoryCurveData?.rows[device]?.Allocated : memoryCurveData?.rows[device]; + let allDatas = deviceTarget === 'Ascend' ? + memoryCurveData?.rows[device]?.Allocated : memoryCurveData?.rows[device] if (allDatas.length <= 1) { - setSelectedRange(undefined); - return; + setSelectedRange(undefined) + return } - let startTs = 0; - let endTs = 0; - let realStart = 0; - let realEnd = 0; - let startId = 1; - let endId = 0; - let needLoopStart = true; + let startTs = 0 + let endTs = 0 + let realStart = 0 + let realEnd = 0 + let startId = 1 + let endId = 0 + let needLoopStart = true for (let i = 1; i < allDatas.length; i++) { if (startId > start && needLoopStart) { - needLoopStart = false; - realStart = i - 1; + needLoopStart = false + realStart = i - 1 } if (allDatas[i][0] !== allDatas[i - 1][0]) { if (startId <= start) { - startId += 1; + startId += 1 } - endId += 1; + endId += 1 } if (endId > end) { - realEnd = i - 1; - break; + realEnd = i - 1 + break } else { - realEnd = i; + realEnd = i if (needLoopStart) { - realStart = i; + realStart = i } } } if (deviceTarget === 'Ascend') { - startTs = allDatas[realStart][0]; - endTs = allDatas[realEnd][0]; + startTs = allDatas[realStart][0] + endTs = allDatas[realEnd][0] } else { - let bias = memoryCurveData?.metadata.first_ts ?? 0; - let scale = 1 / (memoryCurveData?.metadata.time_factor ?? 1); - startTs = Math.round((allDatas[realStart][0] * scale) + bias); - endTs = Math.round((allDatas[realEnd][0] * scale) + bias); + let bias = memoryCurveData?.metadata.first_ts ?? 0 + let scale = 1 / (memoryCurveData?.metadata.time_factor ?? 1) + startTs = Math.round(allDatas[realStart][0] * scale + bias) + endTs = Math.round(allDatas[realEnd][0] * scale + bias) } - setSelectedRange({ start, end, startTs, endTs }); - }; + setSelectedRange({ start, end, startTs, endTs }) + } return (
- - + + - + - {(graph): JSX.Element => ( - + {(graph) => ( + - Device - + {devices.map((device) => ( + {device} ))} - {deviceTarget === 'Ascend' && ( + {deviceTarget === 'Ascend' && - Group By - + {tags.map((device) => ( + {device} ))} - )} + } {showCurve() && lineChartData && lineChartData.columns.length > 0 && ( @@ -411,28 +471,28 @@ export const MemoryView: React.FC = React.memo((props) => { {showEvents() && ( <> - {(deviceTarget !== 'Ascend' || tag === 'Operator') && ( + {(deviceTarget !== 'Ascend' || tag === 'Operator') && - + - + = React.memo((props) => { min: 0, max: filterEventSize[device]?.[1] ?? 0, type: 'number', - 'aria-labelledby': 'input-slider', + 'aria-labelledby': 'input-slider' }} /> @@ -449,7 +509,7 @@ export const MemoryView: React.FC = React.memo((props) => { className={classes.filterSlider} value={filterEventSize[device] ?? [0, 0]} onChange={onFilterEventSizeChanged} - aria-labelledby='input-slider' + aria-labelledby="input-slider" min={0} max={maxSize[device] ?? 0} step={getStep(maxSize[device] ?? 0, 5)} @@ -458,7 +518,7 @@ export const MemoryView: React.FC = React.memo((props) => { = React.memo((props) => { min: filterEventSize[device]?.[0] ?? 0, max: maxSize[device] ?? 0, type: 'number', - 'aria-labelledby': 'input-slider', + 'aria-labelledby': 'input-slider' }} /> - )} - + } + - {(data): JSX.Element => { + {(data) => { return ( - ); + ) }} @@ -494,29 +555,29 @@ export const MemoryView: React.FC = React.memo((props) => { )} {deviceTarget !== 'Ascend' && ( <> - - + + - + - {(data): JSX.Element => ( + {(data) => ( )} @@ -527,5 +588,5 @@ export const MemoryView: React.FC = React.memo((props) => {
- ); -}); + ) +}) diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/ModuleView.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/ModuleView.tsx index a66a825365fd3c813e58865c609643ab547b4c49..396188aba4e69cced5208ff4af86631bf02e172c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/ModuleView.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/ModuleView.tsx @@ -1,227 +1,241 @@ /*--------------------------------------------------------------------------------------------- * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardHeader from '@material-ui/core/CardHeader'; -import InputLabel from '@material-ui/core/InputLabel'; -import MenuItem from '@material-ui/core/MenuItem'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import { makeStyles } from '@material-ui/core/styles'; -import { message, Table } from 'antd'; -import * as React from 'react'; -import { FlameGraph } from 'react-flame-graph'; -import { defaultApi, KeyedColumn, ModuleStats, ModuleViewData, OperatorNode } from '../api'; +import Card from '@material-ui/core/Card' +import CardHeader from '@material-ui/core/CardHeader' +import InputLabel from '@material-ui/core/InputLabel' +import MenuItem from '@material-ui/core/MenuItem' +import Select, { SelectProps } from '@material-ui/core/Select' +import { makeStyles } from '@material-ui/core/styles' +import { Table } from 'antd' +import * as React from 'react' +import { FlameGraph } from 'react-flame-graph' +import { + defaultApi, + KeyedColumn, + ModuleStats, + ModuleViewData, + OperatorNode +} from '../api' const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, hide: { - display: 'none', - }, -})); + display: 'none' + } +})) export interface IProps { - run: string; - worker: string; - span: string; + run: string + worker: string + span: string } -const getKeyedTableColumns = (columns: KeyedColumn[]): any[] => { +const getKeyedTableColumns = (columns: KeyedColumn[]) => { return columns.map((col) => { return { dataIndex: col.key, key: col.key, - title: col.name, - }; - }); -}; + title: col.name + } + }) +} -const getTableRows = (key: number, rows: ModuleStats[]): any[] => { - let initialKey = key; +const getTableRows = (key: number, rows: ModuleStats[]) => { return rows.map((row) => { - const currentKey = initialKey++; const data: any = { - key: currentKey, + key: key++, name: row.name, occurences: row.occurences, operators: row.operators, host_duration: row.host_duration, self_host_duration: row.self_host_duration, device_duration: row.device_duration, - self_device_duration: row.self_device_duration, - }; + self_device_duration: row.self_device_duration + } if (row.children.length) { - data.children = getTableRows(key, row.children); + data.children = getTableRows(key, row.children) } - return data; - }); -}; + return data + }) +} -const getFlameGraphData = (rows: ModuleStats[]): any[] => { +const getFlameGraphData = (rows: ModuleStats[]) => { return rows.map((row) => { const data: any = { name: row.name, value: row.avg_duration, - tooltip: `${row.name} (module id: ${row.id}): ${row.avg_duration} us`, - }; + tooltip: `${row.name} (module id: ${row.id}): ${row.avg_duration} us` + } if (row.children.length) { - data.children = getFlameGraphData(row.children); + data.children = getFlameGraphData(row.children) } - return data; - }); -}; + return data + }) +} const getTreeHeight = (row: ModuleStats): number => { - if (row.children?.length) { - return 1 + Math.max(...row.children.map((child) => getTreeHeight(child))); + if (row.children && row.children.length) { + return 1 + Math.max(...row.children.map((child) => getTreeHeight(child))) } else { - return 1; + return 1 } -}; +} -const getOperatorTree = (level: number, row: OperatorNode, result: object[]): void => { +const getOperatorTree = ( + level: number, + row: OperatorNode, + result: object[] +) => { result.push({ level: level, name: row.name, start: row.start_time, - end: row.end_time, - }); + end: row.end_time + }) if (row.children.length) { - row.children.forEach((child) => getOperatorTree(level + 1, child, result)); + row.children.forEach((child) => getOperatorTree(level + 1, child, result)) } -}; +} export const ModuleView: React.FC = (props) => { - const { run, worker, span } = props; - const classes = useStyles(); + const { run, worker, span } = props + const classes = useStyles() - const [moduleView, setModuleView] = React.useState(undefined); - const [flameData, setFlameData] = React.useState([]); - const [flameHeight, setFlameHeight] = React.useState(0); - const [modules, setModules] = React.useState([]); - const [module, setModule] = React.useState(0); + const [moduleView, setModuleView] = React.useState< + ModuleViewData | undefined + >(undefined) + const [flameData, setFlameData] = React.useState([]) + const [flameHeight, setFlameHeight] = React.useState(0) + const [modules, setModules] = React.useState([]) + const [module, setModule] = React.useState(0) - const [columns, setColumns] = React.useState([]); - const [rows, setRows] = React.useState([]); + const [columns, setColumns] = React.useState([]) + const [rows, setRows] = React.useState([]) - const cardRef = React.useRef(null); - const [cardWidth, setCardWidth] = React.useState(undefined); - const timelineRef = React.useRef(null); + const cardRef = React.useRef(null) + const [cardWidth, setCardWidth] = React.useState( + undefined + ) + const timelineRef = React.useRef(null) React.useEffect(() => { defaultApi .moduleGet(run, worker, span) .then((resp) => { - setModuleView(resp); + setModuleView(resp) if (resp) { // set the flamegraph data - const flameGraphData: any[] = getFlameGraphData(resp.data); - setFlameData(flameGraphData); - const flameGraphHeight = Math.max(...flameGraphData.map((x) => getTreeHeight(x))); - setFlameHeight(flameGraphHeight * 25); - setModules(Array.from(Array(flameGraphData.length).keys())); - setModule(0); + const flameData: any[] = getFlameGraphData(resp.data) + setFlameData(flameData) + const flameHeight = Math.max( + ...flameData.map((x) => getTreeHeight(x)) + ) + setFlameHeight(flameHeight * 25) + setModules(Array.from(Array(flameData.length).keys())) + setModule(0) // set the tree table data - setColumns(getKeyedTableColumns(resp.columns)); - setRows(getTableRows(1, resp.data)); + setColumns(getKeyedTableColumns(resp.columns)) + setRows(getTableRows(1, resp.data)) } }) .catch((e) => { - if (e.status === 404) { - setModules([]); - setFlameData([]); - setRows([]); + if (e.status == 404) { + setModules([]) + setFlameData([]) + setRows([]) } - }); + }) if (cardRef.current) { - setCardWidth(cardRef.current.offsetWidth - 10); + setCardWidth(cardRef.current.offsetWidth - 10) } try { if (timelineRef.current) { defaultApi.treeGet(run, worker, span).then((resp) => { if (resp) { - const data = new google.visualization.DataTable(); - data.addColumn({ type: 'string', id: 'Layer' }); - data.addColumn({ type: 'string', id: 'Name' }); - data.addColumn({ type: 'string', role: 'tooltip' }); - data.addColumn({ type: 'number', id: 'Start' }); - data.addColumn({ type: 'number', id: 'End' }); - - let timelineData: any[] = []; - getOperatorTree(0, resp, timelineData); - timelineData.sort((a, b) => a.level - b.level); - const maxLevel = timelineData[timelineData.length - 1].level; - timelineData.forEach((d) => { + const data = new google.visualization.DataTable() + data.addColumn({ type: 'string', id: 'Layer' }) + data.addColumn({ type: 'string', id: 'Name' }) + data.addColumn({ type: 'string', role: 'tooltip' }) + data.addColumn({ type: 'number', id: 'Start' }) + data.addColumn({ type: 'number', id: 'End' }) + + let timeline_data: any[] = [] + getOperatorTree(0, resp, timeline_data) + timeline_data.sort((a, b) => a.level - b.level) + const max_level = timeline_data[timeline_data.length - 1].level + timeline_data.forEach((d) => { data.addRow([ d.level.toString(), d.name, `${d.name} Duration: ${d.end - d.start} us`, d.start / 1000.0, // the time unit is us returned from server, but the google charts only accept milliseconds here - d.end / 1000.0, - ]); - }); + d.end / 1000.0 + ]) + }) - const chart = new google.visualization.Timeline(timelineRef.current); + const chart = new google.visualization.Timeline(timelineRef.current) const options = { - height: (maxLevel + 1) * 50, + height: (max_level + 1) * 50, tooltip: { - isHtml: true, + isHtml: true }, timeline: { - showRowLabels: false, - }, - }; - chart.draw(data, options); + showRowLabels: false + } + } + chart.draw(data, options) } - }); + }) } } catch (e) { - message.warning('Timeline in module view is not supported offline.'); + console.warn('Timeline in module view is not supported offline.') } - }, [run, worker, span]); + }, [run, worker, span]) const handleModuleChange: SelectProps['onChange'] = (event) => { - setModule(event.target.value as number); - }; + setModule(event.target.value as number) + } - const moduleComponent = (): JSX.Element => { + const moduleComponent = () => { const moduleFragment = ( - Module + Module - ); + ) if (!modules || modules.length <= 1) { - return
{moduleFragment}
; + return
{moduleFragment}
} else { - return moduleFragment; + return moduleFragment } - }; + } return (
- - + + {rows && rows.length > 0 && ( )} @@ -233,12 +247,13 @@ export const ModuleView: React.FC = (props) => { data={flameData[module]} height={flameHeight} width={cardWidth} - onChange={(node: any): void => {}} + onChange={(node: any) => { + }} /> )}
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Operator.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Operator.tsx index b19bef1967a31915c3c1d660b699b11c83ebb226..7278ca59c938874b85b2a52abbb36c59f924373b 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Operator.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Operator.tsx @@ -15,99 +15,119 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import FormControlLabel from '@material-ui/core/FormControlLabel'; -import Grid from '@material-ui/core/Grid'; -import GridList from '@material-ui/core/GridList'; -import GridListTile from '@material-ui/core/GridListTile'; -import InputLabel from '@material-ui/core/InputLabel'; -import MenuItem from '@material-ui/core/MenuItem'; -import Radio from '@material-ui/core/Radio'; -import RadioGroup, { RadioGroupProps } from '@material-ui/core/RadioGroup'; -import Select, { SelectProps } from '@material-ui/core/Select'; -import { makeStyles } from '@material-ui/core/styles'; -import TextField, { StandardTextFieldProps, TextFieldProps } from '@material-ui/core/TextField'; -import * as React from 'react'; -import * as api from '../api'; -import { OperationTableData, OperationTableDataInner, OperatorGraph } from '../api'; -import { OperationGroupBy } from '../constants/groupBy'; -import { useSearchDirectly } from '../utils/search'; -import { topIsValid, UseTop, useTopN } from '../utils/top'; -import { PieChart } from './charts/PieChart'; -import { DataLoading } from './DataLoading'; -import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers'; -import { OperationTable } from './tables/OperationTable'; +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import FormControlLabel from '@material-ui/core/FormControlLabel' +import Grid from '@material-ui/core/Grid' +import GridList from '@material-ui/core/GridList' +import GridListTile from '@material-ui/core/GridListTile' +import InputLabel from '@material-ui/core/InputLabel' +import MenuItem from '@material-ui/core/MenuItem' +import Radio from '@material-ui/core/Radio' +import RadioGroup, { RadioGroupProps } from '@material-ui/core/RadioGroup' +import Select, { SelectProps } from '@material-ui/core/Select' +import { makeStyles } from '@material-ui/core/styles' +import TextField, { + StandardTextFieldProps, + TextFieldProps +} from '@material-ui/core/TextField' +import * as React from 'react' +import * as api from '../api' +import { + OperationTableData, + OperationTableDataInner, + OperatorGraph +} from '../api' +import { OperationGroupBy } from '../constants/groupBy' +import { useSearchDirectly } from '../utils/search' +import { topIsValid, UseTop, useTopN } from '../utils/top' +import { PieChart } from './charts/PieChart' +import { DataLoading } from './DataLoading' +import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers' +import { OperationTable } from './tables/OperationTable' import { - deviceSelfTimeTooltip, - deviceSelfTimeTooltipAscend, - deviceTotalTimeTooltip, - deviceTotalTimeTooltipAscend, - hostSelfTimeTooltip, - hostTotalTimeTooltip, -} from './TooltipDescriptions'; + DeviceSelfTimeTooltip, + DeviceSelfTimeTooltipAscend, + DeviceTotalTimeTooltip, + DeviceTotalTimeTooltipAscend, + HostSelfTimeTooltip, + HostTotalTimeTooltip +} from './TooltipDescriptions' const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, verticalInput: { display: 'flex', - alignItems: 'center', + alignItems: 'center' }, inputWidth: { - width: '4em', + width: '4em' }, inputWidthOverflow: { minWidth: '15em', - whiteSpace: 'nowrap', + whiteSpace: 'nowrap' }, full: { - width: '100%', + width: '100%' }, description: { - marginLeft: theme.spacing(1), - }, -})); + marginLeft: theme.spacing(1) + } +})) export interface IProps { - run: string; - worker: string; - span: string; - deviceTarget: string; + run: string + worker: string + span: string + deviceTarget: string } export const Operator: React.FC = (props) => { - const { run, worker, span, deviceTarget } = props; - const classes = useStyles(); - const tooltipCommonClasses = useTooltipCommonStyles(); + const { run, worker, span, deviceTarget } = props + const classes = useStyles() + const tooltipCommonClasses = useTooltipCommonStyles() const chartHeaderRenderer = React.useMemo( () => makeChartHeaderRenderer(tooltipCommonClasses), [tooltipCommonClasses] - ); + ) - const [operatorGraph, setOperatorGraph] = React.useState(undefined); - const [operatorTable, setOperatorTable] = React.useState(undefined); - const [sortColumn, setSortColumn] = React.useState(''); - const [tableTooltips, setTableTooltips] = React.useState(undefined); - const [groupBy, setGroupBy] = React.useState(OperationGroupBy.OPERATION); - const [searchOperatorName, setSearchOperatorName] = React.useState(''); + const [operatorGraph, setOperatorGraph] = React.useState< + OperatorGraph | undefined + >(undefined) + const [operatorTable, setOperatorTable] = React.useState< + OperationTableData | undefined + >(undefined) + const [sortColumn, setSortColumn] = React.useState('') + const [tableTooltips, setTableTooltips] = React.useState( + undefined + ) + const [groupBy, setGroupBy] = React.useState(OperationGroupBy.Operation) + const [searchOperatorName, setSearchOperatorName] = React.useState('') const [topText, actualTop, useTop, setTopText, setUseTop] = useTopN({ - defaultUseTop: UseTop.USE, - defaultTop: 10, - }); + defaultUseTop: UseTop.Use, + defaultTop: 10 + }) - const getName = React.useCallback((row: OperationTableDataInner) => row.name, []); - const [searchedOperatorTable] = useSearchDirectly(searchOperatorName, getName, operatorTable); + const getName = React.useCallback( + (row: OperationTableDataInner) => row.name, + [] + ) + const [searchedOperatorTable] = useSearchDirectly( + searchOperatorName, + getName, + operatorTable + ) const onSearchOperatorChanged: TextFieldProps['onChange'] = (event) => { - setSearchOperatorName(event.target.value as string); - }; + setSearchOperatorName(event.target.value as string) + } React.useEffect(() => { if (operatorGraph) { @@ -115,45 +135,49 @@ export const Operator: React.FC = (props) => { operatorGraph.device_self_time?.rows.length ?? 0, operatorGraph.device_total_time?.rows.length ?? 0, operatorGraph.host_self_time.rows?.length ?? 0, - operatorGraph.host_total_time.rows?.length ?? 0, - ]; - setTopText(String(Math.min(Math.max(...counts), 10))); + operatorGraph.host_total_time.rows?.length ?? 0 + ] + setTopText(String(Math.min(Math.max(...counts), 10))) } - }, [operatorGraph]); + }, [operatorGraph]) React.useEffect(() => { - api.defaultApi.operationTableGet(run, worker, span, groupBy).then((resp) => { - setSortColumn(resp.metadata.sort); - setTableTooltips(resp.metadata.tooltips); - setOperatorTable(resp.data); - }); - }, [run, worker, span, groupBy]); + api.defaultApi + .operationTableGet(run, worker, span, groupBy) + .then((resp) => { + setSortColumn(resp.metadata.sort) + setTableTooltips(resp.metadata.tooltips) + setOperatorTable(resp.data) + }) + }, [run, worker, span, groupBy]) React.useEffect(() => { - api.defaultApi.operationGet(run, worker, span, groupBy).then((resp) => { - setOperatorGraph(resp); - }); - }, [run, worker, span, groupBy]); + api.defaultApi + .operationGet(run, worker, span, groupBy) + .then((resp) => { + setOperatorGraph(resp) + }) + }, [run, worker, span, groupBy]) const onGroupByChanged: SelectProps['onChange'] = (event) => { - setGroupBy(event.target.value as OperationGroupBy); - }; + setGroupBy(event.target.value as OperationGroupBy) + } const onUseTopChanged: RadioGroupProps['onChange'] = (event) => { - setUseTop(event.target.value as UseTop); - }; + setUseTop(event.target.value as UseTop) + } - const onTopChanged = (event: React.ChangeEvent): void => { - setTopText(event.target.value); - }; + const onTopChanged = (event: React.ChangeEvent) => { + setTopText(event.target.value) + } const inputProps: StandardTextFieldProps['inputProps'] = { - min: 1, - }; + min: 1 + } - const renderCharts = (graph: api.OperatorGraph): JSX.Element => { + const renderCharts = (graph: api.OperatorGraph) => { return ( - + {graph.device_self_time && ( @@ -161,7 +185,7 @@ export const Operator: React.FC = (props) => { )} @@ -176,7 +200,7 @@ export const Operator: React.FC = (props) => { )} @@ -187,7 +211,12 @@ export const Operator: React.FC = (props) => { {graph.host_self_time.title && ( - + )} @@ -195,34 +224,47 @@ export const Operator: React.FC = (props) => { {graph.host_total_time.title && ( - + )} - ); - }; + ) + } return (
- - + + - + - } label='All operators' /> - } label='Top operators to show' /> + } + label="All operators" + /> + } + label="Top operators to show" + /> - {useTop === UseTop.USE && ( + {useTop === UseTop.Use && ( = (props) => { {renderCharts} - + - + - Group By - + + Operator + Input Shape + + + Operator + @@ -248,10 +298,10 @@ export const Operator: React.FC = (props) => { classes={{ root: classes.inputWidthOverflow }} value={searchOperatorName} onChange={onSearchOperatorChanged} - type='search' - label='Search by Name' + type="search" + label="Search by Name" inputProps={{ - maxLength: 200, + maxLength: 200 }} /> @@ -259,7 +309,7 @@ export const Operator: React.FC = (props) => { - {(table): JSX.Element => ( + {(table) => ( = (props) => {
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Overview.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Overview.tsx index 6a81c567bc5e44b1dd6eb4746135d61268cadb81..e5f6f17bdaae3d276f24ed24f3566fc994fec0ad 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Overview.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/Overview.tsx @@ -2,50 +2,53 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import Card from '@material-ui/core/Card'; -import CardContent from '@material-ui/core/CardContent'; -import CardHeader from '@material-ui/core/CardHeader'; -import Grid from '@material-ui/core/Grid'; -import { makeStyles } from '@material-ui/core/styles'; -import { Table } from 'antd'; -import { ColumnsType } from 'antd/es/table'; -import * as React from 'react'; -import * as api from '../api'; -import { PieChart } from './charts/PieChart'; -import { SteppedAreaChart } from './charts/SteppedAreaChart'; -import { DataLoading } from './DataLoading'; -import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers'; -import { TextListItem } from './TextListItem'; -import { stepTimeBreakDownTooltip } from './TooltipDescriptions'; -import { transformPerformanceIntoPie, transformPerformanceIntoTable } from './transform'; +import Card from '@material-ui/core/Card' +import CardContent from '@material-ui/core/CardContent' +import CardHeader from '@material-ui/core/CardHeader' +import Grid from '@material-ui/core/Grid' +import { makeStyles } from '@material-ui/core/styles' +import { Table } from 'antd' +import { ColumnsType } from 'antd/es/table' +import * as React from 'react' +import * as api from '../api' +import { PieChart } from './charts/PieChart' +import { SteppedAreaChart } from './charts/SteppedAreaChart' +import { DataLoading } from './DataLoading' +import { makeChartHeaderRenderer, useTooltipCommonStyles } from './helpers' +import { TextListItem } from './TextListItem' +import { StepTimeBreakDownTooltip } from './TooltipDescriptions' +import { + transformPerformanceIntoPie, + transformPerformanceIntoTable +} from './transform' -const topGraphHeight = 230; +const topGraphHeight = 230 const useStyles = makeStyles((theme) => ({ root: { - flexGrow: 1, + flexGrow: 1 }, pre: { '& ul': { margin: 0, paddingLeft: theme.spacing(3), - ...theme.typography.body1, + ...theme.typography.body1 }, '& li': {}, '& a': { - color: '#ffa726', + color: '#ffa726' }, '& a:active': { - color: '#ffa726', + color: '#ffa726' }, '& p': { margin: 0, ...theme.typography.subtitle1, - fontWeight: theme.typography.fontWeightBold, - }, + fontWeight: theme.typography.fontWeightBold + } }, topGraph: { - height: topGraphHeight + 40, + height: topGraphHeight + 40 }, table: { height: '100%', @@ -54,87 +57,89 @@ const useStyles = makeStyles((theme) => ({ height: 20, fontSize: '10pt', '& > td': { - padding: '0 8px!important', - }, - }, - }, -})); + padding: '0 8px!important' + } + } + } +})) export interface IProps { - run: string; - worker: string; - span: string; + run: string + worker: string + span: string } export const Overview: React.FC = (props) => { - const { run, worker, span } = props; + const { run, worker, span } = props - const [steps, setSteps] = React.useState(undefined); - const [performances, setPerformances] = React.useState([]); - const [environments, setEnvironments] = React.useState([]); - const [gpuMetrics, setGpuMetrics] = React.useState(undefined); - const [recommendations, setRecommendations] = React.useState(''); - const [columns, setColumns] = React.useState>([]); + const [steps, setSteps] = React.useState(undefined) + const [performances, setPerformances] = React.useState([]) + const [environments, setEnvironments] = React.useState([]) + const [gpuMetrics, setGpuMetrics] = React.useState< + api.GpuMetrics | undefined + >(undefined) + const [recommendations, setRecommendations] = React.useState('') + const [columns, setColumns] = React.useState>([]) const tableRows = React.useMemo(() => { - let dataInfo: api.Graph = transformPerformanceIntoTable(performances); + let dataInfo: api.Graph = transformPerformanceIntoTable(performances) if (dataInfo.columns.length < 3) { - return []; + return [] } - const stringCompare = (a: string, b: string): number => a.localeCompare(b); - const numberCompare = (a: number, b: number): number => a - b; - let column: any[] = dataInfo.columns.map((item) => { + const stringCompare = (a: string, b: string) => a.localeCompare(b) + const numberCompare = (a: number, b: number) => a - b + let column: any[] = dataInfo.columns.map(item => { return { title: item.name, key: item.name, dataIndex: item.name, - sorter: - item.type === 'string' - ? (a: any, b: any): number => stringCompare(a[item.name], b[item.name]) - : (a: any, b: any): number => numberCompare(a[item.name], b[item.name]), - }; - }); - setColumns(column); + sorter: item.type == 'string' ? (a: any, b: any) => stringCompare(a[item.name], b[item.name]) + : (a: any, b: any) => numberCompare(a[item.name], b[item.name]) + } + }) + setColumns(column) return dataInfo.rows.map((row, index) => { if (row.length < 3) { - return null; + return null } return { key: index, [dataInfo.columns[0].name]: row[0], [dataInfo.columns[1].name]: row[1], - [dataInfo.columns[2].name]: row[2], - }; - }); - }, [performances]); + [dataInfo.columns[2].name]: row[2] + } + }) + }, [performances]) const synthesizedPieGraph = React.useMemo(() => { - return transformPerformanceIntoPie(performances); - }, [performances]); + return transformPerformanceIntoPie(performances) + }, [performances]) React.useEffect(() => { api.defaultApi.overviewGet(run, worker, span).then((resp) => { - setPerformances(resp.performance); - setEnvironments(resp.environments); - setSteps(resp.steps); - setRecommendations(resp.recommendations); - setGpuMetrics(resp.gpu_metrics); - }); - }, [run, worker, span]); + setPerformances(resp.performance) + setEnvironments(resp.environments) + setSteps(resp.steps) + setRecommendations(resp.recommendations) + setGpuMetrics(resp.gpu_metrics) + }) + }, [run, worker, span]) - const classes = useStyles(); - const tooltipCommonClasses = useTooltipCommonStyles(); + const classes = useStyles() + const tooltipCommonClasses = useTooltipCommonStyles() const chartHeaderRenderer = React.useMemo( () => makeChartHeaderRenderer(tooltipCommonClasses, false), [tooltipCommonClasses] - ); + ) const stepTimeBreakDownTitle = React.useMemo( - () => chartHeaderRenderer('Step Time Breakdown', stepTimeBreakDownTooltip), + () => chartHeaderRenderer('Step Time Breakdown', StepTimeBreakDownTooltip), [tooltipCommonClasses, chartHeaderRenderer] - ); + ) - const cardSizes = gpuMetrics ? ([2, 3, 7] as const) : ([4, undefined, 8] as const); + const cardSizes = gpuMetrics + ? ([2, 3, 7] as const) + : ([4, undefined, 8] as const) return (
@@ -143,11 +148,14 @@ export const Overview: React.FC = (props) => { {React.useMemo( () => ( - - + + {environments.map((environment) => ( - + ))} @@ -157,19 +165,28 @@ export const Overview: React.FC = (props) => { {gpuMetrics && ( - - - + + + {gpuMetrics.data.map((metric) => ( - + ))} )} - - + + @@ -182,7 +199,10 @@ export const Overview: React.FC = (props) => { /> - + @@ -191,12 +211,16 @@ export const Overview: React.FC = (props) => { - + - {(graph): JSX.Element => ( - + {(graph) => ( + )} @@ -205,13 +229,13 @@ export const Overview: React.FC = (props) => { - - + +
@@ -221,5 +245,5 @@ export const Overview: React.FC = (props) => {
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TextListItem.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TextListItem.tsx index 59eb79c2a8f05cc750d264880bb66ab646c4bbb4..c5e4eee5251f7ab8afedf58f305a5cb30ad92a19 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TextListItem.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TextListItem.tsx @@ -2,69 +2,76 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import Grid from '@material-ui/core/Grid'; -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; +import Grid from '@material-ui/core/Grid' +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' export interface IStylesProps { - root?: string; - name?: string; + root?: string + name?: string } export interface IProps { - name: string; - value?: string; - description?: string; - extra?: string; - classes?: IStylesProps; - dangerouslyAllowHtml?: boolean; + name: string + value?: string + description?: string + extra?: string + classes?: IStylesProps + dangerouslyAllowHtml?: boolean } const useStyles = makeStyles((theme) => ({ label: { ...theme.typography.subtitle2, - fontWeight: 'bolder', + fontWeight: 'bolder' }, value: { textAlign: 'right', ...theme.typography.subtitle2, - fontWeight: 'bolder', - }, -})); + fontWeight: 'bolder' + } +})) export const TextListItem: React.FC = (props) => { - const classes = useStyles(); + const classes = useStyles() - const getSizes = function (): readonly any[] { + const getSizes = function () { if (props.value && props.extra) { - return [4, 4, 4] as const; + return [4, 4, 4] as const } if (props.value) { if (props.value.length > props.name.length) { - return [4, 8, undefined] as const; + return [4, 8, undefined] as const } - return [8, 4, undefined] as const; + return [8, 4, undefined] as const } - return [12, undefined, undefined] as const; - }; + return [12, undefined, undefined] as const + } - const sizes = getSizes(); + const sizes = getSizes() - const renderSpan = function (content: string, className?: string): React.JSX.Element { + const renderSpan = function (content: string, className?: string) { if (props.dangerouslyAllowHtml) { - return ; + return ( + + ) } - return {content}; - }; + return {content} + } return ( - + {renderSpan(props.name, props.classes?.name)} - {props.description && {renderSpan(props.description)}} + {props.description && ( + {renderSpan(props.description)} + )} {props.value && ( @@ -78,5 +85,5 @@ export const TextListItem: React.FC = (props) => { )} - ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TooltipDescriptions.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TooltipDescriptions.ts index 6d3631fee97a4dd8da5ebde1550573d8c6e501fa..8f434221ddbdbd48a7a41ab6c73b2901519007c5 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TooltipDescriptions.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TooltipDescriptions.ts @@ -2,37 +2,37 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -export const stepTimeBreakDownTooltip = `The time spent on each step is broken down into multiple categories as follows: +export const StepTimeBreakDownTooltip = `The time spent on each step is broken down into multiple categories as follows: Kernel: Kernels execution time on GPU device; Memcpy: GPU involved memory copy time (either D2D, D2H or H2D); Memset: GPU involved memory set time; Runtime: CUDA runtime execution time on host side; Such as cudaLaunchKernel, cudaMemcpyAsync, cudaStreamSynchronize, ... DataLoader: The data loading time spent in PyTorch DataLoader object; CPU Exec: Host compute time, including every PyTorch operator running time; -Other: The time not included in any of the above.`; +Other: The time not included in any of the above.` -export const deviceSelfTimeTooltip = `The accumulated time spent on GPU, not including this operator’s child operators.`; +export const DeviceSelfTimeTooltip = `The accumulated time spent on GPU, not including this operator’s child operators.` -export const deviceSelfTimeTooltipAscend = `The accumulated time spent on NPU, not including this operator’s child operators.`; +export const DeviceSelfTimeTooltipAscend = `The accumulated time spent on NPU, not including this operator’s child operators.` -export const deviceTotalTimeTooltip = `The accumulated time spent on GPU, including this operator’s child operators.`; +export const DeviceTotalTimeTooltip = `The accumulated time spent on GPU, including this operator’s child operators.` -export const deviceTotalTimeTooltipAscend = `The accumulated time spent on NPU, including this operator’s child operators.`; +export const DeviceTotalTimeTooltipAscend = `The accumulated time spent on NPU, including this operator’s child operators.` -export const hostSelfTimeTooltip = `The accumulated time spent on Host, not including this operator’s child operators.`; +export const HostSelfTimeTooltip = `The accumulated time spent on Host, not including this operator’s child operators.` -export const hostTotalTimeTooltip = `The accumulated time spent on Host, including this operator’s child operators.`; +export const HostTotalTimeTooltip = `The accumulated time spent on Host, including this operator’s child operators.` -export const gpuKernelTotalTimeTooltip = `The accumulated time of all calls of this kernel.`; +export const GPUKernelTotalTimeTooltip = `The accumulated time of all calls of this kernel.` -export const tensorCoresPieChartTooltip = `The accumulated time of all kernels using or not using Tensor Cores.`; +export const TensorCoresPieChartTooltip = `The accumulated time of all kernels using or not using Tensor Cores.` -export const tensorCoresPieChartTooltipAscend = `The accumulated time of all kernels group by Accelerator Core.`; +export const TensorCoresPieChartTooltipAscend = `The accumulated time of all kernels group by Accelerator Core.` -export const distributedGpuInfoTableTooltip = `Information about GPU hardware used during the run.`; +export const DistributedGpuInfoTableTooltip = `Information about GPU hardware used during the run.` -export const distributedOverlapGraphTooltip = `The time spent on computation vs communication.`; +export const DistributedOverlapGraphTooltip = `The time spent on computation vs communication.` -export const distributedWaittimeGraphTooltip = `The time spent waiting vs communicating between devices.`; +export const DistributedWaittimeGraphTooltip = `The time spent waiting vs communicating between devices.` -export const distributedCommopsTableTooltip = `Statistics for operations managing communications between nodes.`; +export const DistributedCommopsTableTooltip = `Statistics for operations managing communications between nodes.` diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TraceView.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TraceView.tsx index be499794936a085ed72740eea8bac5f33df37171..8f1f3684305cabfe6f35d341557386c1d8f71cf1 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TraceView.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/TraceView.tsx @@ -2,78 +2,85 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import ClickAwayListener from '@material-ui/core/ClickAwayListener'; -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; -import * as api from '../api'; +import ClickAwayListener from '@material-ui/core/ClickAwayListener' +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import * as api from '../api' export interface IProps { - run: string; - worker: string; - span: string; - iframeRef: React.RefObject; + run: string + worker: string + span: string + iframeRef: React.RefObject } const useStyles = makeStyles(() => ({ root: { - flexGrow: 1, + flexGrow: 1 }, frame: { width: '100%', height: 'calc(100vh - 48px)', - border: 'none', - }, -})); + border: 'none' + } +})) export const TraceView: React.FC = (props) => { - const { run, worker, span, iframeRef } = props; - const classes = useStyles(); + const { run, worker, span, iframeRef } = props + const classes = useStyles() - const [traceData, setTraceData] = React.useState | null>(null); - const [traceViewReady, setTraceViewReady] = React.useState(false); + const [traceData, setTraceData] = React.useState | null>(null) + const [traceViewReady, setTraceViewReady] = React.useState(false) React.useEffect(() => { setTraceData( api.defaultApi.traceGet(run, worker, span).then((resp) => { - return JSON.stringify(resp); + return JSON.stringify(resp) }) - ); - }, [run, worker, span]); + ) + }, [run, worker, span]) React.useEffect(() => { - function callback(event: MessageEvent): void { - const data = event.data || {}; + function callback(event: MessageEvent) { + const data = event.data || {} if (data.msg === 'ready') { - setTraceViewReady(true); + setTraceViewReady(true) } } - window.addEventListener('message', callback); + window.addEventListener('message', callback) return () => { - window.removeEventListener('message', callback); - }; - }, []); + window.removeEventListener('message', callback) + } + }, []) React.useEffect(() => { if (traceData && traceViewReady) { traceData.then((data) => { - iframeRef.current?.contentWindow?.postMessage({ msg: 'data', data }, window.origin); - }); + iframeRef.current?.contentWindow?.postMessage( + { msg: 'data', data }, + '*' + ) + }) } - }, [traceData, traceViewReady]); - const setIframeActive = (): void => { - iframeRef.current?.focus(); - }; + }, [traceData, traceViewReady]) + const SetIframeActive = () => { + iframeRef.current?.focus() + } return (
{React.useMemo( () => ( - - + + ), [] )}
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AntTableChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AntTableChart.tsx index 83618064b55223ab06d4d1fec8b8b5eeab8d3268..064167fc64b4e00ec79b648a85d12dff23ecfcd0 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AntTableChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AntTableChart.tsx @@ -2,110 +2,110 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import { Table } from 'antd'; -import * as React from 'react'; -import { Graph } from '../../api'; +import { makeStyles } from '@material-ui/core/styles' +import { Table } from 'antd' +import * as React from 'react' +import { Graph } from '../../api' interface IProps { - graph: Graph; - sortColumn?: string; - initialPageSize?: number; - onRowSelected?: (record?: object, rowIndex?: number) => void; + graph: Graph + sortColumn?: string + initialPageSize?: number + onRowSelected?: (record?: object, rowIndex?: number) => void } const useStyles = makeStyles((theme) => ({ tooltip: { - whiteSpace: 'pre-wrap', + whiteSpace: 'pre-wrap' }, row: { - wordBreak: 'break-word', - }, -})); + wordBreak: 'break-word' + } +})) -const getTableColumns = function (columns: any, sort: string | undefined, tooltipClass: string): any { - let i = 0; - return columns.map((col: any) => { - const key = `col${i++}`; - const stringCompare = (a: any, b: any): number => a[key].localeCompare(b[key]); - const numberCompare = (a: any, b: any): number => (a[key] || 0) - (b[key] || 0); +const getTableColumns = function ( + columns: any, + sort: string | undefined, + tooltipClass: string +) { + let i = 0 + return columns.map(function (col: any) { + const key = 'col' + i++ + const stringCompare = (a: any, b: any) => a[key].localeCompare(b[key]) + const numberCompare = (a: any, b: any) => (a[key] || 0) - (b[key] || 0) return { dataIndex: key, key: key, title: col.name, - sorter: col.type === 'string' ? stringCompare : numberCompare, - defaultSortOrder: sort === col.name ? ('descend' as const) : undefined, - showSorterTooltip: col.tooltip ? { title: col.tooltip, overlayClassName: tooltipClass } : true, - }; - }); -}; + sorter: col.type == 'string' ? stringCompare : numberCompare, + defaultSortOrder: sort == col.name ? ('descend' as const) : undefined, + showSorterTooltip: col.tooltip + ? { title: col.tooltip, overlayClassName: tooltipClass } + : true + } + }) +} -const getTableRows = function (rows: any): any { - return rows.map((row: any) => { - let i = 0; - const res: any = {}; - row.forEach((entry: any) => { - res[`col${i++}`] = entry; - }); - return res; - }); -}; +const getTableRows = function (rows: any) { + return rows.map(function (row: any) { + let i = 0 + const res: any = {} + row.forEach(function (entry: any) { + res['col' + i++] = entry + }) + return res + }) +} export const AntTableChart: React.FC = (props) => { - const { graph, sortColumn, initialPageSize, onRowSelected } = props; - const classes = useStyles(props); + const { graph, sortColumn, initialPageSize, onRowSelected } = props + const classes = useStyles(props) - const rows = React.useMemo(() => getTableRows(graph.rows), [graph.rows]); + const rows = React.useMemo(() => getTableRows(graph.rows), [graph.rows]) const columns = React.useMemo( () => getTableColumns(graph.columns, sortColumn, classes.tooltip), [graph.columns, sortColumn, classes.tooltip] - ); + ) // key is used to reset the Table state (page and sort) if the columns change - const key: string = React.useMemo(() => `${Math.random()}`, [graph.columns]); + const key = React.useMemo(() => Math.random() + '', [graph.columns]) - const [pageSize, setPageSize] = React.useState(initialPageSize ?? 30); - const onShowSizeChange = (current: number, size: number): void => { - setPageSize(size); - }; + const [pageSize, setPageSize] = React.useState(initialPageSize ?? 30) + const onShowSizeChange = (current: number, size: number) => { + setPageSize(size) + } - const onRow = ( - record: object, - rowIndex?: number - ): { - onMouseEnter: (event: any) => void; - onMouseLeave: (event: any) => void; - } => { + const onRow = (record: object, rowIndex?: number) => { return { - onMouseEnter: (event: any): void => { + onMouseEnter: (event: any) => { if (onRowSelected) { - onRowSelected(record, rowIndex); + onRowSelected(record, rowIndex) } }, - onMouseLeave: (event: any): void => { + onMouseLeave: (event: any) => { if (onRowSelected) { - onRowSelected(undefined, undefined); + onRowSelected(undefined, undefined) } - }, - }; - }; + } + } + } return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AreaChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AreaChart.tsx index cda12860c2fba41f5a15c5d9e73fb92093c0371b..6a0f5b484d9c156927edfeae64a729bec821c164 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AreaChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/AreaChart.tsx @@ -2,46 +2,44 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; -import { Graph } from '../../api'; -import { useResizeEventDependency } from '../../utils/resize'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { Graph } from '../../api' +import { useResizeEventDependency } from '../../utils/resize' interface IProps { - graph: Graph; - height?: number; - hAxisTitle?: string; + graph: Graph + height?: number + hAxisTitle?: string } const useStyles = makeStyles(() => ({ root: { - height: (props: Pick): number | undefined => props.height, - }, -})); + height: (props: Pick) => props.height + } +})) export const AreaChart: React.FC = (props) => { - const { graph, height = 400, hAxisTitle } = props; - const classes = useStyles({ height }); - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); + const { graph, height = 400, hAxisTitle } = props + const classes = useStyles({ height }) + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } + const element = graphRef.current + if (!element) return - const data = new google.visualization.DataTable(); - data.addColumn('string', 'step'); + const data = new google.visualization.DataTable() + data.addColumn('string', 'step') graph.columns.forEach((column) => { data.addColumn({ type: column.type, label: column.name, role: column.role, - p: column.p, - }); - }); - data.addRows(graph.rows.map((x, i) => [(i + 1).toString(), ...x])); + p: column.p + }) + }) + data.addRows(graph.rows.map((x, i) => [(i + 1).toString(), ...x])) const options = { title: graph.title, @@ -51,22 +49,22 @@ export const AreaChart: React.FC = (props) => { tooltip: { isHtml: true }, chartArea: { left: '15%', width: '80%', top: '10%' }, hAxis: { - title: hAxisTitle, - }, - }; + title: hAxisTitle + } + } - const chart = new google.visualization.AreaChart(element); + const chart = new google.visualization.AreaChart(element) - chart.draw(data, options); + chart.draw(data, options) return () => { - chart.clearChart(); - }; - }, [graph, height, resizeEventDependency]); + chart.clearChart() + } + }, [graph, height, resizeEventDependency]) return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/ColumnChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/ColumnChart.tsx index ae51dc1a34e94b1c91eab2fe502ffe2cbc20f618..1c83eea95998222903a161d6ddbb678189a03775 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/ColumnChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/ColumnChart.tsx @@ -15,62 +15,58 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Offer offline supporting. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { useResizeEventDependency } from '../../utils/resize'; -import * as echarts from 'echarts'; +import * as React from 'react' +import { useResizeEventDependency } from '../../utils/resize' +import * as echarts from 'echarts' interface IProps { - title?: string; - units?: string; - colors?: Array; - chartData: ColumnChartData; + title?: string + units?: string + colors?: Array + chartData: ColumnChartData } export interface ColumnChartData { - legends: Array; - barLabels: Array; - barHeights: Array>; + legends: Array + barLabels: Array + barHeights: Array> } export const ColumnChart: React.FC = (props) => { - const { title, units, colors, chartData } = props; - const { legends, barLabels, barHeights } = chartData; - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); + const { title, units, colors, chartData } = props + const { legends, barLabels, barHeights } = chartData + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() - const getAngleByDataLength = (data: number): number => { + const getAngleByDataLength = (data: number) => { if (data < 10) { - return 0; + return 0 } else { // 数量越大越趋近于旋转90度 - return 90 * (1 - (10 / data)); + return 90 * (1 - 10 / data) } - }; + } React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } + const element = graphRef.current + if (!element) return - const chart = echarts.init(element); - const dataSource: Array> = []; - dataSource.push(['worker', ...legends]); + const chart = echarts.init(element) + const dataSource: Array> = [] + dataSource.push(['worker', ...legends]) barHeights.forEach((item, index) => { - if (barLabels[index] !== undefined) { - dataSource.push([barLabels[index], ...item]); - } - }); + barLabels[index] !== undefined && dataSource.push([barLabels[index], ...item]) + }) const options: echarts.EChartsOption = { title: { - text: title, + text: title }, legend: { - bottom: 0, + bottom: 0 }, xAxis: { type: 'category', @@ -78,41 +74,43 @@ export const ColumnChart: React.FC = (props) => { interval: 0, rotate: getAngleByDataLength(barLabels.length), formatter: (name: string) => { - const index = name.indexOf('@'); - const processedName = index > -1 ? name.slice(index + 1) : name; // 使用新变量处理 - return processedName.length > 16 ? `${processedName.slice(0, 14)}...` : processedName; - }, - }, + const index = name.indexOf('@') + if (index > -1) { + name = name.slice(index + 1) + } + return name.length > 16 ? name.slice(0, 14) + "..." : name; + } + } }, yAxis: { type: 'value', name: units, nameTextStyle: { - fontSize: 16, - }, + fontSize: 16 + } }, tooltip: { - trigger: 'item', + trigger: 'item' }, dataset: { - source: dataSource, + source: dataSource }, series: Array(legends.length).fill({ type: 'bar', - stack: 'samesign', + stack: 'samesign' }), - }; + } if (colors) { - options.color = colors.slice(0, barLabels.length); + options.color = colors.slice(0, barLabels.length) } - if (options) { - chart.setOption(options, true); - } + options && chart.setOption(options, true) return () => { - chart.dispose(); - }; - }, [title, chartData, resizeEventDependency]); + chart.dispose() + } + }, [title, chartData, resizeEventDependency]) - return
; -}; + return ( +
+ ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/LineChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/LineChart.tsx new file mode 100644 index 0000000000000000000000000000000000000000..b9a031d3a44336e568f30524abc8837590b3f603 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/LineChart.tsx @@ -0,0 +1,224 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { Graph, GraphAscend } from '../../api' +import { useResizeEventDependency } from '../../utils/resize' +import { binarySearch } from '../../utils/binarysearch' + +interface IProps { + graph: Graph | GraphAscend + height?: number + deviceTarget: string + tag: string + hAxisTitle?: string + vAxisTitle?: string + explorerOptions?: object + onSelectionChanged?: (start: number, end: number) => void + record?: any +} + +const useStyles = makeStyles(() => ({ + root: { + height: (props: Pick) => props.height + } +})) + +export const LineChart: React.FC = (props) => { + const { + graph, + height = 400, + deviceTarget, + tag, + hAxisTitle, + vAxisTitle, + onSelectionChanged, + explorerOptions, + record + } = props + const classes = useStyles({ height }) + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() + const [chartObj, setChartObj] = React.useState() + + React.useLayoutEffect(() => { + const element = graphRef.current + if (!element) return + + const options = { + title: graph.title, + isStacked: true, + height, + legend: { position: 'bottom' }, + tooltip: { isHtml: true }, + hAxis: { + title: hAxisTitle + }, + vAxis: { + title: vAxisTitle + }, + explorer: explorerOptions + } + + const chart = new google.visualization.LineChart(element) + + // Disable selection of single point + google.visualization.events.addListener(chart, 'select', function () { + chart.setSelection() + }) + + google.visualization.events.addListener(chart, 'ready', function () { + var zoomLast = getCoords() + var observer = new MutationObserver(function () { + var zoomCurrent = getCoords() + if (JSON.stringify(zoomLast) !== JSON.stringify(zoomCurrent)) { + zoomLast = getCoords() + if (onSelectionChanged) { + onSelectionChanged(zoomLast.x_min, zoomLast.x_max) + } + } + }) + if (graphRef.current) { + observer.observe(graphRef.current, { + childList: true, + subtree: true + }) + } + }) + + function getCoords() { + var chartLayout = chart.getChartLayoutInterface() + var chartBounds = chartLayout.getChartAreaBoundingBox() + + return { + x_min: chartLayout.getHAxisValue(chartBounds.left), + x_max: chartLayout.getHAxisValue(chartBounds.width + chartBounds.left) + } + } + + if (deviceTarget === 'Ascend') { + let data = new google.visualization.DataTable() + if (tag === 'Component') { + if (graph.columns.length === 3) { + graph.columns.forEach((column) => { + data.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + }) + data.addRows(graph.rows['PTA'] ?? graph.rows['GE']) + } else if (graph.columns.length === 5) { + const data2 = new google.visualization.DataTable() + graph.columns.forEach((column, index) => { + if (index === 0 || index < 3) { + data.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + } + if (index === 0 || index >= 3) { + data2.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + } + }) + data.addRows(graph.rows['PTA']) + data2.addRows(graph.rows['GE']) + data = google.visualization.data.join(data, data2, 'full', [[0, 0]], [1, 2], [1, 2]) + } + } else { + if (graph.columns.length === 2) { + graph.columns.forEach((column) => { + data.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + }) + data.addRows(graph.rows['Allocated'] ?? graph.rows['Reserved']) + } else if (graph.columns.length === 3) { + const data2 = new google.visualization.DataTable() + graph.columns.forEach((column, index) => { + if (index === 0 || index < 2) { + data.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + } + if (index === 0 || index >= 2) { + data2.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + } + }) + data.addRows(graph.rows['Allocated']) + data2.addRows(graph.rows['Reserved']) + data = google.visualization.data.join(data, data2, 'full', [[0, 0]], [1], [1]) + } + } + + chart.draw(data, options) + } else { + const data = new google.visualization.DataTable() + graph.columns.forEach((column) => { + data.addColumn({ + type: column.type, + label: column.name, + role: column.role, + p: column.p + }) + }) + data.addRows(graph.rows) + chart.draw(data, options) + } + + setChartObj(chart) + }, [graph, height, resizeEventDependency]) + + React.useEffect(() => { + const compare_fn = (key: number, mid: Array) => + key - parseFloat(mid[0].toFixed(2)) + if (chartObj && tag === 'Operator') { + if (record) { + if (deviceTarget === 'Ascend') { + let startId = binarySearch(graph.rows['Allocated'], record.col2, compare_fn) + let endId = binarySearch(graph.rows['Allocated'], record.col3, compare_fn) + let selection = [] + if (startId >= 0) selection.push({ row: startId, column: 1 }) + if (endId >= 0) selection.push({ row: endId, column: 1 }) + chartObj.setSelection(selection) + } else { + let startId = binarySearch(graph.rows, record.col2, compare_fn) + let endId = binarySearch(graph.rows, record.col3, compare_fn) + let selection = [] + if (startId >= 0) selection.push({ row: startId, column: 1 }) + if (endId >= 0) selection.push({ row: endId, column: 1 }) + chartObj.setSelection(selection) + } + } else { + chartObj.setSelection() + } + } + }, [graph, record, chartObj]) + + return ( +
+
+
+ ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/NewLineChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/NewLineChart.tsx index a6e222a6cc9d04b3b0c9031be60b91b75fe9ab37..af350e93d96c364d9baf4952bd59458a7bbd0801 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/NewLineChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/NewLineChart.tsx @@ -15,79 +15,85 @@ * limitations under the License. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { Graph, GraphAscend } from '../../api'; -import { useResizeEventDependency } from '../../utils/resize'; -import { binarySearch } from '../../utils/binarysearch'; -import * as echarts from 'echarts'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { Graph, GraphAscend } from '../../api' +import { useResizeEventDependency } from '../../utils/resize' +import { binarySearch } from '../../utils/binarysearch' +import * as echarts from 'echarts' interface IProps { - graph: Graph | GraphAscend; - height?: number; - deviceTarget: string; - tag: string; - hAxisTitle?: string; - vAxisTitle?: string; - onSelectionChanged?: (start: number, end: number) => void; - record?: any; + graph: Graph | GraphAscend + height?: number + deviceTarget: string + tag: string + hAxisTitle?: string + vAxisTitle?: string + onSelectionChanged?: (start: number, end: number) => void + record?: any } export const LineChart: React.FC = (props) => { - const { graph, height = 400, deviceTarget, tag, hAxisTitle, vAxisTitle, onSelectionChanged, record } = props; - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); - const [chartObj, setChartObj] = React.useState(); - const selectedPoints = React.useRef>([]); + const { + graph, + height = 400, + deviceTarget, + tag, + hAxisTitle, + vAxisTitle, + onSelectionChanged, + record + } = props + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() + const [chartObj, setChartObj] = React.useState() + const selectedPoints = React.useRef>([]) React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } - element.oncontextmenu = (): boolean => { - return false; - }; + const element = graphRef.current + if (!element) return + element.oncontextmenu = () => { return false } - let myChart = echarts.init(element); + let myChart = echarts.init(element) let option: echarts.EChartsOption = { title: { text: graph.title, textStyle: { - fontSize: 16, - }, + fontSize: 16 + } }, tooltip: { trigger: 'axis' }, legend: { type: 'scroll', - bottom: 0, + bottom: 0 }, xAxis: { type: 'category', boundaryGap: false, - name: hAxisTitle, + name: hAxisTitle }, yAxis: { type: 'value', name: vAxisTitle, - scale: true, + scale: true }, toolbox: { feature: { dataZoom: { - yAxisIndex: 'none', + yAxisIndex: 'none' }, - restore: {}, - }, - }, - }; + restore: {} + } + } + } if (deviceTarget === 'Ascend') { if (tag === 'Component') { const mixedTooltip: echarts.TooltipComponentOption = { trigger: 'axis', formatter: function (params: any) { - let res = `${params[0].name}
`; + var res = `${params[0].name}
` for (const item of params) { if (typeof item.value[item.encode.y[0]] === 'number') { res += ` - ${item.seriesName}: ${item.value[item.encode.y[0]]}
`; + ${item.seriesName}: ${item.value[item.encode.y[0]]}
` } } - return res; - }, - }; + return res + } + } if (graph.columns.length <= 4) { - let finalRows = graph.rows.PTA ?? graph.rows.GE; + let finalRows = graph.rows['PTA'] ?? graph.rows['GE'] if (graph.columns.length === 4) { - const mergedAPPRows = graph.rows.APP.map((item: Array) => { - return [item[0], null, null, item[1]]; - }); + const mergedAPPRows = graph.rows['APP'].map((item: Array) => { + return [item[0], null, null, item[1]] + }) finalRows = finalRows.concat(mergedAPPRows).sort((a: any, b: any) => { - return a[0] - b[0]; - }); + return a[0] - b[0] + }) } option = { ...option, tooltip: mixedTooltip, dataset: { - source: [graph.columns.map((column) => column.name), ...finalRows], + source: [ + graph.columns.map(column => column.name), + ...finalRows + ] }, - series: Array(graph.columns.length - 1).fill({ - type: 'line', - select: { - itemStyle: { - borderWidth: 5, - shadowBlur: 5, + series: Array(graph.columns.length - 1).fill( + { + type: 'line', + select: { + itemStyle: { + borderWidth: 5, + shadowBlur: 5 + } }, - }, - emphasis: { - itemStyle: { - borderWidth: 5, - shadowBlur: 5, + emphasis: { + itemStyle: { + borderWidth: 5, + shadowBlur: 5 + } }, - }, - selectedMode: 'single', - }), - }; + selectedMode: 'single', + } + ) + } } else if (graph.columns.length <= 6) { - const datasetTitle = graph.columns.map((item) => item.name); - let mergedGERows = graph.rows.GE.map((item: Array) => { - return [item[0], null, null, item[1], item[2]]; - }); + const datasetTitle = graph.columns.map(item => item.name) + let mergedGERows = graph.rows['GE'].map((item: Array) => { + return [item[0], null, null, item[1], item[2]] + }) if (graph.columns.length === 6) { - const mergedAPPRows = graph.rows.APP.map((item: Array) => { - return [item[0], null, null, null, null, item[2]]; - }); - mergedGERows = mergedGERows.concat(mergedAPPRows); + const mergedAPPRows = graph.rows['APP'].map((item: Array) => { + return [item[0], null, null, null, null, item[2]] + }) + mergedGERows = mergedGERows.concat(mergedAPPRows) } - const finalRows = graph.rows.PTA.concat(mergedGERows).sort((a: any, b: any) => { - return a[0] - b[0]; - }); + const finalRows = graph.rows['PTA'].concat(mergedGERows).sort((a: any, b: any) => { + return a[0] - b[0] + }) option = { ...option, tooltip: mixedTooltip, - dataset: { - source: [datasetTitle, ...finalRows], + dataset: + { + source: [ + datasetTitle, + ...finalRows + ] }, - series: Array(graph.columns.length - 1).fill({ - type: 'line', - connectNulls: true, - select: { - itemStyle: { - borderWidth: 5, - shadowBlur: 5, + series: Array(graph.columns.length - 1).fill( + { + type: 'line', + connectNulls: true, + select: { + itemStyle: { + borderWidth: 5, + shadowBlur: 5 + } }, - }, - emphasis: { - itemStyle: { - borderWidth: 5, - shadowBlur: 5, + emphasis: { + itemStyle: { + borderWidth: 5, + shadowBlur: 5 + } }, - }, - selectedMode: 'single', - datasetIndex: 0, - }), - }; + selectedMode: 'single', + datasetIndex: 0 + }) + } } } else { if (graph.columns.length === 3) { - const datasetTitle1: Array = []; - const datasetTitle2: Array = []; + const datasetTitle1: Array = [] + const datasetTitle2: Array = [] graph.columns.forEach((column, index) => { if (index === 0 || index < 2) { - datasetTitle1.push(column.name); + datasetTitle1.push(column.name) } if (index === 0 || index >= 2) { - datasetTitle2.push(column.name); + datasetTitle2.push(column.name) } - }); + }) option = { ...option, dataset: [ { - source: [datasetTitle1, ...graph.rows.Allocated], + source: [ + datasetTitle1, + ...graph.rows['Allocated'] + ] }, { - source: [datasetTitle2, ...graph.rows.Reserved], - }, + source: [ + datasetTitle2, + ...graph.rows['Reserved'] + ] + } ], series: [ { @@ -204,20 +226,20 @@ export const LineChart: React.FC = (props) => { name: 'Allocated', emphasis: { label: { - show: true, + show: true }, itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, select: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, - datasetIndex: 0, + datasetIndex: 0 }, { type: 'line', @@ -225,27 +247,30 @@ export const LineChart: React.FC = (props) => { select: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, emphasis: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, selectedMode: 'single', - datasetIndex: 1, - }, - ], - }; + datasetIndex: 1 + } + ] + } } } } else { option = { ...option, dataset: { - source: [graph.columns.map((column) => column.name), ...graph.rows], + source: [ + graph.columns.map(column => column.name), + ...graph.rows + ] }, series: [ { @@ -254,16 +279,16 @@ export const LineChart: React.FC = (props) => { select: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, emphasis: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, - selectedMode: 'single', + selectedMode: 'single' }, { type: 'line', @@ -271,116 +296,112 @@ export const LineChart: React.FC = (props) => { select: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, emphasis: { itemStyle: { borderWidth: 5, - shadowBlur: 5, - }, + shadowBlur: 5 + } }, - selectedMode: 'single', - }, - ], - }; + selectedMode: 'single' + } + ] + } } - if (option) { - myChart.setOption(option, true); - } + option && myChart.setOption(option, true) myChart.dispatchAction({ type: 'takeGlobalCursor', key: 'dataZoomSelect', - dataZoomSelectActive: true, - }); + dataZoomSelectActive: true + }) myChart.on('dataZoom', (param: any) => { if (onSelectionChanged) { - onSelectionChanged(param.batch[0].startValue, param.batch[0].endValue); + onSelectionChanged(param.batch[0].startValue, param.batch[0].endValue) } - }); + }) myChart.on('restore', () => { if (onSelectionChanged) { // Set startId greater than endId to query all memory events. - onSelectionChanged(0, -1); + onSelectionChanged(0, -1) } - }); + }) myChart.on('click', (param) => { myChart.dispatchAction({ type: 'unselect', seriesId: param.seriesId, - dataIndex: selectedPoints.current, - }); + dataIndex: selectedPoints.current + }) myChart.dispatchAction({ type: 'select', seriesId: param.seriesId, - dataIndex: param.dataIndex, - }); + dataIndex: param.dataIndex + }) - selectedPoints.current = [param.dataIndex]; - }); + selectedPoints.current = [param.dataIndex] + }) myChart.getZr().on('contextmenu', () => { myChart.dispatchAction({ - type: 'restore', - }); + type: 'restore' + }) myChart.dispatchAction({ type: 'takeGlobalCursor', key: 'dataZoomSelect', - dataZoomSelectActive: true, - }); - }); + dataZoomSelectActive: true + }) + }) - setChartObj(myChart); + setChartObj(myChart) return () => { - myChart.dispose(); - }; - }, [graph, height, resizeEventDependency]); + myChart.dispose() + } + }, [graph, height, resizeEventDependency]) React.useEffect(() => { - const compareFn = (key: number, mid: Array): number => key - mid[0]; + const compare_fn = (key: number, mid: Array) => key - mid[0] if (chartObj && tag === 'Operator') { if (record) { - let startId = -1; - let endId = -1; + let startId = -1 + let endId = -1 if (deviceTarget === 'Ascend') { - startId = binarySearch(graph.rows.Allocated, record.col2, compareFn); - endId = binarySearch(graph.rows.Allocated, record.col3, compareFn); + startId = binarySearch(graph.rows['Allocated'], record.col2, compare_fn) + endId = binarySearch(graph.rows['Allocated'], record.col3, compare_fn) } else { - startId = binarySearch(graph.rows, record.col2, compareFn); - endId = binarySearch(graph.rows, record.col3, compareFn); - } - let selection = []; - if (startId >= 0) { - selection.push(startId); - } - if (endId >= 0) { - selection.push(endId); + startId = binarySearch(graph.rows, record.col2, compare_fn) + endId = binarySearch(graph.rows, record.col3, compare_fn) } + let selection = [] + startId >= 0 && selection.push(startId) + endId >= 0 && selection.push(endId) chartObj.dispatchAction({ type: 'downplay', seriesName: 'Allocated', - dataIndex: selectedPoints.current, - }); + dataIndex: selectedPoints.current + }) chartObj.dispatchAction({ type: 'highlight', seriesName: 'Allocated', - dataIndex: selection, - }); - selectedPoints.current = selection; + dataIndex: selection + }) + selectedPoints.current = selection } else { chartObj.dispatchAction({ type: 'downplay', seriesName: 'Allocated', - dataIndex: selectedPoints.current, - }); - selectedPoints.current = []; + dataIndex: selectedPoints.current + }) + selectedPoints.current = [] } } - }, [graph, record, chartObj]); + }, [graph, record, chartObj]) - return
; -}; + return ( +
+ ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/PieChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/PieChart.tsx index 49c59ff02e91f7b7fe0d90ddff4239478ca19a0a..2c7ea1c1413ab932c226d1a919362a611a88d4ae 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/PieChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/PieChart.tsx @@ -19,104 +19,83 @@ * Modifications: Offer offline supporting. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { Graph } from '../../api'; -import { value } from '../../utils'; -import { useResizeEventDependency } from '../../utils/resize'; -import * as echarts from 'echarts'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { Graph } from '../../api' +import { value } from '../../utils' +import { useResizeEventDependency } from '../../utils/resize' +import * as echarts from 'echarts' interface IProps { - graph: Graph; - height?: number; - top?: number; - noLegend?: boolean; - title?: string; - colors?: Array; - tooltipMode?: string; + graph: Graph + height?: number + top?: number + noLegend?: boolean + title?: string + colors?: Array + tooltip_mode?: string } -interface IAreaPosition { - left: string; - width: string; - top?: string; - height?: string; -} - -const noLegendArea: IAreaPosition = { - left: '5%', - width: '90%', - top: '5%', - height: '90%', -}; -const normalArea: IAreaPosition = { left: '5%', width: '95%' }; -const noTitleArea: IAreaPosition = { - left: '5%', - width: '95%', - top: '10%', - height: '80%', -}; +const noLegendArea = { left: '5%', width: '90%', top: '5%', height: '90%' } +const normalArea = { left: '5%', width: '95%' } +const noTitleArea = { left: '5%', width: '95%', top: '10%', height: '80%' } export const PieChart: React.FC = (props) => { - const { graph, height = 300, top, noLegend, title, colors, tooltipMode = 'both' } = props; - const graphRef = React.useRef(null); + const { + graph, + height = 300, + top, + noLegend, + title, + colors, + tooltip_mode = 'both' + } = props + const graphRef = React.useRef(null) - const [resizeEventDependency] = useResizeEventDependency(); + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } + const element = graphRef.current + if (!element) return - const chart = echarts.init(element); + const chart = echarts.init(element) - let totalValue = 0; - const rowsWithUniqueName: Array<{ name: string; value: number }> = + let totalValue = 0 + const rowsWithUniqueName: Array<{ name: string, value: number }> = top === undefined ? graph.rows.map((item, index) => { - totalValue += item[1] as number; - return { name: `${index}_${item[0]}`, value: item[1] as number }; - }) + totalValue += item[1] as number + return { name: `${index}_${item[0]}`, value: item[1] as number } + }) : graph.rows - .sort((a, b) => (value(b[1]) as number) - (value(a[1]) as number)) - .slice(0, top) - .map((item, index) => { - totalValue += item[1] as number; - return { name: `${index}_${item[0]}`, value: item[1] as number }; - }); + .sort((a, b) => (value(b[1]) as number) - (value(a[1]) as number)) + .slice(0, top).map((item, index) => { + totalValue += item[1] as number + return { name: `${index}_${item[0]}`, value: item[1] as number } + }) const option: echarts.EChartsOption = { height, width: '100%', title: { - text: title, + text: title }, tooltip: { trigger: 'item', formatter: (data) => { - const typedData = data as echarts.DefaultLabelFormatterCallbackParams; - const index = typedData.name.indexOf('_'); - const safeName = typedData.name.replace(//g, '>'); - return `${index > -1 ? safeName.slice(index + 1) : safeName}
${ - tooltipMode === 'both' ? typedData.value : '' - }(${typedData.percent}%)`; + const typedData = data as echarts.DefaultLabelFormatterCallbackParams + const index = typedData.name.indexOf('_') + const safeName = typedData.name.replace(//g, '>') + return `${index > -1 ? safeName.slice(index + 1) : safeName}
${tooltip_mode === 'both' ? + typedData.value : ''}(${typedData.percent}%)` }, confine: true, extraCssText: `max-width: 300px; word-wrap:break-word; white-space:pre-wrap; - padding-right: 10px`, + padding-right: 10px` }, - chartArea: ((): IAreaPosition => { - if (noLegend) { - return noLegendArea; - } - if (!title) { - return noTitleArea; - } else { - return normalArea; - } - })(), + chartArea: noLegend ? noLegendArea : !title ? noTitleArea : normalArea, legend: { type: noLegend ? 'plain' : 'scroll', orient: 'vertical', @@ -125,23 +104,24 @@ export const PieChart: React.FC = (props) => { // Display at most 36 characters. formatter: (name) => { // Show legends for datas with the same name. - const index = name.indexOf('_'); - const processedName = index > -1 ? name.slice(index + 1) : name; // 使用新变量处理 - return processedName.length > 36 ? `${processedName.slice(0, 34)}...` : processedName; + const index = name.indexOf('_') + if (index > -1) { + name = name.slice(index + 1) + } + return name.length > 36 ? name.slice(0, 34) + "..." : name; }, tooltip: { show: true, triggerOn: 'mousemove', formatter: (data) => { - const currentItem = rowsWithUniqueName.find((item) => item.name === data.name); - const index = data.name.indexOf('_'); - const percent = (((currentItem?.value || 0) * 100) / totalValue).toFixed(2); - const safeName = data.name.replace(//g, '>'); - return `${index > -1 ? safeName.slice(index + 1) : safeName}
${ - tooltipMode === 'both' ? currentItem?.value || 0 : '' - }(${percent}%)`; - }, - }, + const currentItem = rowsWithUniqueName.find(item => item.name === data.name) + const index = data.name.indexOf('_') + const percent = ((currentItem?.value || 0) * 100 / totalValue).toFixed(2) + const safeName = data.name.replace(//g, '>') + return `${index > -1 ? safeName.slice(index + 1) : + safeName}
${tooltip_mode === 'both' ? (currentItem?.value || 0) : ''}(${percent}%)` + } + } }, sliceVisibilityThreshold: 0, colors, @@ -153,21 +133,21 @@ export const PieChart: React.FC = (props) => { label: { position: 'inside', formatter: `{d}%`, - color: '#ffffff', + color: '#ffffff' }, - data: rowsWithUniqueName, - }, - ], - }; - - if (option) { - chart.setOption(option, true); + data: rowsWithUniqueName + } + ] } + option && chart.setOption(option, true) + return () => { - chart.dispose(); - }; - }, [graph, height, top, resizeEventDependency]); + chart.dispose() + } + }, [graph, height, top, resizeEventDependency]) - return
; -}; + return ( +
+ ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/SteppedAreaChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/SteppedAreaChart.tsx index 3e3b01ccb112aeb80795246bd6f3e2ad83aa2a66..bc38cc31747cd69e8fee7af4d55476f49bef9914 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/SteppedAreaChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/SteppedAreaChart.tsx @@ -19,88 +19,84 @@ * Modifications: Offer offline supporting. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; -import { StepedGraph } from '../../api'; -import { useResizeEventDependency } from '../../utils/resize'; -import * as echarts from 'echarts'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { StepedGraph } from '../../api' +import { useResizeEventDependency } from '../../utils/resize' +import * as echarts from 'echarts' interface IProps { - graph: StepedGraph; - height?: number; - hAxisTitle?: string; - vAxisTitle?: string; + graph: StepedGraph + height?: number + hAxisTitle?: string + vAxisTitle?: string } const useStyles = makeStyles(() => ({ root: { - height: (props: Pick): number | undefined => props.height, - }, -})); + height: (props: Pick) => props.height + } +})) export const SteppedAreaChart: React.FC = (props) => { - const { graph, height = 400, hAxisTitle, vAxisTitle } = props; - const classes = useStyles({ height }); - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); + const { graph, height = 400, hAxisTitle, vAxisTitle } = props + const classes = useStyles({ height }) + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element) { - return undefined; - } + const element = graphRef.current + if (!element) return - const chart = echarts.init(element); - const dataSource: Array> = []; - dataSource.push(graph.columns); + const chart = echarts.init(element) + const dataSource: Array> = [] + dataSource.push(graph.columns) graph.rows.forEach((row) => { - dataSource.push(row.map((item) => item.value)); - }); + dataSource.push(row.map(item => item.value)) + }) const options: echarts.EChartsOption = { title: { - text: graph.title, + text: graph.title }, legend: { - bottom: 0, + bottom: 0 }, xAxis: { type: 'category', name: hAxisTitle, axisLabel: { interval: 0, - }, + } }, yAxis: { type: 'value', - name: vAxisTitle, + name: vAxisTitle }, tooltip: { trigger: 'item', formatter: (params: any) => { - return graph.rows[params.dataIndex][params.seriesIndex + 1]?.tooltip || ''; - }, + return graph.rows[params.dataIndex][params.seriesIndex + 1]?.tooltip || '' + } }, dataset: { - source: dataSource, + source: dataSource }, series: Array(graph.columns.length - 1).fill({ type: 'bar', - stack: 'samesign', - }), - }; - - if (options) { - chart.setOption(options, true); + stack: 'samesign' + }) } + options && chart.setOption(options, true) + return () => { - chart.dispose(); - }; - }, [graph, height, resizeEventDependency]); + chart.dispose() + } + }, [graph, height, resizeEventDependency]) return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/TableChart.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/TableChart.tsx index 444b41b196c162340b846ac488d70eb908c7b717..267624c85e02e30e047ff50e7d126259b765c83e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/TableChart.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/charts/TableChart.tsx @@ -2,54 +2,56 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import * as React from 'react'; -import { Graph } from '../../api'; -import { useResizeEventDependency } from '../../utils/resize'; +import { makeStyles } from '@material-ui/core/styles' +import * as React from 'react' +import { Graph } from '../../api' +import { useResizeEventDependency } from '../../utils/resize' interface IProps { - graph: Graph; - sortColumn?: number; - height?: number; - allowHtml?: boolean; - setCellProperty?: (row: number, column: number, cb: (key: string, value: any) => void) => void; + graph: Graph + sortColumn?: number + height?: number + allowHtml?: boolean + setCellProperty?: ( + row: number, + column: number, + cb: (key: string, value: any) => void + ) => void } const useStyles = makeStyles(() => ({ root: { - height: (props: IProps): number | undefined => props.height, - }, -})); + height: (props: IProps) => props.height + } +})) export const TableChart: React.FC = (props) => { - const { graph, sortColumn, setCellProperty, allowHtml } = props; - const classes = useStyles(props); - const graphRef = React.useRef(null); - const [resizeEventDependency] = useResizeEventDependency(); + const { graph, sortColumn, setCellProperty, allowHtml } = props + const classes = useStyles(props) + const graphRef = React.useRef(null) + const [resizeEventDependency] = useResizeEventDependency() React.useLayoutEffect(() => { - const element = graphRef.current; - if (!element || !element.parentElement) { - return; - } + const element = graphRef.current + if (!element) return - const data = new google.visualization.DataTable(); + const data = new google.visualization.DataTable() graph.columns.forEach((column) => { data.addColumn({ type: column.type, label: column.name, role: column.role, - p: column.p, - }); - }); - data.addRows(graph.rows); + p: column.p + }) + }) + data.addRows(graph.rows) if (setCellProperty) { for (let row = 0; row < graph.rows.length; ++row) { for (let column = 0; column < graph.columns.length; ++column) { setCellProperty(row, column, (key: string, value: any) => { - data.setProperty(row, column, key, value); - }); + data.setProperty(row, column, key, value) + }) } } } @@ -62,24 +64,24 @@ export const TableChart: React.FC = (props) => { pageSize: 30, tooltip: { isHtml: true }, sortColumn: sortColumn, - sortAscending: false, - }; + sortAscending: false + } - const chart = new google.visualization.Table(element); + const chart = new google.visualization.Table(element) /* `chart.draw()` removes the contents of `element` and rebuilds it. This can cause a jump in the scroll position * if the height/width change to 0. Since we can't change the code of Google Charts, we temporarily lock the dims * of the parent container. */ if (element.offsetHeight > 0) { - element.parentElement.style.height = `${element.offsetHeight}px`; + element.parentElement!.style.height = element.offsetHeight + 'px' } - chart.draw(data, options); - element.parentElement.style.height = ''; - }, [graph, resizeEventDependency]); + chart.draw(data, options) + element.parentElement!.style.height = '' + }, [graph, resizeEventDependency]) return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/helpers.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/helpers.tsx index bfbb346e4b3daf65247e6e954346ed7245993f31..b787a5e91976a7f8f5839978276b35cf2a900cab 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/helpers.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/helpers.tsx @@ -2,40 +2,48 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { makeStyles } from '@material-ui/core/styles'; -import Tooltip from '@material-ui/core/Tooltip'; -import HelpOutline from '@material-ui/icons/HelpOutline'; -import clsx from 'clsx'; -import * as React from 'react'; +import { makeStyles } from '@material-ui/core/styles' +import Tooltip from '@material-ui/core/Tooltip' +import HelpOutline from '@material-ui/icons/HelpOutline' +import clsx from 'clsx' +import * as React from 'react' export const useTooltipCommonStyles = makeStyles((theme) => ({ tooltip: { maxWidth: '600px', whiteSpace: 'pre-wrap', - fontSize: '14px', + fontSize: '14px' }, cardTitle: { display: 'flex', - alignItems: 'center', + alignItems: 'center' }, titleText: { - marginRight: theme.spacing(0.5), + marginRight: theme.spacing(0.5) }, smallTitleText: { fontSize: '.8rem', - fontWeight: 'bold', - }, -})); + fontWeight: 'bold' + } +})) -export const makeChartHeaderRenderer = - (classes: ReturnType, smallTitleText = true) => - (title: string, tooltip: string): JSX.Element => { - return ( - - {title} - - - +export const makeChartHeaderRenderer = ( + classes: ReturnType, + smallTitleText = true +) => (title: string, tooltip: string) => { + return ( + + + {title} - ); - }; + + + + + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallFrameList.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallFrameList.tsx index 0334d29e511399664d5204224e47cf1b88d50655..1e2a385bb634b3988142ada0d947adbb46c99715 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallFrameList.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallFrameList.tsx @@ -2,25 +2,25 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { CallStackFrame } from './transform'; -import { List } from 'antd'; -import { NavToCodeButton } from './NavToCodeButton'; -import { makeStyles } from '@material-ui/core/styles'; +import * as React from 'react' +import { CallStackFrame } from './transform' +import { List } from 'antd' +import { NavToCodeButton } from './NavToCodeButton' +import { makeStyles } from '@material-ui/core/styles' interface IProps { - callFrames: CallStackFrame[]; + callFrames: CallStackFrame[] } const useStyles = makeStyles(() => ({ item: { paddingTop: '1px !important', - paddingBottom: '1px !important', - }, -})); + paddingBottom: '1px !important' + } +})) -export const CallFrameList = (props: IProps): React.JSX.Element => { - const classes = useStyles(); +export const CallFrameList = (props: IProps) => { + const classes = useStyles() const renderItem = React.useCallback( (item: CallStackFrame) => ( @@ -29,7 +29,14 @@ export const CallFrameList = (props: IProps): React.JSX.Element => { ), [classes.item] - ); + ) - return ; -}; + return ( + + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallStackTable.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallStackTable.tsx index c3176428d11b8b40c691947b2f0da8fc15674c16..359d7c9028aaeb7497e0a8aa1baba8fa6d8768c1 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallStackTable.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/CallStackTable.tsx @@ -15,89 +15,99 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { makeStyles } from '@material-ui/core/styles'; -import { CallStackTableData, OperationTableDataInner } from '../../api'; -import { Table, TableProps } from 'antd'; +import * as React from 'react' +import { makeStyles } from '@material-ui/core/styles' +import { CallStackTableData, OperationTableDataInner } from '../../api' +import { Table, TableProps } from 'antd' -import * as api from '../../api'; -import { transformTableData, TransformedCallStackDataInner } from './transform'; -import { attachId, getCommonOperationColumns } from './common'; -import { OperationGroupBy } from '../../constants/groupBy'; -import { makeExpandIcon } from './ExpandIcon'; -import { CallFrameList } from './CallFrameList'; +import * as api from '../../api' +import { transformTableData, TransformedCallStackDataInner } from './transform' +import { attachId, getCommonOperationColumns } from './common' +import { OperationGroupBy } from '../../constants/groupBy' +import { makeExpandIcon } from './ExpandIcon' +import { CallFrameList } from './CallFrameList' export interface IProps { - data: OperationTableDataInner; - run: string; - worker: string; - span: string; - groupBy: OperationGroupBy; - deviceTarget: string; + data: OperationTableDataInner + run: string + worker: string + span: string + groupBy: OperationGroupBy + deviceTarget: string } const useStyles = makeStyles((theme) => ({ tooltip: { - whiteSpace: 'pre-wrap', - }, -})); + whiteSpace: 'pre-wrap' + } +})) const expandIcon = makeExpandIcon( 'View call frames', (record) => !record.callStackFrames.length -); +) -const rowExpandable = (record: TransformedCallStackDataInner): boolean => !!record.callStackFrames.length; -const expandedRowRender = (record: TransformedCallStackDataInner): React.JSX.Element => ( +const rowExpandable = (record: TransformedCallStackDataInner) => + !!record.callStackFrames.length +const expandedRowRender = (record: TransformedCallStackDataInner) => ( -); +) -export const CallStackTable = (props: IProps): React.JSX.Element => { - const { data, run, worker, span, groupBy, deviceTarget } = props; - const { name, input_shape } = data; - const classes = useStyles(props); +export const CallStackTable = (props: IProps) => { + const { data, run, worker, span, groupBy, deviceTarget } = props + const { name, input_shape } = data + const classes = useStyles(props) - const [stackData, setStackData] = React.useState(undefined); - const [tooltips, setTooltips] = React.useState(); + const [stackData, setStackData] = React.useState< + CallStackTableData | undefined + >(undefined) + const [tooltips, setTooltips] = React.useState() React.useEffect(() => { - api.defaultApi.operationStackGet(run, worker, span, groupBy, name, input_shape).then((resp) => { - setTooltips(resp.metadata.tooltips); - setStackData(resp.data); - }); - }, [name, input_shape, run, worker, span, groupBy]); + api.defaultApi + .operationStackGet(run, worker, span, groupBy, name, input_shape) + .then((resp) => { + setTooltips(resp.metadata.tooltips) + setStackData(resp.data) + }) + }, [name, input_shape, run, worker, span, groupBy]) - const transformedData = React.useMemo(() => stackData && transformTableData(attachId(stackData)), [stackData]); + const transformedData = React.useMemo( + () => stackData && transformTableData(attachId(stackData)), + [stackData] + ) const columns = React.useMemo( - () => transformedData && getCommonOperationColumns(transformedData, deviceTarget, undefined, tooltips, classes), + () => + transformedData && + getCommonOperationColumns(transformedData, deviceTarget, undefined, tooltips, classes), [transformedData] - ); + ) - const expandIconColumnIndex = columns?.length; + const expandIconColumnIndex = columns?.length const expandable: TableProps['expandable'] = React.useMemo( () => ({ expandIconColumnIndex, expandIcon, expandedRowRender, - rowExpandable, + rowExpandable }), [expandIconColumnIndex] - ); + ) return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/ExpandIcon.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/ExpandIcon.tsx index 422bb781630c24c6dc4915c3aed8c1f341dba363..68ff482827679d9c51c1ca0178b256dc5ae39581 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/ExpandIcon.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/ExpandIcon.tsx @@ -2,34 +2,33 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { Button, TableProps } from 'antd'; -import { OperationTableDataInner, CallStackTableDataInner } from '../../api'; -import { Arguments } from '../../utils/type'; +import * as React from 'react' +import { Button, TableProps } from 'antd' +import { OperationTableDataInner, CallStackTableDataInner } from '../../api' +import { Arguments } from '../../utils/type' -type Types = NonNullable['expandable']>['expandIcon']; -type BasePropType = Arguments>>[0]; -type PropType = BasePropType & { text: string; disabled?: boolean }; +type Types = NonNullable['expandable']>['expandIcon'] +type BasePropType = Arguments>>[0] +type PropType = BasePropType & { text: string; disabled?: boolean } -export function ExpandIcon( - props: PropType -): React.JSX.Element { - const onClick = (e: React.MouseEvent): void => { - props.onExpand(props.record, e); - }; +export function ExpandIcon< + T extends OperationTableDataInner | CallStackTableDataInner +>(props: PropType) { + const onClick = (e: React.MouseEvent) => { + props.onExpand(props.record, e) + } return ( - - ); + ) } -export function makeExpandIcon( - text: string, - disabled?: (v: T) => boolean -) { - return (props: BasePropType): React.JSX.Element => ( +export function makeExpandIcon< + T extends OperationTableDataInner | CallStackTableDataInner +>(text: string, disabled?: (v: T) => boolean) { + return (props: BasePropType) => ( - ); + ) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/MemoryStatsTable.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/MemoryStatsTable.tsx index c7e1809a3c0b58297ca99066243cf7d65fbe4c8c..0b33ab4167ba11e9bb610d7ebc0717def2addda2 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/MemoryStatsTable.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/MemoryStatsTable.tsx @@ -2,76 +2,84 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { Table } from 'antd'; -import { makeStyles } from '@material-ui/core'; +import * as React from 'react' +import { Table } from 'antd' +import { makeStyles } from '@material-ui/core' export interface IProps { - data: any; - sort: string; + data: any + sort: string } const useStyles = makeStyles((theme) => ({ tooltip: { - whiteSpace: 'pre-wrap', - }, -})); + whiteSpace: 'pre-wrap' + } +})) -const getMemoryStatsTableColumns = function (columns: any, sort: string, tooltipClass: string): any { - let i = 0; - return columns.map((col: any) => { - const key = `col${i++}`; - const stringCompare = (a: any, b: any): number => a[key].localeCompare(b[key]); - const numberCompare = (a: any, b: any): number => (a[key] || 0) - (b[key] || 0); +const getMemoryStatsTableColumns = function ( + columns: any, + sort: string, + tooltipClass: string +) { + let i = 0 + return columns.map(function (col: any) { + const key = 'col' + i++ + const stringCompare = (a: any, b: any) => a[key].localeCompare(b[key]) + const numberCompare = (a: any, b: any) => (a[key] || 0) - (b[key] || 0) return { dataIndex: key, key: key, title: col.name, - sorter: col.type === 'string' ? stringCompare : numberCompare, - defaultSortOrder: sort === col.name ? ('descend' as const) : undefined, - showSorterTooltip: col.tooltip ? { title: col.tooltip, overlayClassName: tooltipClass } : true, - }; - }); -}; + sorter: col.type == 'string' ? stringCompare : numberCompare, + defaultSortOrder: sort == col.name ? ('descend' as const) : undefined, + showSorterTooltip: col.tooltip + ? { title: col.tooltip, overlayClassName: tooltipClass } + : true + } + }) +} -const getMemoryStatsTableRows = function (rows: any): any { - return rows.map((row: any) => { - let i = 0; - const res: any = {}; - row.forEach((entry: any) => { - res[`col${i++}`] = entry; - }); - return res; - }); -}; +const getMemoryStatsTableRows = function (rows: any) { + return rows.map(function (row: any) { + let i = 0 + const res: any = {} + row.forEach(function (entry: any) { + res['col' + i++] = entry + }) + return res + }) +} -export const MemoryStatsTable = (props: IProps): React.JSX.Element => { - const { data, sort } = props; - const classes = useStyles(); +export const MemoryStatsTable = (props: IProps) => { + const { data, sort } = props + const classes = useStyles() - const rows = React.useMemo(() => getMemoryStatsTableRows(data.rows), [data.rows]); + const rows = React.useMemo(() => getMemoryStatsTableRows(data.rows), [ + data.rows + ]) const columns = React.useMemo( () => getMemoryStatsTableColumns(data.columns, sort, classes.tooltip), [data.columns, sort, classes.tooltip] - ); + ) - const [pageSize, setPageSize] = React.useState(30); - const onShowSizeChange = (current: number, size: number): void => { - setPageSize(size); - }; + const [pageSize, setPageSize] = React.useState(30) + const onShowSizeChange = (current: number, size: number) => { + setPageSize(size) + } return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/NavToCodeButton.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/NavToCodeButton.tsx index 2c999aa12a49726aad12321f260b31b6f331eda2..fb40e7f38bf5ccbe89851b5fe2d0b684af71239a 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/NavToCodeButton.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/NavToCodeButton.tsx @@ -2,28 +2,28 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { CallStackFrame } from './transform'; -import { Button } from 'antd'; -import { navToCode } from '../../utils/vscode'; +import * as React from 'react' +import { CallStackFrame } from './transform' +import { Button } from 'antd' +import { navToCode } from '../../utils/vscode' interface IProps { - frame: CallStackFrame; + frame: CallStackFrame } -export const NavToCodeButton = (props: IProps): React.JSX.Element => { - const { raw, line, file } = props.frame; - const couldNavToFile = line && file; +export const NavToCodeButton = (props: IProps) => { + const { raw, line, file } = props.frame + const couldNavToFile = line && file - const onClick = (): void => { + const onClick = () => { if (line && file) { - navToCode(file, line - 1); + navToCode(file, line - 1) } - }; + } return ( - - ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/OperationTable.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/OperationTable.tsx index 1ce77ee817967ee69961ccd8c91dbc3b0357bed7..799b8497a04cce30dfc248b380bf477eab85909a 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/OperationTable.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/OperationTable.tsx @@ -15,55 +15,62 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { makeStyles } from '@material-ui/core/styles'; -import { OperationTableData, OperationTableDataInner, TableMetadata } from '../../api'; -import { OperationGroupBy } from '../../constants/groupBy'; -import { attachId, getCommonOperationColumns } from './common'; -import { Table, TableProps } from 'antd'; -import { makeExpandIcon } from './ExpandIcon'; -import { CallStackTable } from './CallStackTable'; +import * as React from 'react' +import { makeStyles } from '@material-ui/core/styles' +import { + OperationTableData, + OperationTableDataInner, + TableMetadata +} from '../../api' +import { OperationGroupBy } from '../../constants/groupBy' +import { attachId, getCommonOperationColumns } from './common' +import { Table, TablePaginationConfig, TableProps } from 'antd' +import { makeExpandIcon } from './ExpandIcon' +import { CallStackTable } from './CallStackTable' export interface IProps { - data: OperationTableData; - run: string; - worker: string; - span: string; - groupBy: OperationGroupBy; - sortColumn: string; - tooltips?: any; - deviceTarget: string; + data: OperationTableData + run: string + worker: string + span: string + groupBy: OperationGroupBy + sortColumn: string + tooltips?: any + deviceTarget: string } const useStyles = makeStyles((theme) => ({ tooltip: { - whiteSpace: 'pre-wrap', - }, -})); + whiteSpace: 'pre-wrap' + } +})) -const rowExpandable = (record: OperationTableDataInner): boolean => record.has_call_stack; -const expandIcon = makeExpandIcon('View CallStack', (record) => !record.has_call_stack); -export const OperationTable = (props: IProps): React.JSX.Element => { - const { data, run, worker, span, groupBy, sortColumn, tooltips, deviceTarget } = props; - const classes = useStyles(props); +const rowExpandable = (record: OperationTableDataInner) => record.has_call_stack +const expandIcon = makeExpandIcon( + 'View CallStack', + (record) => !record.has_call_stack +) +export const OperationTable = (props: IProps) => { + const { data, run, worker, span, groupBy, sortColumn, tooltips, deviceTarget } = props + const classes = useStyles(props) - const rows = React.useMemo(() => attachId(data), [data]); + const rows = React.useMemo(() => attachId(data), [data]) const columns = React.useMemo( () => getCommonOperationColumns(rows, deviceTarget, sortColumn, tooltips, classes), [rows] - ); + ) - const [pageSize, setPageSize] = React.useState(30); - const onShowSizeChange = (current: number, size: number): void => { - setPageSize(size); - }; + const [pageSize, setPageSize] = React.useState(30) + const onShowSizeChange = (current: number, size: number) => { + setPageSize(size) + } - const expandIconColumnIndex = columns.length; + const expandIconColumnIndex = columns.length const expandedRowRender = React.useCallback( (record: OperationTableDataInner) => ( { /> ), [run, worker, span, groupBy] - ); + ) const expandable: TableProps['expandable'] = React.useMemo( () => ({ expandIconColumnIndex, expandIcon, expandedRowRender, - rowExpandable, + rowExpandable }), [expandIconColumnIndex, expandedRowRender] - ); + ) return (
- ); -}; + ) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/common.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/common.tsx index a84a1a3bb3ff96fd5df257af51bdcd302dc318e2..a6f1770e7424539d916c01abef122808291d86a6 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/common.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/common.tsx @@ -15,136 +15,147 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * + * * Modifications: Add visualization of PyTorch Ascend profiling. *--------------------------------------------------------------------------------------------*/ -import { firstOrUndefined, isDef } from '../../utils/def'; -import { CallStackTableDataInner, OperationTableDataInner } from '../../api'; -import type { ColumnsType } from 'antd/es/table'; -import { ClassNameMap } from '@material-ui/styles'; +import { firstOrUndefined, isDef } from '../../utils/def' +import { CallStackTableDataInner, OperationTableDataInner } from '../../api' +import type { ColumnsType } from 'antd/es/table' +import { ClassNameMap } from '@material-ui/styles' -export function getCommonOperationColumns( - data?: T[], +export function getCommonOperationColumns< + T extends OperationTableDataInner | CallStackTableDataInner +>( + data: T[] | undefined, deviceTarget?: string, defaultSort?: string, tooltips?: any, classes?: ClassNameMap<'tooltip'> ): ColumnsType { - const firstData = firstOrUndefined(data); + const firstData = firstOrUndefined(data) - const hasInputShape = !firstData || isDef(firstData.input_shape); - const hasDeviceSelfDuration = !firstData || isDef(firstData.device_self_duration); - const hasDeviceTotalDuration = !firstData || isDef(firstData.device_total_duration); - const hasTcEligible = !firstData || isDef(firstData.tc_eligible); - const hasTcSelfRatio = !firstData || isDef(firstData.tc_self_ratio); - const hasTcTotalRatio = !firstData || isDef(firstData.tc_total_ratio); + const hasInputShape = !firstData || isDef(firstData.input_shape) + const hasDeviceSelfDuration = + !firstData || isDef(firstData.device_self_duration) + const hasDeviceTotalDuration = + !firstData || isDef(firstData.device_total_duration) + const hasTcEligible = !firstData || isDef(firstData.tc_eligible) + const hasTcSelfRatio = !firstData || isDef(firstData.tc_self_ratio) + const hasTcTotalRatio = !firstData || isDef(firstData.tc_total_ratio) - const nameCompare = (a: T, b: T): number => a.name.localeCompare(b.name); - const callsCompare = (a: T, b: T): number => a.calls - b.calls; - const deviceSelfDurationCompare = (a: T, b: T): number => - (a.device_self_duration || 0) - (b.device_self_duration || 0); - const deviceTotalDurationCompare = (a: T, b: T): number => - (a.device_total_duration || 0) - (b.device_total_duration || 0); - const hostSelfDurationCompare = (a: T, b: T): number => (a.host_self_duration || 0) - (b.host_self_duration || 0); - const hostTotalDurationCompare = (a: T, b: T): number => (a.host_total_duration || 0) - (b.host_total_duration || 0); - const tcEligibleCompare = (a: T, b: T): number => (a.tc_eligible ?? '').localeCompare(b.tc_eligible ?? ''); - const tcSelfRatioCompare = (a: T, b: T): number => (a.tc_self_ratio || 0) - (b.tc_self_ratio || 0); - const tcTotalRatioCompare = (a: T, b: T): number => (a.tc_total_ratio || 0) - (b.tc_total_ratio || 0); + const nameCompare = (a: T, b: T) => a.name.localeCompare(b.name) + const callsCompare = (a: T, b: T) => a.calls - b.calls + const deviceSelfDurationCompare = (a: T, b: T) => + (a.device_self_duration || 0) - (b.device_self_duration || 0) + const deviceTotalDurationCompare = (a: T, b: T) => + (a.device_total_duration || 0) - (b.device_total_duration || 0) + const hostSelfDurationCompare = (a: T, b: T) => + (a.host_self_duration || 0) - (b.host_self_duration || 0) + const hostTotalDurationCompare = (a: T, b: T) => + (a.host_total_duration || 0) - (b.host_total_duration || 0) + const tcEligibleCompare = (a: T, b: T) => + a.tc_eligible!.localeCompare(b.tc_eligible!) + const tcSelfRatioCompare = (a: T, b: T) => + (a.tc_self_ratio || 0) - (b.tc_self_ratio || 0) + const tcTotalRatioCompare = (a: T, b: T) => + (a.tc_total_ratio || 0) - (b.tc_total_ratio || 0) const columns: ColumnsType = [ { dataIndex: 'name', key: 'name', title: 'Name', - sorter: nameCompare, + sorter: nameCompare }, hasInputShape ? { - dataIndex: 'input_shape', - key: 'input_shape', - title: 'Input Shape', - } + dataIndex: 'input_shape', + key: 'input_shape', + title: 'Input Shape' + } : undefined, { dataIndex: 'calls', sorter: callsCompare, key: 'calls', - title: 'Calls', + title: 'Calls' }, hasDeviceSelfDuration ? { - dataIndex: 'device_self_duration', - key: 'device_self_duration', - title: 'Device Self Duration (us)', - sorter: deviceSelfDurationCompare, - // Use device_self_duration as default sort if defaultSort is unspecified - defaultSortOrder: defaultSort ? undefined : ('descend' as const), - } + dataIndex: 'device_self_duration', + key: 'device_self_duration', + title: 'Device Self Duration (us)', + sorter: deviceSelfDurationCompare, + // Use device_self_duration as default sort if defaultSort is unspecified + defaultSortOrder: defaultSort ? undefined : ('descend' as const) + } : undefined, hasDeviceTotalDuration ? { - dataIndex: 'device_total_duration', - key: 'device_total_duration', - title: 'Device Total Duration (us)', - sorter: deviceTotalDurationCompare, - } + dataIndex: 'device_total_duration', + key: 'device_total_duration', + title: 'Device Total Duration (us)', + sorter: deviceTotalDurationCompare + } : undefined, { dataIndex: 'host_self_duration', key: 'host_self_duration', title: 'Host Self Duration (us)', - sorter: hostSelfDurationCompare, + sorter: hostSelfDurationCompare }, { dataIndex: 'host_total_duration', key: 'host_total_duration', title: 'Host Total Duration (us)', - sorter: hostTotalDurationCompare, + sorter: hostTotalDurationCompare }, hasTcEligible ? { - dataIndex: 'tc_eligible', - key: 'tc_eligible', - title: deviceTarget === 'Ascend' ? 'AI Cores Eligible' : 'Tensor Cores Eligible', - sorter: tcEligibleCompare, - } + dataIndex: 'tc_eligible', + key: 'tc_eligible', + title: deviceTarget === 'Ascend' ? 'AI Cores Eligible' : 'Tensor Cores Eligible', + sorter: tcEligibleCompare + } : undefined, hasTcSelfRatio ? { - dataIndex: 'tc_self_ratio', - key: 'tc_self_ratio', - title: deviceTarget === 'Ascend' ? 'AI Cores Self(%)' : 'Tensor Cores Self(%)', - sorter: tcSelfRatioCompare, - } + dataIndex: 'tc_self_ratio', + key: 'tc_self_ratio', + title: deviceTarget === 'Ascend' ? 'AI Cores Self(%)' : 'Tensor Cores Self(%)', + sorter: tcSelfRatioCompare + } : undefined, hasTcTotalRatio ? { - dataIndex: 'tc_total_ratio', - key: 'tc_total_ratio', - title: deviceTarget === 'Ascend' ? 'AI Cores Total(%)' : 'Tensor Cores Total(%)', - sorter: tcTotalRatioCompare, - } - : undefined, - ].filter(isDef); + dataIndex: 'tc_total_ratio', + key: 'tc_total_ratio', + title: deviceTarget === 'Ascend' ? 'AI Cores Total(%)' : 'Tensor Cores Total(%)', + sorter: tcTotalRatioCompare + } + : undefined + ].filter(isDef) columns.forEach((column) => { - if (column.key === defaultSort) { - column.defaultSortOrder = 'descend' as const; + if (column.key == defaultSort) { + column.defaultSortOrder = 'descend' as const } if (tooltips[column.key as string]) { column.showSorterTooltip = { title: tooltips[column.key as string], - overlayClassName: classes?.tooltip, - }; + overlayClassName: classes?.tooltip + } } - }); - return columns; + }) + return columns } -let uid = 1; -export function attachId(data: T[]): T[] { +let uid = 1 +export function attachId< + T extends CallStackTableDataInner | OperationTableDataInner +>(data: T[]): T[] { return data.map((d) => ({ ...d, - key: uid++, - })); + key: uid++ + })) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/transform.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/transform.ts index 5f59728feb30ef6d3230c3eec9803b08cdd72779..bd051fd429d5cb26a44a59b60f776b207a861d64 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/transform.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/tables/transform.ts @@ -2,49 +2,49 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { CallStackTableData, CallStackTableDataInner } from '../../api'; +import { CallStackTableData, CallStackTableDataInner } from '../../api' export interface CallStackFrame { - file?: string; - line?: number; - raw: string; + file?: string + line?: number + raw: string } export interface TransformedCallStackDataInner extends CallStackTableDataInner { - callStackFrames: CallStackFrame[]; + callStackFrames: CallStackFrame[] } -const lineRegex = /\([0-9]+\)$/; +const lineRegex = /\([0-9]+\)$/ function parseCallStackLine(raw: string): CallStackFrame { - let rawResult = raw.trim(); - const results = rawResult.split(':'); - const location = results.slice(0, results.length - 1).join(':'); + raw = raw.trim() + const results = raw.split(':') + const location = results.slice(0, results.length - 1).join(':') - const result = lineRegex.exec(location); + const result = lineRegex.exec(location) if (!result) { - return { raw: rawResult }; + return { raw } } - const lineWithParens = result[0].trim(); - const file = rawResult.slice(0, result.index).trim(); + const lineWithParens = result[0].trim() + const file = raw.slice(0, result.index).trim() const line = Number( lineWithParens.substr(1, lineWithParens.length - 2).trim() - ); + ) return { - raw: rawResult, + raw, file, - line, - }; + line + } } -function parseCallStack(callStack?: string): CallStackFrame[] { +function parseCallStack(callStack: string | undefined): CallStackFrame[] { const lines = (callStack ?? '') .trim() .split(';') - .map((x) => x.trim()); - return lines.map(parseCallStackLine); + .map((x) => x.trim()) + return lines.map(parseCallStackLine) } function transformCallStackData( @@ -52,12 +52,12 @@ function transformCallStackData( ): TransformedCallStackDataInner { return { ...data, - callStackFrames: parseCallStack(data.call_stack), - }; + callStackFrames: parseCallStack(data.call_stack) + } } export function transformTableData( data: CallStackTableData ): TransformedCallStackDataInner[] { - return data.map(transformCallStackData); + return data.map(transformCallStackData) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/transform.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/transform.ts index 94ee9f384ebde3a3ddb057c88fc42beb69b0c908..08dcb25a20daf1868cc4ff2ea6245f444330b93f 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/components/transform.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/components/transform.ts @@ -2,82 +2,81 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as api from '../api'; -import { assertDef, isDef } from '../utils/def'; +import * as api from '../api' +import { assertDef, isDef } from '../utils/def' -export function transformPerformanceIntoTable(performances: api.Performance[]): api.Graph { +export function transformPerformanceIntoTable( + performances: api.Performance[] +): api.Graph { const columns: api.GraphColumn[] = [ { type: 'string', name: 'Category' }, { type: 'number', name: 'Time Duration (us)' }, - { type: 'number', name: 'Percentage (%)' }, - ]; + { type: 'number', name: 'Percentage (%)' } + ] - const rows: api.Graph['rows'] = []; - const queue = [...performances]; + const rows: api.Graph['rows'] = [] + const queue = [...performances] while (queue.length) { - const first = queue.shift(); - assertDef(first); + const first = queue.shift() + assertDef(first) - const row: api.Graph['rows'][number] = []; - const { name, value, extra, children } = first; - assertDef(value); - assertDef(extra); + const row: api.Graph['rows'][number] = [] + const { name, value, extra, children } = first + assertDef(value) + assertDef(extra) - row.push(name); - row.push(value); - row.push(extra); + row.push(name) + row.push(value) + row.push(extra) if (isDef(children) && children.length) { - queue.push(...children); + queue.push(...children) } - rows.push(row); + rows.push(row) } return { columns, - rows, - }; + rows + } } -export function transformPerformanceIntoPie(performances: api.Performance[]): { - columns: api.GraphColumn[]; - rows: Array>; -} { +export function transformPerformanceIntoPie(performances: api.Performance[]) { const columns: api.GraphColumn[] = [ { type: 'string', name: 'Name' }, - { type: 'number', name: 'Value' }, - ]; + { type: 'number', name: 'Value' } + ] - const rows: api.Graph['rows'] = []; - const queue: api.Performance[] = []; + const rows: api.Graph['rows'] = [] + const queue: api.Performance[] = [] performances.forEach((topLevel) => { if (topLevel.children) { - queue.push(...topLevel.children); + queue.push(...topLevel.children) } - }); + }) while (queue.length) { - const first = queue.shift(); - assertDef(first); + const first = queue.shift() + assertDef(first) - const row: api.Graph['rows'][number] = []; - const { name, value, children } = first; - assertDef(value); + const row: api.Graph['rows'][number] = [] + const { name, value, children } = first + assertDef(value) - row.push(name); - row.push(Number.parseInt(value, 10)); + row.push(name) + row.push(Number.parseInt(value, 10)) if (isDef(children) && children.length) { - queue.push(...children); + queue.push(...children) } - rows.push(row); + rows.push(row) } return { columns, - rows, - }; + rows + } } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/constants/groupBy.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/constants/groupBy.ts index 88ea9e3f42adfecd2a829384cc78b7ddc88d11aa..2b96c6b8dd3a0f1127f2617b72934d65c89f01f0 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/constants/groupBy.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/constants/groupBy.ts @@ -3,11 +3,11 @@ *--------------------------------------------------------------------------------------------*/ export enum OperationGroupBy { - OPERATION = 'Operation', - OPERATION_AND_INPUT_SHAPE = 'OperationAndInputShape', + Operation = 'Operation', + OperationAndInputShape = 'OperationAndInputShape' } export enum KernelGroupBy { - KERNEL = 'Kernel', - KERNEL_NAME_AND_OP_NAME = 'KernelNameAndOpName', + Kernel = 'Kernel', + KernelNameAndOpName = 'KernelNameAndOpName' } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/gstatic.d.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/gstatic.d.ts index 521c5fbb8d985136529d8233f8a65dffb8acca95..646255c2cdc20595fc0166b8cd5ce4743549bd2c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/gstatic.d.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/gstatic.d.ts @@ -2,5 +2,5 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -declare const google: any; -declare module 'react-flame-graph'; +declare const google: any +declare module 'react-flame-graph' diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/index.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/index.tsx index 851474766de5d9adee682e66ed752c85ffd6d4bf..224f37a5fd066414815caf9e83b15298364fd2bd 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/index.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/index.tsx @@ -2,9 +2,9 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { render } from 'react-dom'; -import { App } from './app'; -import 'antd/dist/antd.css'; +import * as React from 'react' +import { render } from 'react-dom' +import { App } from './app' +import 'antd/dist/antd.css' -render(, document.getElementById('app')); +render(, document.getElementById('app')) diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/setup.tsx b/plugins/tensorboard-plugins/tb_plugin/fe/src/setup.tsx index c811ae1524ec7cc6f82410e8aeb999f2ea22476b..5db44e8243119c7988ef33007e2eb3134fe6e857 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/setup.tsx +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/setup.tsx @@ -2,8 +2,8 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -export async function setup(): Promise { +export async function setup() { await google.charts.load('current', { - packages: ['corechart', 'table', 'timeline'], - }); + packages: ['corechart', 'table', 'timeline'] + }) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/binarysearch.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/binarysearch.ts index 41382dcdb7acc8cb9e2b1b4f856e1855fb7ed88f..0477cac74d0b0d6836b53f18689891feb2f10cea 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/binarysearch.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/binarysearch.ts @@ -1,20 +1,20 @@ export function binarySearch( arr: Array, key: any, - compareFn: (key: number, mid: Array) => number + compare_fn: Function ): number { - let low = 0; - let high = arr.length - 1; + let low = 0, + high = arr.length - 1 while (low <= high) { - let mid = Math.round((high + low) / 2); - let cmp = compareFn(key, arr[mid]); + let mid = Math.round((high + low) / 2) + let cmp = compare_fn(key, arr[mid]) if (cmp > 0) { - low = mid + 1; + low = mid + 1 } else if (cmp < 0) { - high = mid - 1; + high = mid - 1 } else { - return mid; + return mid } } - return -1; + return -1 } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/debounce.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/debounce.ts index 82c7f04a98b788ab2c7c7647c292f163b8a92783..fcd6368e6ac9e971c85267fe5e6ccc9781235c9e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/debounce.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/debounce.ts @@ -2,20 +2,20 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; +import * as React from 'react' export function useDebounce(value: T, delay: number): T { - const [debouncedValue, setDebouncedValue] = React.useState(value); + const [debouncedValue, setDebouncedValue] = React.useState(value) React.useEffect(() => { const handler = setTimeout(() => { - setDebouncedValue(value); - }, delay); + setDebouncedValue(value) + }, delay) return () => { - clearTimeout(handler); - }; - }, [value, delay]); + clearTimeout(handler) + } + }, [value, delay]) - return debouncedValue; + return debouncedValue } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/def.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/def.ts index df6bef8eab076d13c0785902127f46a472ff9fa6..c024293a54e18e543c331226c317713f829c5c10 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/def.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/def.ts @@ -2,19 +2,17 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -export function isDef(v?: T | null): v is T { - return v !== null && v !== undefined; +export function isDef(v: T | undefined | null): v is T { + return v !== null && v !== undefined } -export function assertDef(v?: T | null): asserts v is T { +export function assertDef(v: T | undefined | null): asserts v is T { if (!isDef(v)) { - throw new Error('Must be defined'); + throw new Error('Must be defined') } } -export function firstOrUndefined(v?: T[]): T | undefined { - if (!v || !v.length) { - return undefined; - } - return v[0]; +export function firstOrUndefined(v: T[] | undefined): T | undefined { + if (!v || !v.length) return undefined + return v[0] } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/hooks.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/hooks.ts index 473b393d9fa270438be85a7b528d78107c5f87f5..d8dd3eff536eb5e22683debe4338e785fe630616 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/hooks.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/hooks.ts @@ -2,26 +2,26 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; +import * as React from 'react' -const cbs: Array<() => void> = []; -export const useOnResize = (cb: () => void): void => { +const cbs: (() => void)[] = [] +export const useOnResize = (cb: () => void) => { React.useEffect(() => { if (cbs.length === 0) { window.addEventListener('resize', () => { - cbs.forEach((callback) => callback()); - }); + cbs.forEach((cb) => cb()) + }) } - cbs.push(cb); + cbs.push(cb) - return (): void => { - const idx = cbs.findIndex(cb); + return () => { + const idx = cbs.findIndex(cb) if (idx > -1) { - cbs.splice(idx, 1); + cbs.splice(idx, 1) } if (cbs.length === 0) { - window.removeEventListener('reset', cb); + window.removeEventListener('reset', cb) } - }; - }, [cb]); -}; + } + }, [cb]) +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/index.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/index.ts index 5da446721e9d1cac3729d8aea03bca2615031f41..1c7074b4c2002c40dc0b3f2f3da88d9a2b783a5f 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/index.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/index.ts @@ -2,23 +2,23 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { ValueAndFormat } from '../api'; +import { ValueAndFormat } from '../api' -export function firstOrUndefined(v?: T[] | null): T | undefined { - if (!v || !v.length) { - return undefined; - } - return v[0]; +export function firstOrUndefined(v: T[] | undefined | null): T | undefined { + if (!v || !v.length) return undefined + return v[0] } -export function sleep(delay: number): Promise { - return new Promise((resolve) => setTimeout(resolve, delay)); +export function sleep(delay: number) { + return new Promise((resolve) => setTimeout(resolve, delay)) } export function isValueAndFormat(v: any): v is ValueAndFormat { - return 'f' in v && 'v' in v; + return 'f' in v && 'v' in v } -export function value(v: boolean | number | string | ValueAndFormat): boolean | number | string { - return typeof v === 'object' && isValueAndFormat(v) ? v.v : v; +export function value( + v: boolean | number | string | ValueAndFormat +): boolean | number | string { + return typeof v === 'object' && isValueAndFormat(v) ? v.v : v } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/resize.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/resize.ts index 766a10d54143fecd637b1d0dff33db17f22bee0d..57ab394042651fcddb7a48cfa158647d2e6b9faa 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/resize.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/resize.ts @@ -2,26 +2,26 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import debounce from '@material-ui/core/utils/debounce'; +import * as React from 'react' +import debounce from '@material-ui/core/utils/debounce' -export function useResizeEventDependency(): readonly [number] { - const [version, setVersion] = React.useState(0); +export function useResizeEventDependency() { + const [version, setVersion] = React.useState(0) const increaseVersion = React.useCallback( debounce(() => { - setVersion((prev) => prev + 1); + setVersion((prev) => prev + 1) }, 100), [] - ); + ) React.useEffect(() => { - window.addEventListener('resize', increaseVersion); + window.addEventListener('resize', increaseVersion) - return (): void => { - window.removeEventListener('resize', increaseVersion); - }; - }, []); + return () => { + window.removeEventListener('resize', increaseVersion) + } + }, []) - return [version] as const; + return [version] as const } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/search.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/search.ts index 8a2efc36ddf505aee50171affd722bd5ef0a5b86..36689758752625b6c249c5fd532d93c9e5fbafb4 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/search.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/search.ts @@ -2,67 +2,65 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import * as React from 'react'; -import { value } from '.'; -import * as api from '../api'; -import { useDebounce } from './debounce'; +import * as React from 'react' +import { value } from '.' +import * as api from '../api' +import { useDebounce } from './debounce' export function useSearch( searchName: string, columnName: string, - table?: api.Graph + table: api.Graph | undefined ): [api.Graph | undefined] { - const searchNameDebounce = useDebounce(searchName.trim(), 500); + const searchNameDebounce = useDebounce(searchName.trim(), 500) const searchedTable: api.Graph | undefined = React.useMemo(() => { if (!searchNameDebounce) { - return table; + return table } if (!table) { - return undefined; + return undefined } - const columnNameToFind = columnName.toLowerCase(); + const columnNameToFind = columnName.toLowerCase() const nameColumnIdx = table.columns.findIndex( (c) => c.name.toLowerCase() === columnNameToFind - ); + ) if (nameColumnIdx < 0) { - return table; + return table } return { ...table, rows: table.rows.filter((x) => { - const cell = value(x[nameColumnIdx]); - return typeof cell === 'string' && cell.includes(searchNameDebounce); - }), - }; - }, [table, searchNameDebounce]); - return [searchedTable]; + const cell = value(x[nameColumnIdx]) + return typeof cell === 'string' && cell.includes(searchNameDebounce) + }) + } + }, [table, searchNameDebounce]) + return [searchedTable] } export function useSearchDirectly( searchName: string, field: (v: T) => string, - table?: T[] + table: T[] | undefined ): [T[] | undefined] { - const searchNameDebounce = useDebounce(searchName.trim(), 500); + const searchNameDebounce = useDebounce(searchName.trim(), 500) const result = React.useMemo(() => { if (!searchNameDebounce) { - return table; + return table } if (!table) { - return undefined; + return undefined } return table.filter((row) => { - return field(row) - .toLowerCase() - .includes(searchNameDebounce.toLowerCase()); - }); - }, [table, field, searchNameDebounce]); - return [result]; + return field(row).toLowerCase().includes(searchNameDebounce.toLowerCase()) + }) + }, [table, field, searchNameDebounce]) + return [result] } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/top.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/top.ts index 4af19968d637d6c13bf64caa94f09fff104f6091..87bd3c1b86f763a63dbf195ee5feaf649d56e006 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/top.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/top.ts @@ -2,53 +2,49 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import debounce from '@material-ui/core/utils/debounce'; -import * as React from 'react'; +import debounce from '@material-ui/core/utils/debounce' +import * as React from 'react' export enum UseTop { - NOT_USE = 'NotUse', - USE = 'Use', + NotUse = 'NotUse', + Use = 'Use' } interface IOptions { - defaultTop?: number; - defaultUseTop?: UseTop; - noDebounce?: boolean; - wait?: number; + defaultTop?: number + defaultUseTop?: UseTop + noDebounce?: boolean + wait?: number } -export function useTopN( - options?: IOptions -): readonly [ - string, - number | undefined, - UseTop, - React.Dispatch>, - React.Dispatch> -] { - let realOptions = options ?? {}; - - const [topText, setTopText] = React.useState(String(realOptions.defaultTop ?? 15)); - const [actualTop, setActualTop] = React.useState(Number(topText)); - const [useTop, setUseTop] = React.useState(realOptions.defaultUseTop ?? UseTop.NOT_USE); - - const setActualDebounce = !realOptions.noDebounce - ? React.useCallback(debounce(setActualTop, realOptions.wait ?? 500), []) - : setActualTop; +export function useTopN(options?: IOptions) { + options ??= {} + + const [topText, setTopText] = React.useState(String(options.defaultTop ?? 15)) + const [actualTop, setActualTop] = React.useState( + Number(topText) + ) + const [useTop, setUseTop] = React.useState( + options.defaultUseTop ?? UseTop.NotUse + ) + + const setActualDebounce = !options.noDebounce + ? React.useCallback(debounce(setActualTop, options.wait ?? 500), []) + : setActualTop React.useEffect(() => { - if (useTop !== UseTop.USE) { - setActualDebounce(undefined); + if (useTop !== UseTop.Use) { + setActualDebounce(undefined) } else if (topIsValid(topText)) { - setActualDebounce(Number(topText)); + setActualDebounce(Number(topText)) } else { - setActualDebounce(actualTop); + setActualDebounce(actualTop) } - }, [topText, useTop]); + }, [topText, useTop]) - return [topText, actualTop, useTop, setTopText, setUseTop] as const; + return [topText, actualTop, useTop, setTopText, setUseTop] as const } -export function topIsValid(topText: string): boolean { - const top = Number(topText); - return !Number.isNaN(top) && top > 0 && Number.isInteger(top); +export function topIsValid(topText: string) { + const top = Number(topText) + return !Number.isNaN(top) && top > 0 && Number.isInteger(top) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/type.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/type.ts index ccd45fd16e11043abe40a4235a7b39a5d18afcdd..fde74bc598b930f26dd8a83157c91953da2c045c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/type.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/type.ts @@ -6,4 +6,4 @@ export type Arguments void> = T extends ( ...args: infer A ) => void ? A - : never; + : never diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/vscode.ts b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/vscode.ts index 2a763adca54ef3eba96837aa111df627e3f8b116..62f1a90809548691f3b7b7a89d71ac65e4bf622b 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/vscode.ts +++ b/plugins/tensorboard-plugins/tb_plugin/fe/src/utils/vscode.ts @@ -2,12 +2,12 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -export function navToCode(filename: string, line: number): void { +export function navToCode(filename: string, line: number) { window.parent.parent.postMessage( { filename, - line, + line }, - window.origin - ); + '*' + ) } diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/update-static.js b/plugins/tensorboard-plugins/tb_plugin/fe/update-static.js index 67c9be6ccc266ca2470705ad7bb990e550769e96..9923c216781c4cfd3505bdc4cb99a736b1bc61a1 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/update-static.js +++ b/plugins/tensorboard-plugins/tb_plugin/fe/update-static.js @@ -1,7 +1,7 @@ -const fs = require('fs'); -const path = require('path'); +const fs = require('fs') +const path = require('path') fs.copyFileSync( path.resolve(__dirname, 'dist/index.html'), path.resolve(__dirname, '../torch_tb_profiler/static/index.html') -); +) diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/webpack.config.js b/plugins/tensorboard-plugins/tb_plugin/fe/webpack.config.js index a47f8b319e83a9c96c80c11afe5adf09e308fbfa..70541ae9cff81eccfd33a8edd2b2a8424edf5a4b 100644 --- a/plugins/tensorboard-plugins/tb_plugin/fe/webpack.config.js +++ b/plugins/tensorboard-plugins/tb_plugin/fe/webpack.config.js @@ -1,8 +1,8 @@ -const path = require('path'); -const HtmlWebpackPlugin = require('html-webpack-plugin'); -const InlineChunkHtmlPlugin = require('inline-chunk-html-plugin'); +const path = require('path') +const HtmlWebpackPlugin = require('html-webpack-plugin') +const InlineChunkHtmlPlugin = require('inline-chunk-html-plugin') -const isDev = process.env.NODE_ENV !== 'production'; +const isDev = process.env.NODE_ENV !== 'production' /** * @type {import('webpack').Configuration & import('webpack-dev-server').Configuration} @@ -12,25 +12,25 @@ module.exports = { entry: './src/index.tsx', output: { path: path.resolve(__dirname, 'dist'), - filename: 'index.js', + filename: 'index.js' }, resolve: { // Add `.ts` and `.tsx` as a resolvable extension. - extensions: ['.ts', '.tsx', '.js'], + extensions: ['.ts', '.tsx', '.js'] }, module: { rules: [ { test: /\.tsx?$/i, use: 'ts-loader' }, - { test: /\.css$/i, use: ['style-loader', 'css-loader'] }, - ], + { test: /\.css$/i, use: ['style-loader', 'css-loader'] } + ] }, plugins: [ new HtmlWebpackPlugin({ inject: true, scriptLoading: 'blocking', - template: 'index.html', + template: 'index.html' }), - !isDev ? new InlineChunkHtmlPlugin(HtmlWebpackPlugin, [/.*/]) : undefined, + !isDev ? new InlineChunkHtmlPlugin(HtmlWebpackPlugin, [/.*/]) : undefined ].filter(Boolean), - devServer: {}, -}; + devServer: {} +} diff --git a/plugins/tensorboard-plugins/tb_plugin/fe/yarn.lock b/plugins/tensorboard-plugins/tb_plugin/fe/yarn.lock new file mode 100644 index 0000000000000000000000000000000000000000..3e914db864c7654443e9041cfc1899ea2ac30bb1 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_plugin/fe/yarn.lock @@ -0,0 +1,3672 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@ant-design/colors@^6.0.0": + version "6.0.0" + resolved "https://registry.yarnpkg.com/@ant-design/colors/-/colors-6.0.0.tgz#9b9366257cffcc47db42b9d0203bb592c13c0298" + integrity sha512-qAZRvPzfdWHtfameEGP2Qvuf838NhergR35o+EuVyB5XvSA98xod5r4utvi4TJ3ywmevm290g9nsCG5MryrdWQ== + dependencies: + "@ctrl/tinycolor" "^3.4.0" + +"@ant-design/icons-svg@^4.2.1": + version "4.2.1" + resolved "https://registry.yarnpkg.com/@ant-design/icons-svg/-/icons-svg-4.2.1.tgz#8630da8eb4471a4aabdaed7d1ff6a97dcb2cf05a" + integrity sha512-EB0iwlKDGpG93hW8f85CTJTs4SvMX7tt5ceupvhALp1IF44SeUFOMhKUOYqpsoYWQKAOuTRDMqn75rEaKDp0Xw== + +"@ant-design/icons@^4.7.0": + version "4.7.0" + resolved "https://registry.yarnpkg.com/@ant-design/icons/-/icons-4.7.0.tgz#8c3cbe0a556ba92af5dc7d1e70c0b25b5179af0f" + integrity sha512-aoB4Z7JA431rt6d4u+8xcNPPCrdufSRMUOpxa1ab6mz1JCQZOEVolj2WVs/tDFmN62zzK30mNelEsprLYsSF3g== + dependencies: + "@ant-design/colors" "^6.0.0" + "@ant-design/icons-svg" "^4.2.1" + "@babel/runtime" "^7.11.2" + classnames "^2.2.6" + rc-util "^5.9.4" + +"@ant-design/react-slick@~0.28.1": + version "0.28.4" + resolved "https://registry.yarnpkg.com/@ant-design/react-slick/-/react-slick-0.28.4.tgz#8b296b87ad7c7ae877f2a527b81b7eebd9dd29a9" + integrity sha512-j9eAHTn7GxbXUFNknJoHS2ceAsqrQi2j8XykjZE1IXCD8kJF+t28EvhBLniDpbOsBk/3kjalnhriTfZcjBHNqg== + dependencies: + "@babel/runtime" "^7.10.4" + classnames "^2.2.5" + json2mq "^0.2.0" + lodash "^4.17.21" + resize-observer-polyfill "^1.5.0" + +"@babel/runtime@^7.0.0", "@babel/runtime@^7.10.1", "@babel/runtime@^7.10.2", "@babel/runtime@^7.10.4", "@babel/runtime@^7.11.1", "@babel/runtime@^7.11.2", "@babel/runtime@^7.12.5", "@babel/runtime@^7.13.10", "@babel/runtime@^7.3.1", "@babel/runtime@^7.4.4", "@babel/runtime@^7.5.5", "@babel/runtime@^7.8.3", "@babel/runtime@^7.8.4", "@babel/runtime@^7.8.7": + version "7.17.2" + resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.17.2.tgz#66f68591605e59da47523c631416b18508779941" + integrity sha512-hzeyJyMA1YGdJTuWU0e/j4wKXrU4OMFvY2MSlaI9B7VQb0r5cxTE3EAIS2Q7Tn2RIcDkRvTA/v2JsAEhxe99uw== + dependencies: + regenerator-runtime "^0.13.4" + +"@ctrl/tinycolor@^3.4.0": + version "3.4.0" + resolved "https://registry.yarnpkg.com/@ctrl/tinycolor/-/tinycolor-3.4.0.tgz#c3c5ae543c897caa9c2a68630bed355be5f9990f" + integrity sha512-JZButFdZ1+/xAfpguQHoabIXkcqRRKpMrWKBkpEZZyxfY9C1DpADFB8PEqGSTeFr135SaTRfKqGKx5xSCLI7ZQ== + +"@discoveryjs/json-ext@^0.5.0": + version "0.5.6" + resolved "https://registry.yarnpkg.com/@discoveryjs/json-ext/-/json-ext-0.5.6.tgz#d5e0706cf8c6acd8c6032f8d54070af261bbbb2f" + integrity sha512-ws57AidsDvREKrZKYffXddNkyaF14iHNHm8VQnZH6t99E8gczjNN0GpvcGny0imC80yQ0tHz1xVUKk/KFQSUyA== + +"@emotion/hash@^0.8.0": + version "0.8.0" + resolved "https://registry.yarnpkg.com/@emotion/hash/-/hash-0.8.0.tgz#bbbff68978fefdbe68ccb533bc8cbe1d1afb5413" + integrity sha512-kBJtf7PH6aWwZ6fka3zQ0p6SBYzx4fl1LoZXE2RrnYST9Xljm7WfKJrU4g/Xr3Beg72MLrp1AWNUmuYJTL7Cow== + +"@material-ui/core@^4.11.3": + version "4.12.3" + resolved "https://registry.yarnpkg.com/@material-ui/core/-/core-4.12.3.tgz#80d665caf0f1f034e52355c5450c0e38b099d3ca" + integrity sha512-sdpgI/PL56QVsEJldwEe4FFaFTLUqN+rd7sSZiRCdx2E/C7z5yK0y/khAWVBH24tXwto7I1hCzNWfJGZIYJKnw== + dependencies: + "@babel/runtime" "^7.4.4" + "@material-ui/styles" "^4.11.4" + "@material-ui/system" "^4.12.1" + "@material-ui/types" "5.1.0" + "@material-ui/utils" "^4.11.2" + "@types/react-transition-group" "^4.2.0" + clsx "^1.0.4" + hoist-non-react-statics "^3.3.2" + popper.js "1.16.1-lts" + prop-types "^15.7.2" + react-is "^16.8.0 || ^17.0.0" + react-transition-group "^4.4.0" + +"@material-ui/icons@^4.11.2": + version "4.11.2" + resolved "https://registry.yarnpkg.com/@material-ui/icons/-/icons-4.11.2.tgz#b3a7353266519cd743b6461ae9fdfcb1b25eb4c5" + integrity sha512-fQNsKX2TxBmqIGJCSi3tGTO/gZ+eJgWmMJkgDiOfyNaunNaxcklJQFaFogYcFl0qFuaEz1qaXYXboa/bUXVSOQ== + dependencies: + "@babel/runtime" "^7.4.4" + +"@material-ui/styles@^4.11.4": + version "4.11.4" + resolved "https://registry.yarnpkg.com/@material-ui/styles/-/styles-4.11.4.tgz#eb9dfccfcc2d208243d986457dff025497afa00d" + integrity sha512-KNTIZcnj/zprG5LW0Sao7zw+yG3O35pviHzejMdcSGCdWbiO8qzRgOYL8JAxAsWBKOKYwVZxXtHWaB5T2Kvxew== + dependencies: + "@babel/runtime" "^7.4.4" + "@emotion/hash" "^0.8.0" + "@material-ui/types" "5.1.0" + "@material-ui/utils" "^4.11.2" + clsx "^1.0.4" + csstype "^2.5.2" + hoist-non-react-statics "^3.3.2" + jss "^10.5.1" + jss-plugin-camel-case "^10.5.1" + jss-plugin-default-unit "^10.5.1" + jss-plugin-global "^10.5.1" + jss-plugin-nested "^10.5.1" + jss-plugin-props-sort "^10.5.1" + jss-plugin-rule-value-function "^10.5.1" + jss-plugin-vendor-prefixer "^10.5.1" + prop-types "^15.7.2" + +"@material-ui/system@^4.12.1": + version "4.12.1" + resolved "https://registry.yarnpkg.com/@material-ui/system/-/system-4.12.1.tgz#2dd96c243f8c0a331b2bb6d46efd7771a399707c" + integrity sha512-lUdzs4q9kEXZGhbN7BptyiS1rLNHe6kG9o8Y307HCvF4sQxbCgpL2qi+gUk+yI8a2DNk48gISEQxoxpgph0xIw== + dependencies: + "@babel/runtime" "^7.4.4" + "@material-ui/utils" "^4.11.2" + csstype "^2.5.2" + prop-types "^15.7.2" + +"@material-ui/types@5.1.0": + version "5.1.0" + resolved "https://registry.yarnpkg.com/@material-ui/types/-/types-5.1.0.tgz#efa1c7a0b0eaa4c7c87ac0390445f0f88b0d88f2" + integrity sha512-7cqRjrY50b8QzRSYyhSpx4WRw2YuO0KKIGQEVk5J8uoz2BanawykgZGoWEqKm7pVIbzFDN0SpPcVV4IhOFkl8A== + +"@material-ui/utils@^4.11.2": + version "4.11.2" + resolved "https://registry.yarnpkg.com/@material-ui/utils/-/utils-4.11.2.tgz#f1aefa7e7dff2ebcb97d31de51aecab1bb57540a" + integrity sha512-Uul8w38u+PICe2Fg2pDKCaIG7kOyhowZ9vjiC1FsVwPABTW8vPPKfF6OvxRq3IiBaI1faOJmgdvMG7rMJARBhA== + dependencies: + "@babel/runtime" "^7.4.4" + prop-types "^15.7.2" + react-is "^16.8.0 || ^17.0.0" + +"@nodelib/fs.scandir@2.1.5": + version "2.1.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" + integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== + dependencies: + "@nodelib/fs.stat" "2.0.5" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": + version "2.0.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" + integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.8" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" + integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== + dependencies: + "@nodelib/fs.scandir" "2.1.5" + fastq "^1.6.0" + +"@types/body-parser@*": + version "1.19.2" + resolved "https://registry.yarnpkg.com/@types/body-parser/-/body-parser-1.19.2.tgz#aea2059e28b7658639081347ac4fab3de166e6f0" + integrity sha512-ALYone6pm6QmwZoAgeyNksccT9Q4AWZQ6PvfwR37GT6r6FWUPguq6sUmNGSMV2Wr761oQoBxwGGa6DR5o1DC9g== + dependencies: + "@types/connect" "*" + "@types/node" "*" + +"@types/bonjour@^3.5.9": + version "3.5.10" + resolved "https://registry.yarnpkg.com/@types/bonjour/-/bonjour-3.5.10.tgz#0f6aadfe00ea414edc86f5d106357cda9701e275" + integrity sha512-p7ienRMiS41Nu2/igbJxxLDWrSZ0WxM8UQgCeO9KhoVF7cOVFkrKsiDr1EsJIla8vV3oEEjGcz11jc5yimhzZw== + dependencies: + "@types/node" "*" + +"@types/connect-history-api-fallback@^1.3.5": + version "1.3.5" + resolved "https://registry.yarnpkg.com/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.3.5.tgz#d1f7a8a09d0ed5a57aee5ae9c18ab9b803205dae" + integrity sha512-h8QJa8xSb1WD4fpKBDcATDNGXghFj6/3GRWG6dhmRcu0RX1Ubasur2Uvx5aeEwlf0MwblEC2bMzzMQntxnw/Cw== + dependencies: + "@types/express-serve-static-core" "*" + "@types/node" "*" + +"@types/connect@*": + version "3.4.35" + resolved "https://registry.yarnpkg.com/@types/connect/-/connect-3.4.35.tgz#5fcf6ae445e4021d1fc2219a4873cc73a3bb2ad1" + integrity sha512-cdeYyv4KWoEgpBISTxWvqYsVy444DOqehiF3fM3ne10AmJ62RSyNkUnxMJXHQWRQQX2eR94m5y1IZyDwBjV9FQ== + dependencies: + "@types/node" "*" + +"@types/eslint-scope@^3.7.3": + version "3.7.3" + resolved "https://registry.yarnpkg.com/@types/eslint-scope/-/eslint-scope-3.7.3.tgz#125b88504b61e3c8bc6f870882003253005c3224" + integrity sha512-PB3ldyrcnAicT35TWPs5IcwKD8S333HMaa2VVv4+wdvebJkjWuW/xESoB8IwRcog8HYVYamb1g/R31Qv5Bx03g== + dependencies: + "@types/eslint" "*" + "@types/estree" "*" + +"@types/eslint@*": + version "8.4.1" + resolved "https://registry.yarnpkg.com/@types/eslint/-/eslint-8.4.1.tgz#c48251553e8759db9e656de3efc846954ac32304" + integrity sha512-GE44+DNEyxxh2Kc6ro/VkIj+9ma0pO0bwv9+uHSyBrikYOHr8zYcdPvnBOp1aw8s+CjRvuSx7CyWqRrNFQ59mA== + dependencies: + "@types/estree" "*" + "@types/json-schema" "*" + +"@types/estree@*", "@types/estree@^0.0.51": + version "0.0.51" + resolved "https://registry.yarnpkg.com/@types/estree/-/estree-0.0.51.tgz#cfd70924a25a3fd32b218e5e420e6897e1ac4f40" + integrity sha512-CuPgU6f3eT/XgKKPqKd/gLZV1Xmvf1a2R5POBOGQa6uv82xpls89HU5zKeVoyR8XzHd1RGNOlQlvUe3CFkjWNQ== + +"@types/express-serve-static-core@*", "@types/express-serve-static-core@^4.17.18": + version "4.17.28" + resolved "https://registry.yarnpkg.com/@types/express-serve-static-core/-/express-serve-static-core-4.17.28.tgz#c47def9f34ec81dc6328d0b1b5303d1ec98d86b8" + integrity sha512-P1BJAEAW3E2DJUlkgq4tOL3RyMunoWXqbSCygWo5ZIWTjUgN1YnaXWW4VWl/oc8vs/XoYibEGBKP0uZyF4AHig== + dependencies: + "@types/node" "*" + "@types/qs" "*" + "@types/range-parser" "*" + +"@types/express@*", "@types/express@^4.17.13": + version "4.17.13" + resolved "https://registry.yarnpkg.com/@types/express/-/express-4.17.13.tgz#a76e2995728999bab51a33fabce1d705a3709034" + integrity sha512-6bSZTPaTIACxn48l50SR+axgrqm6qXFIxrdAKaG6PaJk3+zuUr35hBlgT7vOmJcum+OEaIBLtHV/qloEAFITeA== + dependencies: + "@types/body-parser" "*" + "@types/express-serve-static-core" "^4.17.18" + "@types/qs" "*" + "@types/serve-static" "*" + +"@types/html-minifier-terser@^6.0.0": + version "6.1.0" + resolved "https://registry.yarnpkg.com/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz#4fc33a00c1d0c16987b1a20cf92d20614c55ac35" + integrity sha512-oh/6byDPnL1zeNXFrDXFLyZjkr1MsBG667IM792caf1L2UPOOMf65NFzjUH/ltyfwjAGfs1rsX1eftK0jC/KIg== + +"@types/http-proxy@^1.17.8": + version "1.17.8" + resolved "https://registry.yarnpkg.com/@types/http-proxy/-/http-proxy-1.17.8.tgz#968c66903e7e42b483608030ee85800f22d03f55" + integrity sha512-5kPLG5BKpWYkw/LVOGWpiq3nEVqxiN32rTgI53Sk12/xHFQ2rG3ehI9IO+O3W2QoKeyB92dJkoka8SUm6BX1pA== + dependencies: + "@types/node" "*" + +"@types/json-schema@*", "@types/json-schema@^7.0.8", "@types/json-schema@^7.0.9": + version "7.0.9" + resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.9.tgz#97edc9037ea0c38585320b28964dde3b39e4660d" + integrity sha512-qcUXuemtEu+E5wZSJHNxUXeCZhAfXKQ41D+duX+VYPde7xyEVZci+/oXKJL13tnRs9lR2pr4fod59GT6/X1/yQ== + +"@types/mime@^1": + version "1.3.2" + resolved "https://registry.yarnpkg.com/@types/mime/-/mime-1.3.2.tgz#93e25bf9ee75fe0fd80b594bc4feb0e862111b5a" + integrity sha512-YATxVxgRqNH6nHEIsvg6k2Boc1JHI9ZbH5iWFFv/MTkchz3b1ieGDa5T0a9RznNdI0KhVbdbWSN+KWWrQZRxTw== + +"@types/node@*": + version "17.0.21" + resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.21.tgz#864b987c0c68d07b4345845c3e63b75edd143644" + integrity sha512-DBZCJbhII3r90XbQxI8Y9IjjiiOGlZ0Hr32omXIZvwwZ7p4DMMXGrKXVyPfuoBOri9XNtL0UK69jYIBIsRX3QQ== + +"@types/prop-types@*": + version "15.7.4" + resolved "https://registry.yarnpkg.com/@types/prop-types/-/prop-types-15.7.4.tgz#fcf7205c25dff795ee79af1e30da2c9790808f11" + integrity sha512-rZ5drC/jWjrArrS8BR6SIr4cWpW09RNTYt9AMZo3Jwwif+iacXAqgVjm0B0Bv/S1jhDXKHqRVNCbACkJ89RAnQ== + +"@types/qs@*": + version "6.9.7" + resolved "https://registry.yarnpkg.com/@types/qs/-/qs-6.9.7.tgz#63bb7d067db107cc1e457c303bc25d511febf6cb" + integrity sha512-FGa1F62FT09qcrueBA6qYTrJPVDzah9a+493+o2PCXsesWHIn27G98TsSMs3WPNbZIEj4+VJf6saSFpvD+3Zsw== + +"@types/range-parser@*": + version "1.2.4" + resolved "https://registry.yarnpkg.com/@types/range-parser/-/range-parser-1.2.4.tgz#cd667bcfdd025213aafb7ca5915a932590acdcdc" + integrity sha512-EEhsLsD6UsDM1yFhAvy0Cjr6VwmpMWqFBCb9w07wVugF7w9nfajxLuVmngTIpgS6svCnm6Vaw+MZhoDCKnOfsw== + +"@types/react-dom@^16.9.8": + version "16.9.14" + resolved "https://registry.yarnpkg.com/@types/react-dom/-/react-dom-16.9.14.tgz#674b8f116645fe5266b40b525777fc6bb8eb3bcd" + integrity sha512-FIX2AVmPTGP30OUJ+0vadeIFJJ07Mh1m+U0rxfgyW34p3rTlXI+nlenvAxNn4BP36YyI9IJ/+UJ7Wu22N1pI7A== + dependencies: + "@types/react" "^16" + +"@types/react-transition-group@^4.2.0": + version "4.4.4" + resolved "https://registry.yarnpkg.com/@types/react-transition-group/-/react-transition-group-4.4.4.tgz#acd4cceaa2be6b757db61ed7b432e103242d163e" + integrity sha512-7gAPz7anVK5xzbeQW9wFBDg7G++aPLAFY0QaSMOou9rJZpbuI58WAuJrgu+qR92l61grlnCUe7AFX8KGahAgug== + dependencies: + "@types/react" "*" + +"@types/react@*": + version "17.0.39" + resolved "https://registry.yarnpkg.com/@types/react/-/react-17.0.39.tgz#d0f4cde092502a6db00a1cded6e6bf2abb7633ce" + integrity sha512-UVavlfAxDd/AgAacMa60Azl7ygyQNRwC/DsHZmKgNvPmRR5p70AJ5Q9EAmL2NWOJmeV+vVUI4IAP7GZrN8h8Ug== + dependencies: + "@types/prop-types" "*" + "@types/scheduler" "*" + csstype "^3.0.2" + +"@types/react@^16", "@types/react@^16.9.51": + version "16.14.23" + resolved "https://registry.yarnpkg.com/@types/react/-/react-16.14.23.tgz#37201b9f2324c5ff8fa4600dbf19079dfdffc880" + integrity sha512-WngBZLuSkP4IAgPi0HOsGCHo6dn3CcuLQnCfC17VbA7YBgipZiZoTOhObwl/93DsFW0Y2a/ZXeonpW4DxirEJg== + dependencies: + "@types/prop-types" "*" + "@types/scheduler" "*" + csstype "^3.0.2" + +"@types/retry@^0.12.0": + version "0.12.1" + resolved "https://registry.yarnpkg.com/@types/retry/-/retry-0.12.1.tgz#d8f1c0d0dc23afad6dc16a9e993a0865774b4065" + integrity sha512-xoDlM2S4ortawSWORYqsdU+2rxdh4LRW9ytc3zmT37RIKQh6IHyKwwtKhKis9ah8ol07DCkZxPt8BBvPjC6v4g== + +"@types/scheduler@*": + version "0.16.2" + resolved "https://registry.yarnpkg.com/@types/scheduler/-/scheduler-0.16.2.tgz#1a62f89525723dde24ba1b01b092bf5df8ad4d39" + integrity sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew== + +"@types/serve-index@^1.9.1": + version "1.9.1" + resolved "https://registry.yarnpkg.com/@types/serve-index/-/serve-index-1.9.1.tgz#1b5e85370a192c01ec6cec4735cf2917337a6278" + integrity sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg== + dependencies: + "@types/express" "*" + +"@types/serve-static@*": + version "1.13.10" + resolved "https://registry.yarnpkg.com/@types/serve-static/-/serve-static-1.13.10.tgz#f5e0ce8797d2d7cc5ebeda48a52c96c4fa47a8d9" + integrity sha512-nCkHGI4w7ZgAdNkrEu0bv+4xNV/XDqW+DydknebMOQwkpDGx8G+HTlj7R7ABI8i8nKxVw0wtKPi1D+lPOkh4YQ== + dependencies: + "@types/mime" "^1" + "@types/node" "*" + +"@types/sockjs@^0.3.33": + version "0.3.33" + resolved "https://registry.yarnpkg.com/@types/sockjs/-/sockjs-0.3.33.tgz#570d3a0b99ac995360e3136fd6045113b1bd236f" + integrity sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw== + dependencies: + "@types/node" "*" + +"@types/ws@^8.2.2": + version "8.5.2" + resolved "https://registry.yarnpkg.com/@types/ws/-/ws-8.5.2.tgz#77e0c2e360e9579da930ffcfa53c5975ea3bdd26" + integrity sha512-VXI82ykONr5tacHEojnErTQk+KQSoYbW1NB6iz6wUwrNd+BqfkfggQNoNdCqhJSzbNumShPERbM+Pc5zpfhlbw== + dependencies: + "@types/node" "*" + +"@webassemblyjs/ast@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/ast/-/ast-1.11.1.tgz#2bfd767eae1a6996f432ff7e8d7fc75679c0b6a7" + integrity sha512-ukBh14qFLjxTQNTXocdyksN5QdM28S1CxHt2rdskFyL+xFV7VremuBLVbmCePj+URalXBENx/9Lm7lnhihtCSw== + dependencies: + "@webassemblyjs/helper-numbers" "1.11.1" + "@webassemblyjs/helper-wasm-bytecode" "1.11.1" + +"@webassemblyjs/floating-point-hex-parser@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/floating-point-hex-parser/-/floating-point-hex-parser-1.11.1.tgz#f6c61a705f0fd7a6aecaa4e8198f23d9dc179e4f" + integrity sha512-iGRfyc5Bq+NnNuX8b5hwBrRjzf0ocrJPI6GWFodBFzmFnyvrQ83SHKhmilCU/8Jv67i4GJZBMhEzltxzcNagtQ== + +"@webassemblyjs/helper-api-error@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-api-error/-/helper-api-error-1.11.1.tgz#1a63192d8788e5c012800ba6a7a46c705288fd16" + integrity sha512-RlhS8CBCXfRUR/cwo2ho9bkheSXG0+NwooXcc3PAILALf2QLdFyj7KGsKRbVc95hZnhnERon4kW/D3SZpp6Tcg== + +"@webassemblyjs/helper-buffer@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.1.tgz#832a900eb444884cde9a7cad467f81500f5e5ab5" + integrity sha512-gwikF65aDNeeXa8JxXa2BAk+REjSyhrNC9ZwdT0f8jc4dQQeDQ7G4m0f2QCLPJiMTTO6wfDmRmj/pW0PsUvIcA== + +"@webassemblyjs/helper-numbers@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-numbers/-/helper-numbers-1.11.1.tgz#64d81da219fbbba1e3bd1bfc74f6e8c4e10a62ae" + integrity sha512-vDkbxiB8zfnPdNK9Rajcey5C0w+QJugEglN0of+kmO8l7lDb77AnlKYQF7aarZuCrv+l0UvqL+68gSDr3k9LPQ== + dependencies: + "@webassemblyjs/floating-point-hex-parser" "1.11.1" + "@webassemblyjs/helper-api-error" "1.11.1" + "@xtuc/long" "4.2.2" + +"@webassemblyjs/helper-wasm-bytecode@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-wasm-bytecode/-/helper-wasm-bytecode-1.11.1.tgz#f328241e41e7b199d0b20c18e88429c4433295e1" + integrity sha512-PvpoOGiJwXeTrSf/qfudJhwlvDQxFgelbMqtq52WWiXC6Xgg1IREdngmPN3bs4RoO83PnL/nFrxucXj1+BX62Q== + +"@webassemblyjs/helper-wasm-section@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.1.tgz#21ee065a7b635f319e738f0dd73bfbda281c097a" + integrity sha512-10P9No29rYX1j7F3EVPX3JvGPQPae+AomuSTPiF9eBQeChHI6iqjMIwR9JmOJXwpnn/oVGDk7I5IlskuMwU/pg== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/helper-buffer" "1.11.1" + "@webassemblyjs/helper-wasm-bytecode" "1.11.1" + "@webassemblyjs/wasm-gen" "1.11.1" + +"@webassemblyjs/ieee754@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/ieee754/-/ieee754-1.11.1.tgz#963929e9bbd05709e7e12243a099180812992614" + integrity sha512-hJ87QIPtAMKbFq6CGTkZYJivEwZDbQUgYd3qKSadTNOhVY7p+gfP6Sr0lLRVTaG1JjFj+r3YchoqRYxNH3M0GQ== + dependencies: + "@xtuc/ieee754" "^1.2.0" + +"@webassemblyjs/leb128@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/leb128/-/leb128-1.11.1.tgz#ce814b45574e93d76bae1fb2644ab9cdd9527aa5" + integrity sha512-BJ2P0hNZ0u+Th1YZXJpzW6miwqQUGcIHT1G/sf72gLVD9DZ5AdYTqPNbHZh6K1M5VmKvFXwGSWZADz+qBWxeRw== + dependencies: + "@xtuc/long" "4.2.2" + +"@webassemblyjs/utf8@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/utf8/-/utf8-1.11.1.tgz#d1f8b764369e7c6e6bae350e854dec9a59f0a3ff" + integrity sha512-9kqcxAEdMhiwQkHpkNiorZzqpGrodQQ2IGrHHxCy+Ozng0ofyMA0lTqiLkVs1uzTRejX+/O0EOT7KxqVPuXosQ== + +"@webassemblyjs/wasm-edit@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.1.tgz#ad206ebf4bf95a058ce9880a8c092c5dec8193d6" + integrity sha512-g+RsupUC1aTHfR8CDgnsVRVZFJqdkFHpsHMfJuWQzWU3tvnLC07UqHICfP+4XyL2tnr1amvl1Sdp06TnYCmVkA== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/helper-buffer" "1.11.1" + "@webassemblyjs/helper-wasm-bytecode" "1.11.1" + "@webassemblyjs/helper-wasm-section" "1.11.1" + "@webassemblyjs/wasm-gen" "1.11.1" + "@webassemblyjs/wasm-opt" "1.11.1" + "@webassemblyjs/wasm-parser" "1.11.1" + "@webassemblyjs/wast-printer" "1.11.1" + +"@webassemblyjs/wasm-gen@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.1.tgz#86c5ea304849759b7d88c47a32f4f039ae3c8f76" + integrity sha512-F7QqKXwwNlMmsulj6+O7r4mmtAlCWfO/0HdgOxSklZfQcDu0TpLiD1mRt/zF25Bk59FIjEuGAIyn5ei4yMfLhA== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/helper-wasm-bytecode" "1.11.1" + "@webassemblyjs/ieee754" "1.11.1" + "@webassemblyjs/leb128" "1.11.1" + "@webassemblyjs/utf8" "1.11.1" + +"@webassemblyjs/wasm-opt@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.1.tgz#657b4c2202f4cf3b345f8a4c6461c8c2418985f2" + integrity sha512-VqnkNqnZlU5EB64pp1l7hdm3hmQw7Vgqa0KF/KCNO9sIpI6Fk6brDEiX+iCOYrvMuBWDws0NkTOxYEb85XQHHw== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/helper-buffer" "1.11.1" + "@webassemblyjs/wasm-gen" "1.11.1" + "@webassemblyjs/wasm-parser" "1.11.1" + +"@webassemblyjs/wasm-parser@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.1.tgz#86ca734534f417e9bd3c67c7a1c75d8be41fb199" + integrity sha512-rrBujw+dJu32gYB7/Lup6UhdkPx9S9SnobZzRVL7VcBH9Bt9bCBLEuX/YXOOtBsOZ4NQrRykKhffRWHvigQvOA== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/helper-api-error" "1.11.1" + "@webassemblyjs/helper-wasm-bytecode" "1.11.1" + "@webassemblyjs/ieee754" "1.11.1" + "@webassemblyjs/leb128" "1.11.1" + "@webassemblyjs/utf8" "1.11.1" + +"@webassemblyjs/wast-printer@1.11.1": + version "1.11.1" + resolved "https://registry.yarnpkg.com/@webassemblyjs/wast-printer/-/wast-printer-1.11.1.tgz#d0c73beda8eec5426f10ae8ef55cee5e7084c2f0" + integrity sha512-IQboUWM4eKzWW+N/jij2sRatKMh99QEelo3Eb2q0qXkvPRISAj8Qxtmw5itwqK+TTkBuUIE45AxYPToqPtL5gg== + dependencies: + "@webassemblyjs/ast" "1.11.1" + "@xtuc/long" "4.2.2" + +"@webpack-cli/configtest@^1.1.1": + version "1.1.1" + resolved "https://registry.yarnpkg.com/@webpack-cli/configtest/-/configtest-1.1.1.tgz#9f53b1b7946a6efc2a749095a4f450e2932e8356" + integrity sha512-1FBc1f9G4P/AxMqIgfZgeOTuRnwZMten8E7zap5zgpPInnCrP8D4Q81+4CWIch8i/Nf7nXjP0v6CjjbHOrXhKg== + +"@webpack-cli/info@^1.4.1": + version "1.4.1" + resolved "https://registry.yarnpkg.com/@webpack-cli/info/-/info-1.4.1.tgz#2360ea1710cbbb97ff156a3f0f24556e0fc1ebea" + integrity sha512-PKVGmazEq3oAo46Q63tpMr4HipI3OPfP7LiNOEJg963RMgT0rqheag28NCML0o3GIzA3DmxP1ZIAv9oTX1CUIA== + dependencies: + envinfo "^7.7.3" + +"@webpack-cli/serve@^1.6.1": + version "1.6.1" + resolved "https://registry.yarnpkg.com/@webpack-cli/serve/-/serve-1.6.1.tgz#0de2875ac31b46b6c5bb1ae0a7d7f0ba5678dffe" + integrity sha512-gNGTiTrjEVQ0OcVnzsRSqTxaBSr+dmTfm+qJsCDluky8uhdLWep7Gcr62QsAKHTMxjCS/8nEITsmFAhfIx+QSw== + +"@xtuc/ieee754@^1.2.0": + version "1.2.0" + resolved "https://registry.yarnpkg.com/@xtuc/ieee754/-/ieee754-1.2.0.tgz#eef014a3145ae477a1cbc00cd1e552336dceb790" + integrity sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA== + +"@xtuc/long@4.2.2": + version "4.2.2" + resolved "https://registry.yarnpkg.com/@xtuc/long/-/long-4.2.2.tgz#d291c6a4e97989b5c61d9acf396ae4fe133a718d" + integrity sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ== + +accepts@~1.3.4, accepts@~1.3.5, accepts@~1.3.8: + version "1.3.8" + resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.8.tgz#0bf0be125b67014adcb0b0921e62db7bffe16b2e" + integrity sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw== + dependencies: + mime-types "~2.1.34" + negotiator "0.6.3" + +acorn-import-assertions@^1.7.6: + version "1.8.0" + resolved "https://registry.yarnpkg.com/acorn-import-assertions/-/acorn-import-assertions-1.8.0.tgz#ba2b5939ce62c238db6d93d81c9b111b29b855e9" + integrity sha512-m7VZ3jwz4eK6A4Vtt8Ew1/mNbP24u0FhdyfA7fSvnJR6LMdfOYnmuIrrJAgrYfYJ10F/otaHTtrtrtmHdMNzEw== + +acorn@^8.4.1, acorn@^8.5.0: + version "8.7.0" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.7.0.tgz#90951fde0f8f09df93549481e5fc141445b791cf" + integrity sha512-V/LGr1APy+PXIwKebEWrkZPwoeoF+w1jiOBUmuxuiUIaOHtob8Qc9BTrYo7VuI5fR8tqsy+buA2WFooR5olqvQ== + +aggregate-error@^3.0.0: + version "3.1.0" + resolved "https://registry.yarnpkg.com/aggregate-error/-/aggregate-error-3.1.0.tgz#92670ff50f5359bdb7a3e0d40d0ec30c5737687a" + integrity sha512-4I7Td01quW/RpocfNayFdFVk1qSuoh0E7JrbRJ16nH01HhKFQ88INq9Sd+nd72zqRySlr9BmDA8xlEJ6vJMrYA== + dependencies: + clean-stack "^2.0.0" + indent-string "^4.0.0" + +ajv-formats@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/ajv-formats/-/ajv-formats-2.1.1.tgz#6e669400659eb74973bbf2e33327180a0996b520" + integrity sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA== + dependencies: + ajv "^8.0.0" + +ajv-keywords@^3.5.2: + version "3.5.2" + resolved "https://registry.yarnpkg.com/ajv-keywords/-/ajv-keywords-3.5.2.tgz#31f29da5ab6e00d1c2d329acf7b5929614d5014d" + integrity sha512-5p6WTN0DdTGVQk6VjcEju19IgaHudalcfabD7yhDGeA6bcQnmL+CpveLJq/3hvfwd1aof6L386Ougkx6RfyMIQ== + +ajv-keywords@^5.0.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/ajv-keywords/-/ajv-keywords-5.1.0.tgz#69d4d385a4733cdbeab44964a1170a88f87f0e16" + integrity sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw== + dependencies: + fast-deep-equal "^3.1.3" + +ajv@^6.12.5: + version "6.12.6" + resolved "https://registry.yarnpkg.com/ajv/-/ajv-6.12.6.tgz#baf5a62e802b07d977034586f8c3baf5adf26df4" + integrity sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g== + dependencies: + fast-deep-equal "^3.1.1" + fast-json-stable-stringify "^2.0.0" + json-schema-traverse "^0.4.1" + uri-js "^4.2.2" + +ajv@^8.0.0, ajv@^8.8.0: + version "8.10.0" + resolved "https://registry.yarnpkg.com/ajv/-/ajv-8.10.0.tgz#e573f719bd3af069017e3b66538ab968d040e54d" + integrity sha512-bzqAEZOjkrUMl2afH8dknrq5KEk2SrwdBROR+vH1EKVQTqaUbJVPdc/gEdggTMM0Se+s+Ja4ju4TlNcStKl2Hw== + dependencies: + fast-deep-equal "^3.1.1" + json-schema-traverse "^1.0.0" + require-from-string "^2.0.2" + uri-js "^4.2.2" + +ansi-html-community@^0.0.8: + version "0.0.8" + resolved "https://registry.yarnpkg.com/ansi-html-community/-/ansi-html-community-0.0.8.tgz#69fbc4d6ccbe383f9736934ae34c3f8290f1bf41" + integrity sha512-1APHAyr3+PCamwNw3bXCPp4HFLONZt/yIH0sZp0/469KWNTEy+qN5jQ3GVX6DMZ1UXAi34yVwtTeaG/HpBuuzw== + +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-regex@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-6.0.1.tgz#3183e38fae9a65d7cb5e53945cd5897d0260a06a" + integrity sha512-n5M855fKb2SsfMIiFFoVrABHJC8QtHwVx+mHWP3QcEqBHYienj5dHSgjbxtC0WEZXYt4wcD6zrQElDPhFuZgfA== + +ansi-styles@^4.1.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + +antd@^4.17.0: + version "4.19.0" + resolved "https://registry.yarnpkg.com/antd/-/antd-4.19.0.tgz#1c637a4d7dde091a2299260ca89f05c29fb21f80" + integrity sha512-4Kp47+zg3j1g1lWmzFstGrmlGdHzUIvxAVXxYJKJqX+iQs++QYgcK2HF+9PBpwEwP6H6VPZCsL0LqKEflke5qg== + dependencies: + "@ant-design/colors" "^6.0.0" + "@ant-design/icons" "^4.7.0" + "@ant-design/react-slick" "~0.28.1" + "@babel/runtime" "^7.12.5" + "@ctrl/tinycolor" "^3.4.0" + classnames "^2.2.6" + copy-to-clipboard "^3.2.0" + lodash "^4.17.21" + memoize-one "^6.0.0" + moment "^2.25.3" + rc-cascader "~3.2.1" + rc-checkbox "~2.3.0" + rc-collapse "~3.1.0" + rc-dialog "~8.6.0" + rc-drawer "~4.4.2" + rc-dropdown "~3.3.2" + rc-field-form "~1.23.0" + rc-image "~5.2.5" + rc-input "^0.0.1-alpha.5" + rc-input-number "~7.3.0" + rc-mentions "~1.6.1" + rc-menu "~9.2.1" + rc-motion "^2.4.4" + rc-notification "~4.5.7" + rc-pagination "~3.1.9" + rc-picker "~2.6.4" + rc-progress "~3.2.1" + rc-rate "~2.9.0" + rc-resize-observer "^1.2.0" + rc-select "~14.0.0-alpha.15" + rc-slider "~10.0.0-alpha.4" + rc-steps "~4.1.0" + rc-switch "~3.2.0" + rc-table "~7.23.0" + rc-tabs "~11.10.0" + rc-textarea "~0.3.0" + rc-tooltip "~5.1.1" + rc-tree "~5.4.3" + rc-tree-select "~5.1.1" + rc-trigger "^5.2.10" + rc-upload "~4.3.0" + rc-util "^5.14.0" + scroll-into-view-if-needed "^2.2.25" + +anymatch@~3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" + integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + +array-flatten@1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/array-flatten/-/array-flatten-1.1.1.tgz#9a5f699051b1e7073328f2a008968b64ea2955d2" + integrity sha1-ml9pkFGx5wczKPKgCJaLZOopVdI= + +array-flatten@^2.1.0: + version "2.1.2" + resolved "https://registry.yarnpkg.com/array-flatten/-/array-flatten-2.1.2.tgz#24ef80a28c1a893617e2149b0c6d0d788293b099" + integrity sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ== + +array-tree-filter@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/array-tree-filter/-/array-tree-filter-2.1.0.tgz#873ac00fec83749f255ac8dd083814b4f6329190" + integrity sha512-4ROwICNlNw/Hqa9v+rk5h22KjmzB1JGTMVKP2AKJBOCgb0yL0ASf0+YvCcLNNwquOHNX48jkeZIJ3a+oOQqKcw== + +array-union@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/array-union/-/array-union-2.1.0.tgz#b798420adbeb1de828d84acd8a2e23d3efe85e8d" + integrity sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw== + +async-validator@^4.0.2: + version "4.0.7" + resolved "https://registry.yarnpkg.com/async-validator/-/async-validator-4.0.7.tgz#034a0fd2103a6b2ebf010da75183bec299247afe" + integrity sha512-Pj2IR7u8hmUEDOwB++su6baaRi+QvsgajuFB9j95foM1N2gy5HM4z60hfusIO0fBPG5uLAEl6yCJr1jNSVugEQ== + +async@^2.6.2: + version "2.6.3" + resolved "https://registry.yarnpkg.com/async/-/async-2.6.3.tgz#d72625e2344a3656e3a3ad4fa749fa83299d82ff" + integrity sha512-zflvls11DCy+dQWzTW2dzuilv8Z5X/pjfmZOWba6TNIVDm+2UDaJmXSOXlasHKfNBs8oo3M0aT50fDEWfKZjXg== + dependencies: + lodash "^4.17.14" + +balanced-match@^1.0.0: + version "1.0.2" + resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee" + integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw== + +batch@0.6.1: + version "0.6.1" + resolved "https://registry.yarnpkg.com/batch/-/batch-0.6.1.tgz#dc34314f4e679318093fc760272525f94bf25c16" + integrity sha1-3DQxT05nkxgJP8dgJyUl+UvyXBY= + +big.js@^5.2.2: + version "5.2.2" + resolved "https://registry.yarnpkg.com/big.js/-/big.js-5.2.2.tgz#65f0af382f578bcdc742bd9c281e9cb2d7768328" + integrity sha512-vyL2OymJxmarO8gxMr0mhChsO9QGwhynfuu4+MHTAW6czfq9humCB7rKpUjDd9YUiDPU4mzpyupFSvOClAwbmQ== + +binary-extensions@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" + integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== + +body-parser@1.19.2: + version "1.19.2" + resolved "https://registry.yarnpkg.com/body-parser/-/body-parser-1.19.2.tgz#4714ccd9c157d44797b8b5607d72c0b89952f26e" + integrity sha512-SAAwOxgoCKMGs9uUAUFHygfLAyaniaoun6I8mFY9pRAJL9+Kec34aU+oIjDhTycub1jozEfEwx1W1IuOYxVSFw== + dependencies: + bytes "3.1.2" + content-type "~1.0.4" + debug "2.6.9" + depd "~1.1.2" + http-errors "1.8.1" + iconv-lite "0.4.24" + on-finished "~2.3.0" + qs "6.9.7" + raw-body "2.4.3" + type-is "~1.6.18" + +bonjour@^3.5.0: + version "3.5.0" + resolved "https://registry.yarnpkg.com/bonjour/-/bonjour-3.5.0.tgz#8e890a183d8ee9a2393b3844c691a42bcf7bc9f5" + integrity sha1-jokKGD2O6aI5OzhExpGkK897yfU= + dependencies: + array-flatten "^2.1.0" + deep-equal "^1.0.1" + dns-equal "^1.0.0" + dns-txt "^2.0.2" + multicast-dns "^6.0.1" + multicast-dns-service-types "^1.1.0" + +boolbase@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/boolbase/-/boolbase-1.0.0.tgz#68dff5fbe60c51eb37725ea9e3ed310dcc1e776e" + integrity sha1-aN/1++YMUes3cl6p4+0xDcwed24= + +brace-expansion@^1.1.7: + version "1.1.11" + resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd" + integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA== + dependencies: + balanced-match "^1.0.0" + concat-map "0.0.1" + +braces@^3.0.1, braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +browserslist@^4.14.5, browserslist@^4.16.5: + version "4.20.0" + resolved "https://registry.yarnpkg.com/browserslist/-/browserslist-4.20.0.tgz#35951e3541078c125d36df76056e94738a52ebe9" + integrity sha512-bnpOoa+DownbciXj0jVGENf8VYQnE2LNWomhYuCsMmmx9Jd9lwq0WXODuwpSsp8AVdKM2/HorrzxAfbKvWTByQ== + dependencies: + caniuse-lite "^1.0.30001313" + electron-to-chromium "^1.4.76" + escalade "^3.1.1" + node-releases "^2.0.2" + picocolors "^1.0.0" + +buffer-from@^1.0.0: + version "1.1.2" + resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5" + integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ== + +buffer-indexof@^1.0.0: + version "1.1.1" + resolved "https://registry.yarnpkg.com/buffer-indexof/-/buffer-indexof-1.1.1.tgz#52fabcc6a606d1a00302802648ef68f639da268c" + integrity sha512-4/rOEg86jivtPTeOUUT61jJO1Ya1TrR/OkqCSZDyq84WJh3LuuiphBYJN+fm5xufIk4XAFcEwte/8WzC8If/1g== + +bytes@3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/bytes/-/bytes-3.0.0.tgz#d32815404d689699f85a4ea4fa8755dd13a96048" + integrity sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg= + +bytes@3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/bytes/-/bytes-3.1.2.tgz#8b0beeb98605adf1b128fa4386403c009e0221a5" + integrity sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg== + +call-bind@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/call-bind/-/call-bind-1.0.2.tgz#b1d4e89e688119c3c9a903ad30abb2f6a919be3c" + integrity sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA== + dependencies: + function-bind "^1.1.1" + get-intrinsic "^1.0.2" + +camel-case@^4.1.2: + version "4.1.2" + resolved "https://registry.yarnpkg.com/camel-case/-/camel-case-4.1.2.tgz#9728072a954f805228225a6deea6b38461e1bd5a" + integrity sha512-gxGWBrTT1JuMx6R+o5PTXMmUnhnVzLQ9SNutD4YqKtI6ap897t3tKECYla6gCWEkplXnlNybEkZg9GEGxKFCgw== + dependencies: + pascal-case "^3.1.2" + tslib "^2.0.3" + +caniuse-lite@^1.0.30001313: + version "1.0.30001313" + resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001313.tgz#a380b079db91621e1b7120895874e2fd62ed2e2f" + integrity sha512-rI1UN0koZUiKINjysQDuRi2VeSCce3bYJNmDcj3PIKREiAmjakugBul1QSkg/fPrlULYl6oWfGg3PbgOSY9X4Q== + +chalk@^4.1.0: + version "4.1.2" + resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01" + integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +chokidar@^3.5.3: + version "3.5.3" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" + integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== + dependencies: + anymatch "~3.1.2" + braces "~3.0.2" + glob-parent "~5.1.2" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.6.0" + optionalDependencies: + fsevents "~2.3.2" + +chrome-trace-event@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/chrome-trace-event/-/chrome-trace-event-1.0.3.tgz#1015eced4741e15d06664a957dbbf50d041e26ac" + integrity sha512-p3KULyQg4S7NIHixdwbGX+nFHkoBiA4YQmyWtjb8XngSKV124nJmRysgAeujbUVb15vh+RvFUfCPqU7rXk+hZg== + +classnames@2.x, classnames@^2.2.1, classnames@^2.2.3, classnames@^2.2.5, classnames@^2.2.6, classnames@^2.3.1: + version "2.3.1" + resolved "https://registry.yarnpkg.com/classnames/-/classnames-2.3.1.tgz#dfcfa3891e306ec1dad105d0e88f4417b8535e8e" + integrity sha512-OlQdbZ7gLfGarSqxesMesDa5uz7KFbID8Kpq/SxIoNGDqY8lSYs0D+hhtBXhcdB3rcbXArFr7vlHheLk1voeNA== + +clean-css@^5.2.2: + version "5.2.4" + resolved "https://registry.yarnpkg.com/clean-css/-/clean-css-5.2.4.tgz#982b058f8581adb2ae062520808fb2429bd487a4" + integrity sha512-nKseG8wCzEuji/4yrgM/5cthL9oTDc5UOQyFMvW/Q53oP6gLH690o1NbuTh6Y18nujr7BxlsFuS7gXLnLzKJGg== + dependencies: + source-map "~0.6.0" + +clean-stack@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/clean-stack/-/clean-stack-2.2.0.tgz#ee8472dbb129e727b31e8a10a427dee9dfe4008b" + integrity sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A== + +clone-deep@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/clone-deep/-/clone-deep-4.0.1.tgz#c19fd9bdbbf85942b4fd979c84dcf7d5f07c2387" + integrity sha512-neHB9xuzh/wk0dIHweyAXv2aPGZIVk3pLMe+/RNzINf17fe0OG96QroktYAUm7SM1PBnzTabaLboqqxDyMU+SQ== + dependencies: + is-plain-object "^2.0.4" + kind-of "^6.0.2" + shallow-clone "^3.0.0" + +clsx@^1.0.4, clsx@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/clsx/-/clsx-1.1.1.tgz#98b3134f9abbdf23b2663491ace13c5c03a73188" + integrity sha512-6/bPho624p3S2pMyvP5kKBPXnI3ufHLObBFCfgx+LkeR5lg2XYy2hqZqUf45ypD8COn2bhgGJSUE+l5dhNBieA== + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +colorette@^2.0.10, colorette@^2.0.14: + version "2.0.16" + resolved "https://registry.yarnpkg.com/colorette/-/colorette-2.0.16.tgz#713b9af84fdb000139f04546bd4a93f62a5085da" + integrity sha512-hUewv7oMjCp+wkBv5Rm0v87eJhq4woh5rSR+42YSQJKecCqgIqNkZ6lAlQms/BwHPJA5NKMRlpxPRv0n8HQW6g== + +commander@^2.20.0: + version "2.20.3" + resolved "https://registry.yarnpkg.com/commander/-/commander-2.20.3.tgz#fd485e84c03eb4881c20722ba48035e8531aeb33" + integrity sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ== + +commander@^7.0.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/commander/-/commander-7.2.0.tgz#a36cb57d0b501ce108e4d20559a150a391d97ab7" + integrity sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw== + +commander@^8.3.0: + version "8.3.0" + resolved "https://registry.yarnpkg.com/commander/-/commander-8.3.0.tgz#4837ea1b2da67b9c616a67afbb0fafee567bca66" + integrity sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww== + +compressible@~2.0.16: + version "2.0.18" + resolved "https://registry.yarnpkg.com/compressible/-/compressible-2.0.18.tgz#af53cca6b070d4c3c0750fbd77286a6d7cc46fba" + integrity sha512-AF3r7P5dWxL8MxyITRMlORQNaOA2IkAFaTr4k7BUumjPtRpGDTZpl0Pb1XCO6JeDCBdp126Cgs9sMxqSjgYyRg== + dependencies: + mime-db ">= 1.43.0 < 2" + +compression@^1.7.4: + version "1.7.4" + resolved "https://registry.yarnpkg.com/compression/-/compression-1.7.4.tgz#95523eff170ca57c29a0ca41e6fe131f41e5bb8f" + integrity sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ== + dependencies: + accepts "~1.3.5" + bytes "3.0.0" + compressible "~2.0.16" + debug "2.6.9" + on-headers "~1.0.2" + safe-buffer "5.1.2" + vary "~1.1.2" + +compute-scroll-into-view@^1.0.17: + version "1.0.17" + resolved "https://registry.yarnpkg.com/compute-scroll-into-view/-/compute-scroll-into-view-1.0.17.tgz#6a88f18acd9d42e9cf4baa6bec7e0522607ab7ab" + integrity sha512-j4dx+Fb0URmzbwwMUrhqWM2BEWHdFGx+qZ9qqASHRPqvTYdqvWnHg0H1hIbcyLnvgnoNAVMlwkepyqM3DaIFUg== + +concat-map@0.0.1: + version "0.0.1" + resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b" + integrity sha1-2Klr13/Wjfd5OnMDajug1UBdR3s= + +connect-history-api-fallback@^1.6.0: + version "1.6.0" + resolved "https://registry.yarnpkg.com/connect-history-api-fallback/-/connect-history-api-fallback-1.6.0.tgz#8b32089359308d111115d81cad3fceab888f97bc" + integrity sha512-e54B99q/OUoH64zYYRf3HBP5z24G38h5D3qXu23JGRoigpX5Ss4r9ZnDk3g0Z8uQC2x2lPaJ+UlWBc1ZWBWdLg== + +content-disposition@0.5.4: + version "0.5.4" + resolved "https://registry.yarnpkg.com/content-disposition/-/content-disposition-0.5.4.tgz#8b82b4efac82512a02bb0b1dcec9d2c5e8eb5bfe" + integrity sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ== + dependencies: + safe-buffer "5.2.1" + +content-type@~1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/content-type/-/content-type-1.0.4.tgz#e138cc75e040c727b1966fe5e5f8c9aee256fe3b" + integrity sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA== + +cookie-signature@1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/cookie-signature/-/cookie-signature-1.0.6.tgz#e303a882b342cc3ee8ca513a79999734dab3ae2c" + integrity sha1-4wOogrNCzD7oylE6eZmXNNqzriw= + +cookie@0.4.2: + version "0.4.2" + resolved "https://registry.yarnpkg.com/cookie/-/cookie-0.4.2.tgz#0e41f24de5ecf317947c82fc789e06a884824432" + integrity sha512-aSWTXFzaKWkvHO1Ny/s+ePFpvKsPnjc551iI41v3ny/ow6tBG5Vd+FuqGNhh1LxOmVzOlGUriIlOaokOvhaStA== + +copy-to-clipboard@^3.2.0: + version "3.3.1" + resolved "https://registry.yarnpkg.com/copy-to-clipboard/-/copy-to-clipboard-3.3.1.tgz#115aa1a9998ffab6196f93076ad6da3b913662ae" + integrity sha512-i13qo6kIHTTpCm8/Wup+0b1mVWETvu2kIMzKoK8FpkLkFxlt0znUAHcMzox+T8sPlqtZXq3CulEjQHsYiGFJUw== + dependencies: + toggle-selection "^1.0.6" + +core-util-is@~1.0.0: + version "1.0.3" + resolved "https://registry.yarnpkg.com/core-util-is/-/core-util-is-1.0.3.tgz#a6042d3634c2b27e9328f837b965fac83808db85" + integrity sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ== + +cross-env@^7.0.2: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-env/-/cross-env-7.0.3.tgz#865264b29677dc015ba8418918965dd232fc54cf" + integrity sha512-+/HKd6EgcQCJGh2PSjZuUitQBQynKor4wrFbRg4DtAgS1aWO+gU52xpH7M9ScGgXSYmAVS9bIJ8EzuaGw0oNAw== + dependencies: + cross-spawn "^7.0.1" + +cross-spawn@^7.0.1, cross-spawn@^7.0.3: + version "7.0.3" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" + integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + dependencies: + path-key "^3.1.0" + shebang-command "^2.0.0" + which "^2.0.1" + +css-loader@^5.2.4: + version "5.2.7" + resolved "https://registry.yarnpkg.com/css-loader/-/css-loader-5.2.7.tgz#9b9f111edf6fb2be5dc62525644cbc9c232064ae" + integrity sha512-Q7mOvpBNBG7YrVGMxRxcBJZFL75o+cH2abNASdibkj/fffYD8qWbInZrD0S9ccI6vZclF3DsHE7njGlLtaHbhg== + dependencies: + icss-utils "^5.1.0" + loader-utils "^2.0.0" + postcss "^8.2.15" + postcss-modules-extract-imports "^3.0.0" + postcss-modules-local-by-default "^4.0.0" + postcss-modules-scope "^3.0.0" + postcss-modules-values "^4.0.0" + postcss-value-parser "^4.1.0" + schema-utils "^3.0.0" + semver "^7.3.5" + +css-select@^4.1.3: + version "4.2.1" + resolved "https://registry.yarnpkg.com/css-select/-/css-select-4.2.1.tgz#9e665d6ae4c7f9d65dbe69d0316e3221fb274cdd" + integrity sha512-/aUslKhzkTNCQUB2qTX84lVmfia9NyjP3WpDGtj/WxhwBzWBYUV3DgUpurHTme8UTPcPlAD1DJ+b0nN/t50zDQ== + dependencies: + boolbase "^1.0.0" + css-what "^5.1.0" + domhandler "^4.3.0" + domutils "^2.8.0" + nth-check "^2.0.1" + +css-vendor@^2.0.8: + version "2.0.8" + resolved "https://registry.yarnpkg.com/css-vendor/-/css-vendor-2.0.8.tgz#e47f91d3bd3117d49180a3c935e62e3d9f7f449d" + integrity sha512-x9Aq0XTInxrkuFeHKbYC7zWY8ai7qJ04Kxd9MnvbC1uO5DagxoHQjm4JvG+vCdXOoFtCjbL2XSZfxmoYa9uQVQ== + dependencies: + "@babel/runtime" "^7.8.3" + is-in-browser "^1.0.2" + +css-what@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/css-what/-/css-what-5.1.0.tgz#3f7b707aadf633baf62c2ceb8579b545bb40f7fe" + integrity sha512-arSMRWIIFY0hV8pIxZMEfmMI47Wj3R/aWpZDDxWYCPEiOMv6tfOrnpDtgxBYPEQD4V0Y/958+1TdC3iWTFcUPw== + +cssesc@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" + integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== + +csstype@^2.5.2: + version "2.6.20" + resolved "https://registry.yarnpkg.com/csstype/-/csstype-2.6.20.tgz#9229c65ea0b260cf4d3d997cb06288e36a8d6dda" + integrity sha512-/WwNkdXfckNgw6S5R125rrW8ez139lBHWouiBvX8dfMFtcn6V81REDqnH7+CRpRipfYlyU1CmOnOxrmGcFOjeA== + +csstype@^3.0.2: + version "3.0.11" + resolved "https://registry.yarnpkg.com/csstype/-/csstype-3.0.11.tgz#d66700c5eacfac1940deb4e3ee5642792d85cd33" + integrity sha512-sa6P2wJ+CAbgyy4KFssIb/JNMLxFvKF1pCYCSXS8ZMuqZnMsrxqI2E5sPyoTpxoPU/gVZMzr2zjOfg8GIZOMsw== + +date-fns@2.x: + version "2.28.0" + resolved "https://registry.yarnpkg.com/date-fns/-/date-fns-2.28.0.tgz#9570d656f5fc13143e50c975a3b6bbeb46cd08b2" + integrity sha512-8d35hViGYx/QH0icHYCeLmsLmMUheMmTyV9Fcm6gvNwdw31yXXH+O85sOBJ+OLnLQMKZowvpKb6FgMIQjcpvQw== + +dayjs@1.x: + version "1.10.8" + resolved "https://registry.yarnpkg.com/dayjs/-/dayjs-1.10.8.tgz#267df4bc6276fcb33c04a6735287e3f429abec41" + integrity sha512-wbNwDfBHHur9UOzNUjeKUOJ0fCb0a52Wx0xInmQ7Y8FstyajiV1NmK1e00cxsr9YrE9r7yAChE0VvpuY5Rnlow== + +debug@2.6.9: + version "2.6.9" + resolved "https://registry.yarnpkg.com/debug/-/debug-2.6.9.tgz#5d128515df134ff327e90a4c93f4e077a536341f" + integrity sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA== + dependencies: + ms "2.0.0" + +debug@^3.1.1: + version "3.2.7" + resolved "https://registry.yarnpkg.com/debug/-/debug-3.2.7.tgz#72580b7e9145fb39b6676f9c5e5fb100b934179a" + integrity sha512-CFjzYYAi4ThfiQvizrFQevTTXHtnCqWfe7x1AhgEscTz6ZbLbfoLRLPugTQyBth6f8ZERVUSyWHFD/7Wu4t1XQ== + dependencies: + ms "^2.1.1" + +debug@^4.1.0: + version "4.3.3" + resolved "https://registry.yarnpkg.com/debug/-/debug-4.3.3.tgz#04266e0b70a98d4462e6e288e38259213332b664" + integrity sha512-/zxw5+vh1Tfv+4Qn7a5nsbcJKPaSvCDhojn6FEl9vupwK2VCSDtEiEtqr8DFtzYFOdz63LBkxec7DYuc2jon6Q== + dependencies: + ms "2.1.2" + +deep-equal@^1.0.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/deep-equal/-/deep-equal-1.1.1.tgz#b5c98c942ceffaf7cb051e24e1434a25a2e6076a" + integrity sha512-yd9c5AdiqVcR+JjcwUQb9DkhJc8ngNr0MahEBGvDiJw8puWab2yZlh+nkasOnZP+EGTAP6rRp2JzJhJZzvNF8g== + dependencies: + is-arguments "^1.0.4" + is-date-object "^1.0.1" + is-regex "^1.0.4" + object-is "^1.0.1" + object-keys "^1.1.1" + regexp.prototype.flags "^1.2.0" + +default-gateway@^6.0.3: + version "6.0.3" + resolved "https://registry.yarnpkg.com/default-gateway/-/default-gateway-6.0.3.tgz#819494c888053bdb743edbf343d6cdf7f2943a71" + integrity sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg== + dependencies: + execa "^5.0.0" + +define-lazy-prop@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz#3f7ae421129bcaaac9bc74905c98a0009ec9ee7f" + integrity sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og== + +define-properties@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/define-properties/-/define-properties-1.1.3.tgz#cf88da6cbee26fe6db7094f61d870cbd84cee9f1" + integrity sha512-3MqfYKj2lLzdMSf8ZIZE/V+Zuy+BgD6f164e8K2w7dgnpKArBDerGYpM46IYYcjnkdPNMjPk9A6VFB8+3SKlXQ== + dependencies: + object-keys "^1.0.12" + +del@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/del/-/del-6.0.0.tgz#0b40d0332cea743f1614f818be4feb717714c952" + integrity sha512-1shh9DQ23L16oXSZKB2JxpL7iMy2E0S9d517ptA1P8iw0alkPtQcrKH7ru31rYtKwF499HkTu+DRzq3TCKDFRQ== + dependencies: + globby "^11.0.1" + graceful-fs "^4.2.4" + is-glob "^4.0.1" + is-path-cwd "^2.2.0" + is-path-inside "^3.0.2" + p-map "^4.0.0" + rimraf "^3.0.2" + slash "^3.0.0" + +depd@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9" + integrity sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak= + +destroy@~1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/destroy/-/destroy-1.0.4.tgz#978857442c44749e4206613e37946205826abd80" + integrity sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA= + +detect-node@^2.0.4: + version "2.1.0" + resolved "https://registry.yarnpkg.com/detect-node/-/detect-node-2.1.0.tgz#c9c70775a49c3d03bc2c06d9a73be550f978f8b1" + integrity sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g== + +dir-glob@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/dir-glob/-/dir-glob-3.0.1.tgz#56dbf73d992a4a93ba1584f4534063fd2e41717f" + integrity sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA== + dependencies: + path-type "^4.0.0" + +dns-equal@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/dns-equal/-/dns-equal-1.0.0.tgz#b39e7f1da6eb0a75ba9c17324b34753c47e0654d" + integrity sha1-s55/HabrCnW6nBcySzR1PEfgZU0= + +dns-packet@^1.3.1: + version "1.3.4" + resolved "https://registry.yarnpkg.com/dns-packet/-/dns-packet-1.3.4.tgz#e3455065824a2507ba886c55a89963bb107dec6f" + integrity sha512-BQ6F4vycLXBvdrJZ6S3gZewt6rcrks9KBgM9vrhW+knGRqc8uEdT7fuCwloc7nny5xNoMJ17HGH0R/6fpo8ECA== + dependencies: + ip "^1.1.0" + safe-buffer "^5.0.1" + +dns-txt@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/dns-txt/-/dns-txt-2.0.2.tgz#b91d806f5d27188e4ab3e7d107d881a1cc4642b6" + integrity sha1-uR2Ab10nGI5Ks+fRB9iBocxGQrY= + dependencies: + buffer-indexof "^1.0.0" + +dom-align@^1.7.0: + version "1.12.2" + resolved "https://registry.yarnpkg.com/dom-align/-/dom-align-1.12.2.tgz#0f8164ebd0c9c21b0c790310493cd855892acd4b" + integrity sha512-pHuazgqrsTFrGU2WLDdXxCFabkdQDx72ddkraZNih1KsMcN5qsRSTR9O4VJRlwTPCPb5COYg3LOfiMHHcPInHg== + +dom-converter@^0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/dom-converter/-/dom-converter-0.2.0.tgz#6721a9daee2e293682955b6afe416771627bb768" + integrity sha512-gd3ypIPfOMr9h5jIKq8E3sHOTCjeirnl0WK5ZdS1AW0Odt0b1PaWaHdJ4Qk4klv+YB9aJBS7mESXjFoDQPu6DA== + dependencies: + utila "~0.4" + +dom-helpers@^5.0.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/dom-helpers/-/dom-helpers-5.2.1.tgz#d9400536b2bf8225ad98fe052e029451ac40e902" + integrity sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA== + dependencies: + "@babel/runtime" "^7.8.7" + csstype "^3.0.2" + +dom-serializer@^1.0.1: + version "1.3.2" + resolved "https://registry.yarnpkg.com/dom-serializer/-/dom-serializer-1.3.2.tgz#6206437d32ceefaec7161803230c7a20bc1b4d91" + integrity sha512-5c54Bk5Dw4qAxNOI1pFEizPSjVsx5+bpJKmL2kPn8JhBUq2q09tTCa3mjijun2NfK78NMouDYNMBkOrPZiS+ig== + dependencies: + domelementtype "^2.0.1" + domhandler "^4.2.0" + entities "^2.0.0" + +domelementtype@^2.0.1, domelementtype@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/domelementtype/-/domelementtype-2.2.0.tgz#9a0b6c2782ed6a1c7323d42267183df9bd8b1d57" + integrity sha512-DtBMo82pv1dFtUmHyr48beiuq792Sxohr+8Hm9zoxklYPfa6n0Z3Byjj2IV7bmr2IyqClnqEQhfgHJJ5QF0R5A== + +domhandler@^4.0.0, domhandler@^4.2.0, domhandler@^4.3.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/domhandler/-/domhandler-4.3.0.tgz#16c658c626cf966967e306f966b431f77d4a5626" + integrity sha512-fC0aXNQXqKSFTr2wDNZDhsEYjCiYsDWl3D01kwt25hm1YIPyDGHvvi3rw+PLqHAl/m71MaiF7d5zvBr0p5UB2g== + dependencies: + domelementtype "^2.2.0" + +domutils@^2.5.2, domutils@^2.8.0: + version "2.8.0" + resolved "https://registry.yarnpkg.com/domutils/-/domutils-2.8.0.tgz#4437def5db6e2d1f5d6ee859bd95ca7d02048135" + integrity sha512-w96Cjofp72M5IIhpjgobBimYEfoPjx1Vx0BSX9P30WBdZW2WIKU0T1Bd0kz2eNZ9ikjKgHbEyKx8BB6H1L3h3A== + dependencies: + dom-serializer "^1.0.1" + domelementtype "^2.2.0" + domhandler "^4.2.0" + +dot-case@^3.0.4: + version "3.0.4" + resolved "https://registry.yarnpkg.com/dot-case/-/dot-case-3.0.4.tgz#9b2b670d00a431667a8a75ba29cd1b98809ce751" + integrity sha512-Kv5nKlh6yRrdrGvxeJ2e5y2eRUpkUosIW4A2AS38zwSz27zu7ufDwQPi5Jhs3XAlGNetl3bmnGhQsMtkKJnj3w== + dependencies: + no-case "^3.0.4" + tslib "^2.0.3" + +ee-first@1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/ee-first/-/ee-first-1.1.1.tgz#590c61156b0ae2f4f0255732a158b266bc56b21d" + integrity sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0= + +electron-to-chromium@^1.4.76: + version "1.4.76" + resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.4.76.tgz#a0494baedaf51094b1c172999919becd9975a934" + integrity sha512-3Vftv7cenJtQb+k00McEBZ2vVmZ/x+HEF7pcZONZIkOsESqAqVuACmBxMv0JhzX7u0YltU0vSqRqgBSTAhFUjA== + +emojis-list@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/emojis-list/-/emojis-list-3.0.0.tgz#5570662046ad29e2e916e71aae260abdff4f6a78" + integrity sha512-/kyM18EfinwXZbno9FyUGeFh87KC8HRQBQGildHZbEuRyWFOmv1U10o9BBp8XVZDVNNuQKyIGIu5ZYAAXJ0V2Q== + +encodeurl@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/encodeurl/-/encodeurl-1.0.2.tgz#ad3ff4c86ec2d029322f5a02c3a9a606c95b3f59" + integrity sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k= + +enhanced-resolve@^4.0.0: + version "4.5.0" + resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-4.5.0.tgz#2f3cfd84dbe3b487f18f2db2ef1e064a571ca5ec" + integrity sha512-Nv9m36S/vxpsI+Hc4/ZGRs0n9mXqSWGGq49zxb/cJfPAQMbUtttJAlNPS4AQzaBdw/pKskw5bMbekT/Y7W/Wlg== + dependencies: + graceful-fs "^4.1.2" + memory-fs "^0.5.0" + tapable "^1.0.0" + +enhanced-resolve@^5.9.2: + version "5.9.2" + resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-5.9.2.tgz#0224dcd6a43389ebfb2d55efee517e5466772dd9" + integrity sha512-GIm3fQfwLJ8YZx2smuHpBKkXC1yOk+OBEmKckVyL0i/ea8mqDEykK3ld5dgH1QYPNyT/lIllxV2LULnxCHaHkA== + dependencies: + graceful-fs "^4.2.4" + tapable "^2.2.0" + +entities@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/entities/-/entities-2.2.0.tgz#098dc90ebb83d8dffa089d55256b351d34c4da55" + integrity sha512-p92if5Nz619I0w+akJrLZH0MX0Pb5DX39XOwQTtXSdQQOaYH03S1uIQp4mhOZtAXrxq4ViO67YTiLBo2638o9A== + +envinfo@^7.7.3: + version "7.8.1" + resolved "https://registry.yarnpkg.com/envinfo/-/envinfo-7.8.1.tgz#06377e3e5f4d379fea7ac592d5ad8927e0c4d475" + integrity sha512-/o+BXHmB7ocbHEAs6F2EnG0ogybVVUdkRunTT2glZU9XAaGmhqskrvKwqXuDfNjEO0LZKWdejEEpnq8aM0tOaw== + +errno@^0.1.3: + version "0.1.8" + resolved "https://registry.yarnpkg.com/errno/-/errno-0.1.8.tgz#8bb3e9c7d463be4976ff888f76b4809ebc2e811f" + integrity sha512-dJ6oBr5SQ1VSd9qkk7ByRgb/1SH4JZjCHSW/mr63/QcXO9zLVxvJ6Oy13nio03rxpSnVDDjFor75SjVeZWPW/A== + dependencies: + prr "~1.0.1" + +es-module-lexer@^0.9.0: + version "0.9.3" + resolved "https://registry.yarnpkg.com/es-module-lexer/-/es-module-lexer-0.9.3.tgz#6f13db00cc38417137daf74366f535c8eb438f19" + integrity sha512-1HQ2M2sPtxwnvOvT1ZClHyQDiggdNjURWpY2we6aMKCQiUVxTmVs2UYPLIrD84sS+kMdUwfBSylbJPwNnBrnHQ== + +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + +escape-html@~1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988" + integrity sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg= + +eslint-scope@5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-5.1.1.tgz#e786e59a66cb92b3f6c1fb0d508aab174848f48c" + integrity sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw== + dependencies: + esrecurse "^4.3.0" + estraverse "^4.1.1" + +esrecurse@^4.3.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/esrecurse/-/esrecurse-4.3.0.tgz#7ad7964d679abb28bee72cec63758b1c5d2c9921" + integrity sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag== + dependencies: + estraverse "^5.2.0" + +estraverse@^4.1.1: + version "4.3.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d" + integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw== + +estraverse@^5.2.0: + version "5.3.0" + resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123" + integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA== + +etag@~1.8.1: + version "1.8.1" + resolved "https://registry.yarnpkg.com/etag/-/etag-1.8.1.tgz#41ae2eeb65efa62268aebfea83ac7d79299b0887" + integrity sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc= + +eventemitter3@^4.0.0: + version "4.0.7" + resolved "https://registry.yarnpkg.com/eventemitter3/-/eventemitter3-4.0.7.tgz#2de9b68f6528d5644ef5c59526a1b4a07306169f" + integrity sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw== + +events@^3.2.0: + version "3.3.0" + resolved "https://registry.yarnpkg.com/events/-/events-3.3.0.tgz#31a95ad0a924e2d2c419a813aeb2c4e878ea7400" + integrity sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q== + +execa@^5.0.0: + version "5.1.1" + resolved "https://registry.yarnpkg.com/execa/-/execa-5.1.1.tgz#f80ad9cbf4298f7bd1d4c9555c21e93741c411dd" + integrity sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg== + dependencies: + cross-spawn "^7.0.3" + get-stream "^6.0.0" + human-signals "^2.1.0" + is-stream "^2.0.0" + merge-stream "^2.0.0" + npm-run-path "^4.0.1" + onetime "^5.1.2" + signal-exit "^3.0.3" + strip-final-newline "^2.0.0" + +express@^4.17.1: + version "4.17.3" + resolved "https://registry.yarnpkg.com/express/-/express-4.17.3.tgz#f6c7302194a4fb54271b73a1fe7a06478c8f85a1" + integrity sha512-yuSQpz5I+Ch7gFrPCk4/c+dIBKlQUxtgwqzph132bsT6qhuzss6I8cLJQz7B3rFblzd6wtcI0ZbGltH/C4LjUg== + dependencies: + accepts "~1.3.8" + array-flatten "1.1.1" + body-parser "1.19.2" + content-disposition "0.5.4" + content-type "~1.0.4" + cookie "0.4.2" + cookie-signature "1.0.6" + debug "2.6.9" + depd "~1.1.2" + encodeurl "~1.0.2" + escape-html "~1.0.3" + etag "~1.8.1" + finalhandler "~1.1.2" + fresh "0.5.2" + merge-descriptors "1.0.1" + methods "~1.1.2" + on-finished "~2.3.0" + parseurl "~1.3.3" + path-to-regexp "0.1.7" + proxy-addr "~2.0.7" + qs "6.9.7" + range-parser "~1.2.1" + safe-buffer "5.2.1" + send "0.17.2" + serve-static "1.14.2" + setprototypeof "1.2.0" + statuses "~1.5.0" + type-is "~1.6.18" + utils-merge "1.0.1" + vary "~1.1.2" + +fast-deep-equal@^3.1.1, fast-deep-equal@^3.1.3: + version "3.1.3" + resolved "https://registry.yarnpkg.com/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz#3a7d56b559d6cbc3eb512325244e619a65c6c525" + integrity sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q== + +fast-glob@^3.2.9: + version "3.2.11" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" + integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.2" + merge2 "^1.3.0" + micromatch "^4.0.4" + +fast-json-stable-stringify@^2.0.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz#874bf69c6f404c2b5d99c481341399fd55892633" + integrity sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw== + +fastest-levenshtein@^1.0.12: + version "1.0.12" + resolved "https://registry.yarnpkg.com/fastest-levenshtein/-/fastest-levenshtein-1.0.12.tgz#9990f7d3a88cc5a9ffd1f1745745251700d497e2" + integrity sha512-On2N+BpYJ15xIC974QNVuYGMOlEVt4s0EOI3wwMqOmK1fdDY+FN/zltPV8vosq4ad4c/gJ1KHScUn/6AWIgiow== + +fastq@^1.6.0: + version "1.13.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" + integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== + dependencies: + reusify "^1.0.4" + +faye-websocket@^0.11.3: + version "0.11.4" + resolved "https://registry.yarnpkg.com/faye-websocket/-/faye-websocket-0.11.4.tgz#7f0d9275cfdd86a1c963dc8b65fcc451edcbb1da" + integrity sha512-CzbClwlXAuiRQAlUyfqPgvPoNKTckTPGfwZV4ZdAhVcP2lh9KUxJg2b5GkE7XbjKQ3YJnQ9z6D9ntLAlB+tP8g== + dependencies: + websocket-driver ">=0.5.1" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + +finalhandler@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/finalhandler/-/finalhandler-1.1.2.tgz#b7e7d000ffd11938d0fdb053506f6ebabe9f587d" + integrity sha512-aAWcW57uxVNrQZqFXjITpW3sIUQmHGG3qSb9mUah9MgMC4NeWhNOlNjXEYq3HjRAvL6arUviZGGJsBg6z0zsWA== + dependencies: + debug "2.6.9" + encodeurl "~1.0.2" + escape-html "~1.0.3" + on-finished "~2.3.0" + parseurl "~1.3.3" + statuses "~1.5.0" + unpipe "~1.0.0" + +find-up@^4.0.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" + integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== + dependencies: + locate-path "^5.0.0" + path-exists "^4.0.0" + +flow-bin@^0.118.0: + version "0.118.0" + resolved "https://registry.yarnpkg.com/flow-bin/-/flow-bin-0.118.0.tgz#fb706364a58c682d67a2ca7df39396467dc397d1" + integrity sha512-jlbUu0XkbpXeXhan5xyTqVK1jmEKNxE8hpzznI3TThHTr76GiFwK0iRzhDo4KNy+S9h/KxHaqVhTP86vA6wHCg== + +follow-redirects@^1.0.0: + version "1.14.9" + resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.14.9.tgz#dd4ea157de7bfaf9ea9b3fbd85aa16951f78d8d7" + integrity sha512-MQDfihBQYMcyy5dhRDJUHcw7lb2Pv/TuE6xP1vyraLukNDHKbDxDNaOE3NbCAdKQApno+GPRyo1YAp89yCjK4w== + +forwarded@0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/forwarded/-/forwarded-0.2.0.tgz#2269936428aad4c15c7ebe9779a84bf0b2a81811" + integrity sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow== + +fresh@0.5.2: + version "0.5.2" + resolved "https://registry.yarnpkg.com/fresh/-/fresh-0.5.2.tgz#3d8cadd90d976569fa835ab1f8e4b23a105605a7" + integrity sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac= + +fs-monkey@1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/fs-monkey/-/fs-monkey-1.0.3.tgz#ae3ac92d53bb328efe0e9a1d9541f6ad8d48e2d3" + integrity sha512-cybjIfiiE+pTWicSCLFHSrXZ6EilF30oh91FDP9S2B051prEa7QWfrVTQm10/dDpswBDXZugPa1Ogu8Yh+HV0Q== + +fs.realpath@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f" + integrity sha1-FQStJSMVjKpA20onh8sBQRmU6k8= + +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +get-intrinsic@^1.0.2: + version "1.1.1" + resolved "https://registry.yarnpkg.com/get-intrinsic/-/get-intrinsic-1.1.1.tgz#15f59f376f855c446963948f0d24cd3637b4abc6" + integrity sha512-kWZrnVM42QCiEA2Ig1bG8zjoIMOgxWwYCEeNdwY6Tv/cOSeGpcoX4pXHfKUxNKVoArnrEr2e9srnAxxGIraS9Q== + dependencies: + function-bind "^1.1.1" + has "^1.0.3" + has-symbols "^1.0.1" + +get-stream@^6.0.0: + version "6.0.1" + resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-6.0.1.tgz#a262d8eef67aced57c2852ad6167526a43cbf7b7" + integrity sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg== + +glob-parent@^5.1.2, glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +glob-to-regexp@^0.4.1: + version "0.4.1" + resolved "https://registry.yarnpkg.com/glob-to-regexp/-/glob-to-regexp-0.4.1.tgz#c75297087c851b9a578bd217dd59a92f59fe546e" + integrity sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw== + +glob@^7.1.3: + version "7.2.0" + resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.0.tgz#d15535af7732e02e948f4c41628bd910293f6023" + integrity sha512-lmLf6gtyrPq8tTjSmrO94wBeQbFR3HbLHbuyD69wuyQkImp2hWqMGB47OX65FBkPffO641IP9jWa1z4ivqG26Q== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^3.0.4" + once "^1.3.0" + path-is-absolute "^1.0.0" + +globby@^11.0.1: + version "11.1.0" + resolved "https://registry.yarnpkg.com/globby/-/globby-11.1.0.tgz#bd4be98bb042f83d796f7e3811991fbe82a0d34b" + integrity sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g== + dependencies: + array-union "^2.1.0" + dir-glob "^3.0.1" + fast-glob "^3.2.9" + ignore "^5.2.0" + merge2 "^1.4.1" + slash "^3.0.0" + +graceful-fs@^4.1.2, graceful-fs@^4.2.4, graceful-fs@^4.2.6, graceful-fs@^4.2.9: + version "4.2.9" + resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.9.tgz#041b05df45755e587a24942279b9d113146e1c96" + integrity sha512-NtNxqUcXgpW2iMrfqSfR73Glt39K+BLwWsPs94yR63v45T0Wbej7eRmL5cWfwEgqXnmjQp3zaJTshdRW/qC2ZQ== + +handle-thing@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/handle-thing/-/handle-thing-2.0.1.tgz#857f79ce359580c340d43081cc648970d0bb234e" + integrity sha512-9Qn4yBxelxoh2Ow62nP+Ka/kMnOXRi8BXnRaUwezLNhqelnN49xKz4F/dPP8OYLxLxq6JDtZb2i9XznUQbNPTg== + +has-flag@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b" + integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ== + +has-symbols@^1.0.1, has-symbols@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.3.tgz#bb7b2c4349251dce87b125f7bdf874aa7c8b39f8" + integrity sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A== + +has-tostringtag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/has-tostringtag/-/has-tostringtag-1.0.0.tgz#7e133818a7d394734f941e73c3d3f9291e658b25" + integrity sha512-kFjcSNhnlGV1kyoGk7OXKSawH5JOb/LzUc5w9B02hOTO0dfFRjbHQKvg1d6cf3HbeUmtU9VbbV3qzZ2Teh97WQ== + dependencies: + has-symbols "^1.0.2" + +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + +he@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/he/-/he-1.2.0.tgz#84ae65fa7eafb165fddb61566ae14baf05664f0f" + integrity sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw== + +hoist-non-react-statics@^3.3.2: + version "3.3.2" + resolved "https://registry.yarnpkg.com/hoist-non-react-statics/-/hoist-non-react-statics-3.3.2.tgz#ece0acaf71d62c2969c2ec59feff42a4b1a85b45" + integrity sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw== + dependencies: + react-is "^16.7.0" + +hpack.js@^2.1.6: + version "2.1.6" + resolved "https://registry.yarnpkg.com/hpack.js/-/hpack.js-2.1.6.tgz#87774c0949e513f42e84575b3c45681fade2a0b2" + integrity sha1-h3dMCUnlE/QuhFdbPEVoH63ioLI= + dependencies: + inherits "^2.0.1" + obuf "^1.0.0" + readable-stream "^2.0.1" + wbuf "^1.1.0" + +html-entities@^2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/html-entities/-/html-entities-2.3.2.tgz#760b404685cb1d794e4f4b744332e3b00dcfe488" + integrity sha512-c3Ab/url5ksaT0WyleslpBEthOzWhrjQbg75y7XUsfSzi3Dgzt0l8w5e7DylRn15MTlMMD58dTfzddNS2kcAjQ== + +html-minifier-terser@^6.0.2: + version "6.1.0" + resolved "https://registry.yarnpkg.com/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz#bfc818934cc07918f6b3669f5774ecdfd48f32ab" + integrity sha512-YXxSlJBZTP7RS3tWnQw74ooKa6L9b9i9QYXY21eUEvhZ3u9XLfv6OnFsQq6RxkhHygsaUMvYsZRV5rU/OVNZxw== + dependencies: + camel-case "^4.1.2" + clean-css "^5.2.2" + commander "^8.3.0" + he "^1.2.0" + param-case "^3.0.4" + relateurl "^0.2.7" + terser "^5.10.0" + +html-webpack-plugin@^5.3.1: + version "5.5.0" + resolved "https://registry.yarnpkg.com/html-webpack-plugin/-/html-webpack-plugin-5.5.0.tgz#c3911936f57681c1f9f4d8b68c158cd9dfe52f50" + integrity sha512-sy88PC2cRTVxvETRgUHFrL4No3UxvcH8G1NepGhqaTT+GXN2kTamqasot0inS5hXeg1cMbFDt27zzo9p35lZVw== + dependencies: + "@types/html-minifier-terser" "^6.0.0" + html-minifier-terser "^6.0.2" + lodash "^4.17.21" + pretty-error "^4.0.0" + tapable "^2.0.0" + +htmlparser2@^6.1.0: + version "6.1.0" + resolved "https://registry.yarnpkg.com/htmlparser2/-/htmlparser2-6.1.0.tgz#c4d762b6c3371a05dbe65e94ae43a9f845fb8fb7" + integrity sha512-gyyPk6rgonLFEDGoeRgQNaEUvdJ4ktTmmUh/h2t7s+M8oPpIPxgNACWa+6ESR57kXstwqPiCut0V8NRpcwgU7A== + dependencies: + domelementtype "^2.0.1" + domhandler "^4.0.0" + domutils "^2.5.2" + entities "^2.0.0" + +http-deceiver@^1.2.7: + version "1.2.7" + resolved "https://registry.yarnpkg.com/http-deceiver/-/http-deceiver-1.2.7.tgz#fa7168944ab9a519d337cb0bec7284dc3e723d87" + integrity sha1-+nFolEq5pRnTN8sL7HKE3D5yPYc= + +http-errors@1.8.1: + version "1.8.1" + resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.8.1.tgz#7c3f28577cbc8a207388455dbd62295ed07bd68c" + integrity sha512-Kpk9Sm7NmI+RHhnj6OIWDI1d6fIoFAtFt9RLaTMRlg/8w49juAStsrBgp0Dp4OdxdVbRIeKhtCUvoi/RuAhO4g== + dependencies: + depd "~1.1.2" + inherits "2.0.4" + setprototypeof "1.2.0" + statuses ">= 1.5.0 < 2" + toidentifier "1.0.1" + +http-errors@~1.6.2: + version "1.6.3" + resolved "https://registry.yarnpkg.com/http-errors/-/http-errors-1.6.3.tgz#8b55680bb4be283a0b5bf4ea2e38580be1d9320d" + integrity sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0= + dependencies: + depd "~1.1.2" + inherits "2.0.3" + setprototypeof "1.1.0" + statuses ">= 1.4.0 < 2" + +http-parser-js@>=0.5.1: + version "0.5.6" + resolved "https://registry.yarnpkg.com/http-parser-js/-/http-parser-js-0.5.6.tgz#2e02406ab2df8af8a7abfba62e0da01c62b95afd" + integrity sha512-vDlkRPDJn93swjcjqMSaGSPABbIarsr1TLAui/gLDXzV5VsJNdXNzMYDyNBLQkjWQCJ1uizu8T2oDMhmGt0PRA== + +http-proxy-middleware@^2.0.0: + version "2.0.3" + resolved "https://registry.yarnpkg.com/http-proxy-middleware/-/http-proxy-middleware-2.0.3.tgz#5df04f69a89f530c2284cd71eeaa51ba52243289" + integrity sha512-1bloEwnrHMnCoO/Gcwbz7eSVvW50KPES01PecpagI+YLNLci4AcuKJrujW4Mc3sBLpFxMSlsLNHS5Nl/lvrTPA== + dependencies: + "@types/http-proxy" "^1.17.8" + http-proxy "^1.18.1" + is-glob "^4.0.1" + is-plain-obj "^3.0.0" + micromatch "^4.0.2" + +http-proxy@^1.18.1: + version "1.18.1" + resolved "https://registry.yarnpkg.com/http-proxy/-/http-proxy-1.18.1.tgz#401541f0534884bbf95260334e72f88ee3976549" + integrity sha512-7mz/721AbnJwIVbnaSv1Cz3Am0ZLT/UBwkC92VlxhXv/k/BBQfM2fXElQNC27BVGr0uwUpplYPQM9LnaBMR5NQ== + dependencies: + eventemitter3 "^4.0.0" + follow-redirects "^1.0.0" + requires-port "^1.0.0" + +human-signals@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/human-signals/-/human-signals-2.1.0.tgz#dc91fcba42e4d06e4abaed33b3e7a3c02f514ea0" + integrity sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw== + +hyphenate-style-name@^1.0.3: + version "1.0.4" + resolved "https://registry.yarnpkg.com/hyphenate-style-name/-/hyphenate-style-name-1.0.4.tgz#691879af8e220aea5750e8827db4ef62a54e361d" + integrity sha512-ygGZLjmXfPHj+ZWh6LwbC37l43MhfztxetbFCoYTM2VjkIUpeHgSNn7QIyVFj7YQ1Wl9Cbw5sholVJPzWvC2MQ== + +iconv-lite@0.4.24: + version "0.4.24" + resolved "https://registry.yarnpkg.com/iconv-lite/-/iconv-lite-0.4.24.tgz#2022b4b25fbddc21d2f524974a474aafe733908b" + integrity sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA== + dependencies: + safer-buffer ">= 2.1.2 < 3" + +icss-utils@^5.0.0, icss-utils@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/icss-utils/-/icss-utils-5.1.0.tgz#c6be6858abd013d768e98366ae47e25d5887b1ae" + integrity sha512-soFhflCVWLfRNOPU3iv5Z9VUdT44xFRbzjLsEzSr5AQmgqPMTHdU3PMT1Cf1ssx8fLNJDA1juftYl+PUcv3MqA== + +ignore@^5.2.0: + version "5.2.0" + resolved "https://registry.yarnpkg.com/ignore/-/ignore-5.2.0.tgz#6d3bac8fa7fe0d45d9f9be7bac2fc279577e345a" + integrity sha512-CmxgYGiEPCLhfLnpPp1MoRmifwEIOgjcHXxOBjv7mY96c+eWScsOP9c112ZyLdWHi0FxHjI+4uVhKYp/gcdRmQ== + +import-local@^3.0.2: + version "3.1.0" + resolved "https://registry.yarnpkg.com/import-local/-/import-local-3.1.0.tgz#b4479df8a5fd44f6cdce24070675676063c95cb4" + integrity sha512-ASB07uLtnDs1o6EHjKpX34BKYDSqnFerfTOJL2HvMqF70LnxpjkzDB8J44oT9pu4AMPkQwf8jl6szgvNd2tRIg== + dependencies: + pkg-dir "^4.2.0" + resolve-cwd "^3.0.0" + +indent-string@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/indent-string/-/indent-string-4.0.0.tgz#624f8f4497d619b2d9768531d58f4122854d7251" + integrity sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg== + +inflight@^1.0.4: + version "1.0.6" + resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9" + integrity sha1-Sb1jMdfQLQwJvJEKEHW6gWW1bfk= + dependencies: + once "^1.3.0" + wrappy "1" + +inherits@2, inherits@2.0.4, inherits@^2.0.1, inherits@^2.0.3, inherits@~2.0.3: + version "2.0.4" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c" + integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== + +inherits@2.0.3: + version "2.0.3" + resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.3.tgz#633c2c83e3da42a502f52466022480f4208261de" + integrity sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4= + +inline-chunk-html-plugin@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/inline-chunk-html-plugin/-/inline-chunk-html-plugin-1.1.1.tgz#f64111aed16fac274d2b929f6a6a08671d82354e" + integrity sha512-6W1eGIj8z/Yla6xJx5il6jJfCxMZS3kVkbiLQThbbjdsDLRIWkUVmpnhfW2l6WAwCW+qfy0zoXVGBZM1E5XF3g== + +interpret@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/interpret/-/interpret-2.2.0.tgz#1a78a0b5965c40a5416d007ad6f50ad27c417df9" + integrity sha512-Ju0Bz/cEia55xDwUWEa8+olFpCiQoypjnQySseKtmjNrnps3P+xfpUmGr90T7yjlVJmOtybRvPXhKMbHr+fWnw== + +ip@^1.1.0: + version "1.1.5" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.5.tgz#bdded70114290828c0a039e72ef25f5aaec4354a" + integrity sha1-vd7XARQpCCjAoDnnLvJfWq7ENUo= + +ipaddr.js@1.9.1: + version "1.9.1" + resolved "https://registry.yarnpkg.com/ipaddr.js/-/ipaddr.js-1.9.1.tgz#bff38543eeb8984825079ff3a2a8e6cbd46781b3" + integrity sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g== + +ipaddr.js@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/ipaddr.js/-/ipaddr.js-2.0.1.tgz#eca256a7a877e917aeb368b0a7497ddf42ef81c0" + integrity sha512-1qTgH9NG+IIJ4yfKs2e6Pp1bZg8wbDbKHT21HrLIeYBTRLgMYKnMTPAuI3Lcs61nfx5h1xlXnbJtH1kX5/d/ng== + +is-arguments@^1.0.4: + version "1.1.1" + resolved "https://registry.yarnpkg.com/is-arguments/-/is-arguments-1.1.1.tgz#15b3f88fda01f2a97fec84ca761a560f123efa9b" + integrity sha512-8Q7EARjzEnKpt/PCD7e1cgUS0a6X8u5tdSiMqXhojOdoV9TsMsiO+9VLC5vAmO8N7/GmXn7yjR8qnA6bVAEzfA== + dependencies: + call-bind "^1.0.2" + has-tostringtag "^1.0.0" + +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-core-module@^2.8.1: + version "2.8.1" + resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.8.1.tgz#f59fdfca701d5879d0a6b100a40aa1560ce27211" + integrity sha512-SdNCUs284hr40hFTFP6l0IfZ/RSrMXF3qgoRHd3/79unUTvrFO/JoXwkGm+5J/Oe3E/b5GsnG330uUNgRpu1PA== + dependencies: + has "^1.0.3" + +is-date-object@^1.0.1: + version "1.0.5" + resolved "https://registry.yarnpkg.com/is-date-object/-/is-date-object-1.0.5.tgz#0841d5536e724c25597bf6ea62e1bd38298df31f" + integrity sha512-9YQaSxsAiSwcvS33MBk3wTCVnWK+HhF8VZR2jRxehM16QcVOdHqPn4VPHmRK4lSr38n9JriurInLcP90xsYNfQ== + dependencies: + has-tostringtag "^1.0.0" + +is-docker@^2.0.0, is-docker@^2.1.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/is-docker/-/is-docker-2.2.1.tgz#33eeabe23cfe86f14bde4408a02c0cfb853acdaa" + integrity sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ== + +is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha1-qIwCU1eR8C7TfHahueqXc8gz+MI= + +is-glob@^4.0.1, is-glob@~4.0.1: + version "4.0.3" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" + integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== + dependencies: + is-extglob "^2.1.1" + +is-in-browser@^1.0.2, is-in-browser@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/is-in-browser/-/is-in-browser-1.1.3.tgz#56ff4db683a078c6082eb95dad7dc62e1d04f835" + integrity sha1-Vv9NtoOgeMYILrldrX3GLh0E+DU= + +is-number@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" + integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== + +is-path-cwd@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/is-path-cwd/-/is-path-cwd-2.2.0.tgz#67d43b82664a7b5191fd9119127eb300048a9fdb" + integrity sha512-w942bTcih8fdJPJmQHFzkS76NEP8Kzzvmw92cXsazb8intwLqPibPPdXf4ANdKV3rYMuuQYGIWtvz9JilB3NFQ== + +is-path-inside@^3.0.2: + version "3.0.3" + resolved "https://registry.yarnpkg.com/is-path-inside/-/is-path-inside-3.0.3.tgz#d231362e53a07ff2b0e0ea7fed049161ffd16283" + integrity sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ== + +is-plain-obj@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-plain-obj/-/is-plain-obj-3.0.0.tgz#af6f2ea14ac5a646183a5bbdb5baabbc156ad9d7" + integrity sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA== + +is-plain-object@^2.0.4: + version "2.0.4" + resolved "https://registry.yarnpkg.com/is-plain-object/-/is-plain-object-2.0.4.tgz#2c163b3fafb1b606d9d17928f05c2a1c38e07677" + integrity sha512-h5PpgXkWitc38BBMYawTYMWJHFZJVnBquFE57xFpjB8pJFiF6gZ+bU+WyI/yqXiFR5mdLsgYNaPe8uao6Uv9Og== + dependencies: + isobject "^3.0.1" + +is-regex@^1.0.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/is-regex/-/is-regex-1.1.4.tgz#eef5663cd59fa4c0ae339505323df6854bb15958" + integrity sha512-kvRdxDsxZjhzUX07ZnLydzS1TU/TJlTUHHY4YLL87e37oUA49DfkLqgy+VjFocowy29cKvcSiu+kIv728jTTVg== + dependencies: + call-bind "^1.0.2" + has-tostringtag "^1.0.0" + +is-stream@^2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/is-stream/-/is-stream-2.0.1.tgz#fac1e3d53b97ad5a9d0ae9cef2389f5810a5c077" + integrity sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg== + +is-wsl@^2.2.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/is-wsl/-/is-wsl-2.2.0.tgz#74a4c76e77ca9fd3f932f290c17ea326cd157271" + integrity sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww== + dependencies: + is-docker "^2.0.0" + +isarray@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/isarray/-/isarray-1.0.0.tgz#bb935d48582cba168c06834957a54a3e07124f11" + integrity sha1-u5NdSFgsuhaMBoNJV6VKPgcSTxE= + +isexe@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/isexe/-/isexe-2.0.0.tgz#e8fbf374dc556ff8947a10dcb0572d633f2cfa10" + integrity sha1-6PvzdNxVb/iUehDcsFctYz8s+hA= + +isobject@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/isobject/-/isobject-3.0.1.tgz#4e431e92b11a9731636aa1f9c8d1ccbcfdab78df" + integrity sha1-TkMekrEalzFjaqH5yNHMvP2reN8= + +jest-worker@^27.4.5: + version "27.5.1" + resolved "https://registry.yarnpkg.com/jest-worker/-/jest-worker-27.5.1.tgz#8d146f0900e8973b106b6f73cc1e9a8cb86f8db0" + integrity sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg== + dependencies: + "@types/node" "*" + merge-stream "^2.0.0" + supports-color "^8.0.0" + +"js-tokens@^3.0.0 || ^4.0.0": + version "4.0.0" + resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" + integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ== + +json-parse-better-errors@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/json-parse-better-errors/-/json-parse-better-errors-1.0.2.tgz#bb867cfb3450e69107c131d1c514bab3dc8bcaa9" + integrity sha512-mrqyZKfX5EhL7hvqcV6WG1yYjnjeuYDzDhhcAAUrq8Po85NBQBJP+ZDUT75qZQ98IkUoBqdkExkukOU7Ts2wrw== + +json-schema-traverse@^0.4.1: + version "0.4.1" + resolved "https://registry.yarnpkg.com/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz#69f6a87d9513ab8bb8fe63bdb0979c448e684660" + integrity sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg== + +json-schema-traverse@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz#ae7bcb3656ab77a73ba5c49bf654f38e6b6860e2" + integrity sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug== + +json2mq@^0.2.0: + version "0.2.0" + resolved "https://registry.yarnpkg.com/json2mq/-/json2mq-0.2.0.tgz#b637bd3ba9eabe122c83e9720483aeb10d2c904a" + integrity sha1-tje9O6nqvhIsg+lyBIOusQ0skEo= + dependencies: + string-convert "^0.2.0" + +json5@^2.1.2: + version "2.2.0" + resolved "https://registry.yarnpkg.com/json5/-/json5-2.2.0.tgz#2dfefe720c6ba525d9ebd909950f0515316c89a3" + integrity sha512-f+8cldu7X/y7RAJurMEJmdoKXGB/X550w2Nr3tTbezL6RwEE/iMcm+tZnXeoZtKuOq6ft8+CqzEkrIgx1fPoQA== + dependencies: + minimist "^1.2.5" + +jss-plugin-camel-case@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-camel-case/-/jss-plugin-camel-case-10.9.0.tgz#4921b568b38d893f39736ee8c4c5f1c64670aaf7" + integrity sha512-UH6uPpnDk413/r/2Olmw4+y54yEF2lRIV8XIZyuYpgPYTITLlPOsq6XB9qeqv+75SQSg3KLocq5jUBXW8qWWww== + dependencies: + "@babel/runtime" "^7.3.1" + hyphenate-style-name "^1.0.3" + jss "10.9.0" + +jss-plugin-default-unit@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-default-unit/-/jss-plugin-default-unit-10.9.0.tgz#bb23a48f075bc0ce852b4b4d3f7582bc002df991" + integrity sha512-7Ju4Q9wJ/MZPsxfu4T84mzdn7pLHWeqoGd/D8O3eDNNJ93Xc8PxnLmV8s8ZPNRYkLdxZqKtm1nPQ0BM4JRlq2w== + dependencies: + "@babel/runtime" "^7.3.1" + jss "10.9.0" + +jss-plugin-global@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-global/-/jss-plugin-global-10.9.0.tgz#fc07a0086ac97aca174e37edb480b69277f3931f" + integrity sha512-4G8PHNJ0x6nwAFsEzcuVDiBlyMsj2y3VjmFAx/uHk/R/gzJV+yRHICjT4MKGGu1cJq2hfowFWCyrr/Gg37FbgQ== + dependencies: + "@babel/runtime" "^7.3.1" + jss "10.9.0" + +jss-plugin-nested@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-nested/-/jss-plugin-nested-10.9.0.tgz#cc1c7d63ad542c3ccc6e2c66c8328c6b6b00f4b3" + integrity sha512-2UJnDrfCZpMYcpPYR16oZB7VAC6b/1QLsRiAutOt7wJaaqwCBvNsosLEu/fUyKNQNGdvg2PPJFDO5AX7dwxtoA== + dependencies: + "@babel/runtime" "^7.3.1" + jss "10.9.0" + tiny-warning "^1.0.2" + +jss-plugin-props-sort@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-props-sort/-/jss-plugin-props-sort-10.9.0.tgz#30e9567ef9479043feb6e5e59db09b4de687c47d" + integrity sha512-7A76HI8bzwqrsMOJTWKx/uD5v+U8piLnp5bvru7g/3ZEQOu1+PjHvv7bFdNO3DwNPC9oM0a//KwIJsIcDCjDzw== + dependencies: + "@babel/runtime" "^7.3.1" + jss "10.9.0" + +jss-plugin-rule-value-function@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-rule-value-function/-/jss-plugin-rule-value-function-10.9.0.tgz#379fd2732c0746fe45168011fe25544c1a295d67" + integrity sha512-IHJv6YrEf8pRzkY207cPmdbBstBaE+z8pazhPShfz0tZSDtRdQua5jjg6NMz3IbTasVx9FdnmptxPqSWL5tyJg== + dependencies: + "@babel/runtime" "^7.3.1" + jss "10.9.0" + tiny-warning "^1.0.2" + +jss-plugin-vendor-prefixer@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss-plugin-vendor-prefixer/-/jss-plugin-vendor-prefixer-10.9.0.tgz#aa9df98abfb3f75f7ed59a3ec50a5452461a206a" + integrity sha512-MbvsaXP7iiVdYVSEoi+blrW+AYnTDvHTW6I6zqi7JcwXdc6I9Kbm234nEblayhF38EftoenbM+5218pidmC5gA== + dependencies: + "@babel/runtime" "^7.3.1" + css-vendor "^2.0.8" + jss "10.9.0" + +jss@10.9.0, jss@^10.5.1: + version "10.9.0" + resolved "https://registry.yarnpkg.com/jss/-/jss-10.9.0.tgz#7583ee2cdc904a83c872ba695d1baab4b59c141b" + integrity sha512-YpzpreB6kUunQBbrlArlsMpXYyndt9JATbt95tajx0t4MTJJcCJdd4hdNpHmOIDiUJrF/oX5wtVFrS3uofWfGw== + dependencies: + "@babel/runtime" "^7.3.1" + csstype "^3.0.2" + is-in-browser "^1.1.3" + tiny-warning "^1.0.2" + +kind-of@^6.0.2: + version "6.0.3" + resolved "https://registry.yarnpkg.com/kind-of/-/kind-of-6.0.3.tgz#07c05034a6c349fa06e24fa35aa76db4580ce4dd" + integrity sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw== + +loader-runner@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/loader-runner/-/loader-runner-4.2.0.tgz#d7022380d66d14c5fb1d496b89864ebcfd478384" + integrity sha512-92+huvxMvYlMzMt0iIOukcwYBFpkYJdpl2xsZ7LrlayO7E8SOv+JJUEK17B/dJIHAOLMfh2dZZ/Y18WgmGtYNw== + +loader-utils@^2.0.0: + version "2.0.2" + resolved "https://registry.yarnpkg.com/loader-utils/-/loader-utils-2.0.2.tgz#d6e3b4fb81870721ae4e0868ab11dd638368c129" + integrity sha512-TM57VeHptv569d/GKh6TAYdzKblwDNiumOdkFnejjD0XwTH87K90w3O7AiJRqdQoXygvi1VQTJTLGhJl7WqA7A== + dependencies: + big.js "^5.2.2" + emojis-list "^3.0.0" + json5 "^2.1.2" + +locate-path@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" + integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== + dependencies: + p-locate "^4.1.0" + +lodash@^4.17.14, lodash@^4.17.20, lodash@^4.17.21: + version "4.17.21" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" + integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== + +loose-envify@^1.1.0, loose-envify@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.4.0.tgz#71ee51fa7be4caec1a63839f7e682d8132d30caf" + integrity sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q== + dependencies: + js-tokens "^3.0.0 || ^4.0.0" + +lower-case@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/lower-case/-/lower-case-2.0.2.tgz#6fa237c63dbdc4a82ca0fd882e4722dc5e634e28" + integrity sha512-7fm3l3NAF9WfN6W3JOmf5drwpVqX78JtoGJ3A6W0a6ZnldM41w2fV5D490psKFTpMds8TJse/eHLFFsNHHjHgg== + dependencies: + tslib "^2.0.3" + +lru-cache@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94" + integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA== + dependencies: + yallist "^4.0.0" + +media-typer@0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748" + integrity sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g= + +memfs@^3.4.1: + version "3.4.1" + resolved "https://registry.yarnpkg.com/memfs/-/memfs-3.4.1.tgz#b78092f466a0dce054d63d39275b24c71d3f1305" + integrity sha512-1c9VPVvW5P7I85c35zAdEr1TD5+F11IToIHIlrVIcflfnzPkJa0ZoYEoEdYDP8KgPFoSZ/opDrUsAoZWym3mtw== + dependencies: + fs-monkey "1.0.3" + +"memoize-one@>=3.1.1 <6": + version "5.2.1" + resolved "https://registry.yarnpkg.com/memoize-one/-/memoize-one-5.2.1.tgz#8337aa3c4335581839ec01c3d594090cebe8f00e" + integrity sha512-zYiwtZUcYyXKo/np96AGZAckk+FWWsUdJ3cHGGmld7+AhvcWmQyGCYUh1hc4Q/pkOhb65dQR/pqCyK0cOaHz4Q== + +memoize-one@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/memoize-one/-/memoize-one-3.1.1.tgz#ef609811e3bc28970eac2884eece64d167830d17" + integrity sha512-YqVh744GsMlZu6xkhGslPSqSurOv6P+kLN2J3ysBZfagLcL5FdRK/0UpgLoL8hwjjEvvAVkjJZyFP+1T6p1vgA== + +memoize-one@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/memoize-one/-/memoize-one-6.0.0.tgz#b2591b871ed82948aee4727dc6abceeeac8c1045" + integrity sha512-rkpe71W0N0c0Xz6QD0eJETuWAJGnJ9afsl1srmwPrI+yBCkge5EycXXbYRyvL29zZVUWQCY7InPRCv3GDXuZNw== + +memory-fs@^0.5.0: + version "0.5.0" + resolved "https://registry.yarnpkg.com/memory-fs/-/memory-fs-0.5.0.tgz#324c01288b88652966d161db77838720845a8e3c" + integrity sha512-jA0rdU5KoQMC0e6ppoNRtpp6vjFq6+NY7r8hywnC7V+1Xj/MtHwGIbB1QaK/dunyjWteJzmkpd7ooeWg10T7GA== + dependencies: + errno "^0.1.3" + readable-stream "^2.0.1" + +merge-descriptors@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/merge-descriptors/-/merge-descriptors-1.0.1.tgz#b00aaa556dd8b44568150ec9d1b953f3f90cbb61" + integrity sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E= + +merge-stream@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/merge-stream/-/merge-stream-2.0.0.tgz#52823629a14dd00c9770fb6ad47dc6310f2c1f60" + integrity sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w== + +merge2@^1.3.0, merge2@^1.4.1: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + +methods@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/methods/-/methods-1.1.2.tgz#5529a4d67654134edcc5266656835b0f851afcee" + integrity sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4= + +micromatch@^4.0.0, micromatch@^4.0.2, micromatch@^4.0.4: + version "4.0.4" + resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.4.tgz#896d519dfe9db25fce94ceb7a500919bf881ebf9" + integrity sha512-pRmzw/XUcwXGpD9aI9q/0XOwLNygjETJ8y0ao0wdqprrzDa4YnxLcz7fQRZr8voh8V10kGhABbNcHVk5wHgWwg== + dependencies: + braces "^3.0.1" + picomatch "^2.2.3" + +mime-db@1.51.0: + version "1.51.0" + resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.51.0.tgz#d9ff62451859b18342d960850dc3cfb77e63fb0c" + integrity sha512-5y8A56jg7XVQx2mbv1lu49NR4dokRnhZYTtL+KGfaa27uq4pSTXkwQkFJl4pkRMyNFz/EtYDSkiiEHx3F7UN6g== + +"mime-db@>= 1.43.0 < 2": + version "1.52.0" + resolved "https://registry.yarnpkg.com/mime-db/-/mime-db-1.52.0.tgz#bbabcdc02859f4987301c856e3387ce5ec43bf70" + integrity sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg== + +mime-types@^2.1.27, mime-types@^2.1.31, mime-types@~2.1.17, mime-types@~2.1.24, mime-types@~2.1.34: + version "2.1.34" + resolved "https://registry.yarnpkg.com/mime-types/-/mime-types-2.1.34.tgz#5a712f9ec1503511a945803640fafe09d3793c24" + integrity sha512-6cP692WwGIs9XXdOO4++N+7qjqv0rqxxVvJ3VHPh/Sc9mVZcQP+ZGhkKiTvWMQRr2tbHkJP/Yn7Y0npb3ZBs4A== + dependencies: + mime-db "1.51.0" + +mime@1.6.0: + version "1.6.0" + resolved "https://registry.yarnpkg.com/mime/-/mime-1.6.0.tgz#32cd9e5c64553bd58d19a568af452acff04981b1" + integrity sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg== + +mimic-fn@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b" + integrity sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg== + +minimalistic-assert@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz#2e194de044626d4a10e7f7fbc00ce73e83e4d5c7" + integrity sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A== + +minimatch@^3.0.4: + version "3.1.2" + resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b" + integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw== + dependencies: + brace-expansion "^1.1.7" + +minimist@^1.2.5: + version "1.2.5" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.5.tgz#67d66014b66a6a8aaa0c083c5fd58df4e4e97602" + integrity sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw== + +mkdirp@^0.5.5: + version "0.5.5" + resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.5.tgz#d91cefd62d1436ca0f41620e251288d420099def" + integrity sha512-NKmAlESf6jMGym1++R0Ra7wvhV+wFW63FaSOFPwRahvea0gMUcGUhVeAg/0BC0wiv9ih5NYPB1Wn1UEI1/L+xQ== + dependencies: + minimist "^1.2.5" + +moment@^2.24.0, moment@^2.25.3: + version "2.29.1" + resolved "https://registry.yarnpkg.com/moment/-/moment-2.29.1.tgz#b2be769fa31940be9eeea6469c075e35006fa3d3" + integrity sha512-kHmoybcPV8Sqy59DwNDY3Jefr64lK/by/da0ViFcuA4DH0vQg5Q6Ze5VimxkfQNSC+Mls/Kx53s7TjP1RhFEDQ== + +ms@2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.0.0.tgz#5608aeadfc00be6c2901df5f9861788de0d597c8" + integrity sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g= + +ms@2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.2.tgz#d09d1f357b443f493382a8eb3ccd183872ae6009" + integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== + +ms@2.1.3, ms@^2.1.1: + version "2.1.3" + resolved "https://registry.yarnpkg.com/ms/-/ms-2.1.3.tgz#574c8138ce1d2b5861f0b44579dbadd60c6615b2" + integrity sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA== + +multicast-dns-service-types@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/multicast-dns-service-types/-/multicast-dns-service-types-1.1.0.tgz#899f11d9686e5e05cb91b35d5f0e63b773cfc901" + integrity sha1-iZ8R2WhuXgXLkbNdXw5jt3PPyQE= + +multicast-dns@^6.0.1: + version "6.2.3" + resolved "https://registry.yarnpkg.com/multicast-dns/-/multicast-dns-6.2.3.tgz#a0ec7bd9055c4282f790c3c82f4e28db3b31b229" + integrity sha512-ji6J5enbMyGRHIAkAOu3WdV8nggqviKCEKtXcOqfphZZtQrmHKycfynJ2V7eVPUA4NhJ6V7Wf4TmGbTwKE9B6g== + dependencies: + dns-packet "^1.3.1" + thunky "^1.0.2" + +nanoid@^3.1.31, nanoid@^3.3.1: + version "3.3.1" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.1.tgz#6347a18cac88af88f58af0b3594b723d5e99bb35" + integrity sha512-n6Vs/3KGyxPQd6uO0eH4Bv0ojGSUvuLlIHtC3Y0kEO23YRge8H9x1GCzLn28YX0H66pMkxuaeESFq4tKISKwdw== + +negotiator@0.6.3: + version "0.6.3" + resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.3.tgz#58e323a72fedc0d6f9cd4d31fe49f51479590ccd" + integrity sha512-+EUsqGPLsM+j/zdChZjsnX51g4XrHFOIXwfnCVPGlQk/k5giakcKsuxCObBRu6DSm9opw/O6slWbJdghQM4bBg== + +neo-async@^2.6.2: + version "2.6.2" + resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.2.tgz#b4aafb93e3aeb2d8174ca53cf163ab7d7308305f" + integrity sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw== + +no-case@^3.0.4: + version "3.0.4" + resolved "https://registry.yarnpkg.com/no-case/-/no-case-3.0.4.tgz#d361fd5c9800f558551a8369fc0dcd4662b6124d" + integrity sha512-fgAN3jGAh+RoxUGZHTSOLJIqUc2wmoBwGR4tbpNAKmmovFoWq0OdRkb0VkldReO2a2iBT/OEulG9XSUc10r3zg== + dependencies: + lower-case "^2.0.2" + tslib "^2.0.3" + +node-fetch@^1.0.1, node-fetch@^2.6.1: + version "2.6.7" + resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.6.7.tgz#24de9fba827e3b4ae44dc8b20256a379160052ad" + integrity sha512-ZjMPFEfVx5j+y2yF35Kzx5sF7kDzxuDj6ziH4FFbOp87zKDZNx8yExJIb05OGF4Nlt9IHFIMBkRl41VdvcNdbQ== + dependencies: + whatwg-url "^5.0.0" + +node-forge@^1.2.0: + version "1.2.1" + resolved "https://registry.yarnpkg.com/node-forge/-/node-forge-1.2.1.tgz#82794919071ef2eb5c509293325cec8afd0fd53c" + integrity sha512-Fcvtbb+zBcZXbTTVwqGA5W+MKBj56UjVRevvchv5XrcyXbmNdesfZL37nlcWOfpgHhgmxApw3tQbTr4CqNmX4w== + +node-releases@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-2.0.2.tgz#7139fe71e2f4f11b47d4d2986aaf8c48699e0c01" + integrity sha512-XxYDdcQ6eKqp/YjI+tb2C5WM2LgjnZrfYg4vgQt49EK268b6gYCHsBLrK2qvJo4FmCtqmKezb0WZFK4fkrZNsg== + +normalize-path@^3.0.0, normalize-path@~3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" + integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== + +npm-run-path@^4.0.1: + version "4.0.1" + resolved "https://registry.yarnpkg.com/npm-run-path/-/npm-run-path-4.0.1.tgz#b7ecd1e5ed53da8e37a55e1c2269e0b97ed748ea" + integrity sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw== + dependencies: + path-key "^3.0.0" + +nth-check@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/nth-check/-/nth-check-2.0.1.tgz#2efe162f5c3da06a28959fbd3db75dbeea9f0fc2" + integrity sha512-it1vE95zF6dTT9lBsYbxvqh0Soy4SPowchj0UBGj/V6cTPnXXtQOPUbhZ6CmGzAD/rW22LQK6E96pcdJXk4A4w== + dependencies: + boolbase "^1.0.0" + +object-assign@^4.1.1: + version "4.1.1" + resolved "https://registry.yarnpkg.com/object-assign/-/object-assign-4.1.1.tgz#2109adc7965887cfc05cbbd442cac8bfbb360863" + integrity sha1-IQmtx5ZYh8/AXLvUQsrIv7s2CGM= + +object-is@^1.0.1: + version "1.1.5" + resolved "https://registry.yarnpkg.com/object-is/-/object-is-1.1.5.tgz#b9deeaa5fc7f1846a0faecdceec138e5778f53ac" + integrity sha512-3cyDsyHgtmi7I7DfSSI2LDp6SK2lwvtbg0p0R1e0RvTqF5ceGx+K2dfSjm1bKDMVCFEDAQvy+o8c6a7VujOddw== + dependencies: + call-bind "^1.0.2" + define-properties "^1.1.3" + +object-keys@^1.0.12, object-keys@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/object-keys/-/object-keys-1.1.1.tgz#1c47f272df277f3b1daf061677d9c82e2322c60e" + integrity sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA== + +obuf@^1.0.0, obuf@^1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/obuf/-/obuf-1.1.2.tgz#09bea3343d41859ebd446292d11c9d4db619084e" + integrity sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg== + +on-finished@~2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/on-finished/-/on-finished-2.3.0.tgz#20f1336481b083cd75337992a16971aa2d906947" + integrity sha1-IPEzZIGwg811M3mSoWlxqi2QaUc= + dependencies: + ee-first "1.1.1" + +on-headers@~1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/on-headers/-/on-headers-1.0.2.tgz#772b0ae6aaa525c399e489adfad90c403eb3c28f" + integrity sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA== + +once@^1.3.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1" + integrity sha1-WDsap3WWHUsROsF9nFC6753Xa9E= + dependencies: + wrappy "1" + +onetime@^5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/onetime/-/onetime-5.1.2.tgz#d0e96ebb56b07476df1dd9c4806e5237985ca45e" + integrity sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg== + dependencies: + mimic-fn "^2.1.0" + +open@^8.0.9: + version "8.4.0" + resolved "https://registry.yarnpkg.com/open/-/open-8.4.0.tgz#345321ae18f8138f82565a910fdc6b39e8c244f8" + integrity sha512-XgFPPM+B28FtCCgSb9I+s9szOC1vZRSwgWsRUA5ylIxRTgKozqjOCrVOqGsYABPYK5qnfqClxZTFBa8PKt2v6Q== + dependencies: + define-lazy-prop "^2.0.0" + is-docker "^2.1.1" + is-wsl "^2.2.0" + +p-limit@^2.2.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" + integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== + dependencies: + p-try "^2.0.0" + +p-locate@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" + integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== + dependencies: + p-limit "^2.2.0" + +p-map@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/p-map/-/p-map-4.0.0.tgz#bb2f95a5eda2ec168ec9274e06a747c3e2904d2b" + integrity sha512-/bjOqmgETBYB5BoEeGVea8dmvHb2m9GLy1E9W43yeyfP6QQCZGFNa+XRceJEuDB6zqr+gKpIAmlLebMpykw/MQ== + dependencies: + aggregate-error "^3.0.0" + +p-retry@^4.5.0: + version "4.6.1" + resolved "https://registry.yarnpkg.com/p-retry/-/p-retry-4.6.1.tgz#8fcddd5cdf7a67a0911a9cf2ef0e5df7f602316c" + integrity sha512-e2xXGNhZOZ0lfgR9kL34iGlU8N/KO0xZnQxVEwdeOvpqNDQfdnxIYizvWtK8RglUa3bGqI8g0R/BdfzLMxRkiA== + dependencies: + "@types/retry" "^0.12.0" + retry "^0.13.1" + +p-try@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" + integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== + +param-case@^3.0.4: + version "3.0.4" + resolved "https://registry.yarnpkg.com/param-case/-/param-case-3.0.4.tgz#7d17fe4aa12bde34d4a77d91acfb6219caad01c5" + integrity sha512-RXlj7zCYokReqWpOPH9oYivUzLYZ5vAPIfEmCTNViosC78F8F0H9y7T7gG2M39ymgutxF5gcFEsyZQSph9Bp3A== + dependencies: + dot-case "^3.0.4" + tslib "^2.0.3" + +parseurl@~1.3.2, parseurl@~1.3.3: + version "1.3.3" + resolved "https://registry.yarnpkg.com/parseurl/-/parseurl-1.3.3.tgz#9da19e7bee8d12dff0513ed5b76957793bc2e8d4" + integrity sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ== + +pascal-case@^3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/pascal-case/-/pascal-case-3.1.2.tgz#b48e0ef2b98e205e7c1dae747d0b1508237660eb" + integrity sha512-uWlGT3YSnK9x3BQJaOdcZwrnV6hPpd8jFH1/ucpiLRPh/2zCVJKS19E4GvYHvaCcACn3foXZ0cLB9Wrx1KGe5g== + dependencies: + no-case "^3.0.4" + tslib "^2.0.3" + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + +path-is-absolute@^1.0.0: + version "1.0.1" + resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f" + integrity sha1-F0uSaHNVNP+8es5r9TpanhtcX18= + +path-key@^3.0.0, path-key@^3.1.0: + version "3.1.1" + resolved "https://registry.yarnpkg.com/path-key/-/path-key-3.1.1.tgz#581f6ade658cbba65a0d3380de7753295054f375" + integrity sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q== + +path-parse@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" + integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== + +path-to-regexp@0.1.7: + version "0.1.7" + resolved "https://registry.yarnpkg.com/path-to-regexp/-/path-to-regexp-0.1.7.tgz#df604178005f522f15eb4490e7247a1bfaa67f8c" + integrity sha1-32BBeABfUi8V60SQ5yR6G/qmf4w= + +path-type@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-type/-/path-type-4.0.0.tgz#84ed01c0a7ba380afe09d90a8c180dcd9d03043b" + integrity sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw== + +picocolors@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" + integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== + +picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.2.3: + version "2.3.1" + resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" + integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== + +pkg-dir@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/pkg-dir/-/pkg-dir-4.2.0.tgz#f099133df7ede422e81d1d8448270eeb3e4261f3" + integrity sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ== + dependencies: + find-up "^4.0.0" + +popper.js@1.16.1-lts: + version "1.16.1-lts" + resolved "https://registry.yarnpkg.com/popper.js/-/popper.js-1.16.1-lts.tgz#cf6847b807da3799d80ee3d6d2f90df8a3f50b05" + integrity sha512-Kjw8nKRl1m+VrSFCoVGPph93W/qrSO7ZkqPpTf7F4bk/sqcfWK019dWBUpE/fBOsOQY1dks/Bmcbfn1heM/IsA== + +portable-fetch@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/portable-fetch/-/portable-fetch-3.0.0.tgz#3cbf4aa6dbc5a5734b41c0419c9273313bfd9ad8" + integrity sha1-PL9KptvFpXNLQcBBnJJzMTv9mtg= + dependencies: + node-fetch "^1.0.1" + whatwg-fetch ">=0.10.0" + +portfinder@^1.0.28: + version "1.0.28" + resolved "https://registry.yarnpkg.com/portfinder/-/portfinder-1.0.28.tgz#67c4622852bd5374dd1dd900f779f53462fac778" + integrity sha512-Se+2isanIcEqf2XMHjyUKskczxbPH7dQnlMjXX6+dybayyHvAf/TCgyMRlzf/B6QDhAEFOGes0pzRo3by4AbMA== + dependencies: + async "^2.6.2" + debug "^3.1.1" + mkdirp "^0.5.5" + +postcss-modules-extract-imports@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-extract-imports/-/postcss-modules-extract-imports-3.0.0.tgz#cda1f047c0ae80c97dbe28c3e76a43b88025741d" + integrity sha512-bdHleFnP3kZ4NYDhuGlVK+CMrQ/pqUm8bx/oGL93K6gVwiclvX5x0n76fYMKuIGKzlABOy13zsvqjb0f92TEXw== + +postcss-modules-local-by-default@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-local-by-default/-/postcss-modules-local-by-default-4.0.0.tgz#ebbb54fae1598eecfdf691a02b3ff3b390a5a51c" + integrity sha512-sT7ihtmGSF9yhm6ggikHdV0hlziDTX7oFoXtuVWeDd3hHObNkcHRo9V3yg7vCAY7cONyxJC/XXCmmiHHcvX7bQ== + dependencies: + icss-utils "^5.0.0" + postcss-selector-parser "^6.0.2" + postcss-value-parser "^4.1.0" + +postcss-modules-scope@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-scope/-/postcss-modules-scope-3.0.0.tgz#9ef3151456d3bbfa120ca44898dfca6f2fa01f06" + integrity sha512-hncihwFA2yPath8oZ15PZqvWGkWf+XUfQgUGamS4LqoP1anQLOsOJw0vr7J7IwLpoY9fatA2qiGUGmuZL0Iqlg== + dependencies: + postcss-selector-parser "^6.0.4" + +postcss-modules-values@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-modules-values/-/postcss-modules-values-4.0.0.tgz#d7c5e7e68c3bb3c9b27cbf48ca0bb3ffb4602c9c" + integrity sha512-RDxHkAiEGI78gS2ofyvCsu7iycRv7oqw5xMWn9iMoR0N/7mf9D50ecQqUo5BZ9Zh2vH4bCUR/ktCqbB9m8vJjQ== + dependencies: + icss-utils "^5.0.0" + +postcss-selector-parser@^6.0.2, postcss-selector-parser@^6.0.4: + version "6.0.9" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.9.tgz#ee71c3b9ff63d9cd130838876c13a2ec1a992b2f" + integrity sha512-UO3SgnZOVTwu4kyLR22UQ1xZh086RyNZppb7lLAKBFK8a32ttG5i87Y/P3+2bRSjZNyJ1B7hfFNo273tKe9YxQ== + dependencies: + cssesc "^3.0.0" + util-deprecate "^1.0.2" + +postcss-value-parser@^4.1.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" + integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== + +postcss@^8.2.15: + version "8.4.8" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.8.tgz#dad963a76e82c081a0657d3a2f3602ce10c2e032" + integrity sha512-2tXEqGxrjvAO6U+CJzDL2Fk2kPHTv1jQsYkSoMeOis2SsYaXRO2COxTdQp99cYvif9JTXaAk9lYGc3VhJt7JPQ== + dependencies: + nanoid "^3.3.1" + picocolors "^1.0.0" + source-map-js "^1.0.2" + +prettier@^2.1.2: + version "2.5.1" + resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.5.1.tgz#fff75fa9d519c54cf0fce328c1017d94546bc56a" + integrity sha512-vBZcPRUR5MZJwoyi3ZoyQlc1rXeEck8KgeC9AwwOn+exuxLxq5toTRDTSaVrXHxelDMHy9zlicw8u66yxoSUFg== + +pretty-error@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/pretty-error/-/pretty-error-4.0.0.tgz#90a703f46dd7234adb46d0f84823e9d1cb8f10d6" + integrity sha512-AoJ5YMAcXKYxKhuJGdcvse+Voc6v1RgnsR3nWcYU7q4t6z0Q6T86sv5Zq8VIRbOWWFpvdGE83LtdSMNd+6Y0xw== + dependencies: + lodash "^4.17.20" + renderkid "^3.0.0" + +process-nextick-args@~2.0.0: + version "2.0.1" + resolved "https://registry.yarnpkg.com/process-nextick-args/-/process-nextick-args-2.0.1.tgz#7820d9b16120cc55ca9ae7792680ae7dba6d7fe2" + integrity sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag== + +prop-types@^15.6.2, prop-types@^15.7.2: + version "15.8.1" + resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.8.1.tgz#67d87bf1a694f48435cf332c24af10214a3140b5" + integrity sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg== + dependencies: + loose-envify "^1.4.0" + object-assign "^4.1.1" + react-is "^16.13.1" + +proxy-addr@~2.0.7: + version "2.0.7" + resolved "https://registry.yarnpkg.com/proxy-addr/-/proxy-addr-2.0.7.tgz#f19fe69ceab311eeb94b42e70e8c2070f9ba1025" + integrity sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg== + dependencies: + forwarded "0.2.0" + ipaddr.js "1.9.1" + +prr@~1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/prr/-/prr-1.0.1.tgz#d3fc114ba06995a45ec6893f484ceb1d78f5f476" + integrity sha1-0/wRS6BplaRexok/SEzrHXj19HY= + +punycode@^2.1.0: + version "2.1.1" + resolved "https://registry.yarnpkg.com/punycode/-/punycode-2.1.1.tgz#b58b010ac40c22c5657616c8d2c2c02c7bf479ec" + integrity sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A== + +qs@6.9.7: + version "6.9.7" + resolved "https://registry.yarnpkg.com/qs/-/qs-6.9.7.tgz#4610846871485e1e048f44ae3b94033f0e675afe" + integrity sha512-IhMFgUmuNpyRfxA90umL7ByLlgRXu6tIfKPpF5TmcfRLlLCckfP/g3IQmju6jjpu+Hh8rA+2p6A27ZSPOOHdKw== + +queue-microtask@^1.2.2: + version "1.2.3" + resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" + integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== + +randombytes@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/randombytes/-/randombytes-2.1.0.tgz#df6f84372f0270dc65cdf6291349ab7a473d4f2a" + integrity sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ== + dependencies: + safe-buffer "^5.1.0" + +range-parser@^1.2.1, range-parser@~1.2.1: + version "1.2.1" + resolved "https://registry.yarnpkg.com/range-parser/-/range-parser-1.2.1.tgz#3cf37023d199e1c24d1a55b84800c2f3e6468031" + integrity sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg== + +raw-body@2.4.3: + version "2.4.3" + resolved "https://registry.yarnpkg.com/raw-body/-/raw-body-2.4.3.tgz#8f80305d11c2a0a545c2d9d89d7a0286fcead43c" + integrity sha512-UlTNLIcu0uzb4D2f4WltY6cVjLi+/jEN4lgEUj3E04tpMDpUlkBo/eSn6zou9hum2VMNpCCUone0O0WeJim07g== + dependencies: + bytes "3.1.2" + http-errors "1.8.1" + iconv-lite "0.4.24" + unpipe "1.0.0" + +rc-align@^4.0.0: + version "4.0.11" + resolved "https://registry.yarnpkg.com/rc-align/-/rc-align-4.0.11.tgz#8198c62db266bc1b8ef05e56c13275bf72628a5e" + integrity sha512-n9mQfIYQbbNTbefyQnRHZPWuTEwG1rY4a9yKlIWHSTbgwI+XUMGRYd0uJ5pE2UbrNX0WvnMBA1zJ3Lrecpra/A== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + dom-align "^1.7.0" + lodash "^4.17.21" + rc-util "^5.3.0" + resize-observer-polyfill "^1.5.1" + +rc-cascader@~3.2.1: + version "3.2.7" + resolved "https://registry.yarnpkg.com/rc-cascader/-/rc-cascader-3.2.7.tgz#74ac3ab9258f930e0c84dfacffd838b122b2cedf" + integrity sha512-M8VtKtifTXXo/qqXj63p12tsMNXm1z45Lytj7tu86L6gxIF8keDPcJ16/ZqrhS5JwlBPfoJNA1VooNl/KId15A== + dependencies: + "@babel/runtime" "^7.12.5" + array-tree-filter "^2.1.0" + classnames "^2.3.1" + rc-select "~14.0.0-alpha.23" + rc-tree "~5.4.3" + rc-util "^5.6.1" + +rc-checkbox@~2.3.0: + version "2.3.2" + resolved "https://registry.yarnpkg.com/rc-checkbox/-/rc-checkbox-2.3.2.tgz#f91b3678c7edb2baa8121c9483c664fa6f0aefc1" + integrity sha512-afVi1FYiGv1U0JlpNH/UaEXdh6WUJjcWokj/nUN2TgG80bfG+MDdbfHKlLcNNba94mbjy2/SXJ1HDgrOkXGAjg== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + +rc-collapse@~3.1.0: + version "3.1.2" + resolved "https://registry.yarnpkg.com/rc-collapse/-/rc-collapse-3.1.2.tgz#76028a811b845d03d9460ccc409c7ea8ad09db14" + integrity sha512-HujcKq7mghk/gVKeI6EjzTbb8e19XUZpakrYazu1MblEZ3Hu3WBMSN4A3QmvbF6n1g7x6lUlZvsHZ5shABWYOQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.3.4" + rc-util "^5.2.1" + shallowequal "^1.1.0" + +rc-dialog@~8.6.0: + version "8.6.0" + resolved "https://registry.yarnpkg.com/rc-dialog/-/rc-dialog-8.6.0.tgz#3b228dac085de5eed8c6237f31162104687442e7" + integrity sha512-GSbkfqjqxpZC5/zc+8H332+q5l/DKUhpQr0vdX2uDsxo5K0PhvaMEVjyoJUTkZ3+JstEADQji1PVLVb/2bJeOQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.6" + rc-motion "^2.3.0" + rc-util "^5.6.1" + +rc-drawer@~4.4.2: + version "4.4.3" + resolved "https://registry.yarnpkg.com/rc-drawer/-/rc-drawer-4.4.3.tgz#2094937a844e55dc9644236a2d9fba79c344e321" + integrity sha512-FYztwRs3uXnFOIf1hLvFxIQP9MiZJA+0w+Os8dfDh/90X7z/HqP/Yg+noLCIeHEbKln1Tqelv8ymCAN24zPcfQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.6" + rc-util "^5.7.0" + +rc-dropdown@^3.2.0, rc-dropdown@~3.3.2: + version "3.3.2" + resolved "https://registry.yarnpkg.com/rc-dropdown/-/rc-dropdown-3.3.2.tgz#097c2ec1b6d55c10eeb94dcf6120ba034c7a58e0" + integrity sha512-49GOz42oNvLtYGoJ2X5UWXJFp7aUiSZkj9OcgTV1UpxFZqHQMw+xijkaL5k3XDkMbb92XsuFnFt7IGG3/C0DKw== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.6" + rc-trigger "^5.0.4" + +rc-field-form@~1.23.0: + version "1.23.1" + resolved "https://registry.yarnpkg.com/rc-field-form/-/rc-field-form-1.23.1.tgz#638c11d05d7ed2efdcb862ff3da5fe2a7d199aaa" + integrity sha512-Mun+eaFmX1Pjud9bz0fD0IvxwDfFKWk2Q8tkt4sg4aKR9/FML/rzYC5MjY77p86X45XBurBDUR3gAda+Cg/ULw== + dependencies: + "@babel/runtime" "^7.8.4" + async-validator "^4.0.2" + rc-util "^5.8.0" + +rc-image@~5.2.5: + version "5.2.5" + resolved "https://registry.yarnpkg.com/rc-image/-/rc-image-5.2.5.tgz#44e6ffc842626827960e7ab72e1c0d6f3a8ce440" + integrity sha512-qUfZjYIODxO0c8a8P5GeuclYXZjzW4hV/5hyo27XqSFo1DmTCs2HkVeQObkcIk5kNsJtgsj1KoPThVsSc/PXOw== + dependencies: + "@babel/runtime" "^7.11.2" + classnames "^2.2.6" + rc-dialog "~8.6.0" + rc-util "^5.0.6" + +rc-input-number@~7.3.0: + version "7.3.4" + resolved "https://registry.yarnpkg.com/rc-input-number/-/rc-input-number-7.3.4.tgz#674aea98260250287d36e330a7e065b174486e9d" + integrity sha512-W9uqSzuvJUnz8H8vsVY4kx+yK51SsAxNTwr8SNH4G3XqQNocLVmKIibKFRjocnYX1RDHMND9FFbgj2h7E7nvGA== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.5" + rc-util "^5.9.8" + +rc-input@^0.0.1-alpha.5: + version "0.0.1-alpha.5" + resolved "https://registry.yarnpkg.com/rc-input/-/rc-input-0.0.1-alpha.5.tgz#cc043c44570c651f4d10d9809b3d634ed12537e6" + integrity sha512-RHvNweOVWFbbx2l/y6hgnSAdOg5fXc1D1VGhX2RNkGGyGr6cemnvyiYMxwZJjcXs0al3YK9jMObm20+DgH/mpw== + dependencies: + "@babel/runtime" "^7.11.1" + classnames "^2.2.1" + rc-util "^5.18.1" + +rc-mentions@~1.6.1: + version "1.6.2" + resolved "https://registry.yarnpkg.com/rc-mentions/-/rc-mentions-1.6.2.tgz#62ed7cdd8fa86d857c3ce3f9e73438022130815e" + integrity sha512-cntfJkNMq8B910rXuvnsnOV88DfmoUidnQnSIeXzWiYiUX4RL5oWUfSZzs+HAXYRU4SL1l8Mwjx95wHETiZ/fQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.6" + rc-menu "^9.0.0" + rc-textarea "^0.3.0" + rc-trigger "^5.0.4" + rc-util "^5.0.1" + +rc-menu@^9.0.0: + version "9.3.2" + resolved "https://registry.yarnpkg.com/rc-menu/-/rc-menu-9.3.2.tgz#bb842d37ebf71da912bea201cf7ef0a27267ad49" + integrity sha512-h3m45oY1INZyqphGELkdT0uiPnFzxkML8m0VMhJnk2fowtqfiT7F5tJLT3znEVaPIY80vMy1bClCkgq8U91CzQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.4.3" + rc-overflow "^1.2.0" + rc-trigger "^5.1.2" + rc-util "^5.12.0" + shallowequal "^1.1.0" + +rc-menu@~9.2.1: + version "9.2.1" + resolved "https://registry.yarnpkg.com/rc-menu/-/rc-menu-9.2.1.tgz#6fbe47f4846363bb81a5a21f0960026c3ada497a" + integrity sha512-UbEtn3rflJ8zS+etYGTVQuzy7Fm+yWXR5c0Rl6ecNTS/dPknRyWAyhJcbeR0Hu1+RdQT+0VCqrUPrgKnm4iY+w== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.4.3" + rc-overflow "^1.2.0" + rc-trigger "^5.1.2" + rc-util "^5.12.0" + shallowequal "^1.1.0" + +rc-motion@^2.0.0, rc-motion@^2.0.1, rc-motion@^2.2.0, rc-motion@^2.3.0, rc-motion@^2.3.4, rc-motion@^2.4.3, rc-motion@^2.4.4: + version "2.4.5" + resolved "https://registry.yarnpkg.com/rc-motion/-/rc-motion-2.4.5.tgz#b061c50bb29ecd3d735d5f4c40924a3c78226cbd" + integrity sha512-f3uJHR4gcpeZS/s8/nYFSOrXt2Wu/h9GrEcbJmC0qmKrVNgwL1pTgrT5kW7lgG6PFeoL4yHDmpQoEKkrPtKIzQ== + dependencies: + "@babel/runtime" "^7.11.1" + classnames "^2.2.1" + rc-util "^5.18.1" + +rc-notification@~4.5.7: + version "4.5.7" + resolved "https://registry.yarnpkg.com/rc-notification/-/rc-notification-4.5.7.tgz#265e6e6a0c1a0fac63d6abd4d832eb8ff31522f1" + integrity sha512-zhTGUjBIItbx96SiRu3KVURcLOydLUHZCPpYEn1zvh+re//Tnq/wSxN4FKgp38n4HOgHSVxcLEeSxBMTeBBDdw== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.2.0" + rc-util "^5.0.1" + +rc-overflow@^1.0.0, rc-overflow@^1.2.0: + version "1.2.3" + resolved "https://registry.yarnpkg.com/rc-overflow/-/rc-overflow-1.2.3.tgz#1754216d807f5473304272b0321c3aba7615f47a" + integrity sha512-Bz6dXTn/ww8nmu70tUQfRV0wT3BkfXY6j1lB1O38OVkDPz4xwfAcGK+LJ2zewUR5cTXkJ8hAN7YULohG8z4M7Q== + dependencies: + "@babel/runtime" "^7.11.1" + classnames "^2.2.1" + rc-resize-observer "^1.0.0" + rc-util "^5.15.0" + +rc-pagination@~3.1.9: + version "3.1.15" + resolved "https://registry.yarnpkg.com/rc-pagination/-/rc-pagination-3.1.15.tgz#e05eddf4c15717a5858290bed0857e27e2f957ff" + integrity sha512-4L3fot8g4E+PjWEgoVGX0noFCg+8ZFZmeLH4vsnZpB3O2T2zThtakjNxG+YvSaYtyMVT4B+GLayjKrKbXQpdAg== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + +rc-picker@~2.6.4: + version "2.6.4" + resolved "https://registry.yarnpkg.com/rc-picker/-/rc-picker-2.6.4.tgz#916aa5fcd8abd11106f1c2fb64bfd549439abfa0" + integrity sha512-Mnc1udPyGNSG7/ya5SmYltUjCUcsMH7jfJnuuXVAvEaEdx9qZxDGMWtIii//+ARC06CSHQ83s5iwiGFwM+FcDw== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + date-fns "2.x" + dayjs "1.x" + moment "^2.24.0" + rc-trigger "^5.0.4" + rc-util "^5.4.0" + shallowequal "^1.1.0" + +rc-progress@~3.2.1: + version "3.2.4" + resolved "https://registry.yarnpkg.com/rc-progress/-/rc-progress-3.2.4.tgz#4036acdae2566438545bc4df2203248babaf7549" + integrity sha512-M9WWutRaoVkPUPIrTpRIDpX0SPSrVHzxHdCRCbeoBFrd9UFWTYNWRlHsruJM5FH1AZI+BwB4wOJUNNylg/uFSw== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.6" + rc-util "^5.16.1" + +rc-rate@~2.9.0: + version "2.9.1" + resolved "https://registry.yarnpkg.com/rc-rate/-/rc-rate-2.9.1.tgz#e43cb95c4eb90a2c1e0b16ec6614d8c43530a731" + integrity sha512-MmIU7FT8W4LYRRHJD1sgG366qKtSaKb67D0/vVvJYR0lrCuRrCiVQ5qhfT5ghVO4wuVIORGpZs7ZKaYu+KMUzA== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.5" + rc-util "^5.0.1" + +rc-resize-observer@^1.0.0, rc-resize-observer@^1.1.0, rc-resize-observer@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/rc-resize-observer/-/rc-resize-observer-1.2.0.tgz#9f46052f81cdf03498be35144cb7c53fd282c4c7" + integrity sha512-6W+UzT3PyDM0wVCEHfoW3qTHPTvbdSgiA43buiy8PzmeMnfgnDeb9NjdimMXMl3/TcrvvWl5RRVdp+NqcR47pQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + rc-util "^5.15.0" + resize-observer-polyfill "^1.5.1" + +rc-select@~14.0.0-alpha.15, rc-select@~14.0.0-alpha.23, rc-select@~14.0.0-alpha.8: + version "14.0.0" + resolved "https://registry.yarnpkg.com/rc-select/-/rc-select-14.0.0.tgz#87735dbc548f1cc8e94d579b21682ed2d34f7653" + integrity sha512-DkoWMhyxmrfpc1KJSqPORZdkKevzgOINvjR4WI+dibRe6i6DyqGB4Jk21sencnK9di6dumzOCHf93x9t9+gp3Q== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.0.1" + rc-overflow "^1.0.0" + rc-trigger "^5.0.4" + rc-util "^5.16.1" + rc-virtual-list "^3.2.0" + +rc-slider@~10.0.0-alpha.4: + version "10.0.0-alpha.4" + resolved "https://registry.yarnpkg.com/rc-slider/-/rc-slider-10.0.0-alpha.4.tgz#f14ec0905d53f1f9d7f495c301527d6eca5781cf" + integrity sha512-ih2xwkBgXAWAf7MjZIZyCiiWo6tnoIMuHifn0UeKXVAup7sH53QdSVvT9x/cysuSZIPNMYWEf6mec184n3gbiQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.5" + rc-tooltip "^5.0.1" + rc-util "^5.18.1" + shallowequal "^1.1.0" + +rc-steps@~4.1.0: + version "4.1.4" + resolved "https://registry.yarnpkg.com/rc-steps/-/rc-steps-4.1.4.tgz#0ba82db202d59ca52d0693dc9880dd145b19dc23" + integrity sha512-qoCqKZWSpkh/b03ASGx1WhpKnuZcRWmvuW+ZUu4mvMdfvFzVxblTwUM+9aBd0mlEUFmt6GW8FXhMpHkK3Uzp3w== + dependencies: + "@babel/runtime" "^7.10.2" + classnames "^2.2.3" + rc-util "^5.0.1" + +rc-switch@~3.2.0: + version "3.2.2" + resolved "https://registry.yarnpkg.com/rc-switch/-/rc-switch-3.2.2.tgz#d001f77f12664d52595b4f6fb425dd9e66fba8e8" + integrity sha512-+gUJClsZZzvAHGy1vZfnwySxj+MjLlGRyXKXScrtCTcmiYNPzxDFOxdQ/3pK1Kt/0POvwJ/6ALOR8gwdXGhs+A== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + rc-util "^5.0.1" + +rc-table@~7.23.0: + version "7.23.0" + resolved "https://registry.yarnpkg.com/rc-table/-/rc-table-7.23.0.tgz#e5f76998ecf3246147d45ed311417c08886e6507" + integrity sha512-Q1gneB2+lUa8EzCCfbrq+jO1qNSwQv1RUUXKB84W/Stdp4EvGOt2+QqGyfotMNM4JUw0fgGLwY+WjnhUhnLuQQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.5" + rc-resize-observer "^1.1.0" + rc-util "^5.14.0" + shallowequal "^1.1.0" + +rc-tabs@~11.10.0: + version "11.10.7" + resolved "https://registry.yarnpkg.com/rc-tabs/-/rc-tabs-11.10.7.tgz#7d8b5dcc17f1608cf3b9425d80069f1415479335" + integrity sha512-7IKmcU7QU3CdYnJTabeXs2DDeLiXLyALC8fvOtgyWWFXUD47G5vG+4bFO3f9+AI+rcFAPpfwapZbXxgmiRuWYQ== + dependencies: + "@babel/runtime" "^7.11.2" + classnames "2.x" + rc-dropdown "^3.2.0" + rc-menu "^9.0.0" + rc-resize-observer "^1.0.0" + rc-util "^5.5.0" + +rc-textarea@^0.3.0, rc-textarea@~0.3.0: + version "0.3.7" + resolved "https://registry.yarnpkg.com/rc-textarea/-/rc-textarea-0.3.7.tgz#987142891efdedb774883c07e2f51b318fde5a11" + integrity sha512-yCdZ6binKmAQB13hc/oehh0E/QRwoPP1pjF21aHBxlgXO3RzPF6dUu4LG2R4FZ1zx/fQd2L1faktulrXOM/2rw== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.1" + rc-resize-observer "^1.0.0" + rc-util "^5.7.0" + shallowequal "^1.1.0" + +rc-tooltip@^5.0.1, rc-tooltip@~5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/rc-tooltip/-/rc-tooltip-5.1.1.tgz#94178ed162d0252bc4993b725f5dc2ac0fccf154" + integrity sha512-alt8eGMJulio6+4/uDm7nvV+rJq9bsfxFDCI0ljPdbuoygUscbsMYb6EQgwib/uqsXQUvzk+S7A59uYHmEgmDA== + dependencies: + "@babel/runtime" "^7.11.2" + rc-trigger "^5.0.0" + +rc-tree-select@~5.1.1: + version "5.1.4" + resolved "https://registry.yarnpkg.com/rc-tree-select/-/rc-tree-select-5.1.4.tgz#3577135399d1f4931b0f4d8245e0845861802e2b" + integrity sha512-sA6vTUQghzbjh3u6YAwJIebKkJEHUWDPFHQpfiPObqsEYqi9TKE1LvWqbJ77NbOlOARZq0KIb7LDGF8X0dikDQ== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-select "~14.0.0-alpha.8" + rc-tree "~5.4.3" + rc-util "^5.16.1" + +rc-tree@~5.4.3: + version "5.4.4" + resolved "https://registry.yarnpkg.com/rc-tree/-/rc-tree-5.4.4.tgz#2ea3663ad3c566aef79a46ba6a1e050d24323e01" + integrity sha512-2qoObRgp31DBXmVzMJmo4qmwP20XEa4hR3imWQtRPcgN3pmljW3WKFmZRrYdOFHz7CyTnRsFZR065bBkIoUpiA== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "2.x" + rc-motion "^2.0.1" + rc-util "^5.16.1" + rc-virtual-list "^3.4.2" + +rc-trigger@^5.0.0, rc-trigger@^5.0.4, rc-trigger@^5.1.2, rc-trigger@^5.2.10: + version "5.2.10" + resolved "https://registry.yarnpkg.com/rc-trigger/-/rc-trigger-5.2.10.tgz#8a0057a940b1b9027eaa33beec8a6ecd85cce2b1" + integrity sha512-FkUf4H9BOFDaIwu42fvRycXMAvkttph9AlbCZXssZDVzz2L+QZ0ERvfB/4nX3ZFPh1Zd+uVGr1DEDeXxq4J1TA== + dependencies: + "@babel/runtime" "^7.11.2" + classnames "^2.2.6" + rc-align "^4.0.0" + rc-motion "^2.0.0" + rc-util "^5.5.0" + +rc-upload@~4.3.0: + version "4.3.3" + resolved "https://registry.yarnpkg.com/rc-upload/-/rc-upload-4.3.3.tgz#e237aa525e5313fa16f4d04d27f53c2f0e157bb8" + integrity sha512-YoJ0phCRenMj1nzwalXzciKZ9/FAaCrFu84dS5pphwucTC8GUWClcDID/WWNGsLFcM97NqIboDqrV82rVRhW/w== + dependencies: + "@babel/runtime" "^7.10.1" + classnames "^2.2.5" + rc-util "^5.2.0" + +rc-util@^5.0.1, rc-util@^5.0.6, rc-util@^5.0.7, rc-util@^5.12.0, rc-util@^5.14.0, rc-util@^5.15.0, rc-util@^5.16.1, rc-util@^5.18.1, rc-util@^5.2.0, rc-util@^5.2.1, rc-util@^5.3.0, rc-util@^5.4.0, rc-util@^5.5.0, rc-util@^5.6.1, rc-util@^5.7.0, rc-util@^5.8.0, rc-util@^5.9.4, rc-util@^5.9.8: + version "5.18.1" + resolved "https://registry.yarnpkg.com/rc-util/-/rc-util-5.18.1.tgz#80bd1450b5254655d2fbea63e3d34f6871e9be79" + integrity sha512-24xaSrMZUEKh1+suDOtJWfPe9E6YrwryViZcoPO0miJTKzP4qhUlV5AAlKQ82AJilz/AOHfi3l6HoX8qa1ye8w== + dependencies: + "@babel/runtime" "^7.12.5" + react-is "^16.12.0" + shallowequal "^1.1.0" + +rc-virtual-list@^3.2.0, rc-virtual-list@^3.4.2: + version "3.4.2" + resolved "https://registry.yarnpkg.com/rc-virtual-list/-/rc-virtual-list-3.4.2.tgz#1078327aa7230b5e456d679ed2ce99f3c036ebd1" + integrity sha512-OyVrrPvvFcHvV0ssz5EDZ+7Rf5qLat/+mmujjchNw5FfbJWNDwkpQ99EcVE6+FtNRmX9wFa1LGNpZLUTvp/4GQ== + dependencies: + classnames "^2.2.6" + rc-resize-observer "^1.0.0" + rc-util "^5.0.7" + +react-dom@^16.13.1: + version "16.14.0" + resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-16.14.0.tgz#7ad838ec29a777fb3c75c3a190f661cf92ab8b89" + integrity sha512-1gCeQXDLoIqMgqD3IO2Ah9bnf0w9kzhwN5q4FGnHZ67hBm9yePzB5JJAIQCc8x3pFnNlwFq4RidZggNAAkzWWw== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + prop-types "^15.6.2" + scheduler "^0.19.1" + +react-flame-graph@^1.4.0: + version "1.4.0" + resolved "https://registry.yarnpkg.com/react-flame-graph/-/react-flame-graph-1.4.0.tgz#52d118cc94348f630a812fc0ec530a5b73c30cdb" + integrity sha512-DaCK9ZX+xK0mNca72kUE5cu6T8hGe/KLsefQWf+eT9sVt+0WP1dVxZCGD8Svfn2KrZB9Mv011Intg/yG2YWSxA== + dependencies: + flow-bin "^0.118.0" + memoize-one "^3.1.1" + react-window "^1" + +react-is@^16.12.0, react-is@^16.13.1, react-is@^16.7.0: + version "16.13.1" + resolved "https://registry.yarnpkg.com/react-is/-/react-is-16.13.1.tgz#789729a4dc36de2999dc156dd6c1d9c18cea56a4" + integrity sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ== + +"react-is@^16.8.0 || ^17.0.0": + version "17.0.2" + resolved "https://registry.yarnpkg.com/react-is/-/react-is-17.0.2.tgz#e691d4a8e9c789365655539ab372762b0efb54f0" + integrity sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w== + +react-transition-group@^4.4.0: + version "4.4.2" + resolved "https://registry.yarnpkg.com/react-transition-group/-/react-transition-group-4.4.2.tgz#8b59a56f09ced7b55cbd53c36768b922890d5470" + integrity sha512-/RNYfRAMlZwDSr6z4zNKV6xu53/e2BuaBbGhbyYIXTrmgu/bGHzmqOs7mJSJBHy9Ud+ApHx3QjrkKSp1pxvlFg== + dependencies: + "@babel/runtime" "^7.5.5" + dom-helpers "^5.0.1" + loose-envify "^1.4.0" + prop-types "^15.6.2" + +react-window@^1: + version "1.8.6" + resolved "https://registry.yarnpkg.com/react-window/-/react-window-1.8.6.tgz#d011950ac643a994118632665aad0c6382e2a112" + integrity sha512-8VwEEYyjz6DCnGBsd+MgkD0KJ2/OXFULyDtorIiTz+QzwoP94tBoA7CnbtyXMm+cCeAUER5KJcPtWl9cpKbOBg== + dependencies: + "@babel/runtime" "^7.0.0" + memoize-one ">=3.1.1 <6" + +react@^16.13.1: + version "16.14.0" + resolved "https://registry.yarnpkg.com/react/-/react-16.14.0.tgz#94d776ddd0aaa37da3eda8fc5b6b18a4c9a3114d" + integrity sha512-0X2CImDkJGApiAlcf0ODKIneSwBPhqJawOa5wCtKbu7ZECrmS26NvtSILynQ66cgkT/RJ4LidJOc3bUESwmU8g== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + prop-types "^15.6.2" + +readable-stream@^2.0.1: + version "2.3.7" + resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-2.3.7.tgz#1eca1cf711aef814c04f62252a36a62f6cb23b57" + integrity sha512-Ebho8K4jIbHAxnuxi7o42OrZgF/ZTNcsZj6nRKyUmkhLFq8CHItp/fy6hQZuZmP/n3yZ9VBUbp4zz/mX8hmYPw== + dependencies: + core-util-is "~1.0.0" + inherits "~2.0.3" + isarray "~1.0.0" + process-nextick-args "~2.0.0" + safe-buffer "~5.1.1" + string_decoder "~1.1.1" + util-deprecate "~1.0.1" + +readable-stream@^3.0.6: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-3.6.0.tgz#337bbda3adc0706bd3e024426a286d4b4b2c9198" + integrity sha512-BViHy7LKeTz4oNnkcLJ+lVSL6vpiFeX6/d3oSH8zCW7UxP2onchk+vTGB143xuFjHS3deTgkKoXXymXqymiIdA== + dependencies: + inherits "^2.0.3" + string_decoder "^1.1.1" + util-deprecate "^1.0.1" + +readdirp@~3.6.0: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" + integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== + dependencies: + picomatch "^2.2.1" + +rechoir@^0.7.0: + version "0.7.1" + resolved "https://registry.yarnpkg.com/rechoir/-/rechoir-0.7.1.tgz#9478a96a1ca135b5e88fc027f03ee92d6c645686" + integrity sha512-/njmZ8s1wVeR6pjTZ+0nCnv8SpZNRMT2D1RLOJQESlYFDBvwpTA4KWJpZ+sBJ4+vhjILRcK7JIFdGCdxEAAitg== + dependencies: + resolve "^1.9.0" + +regenerator-runtime@^0.13.4: + version "0.13.9" + resolved "https://registry.yarnpkg.com/regenerator-runtime/-/regenerator-runtime-0.13.9.tgz#8925742a98ffd90814988d7566ad30ca3b263b52" + integrity sha512-p3VT+cOEgxFsRRA9X4lkI1E+k2/CtnKtU4gcxyaCUreilL/vqI6CdZ3wxVUx3UOUg+gnUOQQcRI7BmSI656MYA== + +regexp.prototype.flags@^1.2.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/regexp.prototype.flags/-/regexp.prototype.flags-1.4.1.tgz#b3f4c0059af9e47eca9f3f660e51d81307e72307" + integrity sha512-pMR7hBVUUGI7PMA37m2ofIdQCsomVnas+Jn5UPGAHQ+/LlwKm/aTLJHdasmHRzlfeZwHiAOaRSo2rbBDm3nNUQ== + dependencies: + call-bind "^1.0.2" + define-properties "^1.1.3" + +relateurl@^0.2.7: + version "0.2.7" + resolved "https://registry.yarnpkg.com/relateurl/-/relateurl-0.2.7.tgz#54dbf377e51440aca90a4cd274600d3ff2d888a9" + integrity sha1-VNvzd+UUQKypCkzSdGANP/LYiKk= + +renderkid@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/renderkid/-/renderkid-3.0.0.tgz#5fd823e4d6951d37358ecc9a58b1f06836b6268a" + integrity sha512-q/7VIQA8lmM1hF+jn+sFSPWGlMkSAeNYcPLmDQx2zzuiDfaLrOmumR8iaUKlenFgh0XRPIUeSPlH3A+AW3Z5pg== + dependencies: + css-select "^4.1.3" + dom-converter "^0.2.0" + htmlparser2 "^6.1.0" + lodash "^4.17.21" + strip-ansi "^6.0.1" + +require-from-string@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/require-from-string/-/require-from-string-2.0.2.tgz#89a7fdd938261267318eafe14f9c32e598c36909" + integrity sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw== + +requires-port@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + integrity sha1-kl0mAdOaxIXgkc8NpcbmlNw9yv8= + +resize-observer-polyfill@^1.5.0, resize-observer-polyfill@^1.5.1: + version "1.5.1" + resolved "https://registry.yarnpkg.com/resize-observer-polyfill/-/resize-observer-polyfill-1.5.1.tgz#0e9020dd3d21024458d4ebd27e23e40269810464" + integrity sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg== + +resolve-cwd@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/resolve-cwd/-/resolve-cwd-3.0.0.tgz#0f0075f1bb2544766cf73ba6a6e2adfebcb13f2d" + integrity sha512-OrZaX2Mb+rJCpH/6CpSqt9xFVpN++x01XnN2ie9g6P5/3xelLAkXWVADpdz1IHD/KFfEXyE6V0U01OQ3UO2rEg== + dependencies: + resolve-from "^5.0.0" + +resolve-from@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-5.0.0.tgz#c35225843df8f776df21c57557bc087e9dfdfc69" + integrity sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw== + +resolve@^1.9.0: + version "1.22.0" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.0.tgz#5e0b8c67c15df57a89bdbabe603a002f21731198" + integrity sha512-Hhtrw0nLeSrFQ7phPp4OOcVjLPIeMnRlr5mcnVuMe7M/7eBn98A3hmFRLoFo3DLZkivSYwhRUJTyPyWAk56WLw== + dependencies: + is-core-module "^2.8.1" + path-parse "^1.0.7" + supports-preserve-symlinks-flag "^1.0.0" + +retry@^0.13.1: + version "0.13.1" + resolved "https://registry.yarnpkg.com/retry/-/retry-0.13.1.tgz#185b1587acf67919d63b357349e03537b2484658" + integrity sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg== + +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + +rimraf@^3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a" + integrity sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA== + dependencies: + glob "^7.1.3" + +run-parallel@^1.1.9: + version "1.2.0" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" + integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== + dependencies: + queue-microtask "^1.2.2" + +safe-buffer@5.1.2, safe-buffer@~5.1.0, safe-buffer@~5.1.1: + version "5.1.2" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.1.2.tgz#991ec69d296e0313747d59bdfd2b745c35f8828d" + integrity sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g== + +safe-buffer@5.2.1, safe-buffer@>=5.1.0, safe-buffer@^5.0.1, safe-buffer@^5.1.0, safe-buffer@~5.2.0: + version "5.2.1" + resolved "https://registry.yarnpkg.com/safe-buffer/-/safe-buffer-5.2.1.tgz#1eaf9fa9bdb1fdd4ec75f58f9cdb4e6b7827eec6" + integrity sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ== + +"safer-buffer@>= 2.1.2 < 3": + version "2.1.2" + resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" + integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== + +scheduler@^0.19.1: + version "0.19.1" + resolved "https://registry.yarnpkg.com/scheduler/-/scheduler-0.19.1.tgz#4f3e2ed2c1a7d65681f4c854fa8c5a1ccb40f196" + integrity sha512-n/zwRWRYSUj0/3g/otKDRPMh6qv2SYMWNq85IEa8iZyAv8od9zDYpGSnpBEjNgcMNq6Scbu5KfIPxNF72R/2EA== + dependencies: + loose-envify "^1.1.0" + object-assign "^4.1.1" + +schema-utils@^3.0.0, schema-utils@^3.1.0, schema-utils@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/schema-utils/-/schema-utils-3.1.1.tgz#bc74c4b6b6995c1d88f76a8b77bea7219e0c8281" + integrity sha512-Y5PQxS4ITlC+EahLuXaY86TXfR7Dc5lw294alXOq86JAHCihAIZfqv8nNCWvaEJvaC51uN9hbLGeV0cFBdH+Fw== + dependencies: + "@types/json-schema" "^7.0.8" + ajv "^6.12.5" + ajv-keywords "^3.5.2" + +schema-utils@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/schema-utils/-/schema-utils-4.0.0.tgz#60331e9e3ae78ec5d16353c467c34b3a0a1d3df7" + integrity sha512-1edyXKgh6XnJsJSQ8mKWXnN/BVaIbFMLpouRUrXgVq7WYne5kw3MW7UPhO44uRXQSIpTSXoJbmrR2X0w9kUTyg== + dependencies: + "@types/json-schema" "^7.0.9" + ajv "^8.8.0" + ajv-formats "^2.1.1" + ajv-keywords "^5.0.0" + +scroll-into-view-if-needed@^2.2.25: + version "2.2.29" + resolved "https://registry.yarnpkg.com/scroll-into-view-if-needed/-/scroll-into-view-if-needed-2.2.29.tgz#551791a84b7e2287706511f8c68161e4990ab885" + integrity sha512-hxpAR6AN+Gh53AdAimHM6C8oTN1ppwVZITihix+WqalywBeFcQ6LdQP5ABNl26nX8GTEL7VT+b8lKpdqq65wXg== + dependencies: + compute-scroll-into-view "^1.0.17" + +select-hose@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/select-hose/-/select-hose-2.0.0.tgz#625d8658f865af43ec962bfc376a37359a4994ca" + integrity sha1-Yl2GWPhlr0Psliv8N2o3NZpJlMo= + +selfsigned@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/selfsigned/-/selfsigned-2.0.0.tgz#e927cd5377cbb0a1075302cff8df1042cc2bce5b" + integrity sha512-cUdFiCbKoa1mZ6osuJs2uDHrs0k0oprsKveFiiaBKCNq3SYyb5gs2HxhQyDNLCmL51ZZThqi4YNDpCK6GOP1iQ== + dependencies: + node-forge "^1.2.0" + +semver@^7.3.4, semver@^7.3.5: + version "7.3.5" + resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.5.tgz#0b621c879348d8998e4b0e4be94b3f12e6018ef7" + integrity sha512-PoeGJYh8HK4BTO/a9Tf6ZG3veo/A7ZVsYrSA6J8ny9nb3B1VrpkuN+z9OE5wfE5p6H4LchYZsegiQgbJD94ZFQ== + dependencies: + lru-cache "^6.0.0" + +send@0.17.2: + version "0.17.2" + resolved "https://registry.yarnpkg.com/send/-/send-0.17.2.tgz#926622f76601c41808012c8bf1688fe3906f7820" + integrity sha512-UJYB6wFSJE3G00nEivR5rgWp8c2xXvJ3OPWPhmuteU0IKj8nKbG3DrjiOmLwpnHGYWAVwA69zmTm++YG0Hmwww== + dependencies: + debug "2.6.9" + depd "~1.1.2" + destroy "~1.0.4" + encodeurl "~1.0.2" + escape-html "~1.0.3" + etag "~1.8.1" + fresh "0.5.2" + http-errors "1.8.1" + mime "1.6.0" + ms "2.1.3" + on-finished "~2.3.0" + range-parser "~1.2.1" + statuses "~1.5.0" + +serialize-javascript@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/serialize-javascript/-/serialize-javascript-6.0.0.tgz#efae5d88f45d7924141da8b5c3a7a7e663fefeb8" + integrity sha512-Qr3TosvguFt8ePWqsvRfrKyQXIiW+nGbYpy8XK24NQHE83caxWt+mIymTT19DGFbNWNLfEwsrkSmN64lVWB9ag== + dependencies: + randombytes "^2.1.0" + +serve-index@^1.9.1: + version "1.9.1" + resolved "https://registry.yarnpkg.com/serve-index/-/serve-index-1.9.1.tgz#d3768d69b1e7d82e5ce050fff5b453bea12a9239" + integrity sha1-03aNabHn2C5c4FD/9bRTvqEqkjk= + dependencies: + accepts "~1.3.4" + batch "0.6.1" + debug "2.6.9" + escape-html "~1.0.3" + http-errors "~1.6.2" + mime-types "~2.1.17" + parseurl "~1.3.2" + +serve-static@1.14.2: + version "1.14.2" + resolved "https://registry.yarnpkg.com/serve-static/-/serve-static-1.14.2.tgz#722d6294b1d62626d41b43a013ece4598d292bfa" + integrity sha512-+TMNA9AFxUEGuC0z2mevogSnn9MXKb4fa7ngeRMJaaGv8vTwnIEkKi+QGvPt33HSnf8pRS+WGM0EbMtCJLKMBQ== + dependencies: + encodeurl "~1.0.2" + escape-html "~1.0.3" + parseurl "~1.3.3" + send "0.17.2" + +setprototypeof@1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.1.0.tgz#d0bd85536887b6fe7c0d818cb962d9d91c54e656" + integrity sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ== + +setprototypeof@1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/setprototypeof/-/setprototypeof-1.2.0.tgz#66c9a24a73f9fc28cbe66b09fed3d33dcaf1b424" + integrity sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw== + +shallow-clone@^3.0.0: + version "3.0.1" + resolved "https://registry.yarnpkg.com/shallow-clone/-/shallow-clone-3.0.1.tgz#8f2981ad92531f55035b01fb230769a40e02efa3" + integrity sha512-/6KqX+GVUdqPuPPd2LxDDxzX6CAbjJehAAOKlNpqqUpAqPM6HeL8f+o3a+JsyGjn2lv0WY8UsTgUJjU9Ok55NA== + dependencies: + kind-of "^6.0.2" + +shallowequal@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/shallowequal/-/shallowequal-1.1.0.tgz#188d521de95b9087404fd4dcb68b13df0ae4e7f8" + integrity sha512-y0m1JoUZSlPAjXVtPPW70aZWfIL/dSP7AFkRnniLCrK/8MDKog3TySTBmckD+RObVxH0v4Tox67+F14PdED2oQ== + +shebang-command@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/shebang-command/-/shebang-command-2.0.0.tgz#ccd0af4f8835fbdc265b82461aaf0c36663f34ea" + integrity sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA== + dependencies: + shebang-regex "^3.0.0" + +shebang-regex@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/shebang-regex/-/shebang-regex-3.0.0.tgz#ae16f1644d873ecad843b0307b143362d4c42172" + integrity sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A== + +signal-exit@^3.0.3: + version "3.0.7" + resolved "https://registry.yarnpkg.com/signal-exit/-/signal-exit-3.0.7.tgz#a9a1767f8af84155114eaabd73f99273c8f59ad9" + integrity sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ== + +slash@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/slash/-/slash-3.0.0.tgz#6539be870c165adbd5240220dbe361f1bc4d4634" + integrity sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q== + +sockjs@^0.3.21: + version "0.3.24" + resolved "https://registry.yarnpkg.com/sockjs/-/sockjs-0.3.24.tgz#c9bc8995f33a111bea0395ec30aa3206bdb5ccce" + integrity sha512-GJgLTZ7vYb/JtPSSZ10hsOYIvEYsjbNU+zPdIHcUaWVNUEPivzxku31865sSSud0Da0W4lEeOPlmw93zLQchuQ== + dependencies: + faye-websocket "^0.11.3" + uuid "^8.3.2" + websocket-driver "^0.7.4" + +source-map-js@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" + integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== + +source-map-support@~0.5.20: + version "0.5.21" + resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.21.tgz#04fe7c7f9e1ed2d662233c28cb2b35b9f63f6e4f" + integrity sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w== + dependencies: + buffer-from "^1.0.0" + source-map "^0.6.0" + +source-map@^0.6.0, source-map@^0.6.1, source-map@~0.6.0: + version "0.6.1" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263" + integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== + +source-map@~0.7.2: + version "0.7.3" + resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.7.3.tgz#5302f8169031735226544092e64981f751750383" + integrity sha512-CkCj6giN3S+n9qrYiBTX5gystlENnRW5jZeNLHpe6aue+SrHcG5VYwujhW9s4dY31mEGsxBDrHR6oI69fTXsaQ== + +spdy-transport@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/spdy-transport/-/spdy-transport-3.0.0.tgz#00d4863a6400ad75df93361a1608605e5dcdcf31" + integrity sha512-hsLVFE5SjA6TCisWeJXFKniGGOpBgMLmerfO2aCyCU5s7nJ/rpAepqmFifv/GCbSbueEeAJJnmSQ2rKC/g8Fcw== + dependencies: + debug "^4.1.0" + detect-node "^2.0.4" + hpack.js "^2.1.6" + obuf "^1.1.2" + readable-stream "^3.0.6" + wbuf "^1.7.3" + +spdy@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/spdy/-/spdy-4.0.2.tgz#b74f466203a3eda452c02492b91fb9e84a27677b" + integrity sha512-r46gZQZQV+Kl9oItvl1JZZqJKGr+oEkB08A6BzkiR7593/7IbtuncXHd2YoYeTsG4157ZssMu9KYvUHLcjcDoA== + dependencies: + debug "^4.1.0" + handle-thing "^2.0.0" + http-deceiver "^1.2.7" + select-hose "^2.0.0" + spdy-transport "^3.0.0" + +"statuses@>= 1.4.0 < 2", "statuses@>= 1.5.0 < 2", statuses@~1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/statuses/-/statuses-1.5.0.tgz#161c7dac177659fd9811f43771fa99381478628c" + integrity sha1-Fhx9rBd2Wf2YEfQ3cfqZOBR4Yow= + +string-convert@^0.2.0: + version "0.2.1" + resolved "https://registry.yarnpkg.com/string-convert/-/string-convert-0.2.1.tgz#6982cc3049fbb4cd85f8b24568b9d9bf39eeff97" + integrity sha1-aYLMMEn7tM2F+LJFaLnZvznu/5c= + +string_decoder@^1.1.1: + version "1.3.0" + resolved "https://registry.yarnpkg.com/string_decoder/-/string_decoder-1.3.0.tgz#42f114594a46cf1a8e30b0a84f56c78c3edac21e" + integrity sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA== + dependencies: + safe-buffer "~5.2.0" + +string_decoder@~1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/string_decoder/-/string_decoder-1.1.1.tgz#9cf1611ba62685d7030ae9e4ba34149c3af03fc8" + integrity sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg== + dependencies: + safe-buffer "~5.1.0" + +strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +strip-ansi@^7.0.0: + version "7.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-7.0.1.tgz#61740a08ce36b61e50e65653f07060d000975fb2" + integrity sha512-cXNxvT8dFNRVfhVME3JAe98mkXDYN2O1l7jmcwMnOslDeESg1rF/OZMtK0nRAhiari1unG5cD4jG3rapUAkLbw== + dependencies: + ansi-regex "^6.0.1" + +strip-final-newline@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/strip-final-newline/-/strip-final-newline-2.0.0.tgz#89b852fb2fcbe936f6f4b3187afb0a12c1ab58ad" + integrity sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA== + +style-loader@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/style-loader/-/style-loader-2.0.0.tgz#9669602fd4690740eaaec137799a03addbbc393c" + integrity sha512-Z0gYUJmzZ6ZdRUqpg1r8GsaFKypE+3xAzuFeMuoHgjc9KZv3wMyCRjQIWEbhoFSq7+7yoHXySDJyyWQaPajeiQ== + dependencies: + loader-utils "^2.0.0" + schema-utils "^3.0.0" + +supports-color@^7.1.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da" + integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw== + dependencies: + has-flag "^4.0.0" + +supports-color@^8.0.0: + version "8.1.1" + resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-8.1.1.tgz#cd6fc17e28500cff56c1b86c0a7fd4a54a73005c" + integrity sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q== + dependencies: + has-flag "^4.0.0" + +supports-preserve-symlinks-flag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" + integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== + +tapable@^1.0.0: + version "1.1.3" + resolved "https://registry.yarnpkg.com/tapable/-/tapable-1.1.3.tgz#a1fccc06b58db61fd7a45da2da44f5f3a3e67ba2" + integrity sha512-4WK/bYZmj8xLr+HUCODHGF1ZFzsYffasLUgEiMBY4fgtltdO6B4WJtlSbPaDTLpYTcGVwM2qLnFTICEcNxs3kA== + +tapable@^2.0.0, tapable@^2.1.1, tapable@^2.2.0: + version "2.2.1" + resolved "https://registry.yarnpkg.com/tapable/-/tapable-2.2.1.tgz#1967a73ef4060a82f12ab96af86d52fdb76eeca0" + integrity sha512-GNzQvQTOIP6RyTfE2Qxb8ZVlNmw0n88vp1szwWRimP02mnTsx3Wtn5qRdqY9w2XduFNUgvOwhNnQsjwCp+kqaQ== + +terser-webpack-plugin@^5.1.3: + version "5.3.1" + resolved "https://registry.yarnpkg.com/terser-webpack-plugin/-/terser-webpack-plugin-5.3.1.tgz#0320dcc270ad5372c1e8993fabbd927929773e54" + integrity sha512-GvlZdT6wPQKbDNW/GDQzZFg/j4vKU96yl2q6mcUkzKOgW4gwf1Z8cZToUCrz31XHlPWH8MVb1r2tFtdDtTGJ7g== + dependencies: + jest-worker "^27.4.5" + schema-utils "^3.1.1" + serialize-javascript "^6.0.0" + source-map "^0.6.1" + terser "^5.7.2" + +terser@^5.10.0, terser@^5.7.2: + version "5.12.0" + resolved "https://registry.yarnpkg.com/terser/-/terser-5.12.0.tgz#728c6bff05f7d1dcb687d8eace0644802a9dae8a" + integrity sha512-R3AUhNBGWiFc77HXag+1fXpAxTAFRQTJemlJKjAgD9r8xXTpjNKqIXwHM/o7Rh+O0kUJtS3WQVdBeMKFk5sw9A== + dependencies: + acorn "^8.5.0" + commander "^2.20.0" + source-map "~0.7.2" + source-map-support "~0.5.20" + +thunky@^1.0.2: + version "1.1.0" + resolved "https://registry.yarnpkg.com/thunky/-/thunky-1.1.0.tgz#5abaf714a9405db0504732bbccd2cedd9ef9537d" + integrity sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA== + +tiny-warning@^1.0.2: + version "1.0.3" + resolved "https://registry.yarnpkg.com/tiny-warning/-/tiny-warning-1.0.3.tgz#94a30db453df4c643d0fd566060d60a875d84754" + integrity sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA== + +to-regex-range@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" + integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== + dependencies: + is-number "^7.0.0" + +toggle-selection@^1.0.6: + version "1.0.6" + resolved "https://registry.yarnpkg.com/toggle-selection/-/toggle-selection-1.0.6.tgz#6e45b1263f2017fa0acc7d89d78b15b8bf77da32" + integrity sha1-bkWxJj8gF/oKzH2J14sVuL932jI= + +toidentifier@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/toidentifier/-/toidentifier-1.0.1.tgz#3be34321a88a820ed1bd80dfaa33e479fbb8dd35" + integrity sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA== + +tr46@~0.0.3: + version "0.0.3" + resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a" + integrity sha1-gYT9NH2snNwYWZLzpmIuFLnZq2o= + +ts-loader@^8.0.18: + version "8.3.0" + resolved "https://registry.yarnpkg.com/ts-loader/-/ts-loader-8.3.0.tgz#83360496d6f8004fab35825279132c93412edf33" + integrity sha512-MgGly4I6cStsJy27ViE32UoqxPTN9Xly4anxxVyaIWR+9BGxboV4EyJBGfR3RePV7Ksjj3rHmPZJeIt+7o4Vag== + dependencies: + chalk "^4.1.0" + enhanced-resolve "^4.0.0" + loader-utils "^2.0.0" + micromatch "^4.0.0" + semver "^7.3.4" + +tslib@^2.0.3: + version "2.3.1" + resolved "https://registry.yarnpkg.com/tslib/-/tslib-2.3.1.tgz#e8a335add5ceae51aa261d32a490158ef042ef01" + integrity sha512-77EbyPPpMz+FRFRuAFlWMtmgUWGe9UOG2Z25NqCwiIjRhOf5iKGuzSe5P2w1laq+FkRy4p+PCuVkJSGkzTEKVw== + +type-is@~1.6.18: + version "1.6.18" + resolved "https://registry.yarnpkg.com/type-is/-/type-is-1.6.18.tgz#4e552cd05df09467dcbc4ef739de89f2cf37c131" + integrity sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g== + dependencies: + media-typer "0.3.0" + mime-types "~2.1.24" + +typescript@^4.0.3: + version "4.6.2" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.6.2.tgz#fe12d2727b708f4eef40f51598b3398baa9611d4" + integrity sha512-HM/hFigTBHZhLXshn9sN37H085+hQGeJHJ/X7LpBWLID/fbc2acUMfU+lGD98X81sKP+pFa9f0DZmCwB9GnbAg== + +unpipe@1.0.0, unpipe@~1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/unpipe/-/unpipe-1.0.0.tgz#b2bf4ee8514aae6165b4817829d21b2ef49904ec" + integrity sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw= + +uri-js@^4.2.2: + version "4.4.1" + resolved "https://registry.yarnpkg.com/uri-js/-/uri-js-4.4.1.tgz#9b1a52595225859e55f669d928f88c6c57f2a77e" + integrity sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg== + dependencies: + punycode "^2.1.0" + +util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha1-RQ1Nyfpw3nMnYvvS1KKJgUGaDM8= + +utila@~0.4: + version "0.4.0" + resolved "https://registry.yarnpkg.com/utila/-/utila-0.4.0.tgz#8a16a05d445657a3aea5eecc5b12a4fa5379772c" + integrity sha1-ihagXURWV6Oupe7MWxKk+lN5dyw= + +utils-merge@1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/utils-merge/-/utils-merge-1.0.1.tgz#9f95710f50a267947b2ccc124741c1028427e713" + integrity sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM= + +uuid@^8.3.2: + version "8.3.2" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-8.3.2.tgz#80d5b5ced271bb9af6c445f21a1a04c606cefbe2" + integrity sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg== + +vary@~1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/vary/-/vary-1.1.2.tgz#2299f02c6ded30d4a5961b0b9f74524a18f634fc" + integrity sha1-IpnwLG3tMNSllhsLn3RSShj2NPw= + +watchpack@^2.3.1: + version "2.3.1" + resolved "https://registry.yarnpkg.com/watchpack/-/watchpack-2.3.1.tgz#4200d9447b401156eeca7767ee610f8809bc9d25" + integrity sha512-x0t0JuydIo8qCNctdDrn1OzH/qDzk2+rdCOC3YzumZ42fiMqmQ7T3xQurykYMhYfHaPHTp4ZxAx2NfUo1K6QaA== + dependencies: + glob-to-regexp "^0.4.1" + graceful-fs "^4.1.2" + +wbuf@^1.1.0, wbuf@^1.7.3: + version "1.7.3" + resolved "https://registry.yarnpkg.com/wbuf/-/wbuf-1.7.3.tgz#c1d8d149316d3ea852848895cb6a0bfe887b87df" + integrity sha512-O84QOnr0icsbFGLS0O3bI5FswxzRr8/gHwWkDlQFskhSPryQXvrTMxjxGP4+iWYoauLoBvfDpkrOauZ+0iZpDA== + dependencies: + minimalistic-assert "^1.0.0" + +webidl-conversions@^3.0.0: + version "3.0.1" + resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-3.0.1.tgz#24534275e2a7bc6be7bc86611cc16ae0a5654871" + integrity sha1-JFNCdeKnvGvnvIZhHMFq4KVlSHE= + +webpack-cli@^4.5.0: + version "4.9.2" + resolved "https://registry.yarnpkg.com/webpack-cli/-/webpack-cli-4.9.2.tgz#77c1adaea020c3f9e2db8aad8ea78d235c83659d" + integrity sha512-m3/AACnBBzK/kMTcxWHcZFPrw/eQuY4Df1TxvIWfWM2x7mRqBQCqKEd96oCUa9jkapLBaFfRce33eGDb4Pr7YQ== + dependencies: + "@discoveryjs/json-ext" "^0.5.0" + "@webpack-cli/configtest" "^1.1.1" + "@webpack-cli/info" "^1.4.1" + "@webpack-cli/serve" "^1.6.1" + colorette "^2.0.14" + commander "^7.0.0" + execa "^5.0.0" + fastest-levenshtein "^1.0.12" + import-local "^3.0.2" + interpret "^2.2.0" + rechoir "^0.7.0" + webpack-merge "^5.7.3" + +webpack-dev-middleware@^5.3.1: + version "5.3.1" + resolved "https://registry.yarnpkg.com/webpack-dev-middleware/-/webpack-dev-middleware-5.3.1.tgz#aa079a8dedd7e58bfeab358a9af7dab304cee57f" + integrity sha512-81EujCKkyles2wphtdrnPg/QqegC/AtqNH//mQkBYSMqwFVCQrxM6ktB2O/SPlZy7LqeEfTbV3cZARGQz6umhg== + dependencies: + colorette "^2.0.10" + memfs "^3.4.1" + mime-types "^2.1.31" + range-parser "^1.2.1" + schema-utils "^4.0.0" + +webpack-dev-server@^4.7.4: + version "4.7.4" + resolved "https://registry.yarnpkg.com/webpack-dev-server/-/webpack-dev-server-4.7.4.tgz#d0ef7da78224578384e795ac228d8efb63d5f945" + integrity sha512-nfdsb02Zi2qzkNmgtZjkrMOcXnYZ6FLKcQwpxT7MvmHKc+oTtDsBju8j+NMyAygZ9GW1jMEUpy3itHtqgEhe1A== + dependencies: + "@types/bonjour" "^3.5.9" + "@types/connect-history-api-fallback" "^1.3.5" + "@types/express" "^4.17.13" + "@types/serve-index" "^1.9.1" + "@types/sockjs" "^0.3.33" + "@types/ws" "^8.2.2" + ansi-html-community "^0.0.8" + bonjour "^3.5.0" + chokidar "^3.5.3" + colorette "^2.0.10" + compression "^1.7.4" + connect-history-api-fallback "^1.6.0" + default-gateway "^6.0.3" + del "^6.0.0" + express "^4.17.1" + graceful-fs "^4.2.6" + html-entities "^2.3.2" + http-proxy-middleware "^2.0.0" + ipaddr.js "^2.0.1" + open "^8.0.9" + p-retry "^4.5.0" + portfinder "^1.0.28" + schema-utils "^4.0.0" + selfsigned "^2.0.0" + serve-index "^1.9.1" + sockjs "^0.3.21" + spdy "^4.0.2" + strip-ansi "^7.0.0" + webpack-dev-middleware "^5.3.1" + ws "^8.4.2" + +webpack-merge@^5.7.3: + version "5.8.0" + resolved "https://registry.yarnpkg.com/webpack-merge/-/webpack-merge-5.8.0.tgz#2b39dbf22af87776ad744c390223731d30a68f61" + integrity sha512-/SaI7xY0831XwP6kzuwhKWVKDP9t1QY1h65lAFLbZqMPIuYcD9QAW4u9STIbU9kaJbPBB/geU/gLr1wDjOhQ+Q== + dependencies: + clone-deep "^4.0.1" + wildcard "^2.0.0" + +webpack-sources@^3.2.3: + version "3.2.3" + resolved "https://registry.yarnpkg.com/webpack-sources/-/webpack-sources-3.2.3.tgz#2d4daab8451fd4b240cc27055ff6a0c2ccea0cde" + integrity sha512-/DyMEOrDgLKKIG0fmvtz+4dUX/3Ghozwgm6iPp8KRhvn+eQf9+Q7GWxVNMk3+uCPWfdXYC4ExGBckIXdFEfH1w== + +webpack@^5.28.0: + version "5.70.0" + resolved "https://registry.yarnpkg.com/webpack/-/webpack-5.70.0.tgz#3461e6287a72b5e6e2f4872700bc8de0d7500e6d" + integrity sha512-ZMWWy8CeuTTjCxbeaQI21xSswseF2oNOwc70QSKNePvmxE7XW36i7vpBMYZFAUHPwQiEbNGCEYIOOlyRbdGmxw== + dependencies: + "@types/eslint-scope" "^3.7.3" + "@types/estree" "^0.0.51" + "@webassemblyjs/ast" "1.11.1" + "@webassemblyjs/wasm-edit" "1.11.1" + "@webassemblyjs/wasm-parser" "1.11.1" + acorn "^8.4.1" + acorn-import-assertions "^1.7.6" + browserslist "^4.14.5" + chrome-trace-event "^1.0.2" + enhanced-resolve "^5.9.2" + es-module-lexer "^0.9.0" + eslint-scope "5.1.1" + events "^3.2.0" + glob-to-regexp "^0.4.1" + graceful-fs "^4.2.9" + json-parse-better-errors "^1.0.2" + loader-runner "^4.2.0" + mime-types "^2.1.27" + neo-async "^2.6.2" + schema-utils "^3.1.0" + tapable "^2.1.1" + terser-webpack-plugin "^5.1.3" + watchpack "^2.3.1" + webpack-sources "^3.2.3" + +websocket-driver@>=0.5.1, websocket-driver@^0.7.4: + version "0.7.4" + resolved "https://registry.yarnpkg.com/websocket-driver/-/websocket-driver-0.7.4.tgz#89ad5295bbf64b480abcba31e4953aca706f5760" + integrity sha512-b17KeDIQVjvb0ssuSDF2cYXSg2iztliJ4B9WdsuB6J952qCPKmnVq4DyW5motImXHDC1cBT/1UezrJVsKw5zjg== + dependencies: + http-parser-js ">=0.5.1" + safe-buffer ">=5.1.0" + websocket-extensions ">=0.1.1" + +websocket-extensions@>=0.1.1: + version "0.1.4" + resolved "https://registry.yarnpkg.com/websocket-extensions/-/websocket-extensions-0.1.4.tgz#7f8473bc839dfd87608adb95d7eb075211578a42" + integrity sha512-OqedPIGOfsDlo31UNwYbCFMSaO9m9G/0faIHj5/dZFDMFqPTcx6UwqyOy3COEaEOg/9VsGIpdqn62W5KhoKSpg== + +whatwg-fetch@>=0.10.0: + version "3.6.2" + resolved "https://registry.yarnpkg.com/whatwg-fetch/-/whatwg-fetch-3.6.2.tgz#dced24f37f2624ed0281725d51d0e2e3fe677f8c" + integrity sha512-bJlen0FcuU/0EMLrdbJ7zOnW6ITZLrZMIarMUVmdKtsGvZna8vxKYaexICWPfZ8qwf9fzNq+UEIZrnSaApt6RA== + +whatwg-url@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/whatwg-url/-/whatwg-url-5.0.0.tgz#966454e8765462e37644d3626f6742ce8b70965d" + integrity sha1-lmRU6HZUYuN2RNNib2dCzotwll0= + dependencies: + tr46 "~0.0.3" + webidl-conversions "^3.0.0" + +which@^2.0.1: + version "2.0.2" + resolved "https://registry.yarnpkg.com/which/-/which-2.0.2.tgz#7c6a8dd0a636a0327e10b59c9286eee93f3f51b1" + integrity sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA== + dependencies: + isexe "^2.0.0" + +wildcard@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/wildcard/-/wildcard-2.0.0.tgz#a77d20e5200c6faaac979e4b3aadc7b3dd7f8fec" + integrity sha512-JcKqAHLPxcdb9KM49dufGXn2x3ssnfjbcaQdLlfZsL9rH9wgDQjUtDxbo8NE0F6SFvydeu1VhZe7hZuHsB2/pw== + +wrappy@1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" + integrity sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8= + +ws@^8.4.2: + version "8.5.0" + resolved "https://registry.yarnpkg.com/ws/-/ws-8.5.0.tgz#bfb4be96600757fe5382de12c670dab984a1ed4f" + integrity sha512-BWX0SWVgLPzYwF8lTzEy1egjhS4S4OEAHfsO8o65WOVsrnSRGaSiUaa9e0ggGlkMTtBlmOpEXiie9RUcBO86qg== + +yallist@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" + integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== diff --git a/plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_0/worker0.1623143089861.pt.trace.json.gz b/plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_0/worker0.1623143089861.pt.trace.json.gz similarity index 100% rename from plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_0/worker0.1623143089861.pt.trace.json.gz rename to plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_0/worker0.1623143089861.pt.trace.json.gz diff --git a/plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_0/worker0.1623143566756.pt.trace.json.gz b/plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_0/worker0.1623143566756.pt.trace.json.gz similarity index 100% rename from plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_0/worker0.1623143566756.pt.trace.json.gz rename to plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_0/worker0.1623143566756.pt.trace.json.gz diff --git a/plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_4/worker0.1623212756351.pt.trace.json.gz b/plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_4/worker0.1623212756351.pt.trace.json.gz similarity index 100% rename from plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_4/worker0.1623212756351.pt.trace.json.gz rename to plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_4/worker0.1623212756351.pt.trace.json.gz diff --git a/plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_4/worker0.1623213129365.pt.trace.json.gz b/plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_4/worker0.1623213129365.pt.trace.json.gz similarity index 100% rename from plugins/tensorboard-plugins/tb_plugin/test/resources/resnet50_num_workers_4/worker0.1623213129365.pt.trace.json.gz rename to plugins/tensorboard-plugins/tb_plugin/samples/resnet50_num_workers_4/worker0.1623213129365.pt.trace.json.gz diff --git a/plugins/tensorboard-plugins/tb_plugin/setup.py b/plugins/tensorboard-plugins/tb_plugin/setup.py index 2d4260b2133ae00a91831a7e2867b467e029d108..3c09006122c776df8fbe8af5836711613e3f6a9c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/setup.py +++ b/plugins/tensorboard-plugins/tb_plugin/setup.py @@ -1,5 +1,6 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. +# Copyright (c) Microsoft Corporation. All rights reserved. +# # Copyright(c) 2023 Huawei Technologies. # All rights reserved # @@ -20,13 +21,8 @@ import os import pathlib import subprocess -from configparser import ConfigParser - import setuptools -config = ConfigParser() -config.read('./torch_tb_profiler/config/config.ini') - def read(rel_path): here = os.path.abspath(os.path.dirname(__file__)) @@ -87,16 +83,17 @@ setuptools.setup( name="torch-tb-profiler-ascend", version=get_version(os.path.join('torch_tb_profiler', '__init__.py')), description="PyTorch Ascend Profiler TensorBoard Plugin", - long_description=f"PyTorch Ascend Profiler TensorBoard Plugin: {config.get('URL', 'repository_url')}", - url=config.get('URL', 'repository_url'), + long_description="PyTorch Ascend Profiler TensorBoard Plugin : \ + https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_plugin", + url="https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_plugin", author="Ascend Team", - author_email=config.get('EMAIL', 'author_email'), + author_email="pmail_mindstudio@huawei.com", cmdclass={ "build_fe": build_fe }, packages=setuptools.find_packages(), package_data={ - "torch_tb_profiler": ["static/**", "config/**"], + "torch_tb_profiler": ["static/**"], }, entry_points={ "tensorboard_plugins": [ diff --git a/plugins/tensorboard-plugins/tb_plugin/test/test_tensorboard_end2end.py b/plugins/tensorboard-plugins/tb_plugin/test/test_tensorboard_end2end.py index 46636d11801a739935b4f385c6ce548009d09916..fae95b49050537b921e291a4771c63a6bff35690 100644 --- a/plugins/tensorboard-plugins/tb_plugin/test/test_tensorboard_end2end.py +++ b/plugins/tensorboard-plugins/tb_plugin/test/test_tensorboard_end2end.py @@ -13,7 +13,7 @@ from urllib.error import HTTPError def get_samples_dir(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resources') + return os.path.join(os.path.dirname(os.path.abspath(__file__)), '../samples') class TestEnd2End(unittest.TestCase): diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py index f7b951e609e5c65895a6db82d391e8d584eb37c8..fd7b265cfa7d67023075ec8d9bc59ed85f4e0f15 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py @@ -4,4 +4,4 @@ # Entry point for Pytorch TensorBoard plugin package. -__version__ = '0.4.0.11' +__version__ = '0.4.0.8' diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/config/config.ini b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/config/config.ini deleted file mode 100644 index 500d472d27b2ca574e07829a64c50d6eb2ab7e71..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/config/config.ini +++ /dev/null @@ -1,11 +0,0 @@ -[URL] -pytorch_data_loading_url = https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading -pytorch_amp_url = https://pytorch.org/docs/stable/amp.html -pytorch_ckp_url = https://pytorch.org/docs/stable/checkpoint.html -cuda_nn_ddp_instead_url = https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead -compress_url = https://pytorch.org/docs/stable/ddp_comm_hooks.html -grad_acc_url = https://towardsdatascience.com/what-is-gradient-accumulation-in-deep-learning-ec034122cfa -lamb_url = https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedLAMB -repository_url = https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_plugin -[EMAIL] -author_email = pmail_mindstudio@huawei.com \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/consts.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/consts.py index b3e202af61eb9df1d210cd366e7d172075e1e570..533effb8bb91f1f775fb1b98725b63854182ef53 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/consts.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/consts.py @@ -35,8 +35,6 @@ NODE_PROCESS_PATTERN = re.compile(r"""^(.*)_(\d+)""") MONITOR_RUN_REFRESH_INTERNAL_IN_SECONDS = 10 MAX_GPU_PER_NODE = 64 MAX_FILE_SIZE = 500 * 1024 * 1024 -MAX_LINUX_PATH_LENGTH = 4096 -MAX_WINDOWS_PATH_LENGTH = 260 View = namedtuple('View', 'id, name, display_name') OVERALL_VIEW = View(1, 'overall', 'Overview') diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/__init__.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/__init__.py index 296f53b7c813b2c97b498469f49b973438d9f3ae..6bd764e88d4fecd142e7a953b1adb5c4a72262b9 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/__init__.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/__init__.py @@ -1,23 +1,4 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. -# Copyright(c) 2023 Huawei Technologies. -# All rights reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Modifications: Add visualization of PyTorch Ascend profiling. -# -------------------------------------------------------------------------- from .cache import Cache from .file import (BaseFileSystem, StatData, abspath, basename, download_file, exists, get_filesystem, glob, isdir, join, listdir, - makedirs, read, register_filesystem, relpath, walk, stat, check_file_valid) + makedirs, read, register_filesystem, relpath, walk, stat) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/azureblob.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/azureblob.py index 2fcd69fee8c24393458875635c17bd74a71b0fc4..b0ac49a655fd3d999ea80dfc3e6fa62e33fc5269 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/azureblob.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/azureblob.py @@ -20,9 +20,9 @@ class AzureBlobSystem(RemotePath, BaseFileSystem): raise ImportError('azure-storage-blob must be installed for Azure Blob support.') self.connection_string = os.environ.get('AZURE_STORAGE_CONNECTION_STRING', None) - def exists(self, filename): + def exists(self, dirname): """Returns whether the path is a directory or not.""" - basename, parts = self.split_blob_path(filename) + basename, parts = self.split_blob_path(dirname) if basename is None or parts is None: return False if basename == '': @@ -31,10 +31,10 @@ class AzureBlobSystem(RemotePath, BaseFileSystem): else: return basename == parts[0] - def read(self, file, binary_mode=False, size=None, continue_from=None): + def read(self, filename, binary_mode=False, size=None, continue_from=None): """Reads contents of a file to a string.""" - logger.info('azure blob: starting reading file %s' % file) - account, container, path = self.container_and_path(file) + logger.info('azure blob: starting reading file %s' % filename) + account, container, path = self.container_and_path(filename) client = self.create_container_client(account, container) blob_client = client.get_blob_client(path) if not blob_client.exists(): @@ -47,7 +47,7 @@ class AzureBlobSystem(RemotePath, BaseFileSystem): continuation_token = downloader.size data = downloader.readall() - logger.info('azure blob: file %s download is done, size is %d' % (file, len(data))) + logger.info('azure blob: file %s download is done, size is %d' % (filename, len(data))) if binary_mode: return as_bytes(data), continuation_token else: @@ -122,7 +122,7 @@ class AzureBlobSystem(RemotePath, BaseFileSystem): items.append(item) return items - def makedirs(self, path): + def makedirs(self, dirname): """No need create directory since the upload blob will automatically create""" pass diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/file.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/file.py index 9ef5d8485264f18426c18147663f2e1b9fb6900e..dc9abb056860d7a7708533bba55995a1ac6a5e79 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/file.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/file.py @@ -15,34 +15,32 @@ The following functionalities are added after forking: """ import glob as py_glob import os -import platform -import sys import tempfile from .. import utils from .base import BaseFileSystem, LocalPath, RemotePath, StatData from .utils import as_bytes, as_text, parse_blob_url -from ..consts import MAX_FILE_SIZE, MAX_WINDOWS_PATH_LENGTH, MAX_LINUX_PATH_LENGTH logger = utils.get_logger() -S3_ENABLED = True try: import boto3 import botocore.exceptions + + S3_ENABLED = True except ImportError: S3_ENABLED = False -BLOB_ENABLED = True try: from azure.storage.blob import ContainerClient + BLOB_ENABLED = True except ImportError: BLOB_ENABLED = False -GS_ENABLED = True try: # Imports the Google Cloud client library from google.cloud import storage + GS_ENABLED = True except ImportError: GS_ENABLED = False @@ -88,23 +86,19 @@ class LocalFileSystem(LocalPath, BaseFileSystem): def __init__(self): pass - @staticmethod - def islink(path): - return os.path.islink(path) - def exists(self, filename): return os.path.exists(filename) - def read(self, file, binary_mode=False, size=None, continue_from=None): + def read(self, filename, binary_mode=False, size=None, continue_from=None): mode = "rb" if binary_mode else "r" encoding = None if binary_mode else "utf8" - if not self.exists(file): - raise FileNotFoundError(file) + if not self.exists(filename): + raise FileNotFoundError(filename) offset = None if continue_from is not None: offset = continue_from.get("opaque_offset", None) - with open(file, mode, encoding=encoding) as f: + with open(filename, mode, encoding=encoding) as f: if offset is not None: f.seek(offset) data = f.read(size) @@ -166,6 +160,10 @@ class LocalFileSystem(LocalPath, BaseFileSystem): return StatData(file_length) def walk(self, top, topdown=True, onerror=None): + # Note on followlinks=True: per the tensorboard documentation [1], users are encouraged to + # use symlink trees to have fine-grained control over the filesystem layout of runs. To + # support such trees, we must follow links. + # [1] https://github.com/tensorflow/tensorboard/blob/master/README.md#logdir--logdir_spec-legacy-mode yield from os.walk(top, topdown, onerror, followlinks=True) @@ -200,10 +198,10 @@ class S3FileSystem(RemotePath, BaseFileSystem): return True return False - def read(self, file, binary_mode=False, size=None, continue_from=None): + def read(self, filename, binary_mode=False, size=None, continue_from=None): """Reads contents of a file to a string.""" s3 = boto3.resource("s3", endpoint_url=self._s3_endpoint) - bucket, path = self.bucket_and_path(file) + bucket, path = self.bucket_and_path(filename) args = {} # S3 use continuation tokens of the form: {byte_offset: number} @@ -218,7 +216,7 @@ class S3FileSystem(RemotePath, BaseFileSystem): if offset != 0 or endpoint != "": args["Range"] = "bytes={}-{}".format(offset, endpoint) - logger.info("s3: starting reading file %s" % file) + logger.info("s3: starting reading file %s" % filename) try: stream = s3.Object(bucket, path).get(**args)["Body"].read() except botocore.exceptions.ClientError as exc: @@ -240,7 +238,7 @@ class S3FileSystem(RemotePath, BaseFileSystem): raise logger.info("s3: file %s download is done, size is %d" % - (file, len(stream))) + (filename, len(stream))) # `stream` should contain raw bytes here (i.e., there has been neither decoding nor newline translation), # so the byte offset increases by the expected amount. continuation_token = {"byte_offset": (offset + len(stream))} @@ -263,6 +261,9 @@ class S3FileSystem(RemotePath, BaseFileSystem): def download_file(self, file_to_download, file_to_save): logger.info("s3: starting downloading file %s as %s" % (file_to_download, file_to_save)) + # Use boto3.resource instead of boto3.client('s3') to support minio. + # https://docs.min.io/docs/how-to-use-aws-sdk-for-python-with-minio-server.html + # To support minio, the S3_ENDPOINT need to be set like: S3_ENDPOINT=http://localhost:9000 s3 = boto3.resource("s3", endpoint_url=self._s3_endpoint) bucket, path = self.bucket_and_path(file_to_download) s3.Bucket(bucket).download_file(path, file_to_save) @@ -320,14 +321,14 @@ class S3FileSystem(RemotePath, BaseFileSystem): keys.append(key) return keys - def makedirs(self, path): + def makedirs(self, dirname): """Creates a directory and all parent/intermediate directories.""" - if not self.exists(path): + if not self.exists(dirname): client = boto3.client("s3", endpoint_url=self._s3_endpoint) - bucket, dir_path = self.bucket_and_path(path) - if not dir_path.endswith("/"): - dir_path += "/" - client.put_object(Body="", Bucket=bucket, Key=dir_path) + bucket, path = self.bucket_and_path(dirname) + if not path.endswith("/"): + path += "/" + client.put_object(Body="", Bucket=bucket, Key=path) def stat(self, filename): """Returns file statistics for a given path.""" @@ -465,7 +466,7 @@ class File(object): if line and (line[-1] == "\n" or not self.buff): return line if not self.buff: - return None + raise StopIteration() else: index = self.buff.find("\n", self.buff_offset) if index != -1: @@ -480,7 +481,7 @@ class File(object): if line and (line[-1] == "\n" or not self.buff): return line if not self.buff: - return None + raise StopIteration() def next(self): return self.__next__() @@ -619,40 +620,3 @@ def stat(filename): def read(file): with File(file, 'rb') as f: return f.read() - - -def is_link(path): - return LocalFileSystem.islink(path) - - -def is_too_big_file(filepath): - return stat(filepath).length > MAX_FILE_SIZE - - -def has_too_long_path(filepath): - if platform.system() == 'Windows' and len(filepath) > MAX_WINDOWS_PATH_LENGTH: - logger.warning( - f'The path length of the file "{filepath}" exceeds the maximum limit of {MAX_WINDOWS_PATH_LENGTH} ' - f'and will be skipped.') - return True - elif len(filepath) > MAX_WINDOWS_PATH_LENGTH: - logger.warning( - f'The path length of the file "{filepath}" exceeds the maximum limit of {MAX_LINUX_PATH_LENGTH} ' - f'and will be skipped.') - return True - else: - return False - - -def check_file_valid(filepath): - if is_link(filepath): - logger.warning(f'File "{filepath}" is a soft link and will be skipped.') - return False - if is_too_big_file(filepath): - logger.warning( - f'File "{filepath}" exceeds the maximum limit size of 500MB and will be skipped.') - return False - if has_too_long_path(filepath): - return False - return True - diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/gs.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/gs.py index 8596bce2b892b7188155d05330a6356a83323eff..d3a46877326b12a5e8be49a65cf4c90be8157311 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/gs.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/io/gs.py @@ -16,14 +16,14 @@ class GoogleBlobSystem(RemotePath, BaseFileSystem): if not storage: raise ImportError('google-cloud-storage must be installed for Google Cloud Blob support.') - def exists(self, filename): + def exists(self, dirname): """Returns whether the path is a directory or not.""" - bucket_name, path = self.bucket_and_path(filename) + bucket_name, path = self.bucket_and_path(dirname) client = self.create_google_cloud_client() bucket = client.bucket(bucket_name) return bucket.blob(path).exists() - def read(self, file, binary_mode=False, size=None, continue_from=None): + def read(self, filename, binary_mode=False, size=None, continue_from=None): raise NotImplementedError def write(self, filename, file_content, binary_mode=False): @@ -62,7 +62,7 @@ class GoogleBlobSystem(RemotePath, BaseFileSystem): items.append(item) return items - def makedirs(self, path): + def makedirs(self, dirname): """No need create directory since the upload blob will automatically create""" pass diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/plugin.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/plugin.py index 2651f87c087a419c950f93b201606e7601a33a08..6091fdbcd906bf49e4e631afe7d2ba57e65ce711 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/plugin.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/plugin.py @@ -1,5 +1,6 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. +# Copyright (c) Microsoft Corporation. All rights reserved. +# # Copyright(c) 2023 Huawei Technologies. # All rights reserved # @@ -46,7 +47,6 @@ def decorate_headers(func): headers = func(*args, **kwargs) headers.extend(TorchProfilerPlugin.headers) return headers - return wrapper @@ -344,23 +344,14 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): end_ts = float(end_ts) for key in operator_memory_events: if start_ts is not None and end_ts is not None: - operator_memory_events[key] = [ - i - for i in operator_memory_events[key] - if i[2] and start_ts <= i[2] <= end_ts - ] + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2] <= end_ts] elif start_ts is not None: - operator_memory_events[key] = [ - i - for i in operator_memory_events[key] - if i[2] and start_ts <= i[2] - ] + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2]] elif end_ts is not None: - operator_memory_events[key] = [ - i - for i in operator_memory_events[key] - if i[2] and end_ts >= i[2] - ] + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and end_ts >= i[2]] return self.respond_as_json(temp_memory_events, True) else: if start_ts is not None: @@ -482,8 +473,9 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): def _monitor_runs(self): logger.info('Monitor runs begin') - touched = set() + try: + touched = set() while True: try: logger.debug('Scan run dir') diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/__init__.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/__init__.py index 59a0e64155546ce75e1c4607cf35c3144a28271f..9ca062abf58245753361a96890a2ee1ccdec42fb 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/__init__.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/__init__.py @@ -1,6 +1,7 @@ # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # -------------------------------------------------------------------------- -__all__ = ['RunLoader'] from .loader import RunLoader + +__all__ = ['RunLoader'] diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py index 0afcdb11a66f89b8a448713bf140e3293db7e503..00f8dc98139d5bbb96daffb5989b9c3c660f2cbc 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py @@ -59,7 +59,7 @@ def analyze_communication_nodes(comm_node_list: List[CommunicationNode])\ total_comm_stats[comm_node.name][0] += 1 bytes_one_value = 0 if comm_node.input_shape: - for i, shape in enumerate(comm_node.input_shape): + for i in range(len(comm_node.input_shape)): if comm_node.input_type[i] == 'long int': bytes_one_value = 8 elif comm_node.input_type[i] == 'float': @@ -76,7 +76,7 @@ def analyze_communication_nodes(comm_node_list: List[CommunicationNode])\ logger.warning('Found an unknown tensor type: {}'.format(comm_node.input_type[i])) bytes_one_value = 0 total_size = 1 - for size in shape: + for size in comm_node.input_shape[i]: total_size *= size total_comm_stats[comm_node.name][1] += total_size * bytes_one_value total_comm_stats[comm_node.name][2].extend(comm_node.kernel_ranges) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py index 00544e635340c556d5346fc307bb29913c08929c..d6f9bb245eb2d170cb4a63e7f912a9c69932e28b 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py @@ -22,16 +22,14 @@ import gzip import io as sysio import json import math -import os.path import re import tempfile from json.decoder import JSONDecodeError from typing import Dict, List, Optional -from configparser import ConfigParser from .op_tree import OpTreeBuilder from .. import io, utils -from ..consts import InputFilesType, INPUT_FILE_LIST +from ..consts import InputFilesType, MAX_FILE_SIZE, INPUT_FILE_LIST from ..utils import href from . import trace from .communication import analyze_communication_nodes @@ -46,9 +44,6 @@ from .tensor_cores_parser import TensorCoresParser from .trace import BaseEvent, EventTypes, MemoryEvent logger = utils.get_logger() -config = ConfigParser() -config_path = os.path.join(os.getcwd(), 'torch_tb_profiler', 'config', '../config/config.ini') -config.read(config_path) class RunProfileData(object): @@ -169,8 +164,15 @@ class RunProfileData(object): has_communication_overlap = False has_communication_wait_ops = False + def _check_file_size_valid(filepath): + if io.stat(filepath).length > MAX_FILE_SIZE: + logger.warning( + f'File "{filepath}" exceeds the maximum limit size of 500MB and will be skipped.') + return False + return True + for file in io.listdir(path): - if utils.is_npu_trace_path(file) and io.check_file_valid(io.join(path, file)): + if utils.is_npu_trace_path(file) and _check_file_size_valid(io.join(path, file)): has_trace = True trace_file = io.join(path, file) trace_path, trace_json = RunProfileData._preprocess_file(trace_file, cache_dir, 'Ascend') @@ -192,7 +194,7 @@ class RunProfileData(object): profile.profiler_start_ts = 0 for file in io.listdir(path): - if str(file) in INPUT_FILE_LIST and io.check_file_valid(io.join(path, file)): + if str(file) in INPUT_FILE_LIST and _check_file_size_valid(io.join(path, file)): if InputFilesType(file) == InputFilesType.KERNEL_DETAILS_CSV: has_kernel = True profile.kernel_file_path = io.join(path, file) @@ -260,10 +262,10 @@ class RunProfileData(object): try: trace_json = json.loads(fout.getvalue()) logger.warning('Get JSONDecodeError: %s, Re-encode it to temp file' % e.msg) + json_reencode = True except JSONDecodeError: logger.error(f'File "{trace_path}" is not in a legal JSON format and will be skipped.') return trace_path, {} - json_reencode = True # work-around to remove the 'Record Window End' events to avoid the huge end timestamp if device_target == 'Ascend': @@ -361,7 +363,7 @@ class RunProfileData(object): dataloader_ratio = self.avg_costs.costs[ProfileRole.DataLoader] / self.avg_costs.costs[ProfileRole.Total] if dataloader_ratio > 0.05: percentage = dataloader_ratio * 100 - url = config.get('URL', 'pytorch_data_loading_url') + url = 'https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading' self.recommendations.append( f'This run has high time cost on input data loading. {percentage:.1f}% of the step ' + "time is in DataLoader. You could try to set num_workers on DataLoader's construction " + @@ -373,11 +375,12 @@ class RunProfileData(object): if self.device_props: # Tensor Cores feature is available on GPU cards with compute capability >= 7.0 + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications major = self.device_props[0].get('computeMajor') # If it's a pure CPU run, then self.tc_used_ratio is None, this rule will not be triggered. if major is not None and major >= 7: if math.isclose(self.tc_used_ratio, 0.0) and self.tc_eligible_ops_kernel_ratio > 0.0: - url = config.get('URL', 'pytorch_amp_url') + url = 'https://pytorch.org/docs/stable/amp.html' self.recommendations.append( f'Kernels with {round(self.tc_eligible_ops_kernel_ratio * 100)}%' ' time are launched by Tensor Cores eligible operators. ' @@ -392,8 +395,8 @@ class RunProfileData(object): if total_mem is not None and peak_mem > total_mem * 0.9: percentage = peak_mem / total_mem * 100 if total_mem > 0 else 0 total_mem_gb = total_mem / 1024 / 1024 / 1024 - ckp_url = config.get('URL', 'pytorch_ckp_url') - amp_url = config.get('URL', 'pytorch_amp_url') + ckp_url = 'https://pytorch.org/docs/stable/checkpoint.html' + amp_url = 'https://pytorch.org/docs/stable/amp.html' self.recommendations.append( f'Device memory usage is at the limit of device memory capacity ' f'({percentage:.1f}% of {total_mem_gb:.1f}GB on GPU{dev_id}). ' @@ -403,7 +406,7 @@ class RunProfileData(object): def _analyze_distributed_metrics(self): if self.use_dp and len(self.used_devices) > 1: - url = config.get('URL', 'cuda_nn_ddp_instead_url') + url = 'https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead' self.recommendations.append( f"It is recommended to {href('use DistributedDataParallel instead of DataParallel', url)}" ' to do multi-GPU training.') @@ -425,9 +428,9 @@ class RunProfileData(object): communication_ratio = self.avg_costs.costs[ProfileRole.Communication] / self.avg_costs.costs[ProfileRole.Total] if communication_ratio > 0.1: percentage = communication_ratio * 100 - compress_url = config.get('URL', 'compress_url') - grad_acc_url = config.get('URL', 'grad_acc_url') - lamb_url = config.get('URL', 'lamb_url') + compress_url = 'https://pytorch.org/docs/stable/ddp_comm_hooks.html', + grad_acc_url = 'https://towardsdatascience.com/what-is-gradient-accumulation-in-deep-learning-ec034122cfa' + lamb_url = 'https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedLAMB' self.recommendations.append( f'This run has high time cost on communication. {percentage:.1f}% of the step time is in ' f"communication. You could try {href('Gradient Compression', compress_url)} or " diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/diffrun/tree.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/diffrun/tree.py index c5cf5fad448122c74db46467cb0c70b8ce4f727e..a164bd3d37390ba367f0d504910e45050227ffbf 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/diffrun/tree.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/diffrun/tree.py @@ -56,9 +56,8 @@ class DiffNode: def compare_operator_nodes( left_nodes: List[OperatorNode], right_nodes: List[OperatorNode]) -> Generator['DiffNode', None, None]: - """Given two OperatorNode lists, find the DataLoader/Module/Backward/Optimizer node and - create the child list DiffNode - """ + '''Given two OperatorNode lists, find the DataLoader/Module/Backward/Optimizer node and create the child list DiffNode + ''' right_keys = [(type(r), r.name) for r in right_nodes] # find matching points in the two list diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py index 9b364e0dbba55e07b939690d45123bbf6dc6fe23..3cd7ce9ff662a152cc9e1e4150bfe4d762e7a691 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py @@ -3,7 +3,6 @@ # ------------------------------------------------------------------------- import sys from collections import defaultdict -from dataclasses import dataclass from enum import IntEnum from typing import Dict, Iterable, List, Optional, Tuple @@ -32,19 +31,11 @@ class ProfileRole(IntEnum): Total = 8 -@dataclass -class NodeInfoParams: - event: DurationEvent - corrid_to_device: Dict[int, List[DeviceNode]] - corrid_to_runtime: Dict[int, RuntimeNode] - externalid_to_runtime: Dict[int, List[RuntimeNode]] - tid2list: Dict[int, List[OperatorNode]] - pl_tid2list: Dict[int, List[PLProfileNode]] - tid2zero_rt_list: Dict[int, List[RuntimeNode]] - - class NodeParserMixin: def __init__(self, *args, **kwargs): + """Please refer to https://stackoverflow.com/questions/9575409/calling-parent-class-init-with-multiple-inheritance-whats-the-right-way # noqa: E501 + to see the reason why we need call super().__init__ like this way + """ super().__init__(*args, **kwargs) self.communication_data: Dict[int, CommunicationNode] = {} @@ -77,9 +68,14 @@ class NodeParserMixin: for event in events: if event.type == EventTypes.MEMORY: continue - params = NodeInfoParams(event, corrid_to_device, corrid_to_runtime, externalid_to_runtime, tid2list, - pl_tid2list, tid2zero_rt_list) - self._parse_node(params) + self._parse_node( + event, + corrid_to_device, + corrid_to_runtime, + externalid_to_runtime, + tid2list, + pl_tid2list, + tid2zero_rt_list) if CommLibTypes.Nccl in self.comm_lib: for event in events: @@ -120,14 +116,14 @@ class NodeParserMixin: return comm_node is not None - def _parse_node(self, params: NodeInfoParams): - event = params.event - corrid_to_device = params.corrid_to_device - corrid_to_runtime = params.corrid_to_runtime - externalid_to_runtime = params.externalid_to_runtime - tid2list = params.tid2list - pl_tid2list = params.pl_tid2list - tid2zero_rt_list = params.tid2zero_rt_list + def _parse_node(self, + event: DurationEvent, + corrid_to_device: Dict[int, List[DeviceNode]], + corrid_to_runtime: Dict[int, RuntimeNode], + externalid_to_runtime: Dict[int, List[RuntimeNode]], + tid2list: Dict[int, List[OperatorNode]], + pl_tid2list: Dict[int, List[PLProfileNode]], + tid2zero_rt_list: Dict[int, List[RuntimeNode]]): corrid = event.correlation_id tid = event.tid if event.type in [EventTypes.KERNEL, EventTypes.MEMCPY, EventTypes.MEMSET]: @@ -230,8 +226,8 @@ class StepParser: self.steps.append((self.cpu_min_ts, self.cpu_max_ts)) self.steps_names.append('0') - for i, role_range in enumerate(self.role_ranges): - self.role_ranges[i] = merge_ranges(role_range) + for i in range(len(self.role_ranges)): + self.role_ranges[i] = merge_ranges(self.role_ranges[i]) def update_device_steps(self, runtime_node_list: List[RuntimeNode]): self._update_steps_duration(*self._find_device_steps(runtime_node_list)) @@ -366,9 +362,9 @@ class StepParser: # Change step time to device side on the condition that any step have device time. is_use_gpu = prev_step_end_time is not None if is_use_gpu: - for i_step, step in enumerate(self.steps): - step_start_time = max(prev_step_end_time, step[0]) - step_end_time = step[1] + for i_step in range(len(self.steps)): + step_start_time = max(prev_step_end_time, self.steps[i_step][0]) + step_end_time = self.steps[i_step][1] if steps_device[i_step][0] == sys.maxsize: # When step i_step has no device event. # Assign to step_start_time when kernel is behind host step end. step_end_time = max(step_end_time, step_start_time) @@ -406,7 +402,7 @@ class StepParser: class EventParser(NodeParserMixin, StepParser): def __init__(self): super().__init__() - self.comm_node_list: List[CommunicationNode] = None + self.comm_node_list: Dict[CommunicationNode] = None def parse(self, events: Iterable[BaseEvent], fwd_bwd_map: Dict[int, int]) -> Dict[int, List[OperatorNode]]: with utils.timing('EventParser: parse nodes'): @@ -443,10 +439,10 @@ class EventParser(NodeParserMixin, StepParser): header = f'[{ctx.tid}]' + '.'.join(ctx.name_stack[1:]) # omit the CallTreeRoot prefix_len = len(ctx.name_stack) * 4 - 4 - 1 if len(ctx.name_stack) > 1: - logger.info(header) + print(header) prefix = ' ' * prefix_len - logger.info(prefix, node.name) - logger.info(prefix, 'time:', node.start_time, '-->', node.end_time) + print(prefix, node.name) + print(prefix, 'time:', node.start_time, '-->', node.end_time) def push(node: OperatorNode): ctx.name_stack.append(node.name) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/kernel_parser.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/kernel_parser.py index 229251e60a90d5bf4fed514d5f175199b92d3870..838fc38ce60619977c3e096791241d7fc697562d 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/kernel_parser.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/kernel_parser.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import pandas as pd -from .tensor_core import TcAllowlist +from .tensor_core import TC_Allowlist from .trace import EventTypes @@ -19,7 +19,7 @@ class KernelParser: events = [vars(event) for event in events if event.type == EventTypes.KERNEL] events = pd.DataFrame(events) events = events.astype({'type': 'category', 'name': 'string'}, copy=False) - events['tc_used'] = events['name'].map(lambda name: name in TcAllowlist) + events['tc_used'] = events['name'].map(lambda name: name in TC_Allowlist) def weighted_avg(x: pd.Series): try: diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/memory_parser.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/memory_parser.py index 64b78127a4c7a5675e5b2f71877754c541dde94f..766782be271240dabffc76bbc389d8659e601299 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/memory_parser.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/memory_parser.py @@ -25,7 +25,7 @@ class MemoryMetrics(IntEnum): class MemoryRecord: def __init__(self, scope: str, pid: int, tid: int, ts: int, device_type: DeviceType, device_id: int, - address: int, record_bytes: int, total_allocated: float, total_reserved: float): + address: int, bytes: int, total_allocated: float, total_reserved: float): self.scope = scope self.tid = tid self.pid = pid @@ -33,7 +33,7 @@ class MemoryRecord: self.device_type = device_type self.device_id = device_id self.addr = address - self.bytes = record_bytes + self.bytes = bytes self.total_allocated = total_allocated self.total_reserved = total_reserved self.op_name: Optional[str] = None @@ -132,7 +132,7 @@ class MemorySnapshot: for i in range(self_metric_length, metric_length): memory_metrics_keyed_by_node[node][device][i] += metrics[i] - for _, root in tid2tree.items(): + for tid, root in tid2tree.items(): for child in root.children: traverse_node_memory(child) @@ -217,8 +217,7 @@ class MemoryParser: """In the loop, one pass will process one record. The basic logic is: It will search from the node that last visited since both the records and tree is ordered already 1. it current node contains the records, then find the exactly child which just embrace it. - 2. otherwise, find the parent node and set the child_index, so that the parent node could continue from - previous visited node. # noqa: E501 + 2. otherwise, find the parent node and set the child_index, so that the parent node could continue from previous visited node. # noqa: E501 3. if there is not any node contains the records, then all remaining records will be ignored. """ record = records[record_index] diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/module_op.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/module_op.py index 15f1e4ef93a5234cdf6273f9830ac1a6f3aeaa41..061a503b411bb900c6a405c0b97c8a07dd986a00 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/module_op.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/module_op.py @@ -260,3 +260,10 @@ def get_module_tree(tid2tree: Dict[int, OperatorNode]): traverse_node(child, None) return modules + + +def dump_modules(level: int, modules: Iterable[Union[Module, ModuleNode]]): + """testing purpose""" + for module in modules: + print(f"{' ' * level}{module.name.replace('nn.Module: ', '')}_{module.module_id}") + dump_modules(level + 1, module.children) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py index 0528491c28752b0358d79e27168d055546bd0310..80860e53661e9a554de6fa9b09e6f13057fca8bb 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py @@ -6,7 +6,7 @@ from abc import ABC from typing import List, Optional, Tuple from .. import utils -from .tensor_core import TcAllowlist, TcOpAllowlist +from .tensor_core import TC_Allowlist, TC_OP_Allowlist from .trace import (DurationEvent, EventTypes, KernelEvent, ModuleEvent, OperatorEvent, PLProfileEvent, NcclOpNameSet, GlooOpNameSet) @@ -16,12 +16,12 @@ ExcludeOpName = ['DataParallel.forward', 'DistributedDataParallel.forward'] class BaseNode(ABC): - def __init__(self, name: str, start_time: int, end_time: int, node_type: str, tid: int, + def __init__(self, name: str, start_time: int, end_time: int, type: str, tid: int, external_id: Optional[int] = None): self.name = name self.start_time = start_time self.end_time = end_time - self.type = node_type + self.type = type self.tid = tid self.external_id = external_id # For consistency check. @@ -31,7 +31,7 @@ class BaseNode(ABC): kwargs['name'] = event.name kwargs['start_time'] = event.ts kwargs['end_time'] = event.ts + event.duration - kwargs['node_type'] = event.type + kwargs['type'] = event.type kwargs['tid'] = event.tid external_id = getattr(event, 'external_id', None) @@ -84,18 +84,15 @@ class OperatorNode(HostNode): self.callstack = callstack self.self_host_duration = self_host_duration self.self_device_duration = self_device_duration - self.tc_eligible = self.name in TcOpAllowlist + # self.parent_node = None + self.tc_eligible = self.name in TC_OP_Allowlist self.tc_self_duration = 0 # Time of TC kernels launched by this op excluding its children operators. self.tc_total_duration = 0 # Time of TC kernels launched by this op including its children operators. def fill_stats(self): - def sort_key(x): - if x.start_time and x.end_time: - return x.start_time, -x.end_time - else: - return sys.maxsize, -sys.maxsize - 1 self.children.sort(key=lambda x: (x.start_time, -x.end_time)) - self.runtimes.sort(key=sort_key) + self.runtimes.sort(key=lambda x: (x.start_time, -x.end_time) + if x.start_time and x.end_time else (sys.maxsize, -sys.maxsize - 1)) for child in self.children: child.fill_stats() @@ -276,7 +273,7 @@ class DeviceNode(BaseNode): self.block = block self.regs_per_thread = regs_per_thread self.shared_memory = shared_memory - self.tc_used = self.name in TcAllowlist + self.tc_used = self.name in TC_Allowlist self.device_id = device_id @classmethod @@ -309,7 +306,7 @@ def create_operator_node(event: OperatorEvent): def is_operator_node(node: BaseNode): - return bool(isinstance(node, OperatorNode) and node.type == EventTypes.OPERATOR and node.name not in ExcludeOpName + return bool(type(node) is OperatorNode and node.type == EventTypes.OPERATOR and node.name not in ExcludeOpName and not node.name.startswith("Optimizer.")) # exclude Optimizer.zero_grad diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_agg.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_agg.py index d6fdb5903d368e02c4ddb9fc3f29f536696e2a2e..08a3f0d7061dc332a78ec97a6ff085bf1840a47d 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_agg.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_agg.py @@ -49,6 +49,7 @@ def aggregate_ops(op_list: List[OperatorNode], agg.self_device_duration += op.self_device_duration agg.tc_self_duration += op.tc_self_duration agg.tc_total_duration += op.tc_total_duration + return agg agg_dicts: List[Dict[str, OperatorAgg]] = [{} for _ in range(len(keys_func))] for op in op_list: diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py index fe919b29ced02efcea862f5e83ab52704f3f0d09..55e264617d835fb5bf94819b329fdbd2ee1c53f6 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py @@ -68,10 +68,9 @@ class OpTreeBuilder: if main_tid: # only append the staled device nodes into main thread self.main_tid = op_list[0].tid - root_node = OpTreeBuilder._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes, - is_ascend) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes, is_ascend) else: - root_node = OpTreeBuilder._build_tree_internal(op_list, zero_rt_list, tid, [], is_ascend) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, [], is_ascend) tid2tree[int(tid)] = root_node return tid2tree @@ -84,8 +83,7 @@ class OpTreeBuilder: # there are multiple tids backward_tid = self._find_backward_tid() tid2len = { - tid: root.end_time - root.start_time - for tid, root in self.tid2tree.items() + tid: root.end_time - root.start_time for tid, root in self.tid2tree.items() if tid != backward_tid or backward_tid is None } # get the maximum length as the main thread @@ -99,8 +97,7 @@ class OpTreeBuilder: return None - @staticmethod - def _build_tree_internal(host_node_list, zero_rt_list, tid, staled_device_nodes, is_ascend): + def _build_tree_internal(self, host_node_list, zero_rt_list, tid, staled_device_nodes, is_ascend): """host_node_list: list of OperatorNode and ProfilerStepNode. zero_rt_list: list of RuntimeNode with external_id=0.""" @@ -113,7 +110,7 @@ class OpTreeBuilder: name='dummy', start_time=None, end_time=None, - node_type=EventTypes.RUNTIME, + type=EventTypes.RUNTIME, tid=0, device_nodes=staled_device_nodes)) dummpy_rt[0].fill_stats() @@ -122,7 +119,7 @@ class OpTreeBuilder: name='CallTreeRoot', start_time=-sys.maxsize - 1, end_time=sys.maxsize, - node_type=EventTypes.PYTHON, + type=EventTypes.PYTHON, tid=tid, runtimes=zero_rt_list + dummpy_rt) # Give the list of RuntimeNode with external_id=0 to root node. node_stack.append(root_node) @@ -133,6 +130,7 @@ class OpTreeBuilder: if node.end_time <= tail_node.end_time or ( is_ascend and math.isclose(node.end_time, tail_node.end_time, rel_tol=1)): tail_node.children.append(node) + # node.parent_node = weakref.ref(tail_node) node_stack.append(node) else: logger.error('Error in input data: ranges on the same thread should not intersect!' @@ -276,7 +274,7 @@ class OpTreeBuilder: if isinstance(node, ModuleNode): backward_node = BackwardNode(name=node.name + '.backward', start_time=None, end_time=None, - node_type='backward', tid=0) + type='backward', tid=0) if parent is None: result.append(backward_node) else: diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/overall_parser.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/overall_parser.py index c646a33b89a673e1738fd38704516df8bfdfaade..e12fbfd1cc502accee83fb44c52b94f8253c64ce 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/overall_parser.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/overall_parser.py @@ -23,8 +23,8 @@ class OverallParser(object): @classmethod def create_from_statistics(cls, statistics: 'OverallParser.Statistics', total_duration: int): costs = [0.] * len(ProfileRole) - for i, cost_range in enumerate(statistics.cost_ranges): - costs[i] = get_ranges_sum(cost_range) + for i in range(len(statistics.cost_ranges)): + costs[i] = get_ranges_sum(statistics.cost_ranges[i]) costs[ProfileRole.Total] = total_duration return cls(costs) @@ -58,8 +58,8 @@ class OverallParser(object): def intersection_with_step(self, step: Tuple[int, int]): cost_ranges: List[List[Tuple[int, int]]] = [] step = [step] - for cost_range in self.cost_ranges: - cost_ranges.append(intersection_ranges_lists(step, cost_range)) + for range in self.cost_ranges: + cost_ranges.append(intersection_ranges_lists(step, range)) return OverallParser.Statistics(cost_ranges) @@ -77,9 +77,6 @@ class OverallParser(object): def aggregate(self, steps: List[Tuple[int, int]], role_ranges: List[List[Tuple[int, int]]]): logger.debug('Overall, statistics') - if len(steps) <= 0: - logger.error('Invalid steps number of 0') - return global_stats = OverallParser.Statistics.create_from_range(steps, role_ranges) if role_ranges[ProfileRole.Kernel]: comm_comp_overlap = intersection_ranges_lists( @@ -92,7 +89,7 @@ class OverallParser(object): for i, step in enumerate(steps): steps_stat = global_stats.intersection_with_step(step) self.steps_costs.append(OverallParser.Costs.create_from_statistics(steps_stat, step[1] - step[0])) - for cost_index, _ in enumerate(self.avg_costs.costs): + for cost_index in range(len(self.avg_costs.costs)): self.avg_costs.costs[cost_index] += self.steps_costs[i].costs[cost_index] comm_costs = OverallParser.StepCommunicationCosts() @@ -110,5 +107,5 @@ class OverallParser(object): self.communication_overlap.append(comm_costs) valid_steps = len(steps) - for i, _ in enumerate(self.avg_costs.costs): + for i in range(len(self.avg_costs.costs)): self.avg_costs.costs[i] /= valid_steps diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index 111dc34e81031a33ff9e0a2c03b0375522de24cf..f2ab0452ec733783d880abfebae948f8ec4b3e6e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # # Copyright(c) 2023 Huawei Technologies. +# All rights reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,140 +49,6 @@ class RunGenerator(object): self.component_curve_data = {} self.process_data = {} - @staticmethod - def check_overlap_data(title): - # csv: step / compute time / communication_not_overlap / overlap / communication / free time - length = len(title) - if length < 5: - return [] - key = ["computing", "overlapped", "communication(not overlapped)", "free"] - get_key = list() - for j in key: - for i in range(length): - if j == title[i]: - get_key.append(i) - if len(get_key) < 4: - return [] - return get_key - - @staticmethod - def get_table_head(name: str, input_shape: str, call_stack: str, value: list): - if name is None: - return {} - temp = { - 'name': name, 'calls': 0, 'host_self_duration': 0, - 'host_total_duration': 0, 'device_self_duration': 0, 'device_total_duration': 0, - 'tc_self_ratio': 0, 'tc_total_ratio': 0, 'tc_eligible': 'Yes' - } - if input_shape is not None: - temp['input_shape'] = input_shape - if call_stack is not None: - temp['call_stack'] = call_stack - else: - temp['has_call_stack'] = False - else: - if call_stack is not None: - temp['call_stack'] = call_stack - else: - temp['has_call_stack'] = False - for vl in iter(value): - if 'has_call_stack' in temp and vl[2]: - temp['has_call_stack'] = True - temp['calls'] += 1 - temp['host_self_duration'] = round(temp['host_self_duration'] + vl[3], 2) - temp['host_total_duration'] = round(temp['host_total_duration'] + vl[4], 2) - temp['device_self_duration'] = round(temp['device_self_duration'] + vl[5], 2) - temp['device_total_duration'] = round(temp['device_total_duration'] + vl[6], 2) - temp['tc_self_ratio'] = round(temp['tc_self_ratio'] + vl[7], 2) - temp['tc_total_ratio'] = round(temp['tc_total_ratio'] + vl[8], 2) - temp['tc_eligible'] = 'Yes' if temp['tc_self_ratio'] > 0 or temp['tc_total_ratio'] > 0 else 'No' - temp['tc_self_ratio'] = 0 if temp['device_self_duration'] == 0 \ - else round(temp['tc_self_ratio'] / temp['device_self_duration'] * 100, 2) - temp['tc_total_ratio'] = 0 if temp['device_total_duration'] == 0 \ - else round(temp['tc_total_ratio'] / temp['device_total_duration'] * 100, 2) - return temp - - @staticmethod - def get_wait_table_by_ops(op, ops): - total_trans = 0 - total_synchronize = 0 - for key, data in op.items(): - if str(key) == "Total Op Info" and data.get("Communication Time Info"): - total_trans += float(data.get("Communication Time Info").get("Transit Time(ms)")) - total_synchronize += float(data.get("Communication Time Info").get("Synchronization Time(ms)")) - continue - k = re.sub(r'[0-9]+', ' ', key).split(" ")[0] - if k not in ops: - ops[k] = [0, 0, 0, 0] - ops[k][0] += 1 - for _, band in data.get("Communication Bandwidth Info").items(): - ops[k][1] += float(band.get("Transit Size(MB)")) - if data.get("Communication Time Info") is not None: - ops[k][2] += data.get("Communication Time Info").get("Elapse Time(ms)") - ops[k][3] += data.get("Communication Time Info").get("Transit Time(ms)") - return total_trans, total_synchronize - - @staticmethod - def trans_shape(shape: str): - result = list() - if ';' not in shape: - result.append('[' + shape.strip() + ']') - return '[' + ', '.join(result) + ']' - if len(shape.strip()) <= 1: - result.append('[]') - return '[' + ', '.join(result) + ']' - shape_spl = shape.split("\n") - for shape_div in iter(shape_spl): - result.append('[' + str(shape_div.replace(';', '')) + ']') - return '[' + ', '.join(result) + ']' - - @staticmethod - def get_process_peaks_and_devices_type(process_data: dict, memory_metric: str): - devices_type = [] - peaks = {} - for device in process_data: - devices_type.append(device) - reserved_list = process_data.get(device).get('Allocated') - if reserved_list is not None: - max_reserved = 0 - for array_value in reserved_list: - max_reserved = max(array_value[1], max_reserved) - peaks[device] = f'Peak Memory Usage: {max_reserved:.1f}{memory_metric}' - return devices_type, peaks - - @staticmethod - def get_pta_ge_peaks_and_devices_type(process_data: dict, memory_metric: str): - devices_type = [] - peaks = {} - for device in process_data: - devices_type.append(device) - peaks[device] = 'Reserved Peak Memory Usage:' - for component in process_data.get(device): - max_reserved = 0 - for array_value in process_data.get(device).get(component): - max_reserved = max(array_value[2], max_reserved) - peaks[device] += f' {component}-{max_reserved:.1f}{memory_metric} |' - return devices_type, peaks - - @staticmethod - def check_csv_columns(columns: list, column_idxs: dict): - column_exist_count = 0 - for idx, column in enumerate(columns): - if column in column_idxs: - column_idxs[column] = idx - column_exist_count += 1 - return column_idxs.values(), column_exist_count - - @staticmethod - def get_csv_data(path: str): - if path is None: - return [] - datas = [] - with open(path, encoding='utf-8-sig') as f: - for row in csv.reader(f, skipinitialspace=True): - datas.append(row) - return datas - def generate_run_profile(self): profile_run = RunProfile(self.worker, self.span) profile_run.is_pytorch_lightning = self.profile_data.is_pytorch_lightning @@ -218,7 +85,7 @@ class RunGenerator(object): profile_run.gpu_metrics = self.profile_data.gpu_metrics_parser.get_gpu_metrics() - gpu_infos = {gpu_id: RunGenerator.get_gpu_info(self.profile_data.device_props, gpu_id) + gpu_infos = {gpu_id: RunGenerator._get_gpu_info(self.profile_data.device_props, gpu_id) for gpu_id in self.profile_data.gpu_metrics_parser.gpu_ids} gpu_infos = {gpu_id: gpu_info for gpu_id, gpu_info in gpu_infos.items() if gpu_info is not None} @@ -273,11 +140,11 @@ class RunGenerator(object): def _npu_get_overlap(self): path = self.profile_data.distributed_csv_path overlap_by_steps: Dict[str, List[float]] = OrderedDict() - data = RunGenerator.get_csv_data(path) + data = RunGenerator._get_csv_data(path) if len(data) <= 1: return overlap_by_steps title = [x.lower() for x in data[0]] - title_name = RunGenerator.check_overlap_data(title) + title_name = RunGenerator._check_overlap_data(title) if not title_name: logger.error(f"Incomplete content of CSV file {path}.") return overlap_by_steps @@ -287,10 +154,8 @@ class RunGenerator(object): key = step[0] if key == '': key = 'all' - overlap = [ - float(step[int(title_name[0])]), float(step[int(title_name[1])]), - float(step[int(title_name[2])]), float(step[int(title_name[3])]) - ] + overlap = [float(step[int(title_name[0])]), float(step[int(title_name[1])]), + float(step[int(title_name[2])]), float(step[int(title_name[3])])] if key in overlap_by_steps: overlap_by_steps[key] = list(np.add(overlap, overlap_by_steps[key])) else: @@ -299,6 +164,22 @@ class RunGenerator(object): logger.error(f'File "{path}" has wrong data format in row {idx + 2} and will skip it.') return overlap_by_steps + @staticmethod + def _check_overlap_data(title): + # csv: step / compute time / communication_not_overlap / overlap / communication / free time + length = len(title) + if length < 5: + return [] + key = ["computing", "overlapped", "communication(not overlapped)", "free"] + get_key = list() + for j in key: + for i in range(length): + if j == title[i]: + get_key.append(i) + if len(get_key) < 4: + return [] + return get_key + def _npu_get_wait_table(self): path = self.profile_data.communication_json_path if not io.exists(path): @@ -333,9 +214,9 @@ class RunGenerator(object): collection_ops = data.get("collective") p2p_ops = data.get("p2p") try: - coll_total_trans, coll_total_synchronize = RunGenerator.get_wait_table_by_ops(collection_ops, - table_ops) - p2p_total_trans, p2p_total_synchronize = RunGenerator.get_wait_table_by_ops(p2p_ops, table_ops) + coll_total_trans, coll_total_synchronize = RunGenerator._get_wait_table_by_ops(collection_ops, + table_ops) + p2p_total_trans, p2p_total_synchronize = RunGenerator._get_wait_table_by_ops(p2p_ops, table_ops) except ValueError: logger.error(f'Time and size info must be number, please check file "{path}"') return wait_by_step, table_ops @@ -346,21 +227,39 @@ class RunGenerator(object): } return wait_by_step, table_ops + @staticmethod + def _get_wait_table_by_ops(op, ops): + total_trans = 0 + total_synchronize = 0 + for key, data in op.items(): + if str(key) == "Total Op Info" and data.get("Communication Time Info"): + total_trans += float(data.get("Communication Time Info").get("Transit Time(ms)")) + total_synchronize += float(data.get("Communication Time Info").get("Synchronization Time(ms)")) + continue + k = re.sub(r'[0-9]+', ' ', key).split(" ")[0] + if k not in ops: + ops[k] = [0, 0, 0, 0] + ops[k][0] += 1 + for _, band in data.get("Communication Bandwidth Info").items(): + ops[k][1] += float(band.get("Transit Size(MB)")) + if data.get("Communication Time Info") is not None: + ops[k][2] += data.get("Communication Time Info").get("Elapse Time(ms)") + ops[k][3] += data.get("Communication Time Info").get("Transit Time(ms)") + return total_trans, total_synchronize + def _get_operator_details_by_name(self): operator_by_name = defaultdict(list) operator_by_name_and_input_shapes = defaultdict(list) path = self.profile_data.operator_path - datas = RunGenerator.get_csv_data(path) + datas = RunGenerator._get_csv_data(path) if len(datas) <= 1: return operator_by_name, operator_by_name_and_input_shapes for idx, ls in enumerate(datas[1:]): try: - temp: list = [ - ls[0], RunGenerator.trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), - float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8]) - ] + temp: list = [ls[0], RunGenerator._trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), + float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8])] operator_by_name[ls[0]].append(temp) - key = "{}###{}".format(str(ls[0]), RunGenerator.trans_shape(str(ls[1]))) + key = "{}###{}".format(str(ls[0]), RunGenerator._trans_shape(str(ls[1]))) operator_by_name_and_input_shapes[key].append(temp) except (ValueError, IndexError): logger.error(f'File "{path}" has wrong data format in row {idx + 2} and will skip it.') @@ -382,10 +281,8 @@ class RunGenerator(object): def _get_operator_pie(self, group_by_input_shape=False): data = {} - tag = { - 'device_self_time': 'Device Self Time (us)', 'device_total_time': 'Device Total Time (us)', - 'host_self_time': 'Host Self Time (us)', 'host_total_time': 'Host Total Time (us)' - } + tag = {'device_self_time': 'Device Self Time (us)', 'device_total_time': 'Device Total Time (us)', + 'host_self_time': 'Host Self Time (us)', 'host_total_time': 'Host Total Time (us)'} for key, value in tag.items(): data[key] = { 'title': value, @@ -410,9 +307,9 @@ class RunGenerator(object): if group_by_input_shape: name = name_key.split("###")[0] shape = name_key.split("###")[1] - result.append(RunGenerator.get_table_head(name, shape, None, values)) + result.append(RunGenerator._get_table_head(name, shape, None, values)) else: - result.append(RunGenerator.get_table_head(name_key, None, None, values)) + result.append(RunGenerator._get_table_head(name_key, None, None, values)) return result def _set_name_callstack_data(self, group_by_input_shape=False): @@ -447,10 +344,24 @@ class RunGenerator(object): 'data': [] } for callstack_key, value in values.items(): - table['data'].append(RunGenerator.get_table_head(name, shape, callstack_key, value)) + table['data'].append(RunGenerator._get_table_head(name, shape, callstack_key, value)) result[name_key] = table return result + @staticmethod + def _trans_shape(shape: str): + result = list() + if ';' not in shape: + result.append('[' + shape.strip() + ']') + return '[' + ', '.join(result) + ']' + if len(shape.strip()) <= 1: + result.append('[]') + return '[' + ', '.join(result) + ']' + shape_spl = shape.split("\n") + for shape_div in iter(shape_spl): + result.append('[' + str(shape_div.replace(';', '')) + ']') + return '[' + ', '.join(result) + ']' + def _get_call_stack_by_name(self): result = dict() name_callstack_data = self._set_name_callstack_data() @@ -467,10 +378,45 @@ class RunGenerator(object): 'data': [] } for callstack_key, value in values.items(): - table['data'].append(RunGenerator.get_table_head(name_key, None, callstack_key, value)) + table['data'].append(RunGenerator._get_table_head(name_key, None, callstack_key, value)) result[name_key] = table return result + @staticmethod + def _get_table_head(name: str, input_shape: str, call_stack: str, value: list): + if name is None: + return {} + temp = {'name': name, 'calls': 0, 'host_self_duration': 0, + 'host_total_duration': 0, 'device_self_duration': 0, 'device_total_duration': 0, + 'tc_self_ratio': 0, 'tc_total_ratio': 0, 'tc_eligible': 'Yes'} + if input_shape is not None: + temp['input_shape'] = input_shape + if call_stack is not None: + temp['call_stack'] = call_stack + else: + temp['has_call_stack'] = False + else: + if call_stack is not None: + temp['call_stack'] = call_stack + else: + temp['has_call_stack'] = False + for vl in iter(value): + if 'has_call_stack' in temp and vl[2]: + temp['has_call_stack'] = True + temp['calls'] += 1 + temp['host_self_duration'] = round(temp['host_self_duration'] + vl[3], 2) + temp['host_total_duration'] = round(temp['host_total_duration'] + vl[4], 2) + temp['device_self_duration'] = round(temp['device_self_duration'] + vl[5], 2) + temp['device_total_duration'] = round(temp['device_total_duration'] + vl[6], 2) + temp['tc_self_ratio'] = round(temp['tc_self_ratio'] + vl[7], 2) + temp['tc_total_ratio'] = round(temp['tc_total_ratio'] + vl[8], 2) + temp['tc_eligible'] = 'Yes' if temp['tc_self_ratio'] > 0 or temp['tc_total_ratio'] > 0 else 'No' + temp['tc_self_ratio'] = 0 if temp['device_self_duration'] == 0 \ + else round(temp['tc_self_ratio'] / temp['device_self_duration'] * 100, 2) + temp['tc_total_ratio'] = 0 if temp['device_total_duration'] == 0 \ + else round(temp['tc_total_ratio'] / temp['device_total_duration'] * 100, 2) + return temp + def _get_memory_event(self, peak_memory_events: dict): display_columns = ('Name', 'Size(KB)', 'Allocation Time(us)', 'Release Time(us)', 'Duration(us)') path = self.profile_data.memory_operator_path @@ -484,16 +430,10 @@ class RunGenerator(object): 'columns': [], 'rows': {} } - datas = RunGenerator.get_csv_data(path) - if len(datas) < 1: - return { - 'operator': table, - 'component': peak_memory_events - } - device_type_form_idx = -1 + datas = RunGenerator._get_csv_data(path) for idx, column in enumerate(datas[0]): if column == 'Device Type': - device_type_form_idx = idx + self.device_type_form_idx = idx if column in display_columns: if column == 'Name': table['columns'].append({'name': column, 'type': 'string'}) @@ -504,22 +444,20 @@ class RunGenerator(object): table['columns'].append({'name': column.replace('(us)', '(ms)'), 'type': 'number'}) required_column_idxs = {key: -1 for key in display_columns} (name_idx, size_idx, allocation_idx, release_idx, duration_idx), column_exist_count = \ - RunGenerator.check_csv_columns(datas[0], required_column_idxs) - if device_type_form_idx < 0 or column_exist_count < len(required_column_idxs): - raise ValueError('Required column is missing in file "operator_memory.csv"') + RunGenerator._check_csv_columns(datas[0], required_column_idxs) + if column_exist_count < len(required_column_idxs): + logger.error('Required column is missing in file "operator_memory.csv"') for idx, ls in enumerate(datas[1:]): - device_type = ls[device_type_form_idx] + device_type = ls[self.device_type_form_idx] # convert time metric 'us' to 'ms' # some operators may not have the following columns try: - nums = [ - ls[name_idx] if ls[name_idx] else '', abs(float(ls[size_idx])), + nums = [ls[name_idx] if ls[name_idx] else '', abs(float(ls[size_idx])), round((float(ls[allocation_idx]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[ allocation_idx] else None, round((float(ls[release_idx]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[ release_idx] else None, - round(float(ls[duration_idx]) / 1000, 3) if ls[duration_idx] else None - ] + round(float(ls[duration_idx]) / 1000, 3) if ls[duration_idx] else None] display_datas[device_type].append(nums) except ValueError: logger.error(f'File "{path}" has wrong data format in row {idx + 2} and will skip it.') @@ -536,8 +474,8 @@ class RunGenerator(object): time_metric: str = 'ms' memory_metric: str = 'MB' cano = Canonicalizer(time_metric, memory_metric) - process_devices_type, process_peaks = RunGenerator.get_process_peaks_and_devices_type(self.process_data, - memory_metric) + process_devices_type, process_peaks = RunGenerator._get_process_peaks_and_devices_type(self.process_data, + memory_metric) total_result = { 'metadata': { 'devices': process_devices_type, @@ -564,8 +502,8 @@ class RunGenerator(object): if len(total_result['columns'][device]) > 0: total_result['columns'][device].insert(0, {'name': f'Time ({cano.time_metric})', 'type': 'number', 'tooltip': 'Time since profiler starts.'}) - pta_ge_devices_type, pta_ge_peaks = RunGenerator.get_pta_ge_peaks_and_devices_type(self.component_curve_data, - memory_metric) + pta_ge_devices_type, pta_ge_peaks = RunGenerator._get_pta_ge_peaks_and_devices_type(self.component_curve_data, + memory_metric) component_curve_result = { 'metadata': { 'devices': pta_ge_devices_type, @@ -609,11 +547,48 @@ class RunGenerator(object): 'ptaGe': component_curve_result } + @staticmethod + def _get_process_peaks_and_devices_type(process_data: dict, memory_metric: str): + devices_type = [] + peaks = {} + for device in process_data: + devices_type.append(device) + reserved_list = process_data.get(device).get('Allocated') + if reserved_list is not None: + max_reserved = 0 + for array_value in reserved_list: + max_reserved = max(array_value[1], max_reserved) + peaks[device] = f'Peak Memory Usage: {max_reserved:.1f}{memory_metric}' + return devices_type, peaks + + @staticmethod + def _get_pta_ge_peaks_and_devices_type(process_data: dict, memory_metric: str): + devices_type = [] + peaks = {} + for device in process_data: + devices_type.append(device) + peaks[device] = 'Reserved Peak Memory Usage:' + for component in process_data.get(device): + max_reserved = 0 + for array_value in process_data.get(device).get(component): + max_reserved = max(array_value[2], max_reserved) + peaks[device] += f' {component}-{max_reserved:.1f}{memory_metric} |' + return devices_type, peaks + + @staticmethod + def _check_csv_columns(columns: list, column_idxs: dict): + column_exist_count = 0 + for idx, column in enumerate(columns): + if column in column_idxs: + column_idxs[column] = idx + column_exist_count += 1 + return column_idxs.values(), column_exist_count + def _handle_memory_data(self): process_data = defaultdict() pta_or_ge_data = defaultdict() path = self.profile_data.memory_curve_path - datas = RunGenerator.get_csv_data(path) + datas = RunGenerator._get_csv_data(path) required_column_idxs = { 'Component': -1, 'Device Type': -1, @@ -622,7 +597,7 @@ class RunGenerator(object): 'Total Allocated(MB)': -1 } (tag_type_idx, device_type_idx, time_idx, reserved_idx, allocated_idx), column_exist_count = \ - RunGenerator.check_csv_columns(datas[0], required_column_idxs) + RunGenerator._check_csv_columns(datas[0], required_column_idxs) if column_exist_count < len(required_column_idxs): logger.error('Required column is missing in file "memory_record.csv"') else: @@ -640,10 +615,8 @@ class RunGenerator(object): pta_or_ge_data.setdefault(device_type, {}).setdefault(ls[tag_type_idx], []).append( line_chart_data) elif ls[tag_type_idx] in ('PTA', 'GE'): - line_chart_data = [ - time_column, round(float(ls[allocated_idx]), 3), - round(float(ls[reserved_idx]), 3) - ] + line_chart_data = [time_column, round(float(ls[allocated_idx]), 3), + round(float(ls[reserved_idx]), 3)] pta_or_ge_data.setdefault(device_type, {}).setdefault(ls[tag_type_idx], []).append( line_chart_data) except ValueError: @@ -663,7 +636,7 @@ class RunGenerator(object): } peak_memory_rows = defaultdict(list) path = self.profile_data.memory_component_path - component_datas = RunGenerator.get_csv_data(path) + component_datas = RunGenerator._get_csv_data(path) if component_datas: required_column_idxs = { 'Component': -1, @@ -672,7 +645,7 @@ class RunGenerator(object): 'Device': -1 } (tag_type_idx, time_idx, reserved_idx, device_type_idx), column_exist_count = \ - RunGenerator.check_csv_columns(component_datas[0], required_column_idxs) + RunGenerator._check_csv_columns(component_datas[0], required_column_idxs) if column_exist_count < len(required_column_idxs): logger.error(f'Required column is missing in file "{path}"') else: @@ -718,16 +691,14 @@ class RunGenerator(object): '{}: {}us
' 'Percentage: {}%' '') - percentage = 0.0 if costs.costs[ProfileRole.Total] == 0 else round( - 100 * part_cost / costs.costs[ProfileRole.Total], 2) + percentage = round(100 * part_cost / costs.costs[ProfileRole.Total], 2) return format_str.format(step_name, costs.costs[ProfileRole.Total], part_name, part_cost, percentage) def build_avg_cost_dict(part_name: str, part_cost: float): - profiler_total_cost = self.profile_data.avg_costs.costs[ProfileRole.Total] cost_dict = {'name': part_name, 'description': '', 'value': round(part_cost), - 'extra': 0.0 if profiler_total_cost == 0 else round(100 * part_cost / profiler_total_cost, 2)} + 'extra': round(100 * part_cost / self.profile_data.avg_costs.costs[ProfileRole.Total], 2)} return cost_dict show_gpu = (self.profile_data.has_runtime @@ -746,7 +717,8 @@ class RunGenerator(object): data['steps']['columns'].extend(['DataLoader', 'CPU Exec', 'Other']) data['steps']['rows'] = [] - for i, costs in enumerate(self.profile_data.steps_costs): + for i in range(len(self.profile_data.steps_costs)): + costs = self.profile_data.steps_costs[i] step_name = self.profile_data.steps_names[i] row = [{'value': step_name}] if show_gpu: @@ -791,11 +763,9 @@ class RunGenerator(object): build_avg_cost_dict('Other', self.profile_data.avg_costs.costs[ProfileRole.Other]) ]) - data['performance'] = [ - {'name': 'Average Step Time', 'description': '', + data['performance'] = [{'name': 'Average Step Time', 'description': '', 'value': round(self.profile_data.avg_costs.costs[ProfileRole.Total]), - 'extra': 100, 'children': avg_costs} - ] + 'extra': 100, 'children': avg_costs}] if len(self.profile_data.recommendations) == 0: html = '
  • N/A
  • ' @@ -945,8 +915,7 @@ class RunGenerator(object): }, 'data': table } - table['columns'] = [ - {'type': 'string', 'name': 'Name'}, + table['columns'] = [{'type': 'string', 'name': 'Name'}, {'type': 'string', 'name': 'Operator'}, {'type': 'string', 'name': 'Grid'}, {'type': 'string', 'name': 'Block'}, @@ -955,8 +924,7 @@ class RunGenerator(object): {'type': 'string', 'name': 'Kernel Uses Tensor Cores', 'tooltip': consts.TOOLTIP_KERNEL_USES_TC}, {'type': 'string', 'name': 'Op is Tensor Cores eligible', - 'tooltip': consts.TOOLTIP_KERNEL_OP_TC_ELIGIBLE} - ] + 'tooltip': consts.TOOLTIP_KERNEL_OP_TC_ELIGIBLE}] col_names = ['Calls', 'Total Duration (us)', 'Mean Duration (us)', 'Max Duration (us)', 'Min Duration (us)'] for column in col_names: table['columns'].append({'type': 'number', 'name': column}) @@ -967,16 +935,14 @@ class RunGenerator(object): kernel_list: List[KernelAggByNameOp] = sorted( self.profile_data.kernel_list_groupby_name_op, key=lambda x: x.total_duration, reverse=True) for agg_by_name_op in kernel_list: - kernel_op_row = [ - agg_by_name_op.name, agg_by_name_op.op_name, + kernel_op_row = [agg_by_name_op.name, agg_by_name_op.op_name, str(agg_by_name_op.grid), str(agg_by_name_op.block), str(agg_by_name_op.regs_per_thread or '0'), str(agg_by_name_op.shared_memory or '0'), 'Yes' if agg_by_name_op.tc_used else 'No', 'Yes' if agg_by_name_op.op_tc_eligible else 'No', agg_by_name_op.calls, agg_by_name_op.total_duration, round(agg_by_name_op.avg_duration), - agg_by_name_op.max_duration, agg_by_name_op.min_duration - ] + agg_by_name_op.max_duration, agg_by_name_op.min_duration] if self.profile_data.gpu_metrics_parser.has_blocks_per_sm: kernel_op_row.append(round(agg_by_name_op.avg_blocks_per_sm, 2)) if self.profile_data.gpu_metrics_parser.has_occupancy: @@ -999,11 +965,9 @@ class RunGenerator(object): }, 'data': table } - table['columns'] = [ - {'type': 'string', 'name': 'Name'}, + table['columns'] = [{'type': 'string', 'name': 'Name'}, {'type': 'string', 'name': 'Tensor Cores Used', - 'tooltip': consts.TOOLTIP_KERNEL_USES_TC} - ] + 'tooltip': consts.TOOLTIP_KERNEL_USES_TC}] columns = ['count', 'sum', 'mean', 'max', 'min'] round_digits = [0, 0, 0, 0, 0] if self.profile_data.gpu_metrics_parser.has_blocks_per_sm: @@ -1047,8 +1011,7 @@ class RunGenerator(object): {'type': 'number', 'name': 'Total Durations(us)'}, {'type': 'number', 'name': 'Min Durations(us)'}, {'type': 'number', 'name': 'Avg Durations(us)'}, - {'type': 'number', 'name': 'Max Durations(us)'} - ] + {'type': 'number', 'name': 'Max Durations(us)'}] table['rows'] = [] for key, value in self.statistic_data.items(): temp = [key] @@ -1074,14 +1037,14 @@ class RunGenerator(object): 'data': table } path = self.profile_data.kernel_file_path - datas = RunGenerator.get_csv_data(path) + datas = RunGenerator._get_csv_data(path) required_column_idxs = { 'Name': -1, 'Duration(us)': -1, 'Accelerator Core': -1 } (name_idx, duration_idx, core_type_idx), column_exist_count = \ - RunGenerator.check_csv_columns(datas[0], required_column_idxs) + RunGenerator._check_csv_columns(datas[0], required_column_idxs) if column_exist_count < 3: logger.error('Required column is missing in file "kernel_details.csv"') else: @@ -1095,6 +1058,16 @@ class RunGenerator(object): table['rows'] = datas[1:] return result + @staticmethod + def _get_csv_data(path: str): + if path is None: + return [] + datas = [] + with open(path, encoding='utf-8-sig') as f: + for row in csv.reader(f, skipinitialspace=True): + datas.append(row) + return datas + def _generate_tc_pie_npu(self): pie = {'columns': [{'type': 'string', 'name': 'name'}, {'type': 'number', 'name': 'value'}], 'rows': []} for key, val in self.accelerator_data.items(): @@ -1103,7 +1076,7 @@ class RunGenerator(object): return data @staticmethod - def get_gpu_info(device_props, gpu_id): + def _get_gpu_info(device_props, gpu_id): if (device_props is None) or (gpu_id >= len(device_props)) or (gpu_id < 0): return None @@ -1144,17 +1117,12 @@ class RunGenerator(object): self.accelerator_data[call_type] = call_duration if self.statistic_data.get(call_name) is not None: - temp = self.statistic_data.get(call_name, {}) - temp['Max'] = max(temp.get('Max', 0), call_duration) - temp['Min'] = min(temp.get('Min', 0), call_duration) - temp['Total'] = round(temp.get('Total', 0) + call_duration, 2) - temp['Calls'] = temp.get('Calls', 0) + 1 - if temp['Calls'] == 0: - logger.error( - f'temp["Calls"] is zero which can not be divisor.') - temp['Average'] = 0 - else: - temp['Average'] = round(temp['Total'] / temp['Calls'], 2) + temp = self.statistic_data[call_name] + temp['Max'] = max(temp['Max'], call_duration) + temp['Min'] = min(temp['Min'], call_duration) + temp['Total'] = round(temp['Total'] + call_duration, 2) + temp['Calls'] += 1 + temp['Average'] = round(temp['Total'] / temp['Calls'], 2) else: self.statistic_data[call_name] = { 'Calls': 1, @@ -1204,7 +1172,7 @@ class DistributedRunGenerator(object): process_id = 'Process ' + str(process_id) result[node][process_id] = OrderedDict() for used_device in data.used_devices: - gpu_info = RunGenerator.get_gpu_info(data.device_props, used_device) + gpu_info = RunGenerator._get_gpu_info(data.device_props, used_device) if gpu_info is not None: result[node][process_id]['GPU' + str(used_device)] = gpu_info @@ -1255,9 +1223,7 @@ class DistributedRunGenerator(object): round(costs.other, 3) ] steps_to_overlap['all'][data.worker] = [ - sum(x) - for x in zip(steps_to_overlap['all'][data.worker], steps_to_overlap[step_name][data.worker]) - ] + sum(x) for x in zip(steps_to_overlap['all'][data.worker], steps_to_overlap[step_name][data.worker])] @staticmethod def _get_npu_overlap_data(data, steps_to_overlap): @@ -1269,9 +1235,7 @@ class DistributedRunGenerator(object): steps_to_overlap[k][data.worker] = list( [round(v[0] - v[1], 3), round(v[1], 3), round(v[2], 3), round(v[3], 3)]) steps_to_overlap['all'][data.worker] = [ - sum(x) - for x in zip(steps_to_overlap['all'][data.worker], steps_to_overlap[k][data.worker]) - ] + sum(x) for x in zip(steps_to_overlap['all'][data.worker], steps_to_overlap[k][data.worker])] @staticmethod def _get_npu_wait_data(data, steps_to_wait): @@ -1286,9 +1250,7 @@ class DistributedRunGenerator(object): wait = round(v.get('Synchronize') * 1000, 3) # 1ms = 1000us steps_to_wait[k][data.worker] = list([trans, wait]) steps_to_wait['all'][data.worker] = [ - sum(x) - for x in zip(steps_to_wait['all'][data.worker], steps_to_wait[k][data.worker]) - ] + sum(x) for x in zip(steps_to_wait['all'][data.worker], steps_to_wait[k][data.worker])] steps_to_wait['all'][data.worker] = [x / step_number for x in steps_to_wait['all'][data.worker]] @staticmethod @@ -1302,9 +1264,7 @@ class DistributedRunGenerator(object): round(comm_stats[0] - comm_stats[1], 3) ] steps_to_wait['all'][data.worker] = [ - sum(x) - for x in zip(steps_to_wait['all'][data.worker], steps_to_wait[step][data.worker]) - ] + sum(x) for x in zip(steps_to_wait['all'][data.worker], steps_to_wait[step][data.worker])] steps_to_wait['all'][data.worker] = [int(x / step_number) for x in steps_to_wait['all'][data.worker]] def _generate_wait_graph(self): @@ -1392,11 +1352,10 @@ class DistributedRunGenerator(object): op, stats[0], round(stats[1], 3), - - round(stats[1] / stats[0] if stats[0] != 0 else 0), + round(stats[1] / stats[0] if stats != 0 else 0), round(stats[2], 3), - round(stats[2] / stats[0] if stats[0] != 0 else 0), + round(stats[2] / stats[0] if stats != 0 else 0), round(stats[3], 3), - round(stats[3] / stats[0] if stats[0] != 0 else 0) + round(stats[3] / stats[0] if stats != 0 else 0) ] table['rows'].append(row) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/tensor_core.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/tensor_core.py index cc53ab217f0ee6f88817c51da6ba46da68df4e28..3a69cf70b881acc4588682fc4440cb5534541eb1 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/tensor_core.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/tensor_core.py @@ -1,13 +1,14 @@ # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # ------------------------------------------------------------------------- -class TcAllowlistMeta(type): - # Enable grammar sugar as 'v in TcAllowlist'. +class TC_Allowlist_Meta(type): + # Enable grammar sugar as 'v in TC_Allowlist'. def __contains__(cls, item): return cls.__contains__(item) -class TcAllowlist(metaclass=TcAllowlistMeta): +class TC_Allowlist(metaclass=TC_Allowlist_Meta): + # Refer to https://github.com/NVIDIA/PyProf/blob/fd1b2902e3306119eee40ba6b6e8b2f816920c29/pyprof/prof/tc.py#L19 allowlist = ['h884', 's884', 'h1688', 's1688', 'hmma', 'i8816', '16816', 'dgrad_1x1_stride_2x2', 'first_layer_wgrad_kernel', 'conv1x1', 'conv2d_c1_k1', 'direct_group', 'xmma_implicit_gemm', @@ -23,7 +24,8 @@ class TcAllowlist(metaclass=TcAllowlistMeta): return False -class TcOpAllowlist(metaclass=TcAllowlistMeta): +class TC_OP_Allowlist(metaclass=TC_Allowlist_Meta): + # Refer to https://github.com/pytorch/pytorch/blob/69b2bf70f9c0e591ce5e566afa59e19618031ead/aten/src/ATen/autocast_mode.cpp#L290-L351 # noqa: E501 allowlist = ['aten::_convolution', 'aten::conv1d', 'aten::conv2d', 'aten::conv3d', 'aten::conv_tbc', 'aten::conv_transpose1d', 'aten::conv_transpose2d', 'aten::conv_transpose3d', 'aten::convolution', 'aten::cudnn_convolution', 'aten::cudnn_convolution_transpose', diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py index ea09f79666bd184956469f48fc7922854394940d..e76f8b18dd80a9f12a867c9395de6a96a39bc2c1 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py @@ -1,13 +1,13 @@ # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # -------------------------------------------------------------------------- -__all__ = ['EventTypes', 'create_event'] - from enum import IntEnum from typing import Dict, Optional from .. import utils +__all__ = ['EventTypes', 'create_event'] + logger = utils.get_logger() NcclOpNameSet = ['nccl:broadcast', 'nccl:reduce', 'nccl:all_reduce', 'nccl:all_gather', 'nccl:reduce_scatter'] @@ -56,8 +56,8 @@ EventTypeMap = { class BaseEvent(object): - def __init__(self, event_type, data): - self.type: str = event_type + def __init__(self, type, data): + self.type: str = type self.name: str = data.get('name') self.ts: int = data.get('ts') self.pid: int = data.get('pid') @@ -66,8 +66,8 @@ class BaseEvent(object): class DurationEvent(BaseEvent): - def __init__(self, event_type, data): - super().__init__(event_type, data) + def __init__(self, type, data): + super().__init__(type, data) self.category: str = data.get('cat', '') self.duration: int = data.get('dur') @@ -79,8 +79,8 @@ class DurationEvent(BaseEvent): class KernelEvent(DurationEvent): - def __init__(self, event_type, data): - super().__init__(event_type, data) + def __init__(self, type, data): + super().__init__(type, data) self.occupancy = self.args.get('est. achieved occupancy %') self.blocks_per_sm = self.args.get('blocks per SM') self.grid = self.args.get('grid') @@ -91,8 +91,8 @@ class KernelEvent(DurationEvent): class OperatorEvent(DurationEvent): - def __init__(self, event_type, data): - super().__init__(event_type, data) + def __init__(self, type, data): + super().__init__(type, data) self.callstack = self.args.get('Call stack') self.input_type = self.args.get('Input type') @@ -111,8 +111,8 @@ class ProfilerStepEvent(OperatorEvent): class MemoryEvent(BaseEvent): - def __init__(self, event_type, data): - super().__init__(event_type, data) + def __init__(self, type, data): + super().__init__(type, data) self.scope: str = data.get('s', '') self.device_id: int = self.args.get('Device Id') dtype = self.args.get('Device Type') @@ -142,8 +142,8 @@ class MemoryEvent(BaseEvent): class PythonFunctionEvent(DurationEvent): - def __init__(self, event_type, data): - super().__init__(event_type, data) + def __init__(self, type, data): + super().__init__(type, data) self.python_id: int = self.args.get('Python id') self.python_parent_id: int = self.args.get('Python parent id') diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/run.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/run.py index 9e30f225244280df7acfd7d2deb95a40208cfa54..2f719fb0c6139e498f51afdcad2497293e90ad1e 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/run.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/run.py @@ -77,7 +77,7 @@ class Run(object): if worker is not None: if self.span_view.get(worker) is None: return None - spans = self.span_view.get(worker, []) + spans = self.span_view[worker] else: spans = [s for _, s in self.profiles.keys()] diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/static/trace_embedding.html b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/static/trace_embedding.html index 462d2c395f81d932fbf0196ccc53f4b0ece6e93a..bb84da0d0c0cb92d51a2d6ab1cb92ce308b23241 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/static/trace_embedding.html +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/static/trace_embedding.html @@ -11,7 +11,7 @@ found in the LICENSE file. 'use strict'; function onTraceViewerImportFail() { - document.addEventListener('DOMContentLoaded', () => { + document.addEventListener('DOMContentLoaded', function () { document.body.textContent = 'tracing/bin/trace_viewer_full.html is missing. ' + 'Run vulcanize_trace_viewer from $TRACE_VIEWER and reload.'; @@ -52,11 +52,12 @@ found in the LICENSE file. // warning. window.__hideTraceViewerPolyfillWarning = true; - window.addEventListener('message', event => { - const data = event.data || {}; - name = data.name || 'unknown'; - onResult(data.data); - }); + window.addEventListener("message", event => { + const data = event.data || {} + console.log(data) + name = data.name || 'unknown' + onResult(data.data) + }) function onResult(result) { model = new tr.Model(); @@ -77,7 +78,7 @@ found in the LICENSE file. overlay.visible = true; } - document.addEventListener('WebComponentsReady', () => { + document.addEventListener('WebComponentsReady', function () { const container = document.createElement('track-view-container'); container.id = 'track_view_container'; @@ -90,7 +91,7 @@ found in the LICENSE file. Polymer.dom(document.body).appendChild(viewer); if (window.parent) { - window.parent.postMessage({ msg: 'ready' }, window.origin); + window.parent.postMessage({ msg: 'ready' }, '*') } }); }()); diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/utils.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/utils.py index 5991cf2b33d1e818e6876c8d7550fbb6c87cdaa3..8f4189d765e6e9233478d800ab2d1424597af254 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/utils.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/utils.py @@ -23,15 +23,14 @@ import math import os import time from contextlib import contextmanager +from math import pow from . import consts -predefined_logging_level = ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET') - def get_logging_level(): log_level = os.environ.get('TORCH_PROFILER_LOG_LEVEL', 'INFO').upper() - if log_level not in predefined_logging_level: + if log_level not in logging._levelToName.values(): log_level = logging.getLevelName(logging.INFO) return log_level @@ -77,6 +76,7 @@ class Canonicalizer: input_time_metric='us', input_memory_metric='B'): # raw timestamp is in microsecond + # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/csrc/autograd/profiler_kineto.cpp#L33 time_metric_to_factor = { 'us': 1, 'ms': 1e3, @@ -84,10 +84,10 @@ class Canonicalizer: } # raw memory is in bytes memory_metric_to_factor = { - 'B': math.pow(1024, 0), - 'KB': math.pow(1024, 1), - 'MB': math.pow(1024, 2), - 'GB': math.pow(1024, 3), + 'B': pow(1024, 0), + 'KB': pow(1024, 1), + 'MB': pow(1024, 2), + 'GB': pow(1024, 3), } # canonicalize the memory metric to a string @@ -125,7 +125,7 @@ class DisplayRounder: def __init__(self, ndigits): self.ndigits = ndigits - self.precision = math.pow(10, -ndigits) + self.precision = pow(10, -ndigits) def __call__(self, v: float): _v = abs(v) diff --git a/profiler/__init__.py b/profiler/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..de0604079e1323b2749bc801a6e8326893c73498 100644 --- a/profiler/__init__.py +++ b/profiler/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/example/mstx_torch_plugin/mstx_torch_plugin.py b/profiler/example/mstx_torch_plugin/mstx_torch_plugin.py index ed22a3d0b7eed2ab0b457bb0a185061dacabc186..f6b25db7cf48cb5bcd2be687251c2499ecb30965 100644 --- a/profiler/example/mstx_torch_plugin/mstx_torch_plugin.py +++ b/profiler/example/mstx_torch_plugin/mstx_torch_plugin.py @@ -14,9 +14,8 @@ # limitations under the License. import os import functools -import re -import site import torch +import torch_npu from torch.nn import Module from torch.utils.data import DataLoader from torch.optim.optimizer import register_optimizer_step_post_hook @@ -29,18 +28,6 @@ original_multinext = torch.utils.data.dataloader._MultiProcessingDataLoaderIter. origin_patch_step_function = torch.optim.Optimizer._patch_step_function -def _check_directory_path_readable(path): - if not os.path.exists(path): - msg = f"The path dose not exist: {path}" - raise RuntimeError(msg) - if os.path.islink(path): - msg = f"Invalid path is a soft chain: {path}" - raise RuntimeError(msg) - if not os.access(path, os.R_OK): - msg = f"The path permission check failed: {path}" - raise RuntimeError(msg) - - class MstxState: def __init__(self): self.module_dict = {} @@ -157,57 +144,9 @@ def _custom_step(optimizer: torch.optim.Optimizer): mstx_state.last_optimizer_id = id(optimizer) -def _get_torch_npu_version_str(): - torch_npu_version_str = "" - site_packages = site.getsitepackages() - if site_packages and site_packages[0]: - path = site_packages[0] - version_path = os.path.join(path, "torch_npu", "version.py") - _check_directory_path_readable(version_path) - # example version info: "__version__ = '2.1.0.post11.xxxxxx'" - try: - with open(version_path, "r") as f: - for line in f: - if line.find("__version__") != -1: - torch_npu_version_str = line.strip().split("=")[-1][2:-1] - break - except Exception as e: - raise RuntimeError(f"Failed to open {version_path} to get torch npu version.") from e - return torch_npu_version_str - - -def _get_torch_npu_info(version_str: str): - # version info example: "2.1.0.post11.xxxxxx" - match = re.search(r"^(\d+\.\d+\.\d+)\.post(\d+)", version_str) - if match and len(match.groups()) == 2: - return match.group(1), match.group(2) - else: - return '', '' - - -def _check_pta_support_patch(): - pta_support_patch_version = { - "2.1.0": 10, - "2.3.1": 4, - "2.4.0": 2, - } - torch_npu_version_str = _get_torch_npu_version_str() - if not torch_npu_version_str: - raise RuntimeError("Failed to get torch_npu version info.") - torch_branch, torch_npu_version = _get_torch_npu_info(torch_npu_version_str) - if not torch_branch or not torch_npu_version or not torch_npu_version.isdigit(): - raise RuntimeError("Failed to get valid torch branch or torch_npu version.") - for branch, post_version in pta_support_patch_version.items(): - if torch_branch == branch and int(torch_npu_version) <= post_version: - return False - return True - - def apply_mstx_patch(): - pta_support_patch = _check_pta_support_patch() Module.__call__ = _custom_forward_call - if not pta_support_patch: - DataLoader.__iter__ = _custom_dataloader_iter - torch.serialization.save = _custom_save(original_save) + DataLoader.__iter__ = _custom_dataloader_iter + torch.serialization.save = _custom_save(original_save) torch.optim.Optimizer._patch_step_function = _custom_step register_optimizer_step_post_hook(_step_hook) diff --git a/profiler/merge_profiling_timeline/README.md b/profiler/merge_profiling_timeline/README.md new file mode 100644 index 0000000000000000000000000000000000000000..24db91adee88d74bff99117189e70a6ad632ddd3 --- /dev/null +++ b/profiler/merge_profiling_timeline/README.md @@ -0,0 +1,115 @@ +# 合并大json工具 + +merge_profiling_timeline(合并大json工具)支持合并Profiling的timeline数据,支持合并指定rank的timline、合并指定timeline中的item。 + + +## 多timeline融合 + +### 性能数据采集 + +使用Ascend PyTorch Profiler或者E2E性能采集工具采集性能数据,E2E profiling将被废弃,不建议使用。Ascend PyTorch Profiler采集方式参考:[Profiling数据采集](https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze)。将采集到的所有节点的性能数据拷贝到当前环境同一目录下,以下假设数据在/home/test/cann_profiling下。 + +E2E Profiling数据目录结构示例如下: + +```bash +|- cann_profiling + |- PROF_*** + |- timeline + |- msprof.json + |- device_* + |- info.json.* + ... + |- PROF_*** + ... +``` + +Ascend PyTorch Profiler数据目录结构示例如下: + +```bash +|- ascend_pytorch_profiling + |- **_ascend_pt + |- ASCEND_PROFILER_OUTPUT + |- trace_view.json + |- FRAMEWORK + |- PROF_*** + |- **_ascend_pt +``` + +### 参数说明 + +| 参数名称 | 说明 | 是否必选 | +| -------- | ------------------------------------------------------------ | -------- | +| -i | 指定Profiling数据目录路径。 | 是 | +| --type | 指定需要合并timeline场景,可选取值:`pytorch`(通过Ascend PyTorch Profiler方式采集profiling数据,合并所有卡的trace_view.json)、`e2e`(通过E2E Profiling方式采集Profiling数据,优先合并总timeline,没有生成则选择合并device目录下的msprof_*.json)、`custom` (自定义需要合并的timeline数据,具体参考**使用示例**)。 | 是 | +| -o | 指定合并后的timeline文件输出的路径(路径末尾可以设置文件名,具体用法参考**使用示例**),不设置该参数的情况下默认文件输出的路径为当前目录(默认文件名为merged.json)。 | 否 | +| --rank | 指定需要合并timeline的Rank ID,默认全部合并。 | 否 | +| --items | 指定需要合并的Profiling数据项,包括:python、Ascend Hardware、CANN、HCCL、PTA、Overlap Analysis,默认全部合并。 | 否 | + +### 使用示例 + +1. 合并单机多卡timeline,默认合并所有卡、所有数据项,生成first.json在path/to/cann_profiling/output/目录下 + + ```bash + python3 main.py -i path/to/cann_profiling/ -o path/to/cann_profiling/output/first --type pytorch + ``` + +2. 合并单机多卡timeline,默认合并所有卡、所有数据项,不设置-o参数时默认生成merge.json在当前目录下 + + ```bash + python3 main.py -i path/to/cann_profiling/ --type pytorch + ``` + +3. 合并单机多卡timeline,只合并0卡和1卡 + + ```bash + python3 main.py -i path/to/cann_profiling/ -o path/to/cann_profiling/output/2p --type pytorch --rank 0,1 + ``` + +4. 合并单机多卡timeline,合并所有卡的CANN层和Ascend_Hardware层数据 + + ```bash + python3 main.py -i path/to/cann_profiling/ --type pytorch --items "CANN,Ascend Hardware" + ``` + +5. 合并多timeline(自定义) + + 以上场景不支持的情况下,可以使用自定义的合并方式,将需要合并的timeline文件放在同一目录下(附:该场景比较特殊,与正常合并不同,无法直接读取info.json中的rank_id,因此该场景下的rank_id为默认分配的序号,用于区分不同文件的相同层,不代表实际rank_id) + 数据目录结构示意如下: + + ```bash + |- timeline + |- msprof_0.json + |- msprof_1.json + |- msprof_2.json + |- hccl_3.json + |- hccl_4.json + ... + ``` + + 通过下面的命令合并所有timeline,同样支持-o、--rank、--items等参数。 + + ```bash + python3 main.py -i path/to/timeline/ -o path/to/timeline/xxx --type custom + ``` + + 合并timeline查看:在 -o 指定的目录(不设置-o时默认在当前目录下的merged.json)的xxx.json为合并后的文件。 + + +## 超大timeline文件查看 + +[下载whl](https://gitee.com/aerfaliang/trace_processor/releases/download/trace_processor_37.0/trace_processor-37.0-py3-none-any.whl)包并执行如下命令安装(windows): + +```bash +pip3 install trace_processor-37.0-py3-none-any.whl +``` + +安装完成后直接执行如下命令: + +```bash +python -m trace_processor --httpd path/to/xxx_merged.json +``` + +等待加载完毕,刷新[perfetto](https://ui.perfetto.dev/)界面,单击Use old version regardless,再单击`YES, use loaded trace`即可展示timeline(通过W放大、S缩小、A左移、D右移来查看timeline文件)。 + +![输入图片说明](perfetto使用指导截图1.png) +![输入图片说明](perfetto使用指导截图2.png) \ No newline at end of file diff --git a/profiler/merge_profiling_timeline/__init__.py b/profiler/merge_profiling_timeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/merge_profiling_timeline/main.py b/profiler/merge_profiling_timeline/main.py new file mode 100644 index 0000000000000000000000000000000000000000..722457812b8c039317cbf541d26767ee2bb91361 --- /dev/null +++ b/profiler/merge_profiling_timeline/main.py @@ -0,0 +1,237 @@ +#! /usr/bin/python3 +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re + +from functools import partial +from argparse import ArgumentParser +from decimal import Decimal + + +FILTER_DIRS = [".profiler", "HCCL_PROF", "timeline", "query", 'sqlite', 'log'] +RANK_ID_POS = 1000 + +def get_path_dir(path: str) -> list: + """ + check result path exist JOB dir + path : result path + """ + path_dir_filter = filter(partial(_path_dir_filter_func, root_dir=path), os.listdir(path)) + sub_dirs = list(path_dir_filter) + return sub_dirs + + +def _path_dir_filter_func(sub_path, root_dir): + return sub_path not in FILTER_DIRS and os.path.isdir(os.path.realpath(os.path.join(root_dir, sub_path))) + + +def natural_sort(files): + def convert(text): + return int(text) if text.isdigit() else text.lower() + + def alphanum_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted(files, key=alphanum_key) + + +def get_timeline_info(args, prof_dirs): + timeline_info = {} + + for prof in prof_dirs: + pro_path = os.path.join(args.input, prof) + + # 从info.json读取rank_id + rank_id = get_rank_id_from_info_json(pro_path) + if rank_id is None: + print(f"WARN, There is not rank id info in {pro_path}") + continue + + timeline_path = get_timeline_path(pro_path, args.type) + + if os.path.exists(timeline_path): + timeline_info[rank_id] = timeline_path + else: + print(f"WARN, The file \"{timeline_path}\" does not exist.") + return timeline_info + + +def get_timeline_path(pro_path, type): + for root, dirs, files in os.walk(pro_path): + for dir_ in dirs: + if 'ASCEND_PROFILER_OUTPUT' == dir_ and type == 'pytorch': + timeline_path = os.path.realpath(os.path.join(root, dir_, 'trace_view.json')) + return timeline_path + + for file_ in sorted(files, reverse=True): + if 'msprof' in file_: + timeline_path = os.path.join(root, file_) + return timeline_path + return None + +def get_rank_id_from_info_json(pro_path): + info_json = "" + rank_id = None + for root, _, files in os.walk(pro_path): + for file in files: + if "info.json." in file and ".done" not in file: + info_json = os.path.join(root, file) + break + + if info_json: + if os.path.islink(info_json): + print(f"The file: \"{info_json}\" is link. Please check the path.") + return None + try: + with open(info_json, "r+") as f: + info = json.load(f) + rank_id = info.get("rank_id") + except Exception as err: + print("[ERROR] %s" % err) + return None + return rank_id + + +def merge_timeline_general(args): + """合并e2e profiling生成的msprof*.json""" + if not os.path.isdir(args.input): + print(f"No such file or directory: \"{args.input}\". Please check the path.") + return + prof_dir = get_path_dir(args.input) + if not prof_dir: + message = f"The path \"{args.input}\" does not have PROF dir. Please check the path." + print(message) + return + timeline_info = get_timeline_info(args, prof_dir) + timeline_files_dict = {} + + # 合并部分profiling items + process_list = args.items.split(",") if args.items else None + + # 合并部分rank + if args.rank: + rank_ids = [int(rank_id) for rank_id in args.rank.split(",")] + else: + rank_ids = list(timeline_info.keys()) + + for rank_id in rank_ids: + if not timeline_info.get(rank_id): + print(f"main.py: error rank_id '{rank_id}' ") + return + timeline_files_dict[rank_id] = timeline_info.get(rank_id) + merge_timeline_events(timeline_files_dict, process_list) + + +def merge_timeline_custom(args): + """合并指定目录里所有timeline文件""" + timeline_files = natural_sort(os.listdir(args.input)) + timeline_files_dict = {} + for idx, timeline_file in enumerate(timeline_files): + timeline_files_dict[idx] = os.path.join(args.input, timeline_file) + # 合并部分profiling items + process_list = args.items.split(",") if args.items else None + merge_timeline_events(timeline_files_dict, process_list) + + +def merge_timeline_events(timeline_file_dict, process_list): + """ + 输入需要合并的timeline文件路径及对应的rank_id/id、需要合并的process_list + 输出合并timeline + """ + new_events = [] + for rank_id, timeline_path in timeline_file_dict.items(): + node = rank_id // 8 + print("rank id: ", rank_id, "timeline file: ", timeline_path) + if os.path.islink(timeline_path): + print(f"The file: \"{timeline_path}\" is link. Please check the path.") + return + try: + with open(timeline_path, 'r+') as f: + cur_events = json.load(f) + except Exception as err: + print("[ERROR] %s" % err) + return + + proc_pid_dict = {} + for event in cur_events: + if event.get("name") == "process_name" and event.get("ph") == "M": + if event.get("args"): + proc_pid_dict[event["args"].get("name")] = event.get("pid") + process_list_tmp = process_list if process_list else list(proc_pid_dict.keys()) + # 提取待合并的items的pid + merged_pids = set() + for pro in process_list_tmp: + if pro not in proc_pid_dict.keys(): + print(f"main.py: error argument --items: invalid choice: '{pro}' (choose from {list(proc_pid_dict.keys())})") + return + merged_pids.add(proc_pid_dict.get(pro)) + + for event in cur_events: + + # 只合并特定数据项 + if merged_pids and event.get('pid') not in merged_pids: + continue + + # convert tid to int + if not isinstance(event['tid'], int): + print(f"[WARNNING] {event['tid']} is not int type") + + # 进程名加上rank_id区分不同rank + if event.get("name") == "process_name" and event.get("ph") == "M": + if event.get("args") is not None and event["args"].get("name") is not None: + event["args"]["name"] = event["args"]["name"] + f"_{rank_id}" + + #modify connect id + if event.get('id') and (event.get('ph') == 's' or event.get('ph') == 'f'): + event['id'] = float(event.get('id')) * RANK_ID_POS + rank_id + + new_events.append(event) + out_path = f"{args.output}.json" + if os.path.islink(out_path): + print(f"The file: \"{out_path}\" is link. Please check the path.") + return + if os.path.exists(out_path): + print(f"File {out_path} existed before and is now overwritten.") + os.remove(out_path) + try: + # 设置文件权限为640,安全考虑 + with os.fdopen(os.open(out_path, os.O_WRONLY | os.O_CREAT, 0o640), 'w') as f: + json.dump(new_events, f) + except FileNotFoundError: + print(f"Param -o (output path) is not exists, please check it.") + return + print(f"timeline merged output path: {out_path}") + + +def parse_args(): + parser = ArgumentParser(description="Merge timeline for multi card") + parser.add_argument("-i", "--input", default=None, help="root dir of PROF_* data") + parser.add_argument("-o", "--output", default="./merged", help="save path of merged.json ") + parser.add_argument("--rank", default=None, help="List of ranks to be merged. By default, all ranks are merged") + parser.add_argument("--items", default=None, help="Specify the data items (python,CANN,Ascend Hardware,HCCL,..)to be merged. in the timeline.") + parser.add_argument("--type", choices=('pytorch', 'e2e', 'custom'), help="Customize the timeline file to be merged.") + arg = parser.parse_args() + return arg + + +if __name__ == "__main__": + args = parse_args() + print("========================== start merge timeline ====================") + if args.type == "custom": + merge_timeline_custom(args) + else: + merge_timeline_general(args) \ No newline at end of file diff --git "a/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2761.png" "b/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2761.png" new file mode 100644 index 0000000000000000000000000000000000000000..beef396ce2996c25ecd74298285ccab5011ddea1 Binary files /dev/null and "b/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2761.png" differ diff --git "a/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2762.png" "b/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2762.png" new file mode 100644 index 0000000000000000000000000000000000000000..48793f136e48f21f618ff3cb13bdcc3388f76930 Binary files /dev/null and "b/profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2762.png" differ diff --git a/profiler/msprof_analyze/MANIFEST.in b/profiler/msprof_analyze/MANIFEST.in index b4d096405c98ea1a906b8882418362d428cbf1b6..df1488cce957db8d6135caf1e65e834103fe92ed 100644 --- a/profiler/msprof_analyze/MANIFEST.in +++ b/profiler/msprof_analyze/MANIFEST.in @@ -3,5 +3,6 @@ recursive-include msprof_analyze/cli/ * recursive-include msprof_analyze/prof_common/ * recursive-include msprof_analyze/compare_tools/ * recursive-include msprof_analyze/cluster_analyse/ * +recursive-include msprof_analyze/precheck/ * global-exclude */__pycache__/* global-exclude *.pyc diff --git a/profiler/msprof_analyze/OWNERS b/profiler/msprof_analyze/OWNERS index 864e7ecc649aab5a9eb5d6db1b33e9dd8a8882dc..7524470824c5552b570c09cc231e74811a15adf7 100644 --- a/profiler/msprof_analyze/OWNERS +++ b/profiler/msprof_analyze/OWNERS @@ -1,10 +1,12 @@ -options: - no_parent_owners: true -approvers: -- xhahn -- aerfaliang -- chenhao_1209 -- feng123www -reviewers: -- Seanesmhxocism -- wjchuee +options: + no_parent_owners: true +approvers: +- xhahn +- aerfaliang +- chenhao_1209 +- feng123www +- sunboquan +reviewers: +- sunboquan +- Seanesmhxocism +- wjchuee diff --git a/profiler/msprof_analyze/README.md b/profiler/msprof_analyze/README.md index 7e2267a55596bac342b0e2ada564ea31c5625a84..c3be2acd6ef1a33c629a07cba10be953036cfefd 100644 --- a/profiler/msprof_analyze/README.md +++ b/profiler/msprof_analyze/README.md @@ -1,250 +1,250 @@ -# 性能工具 - -MindStudio Training Tools工具针对训练&大模型场景,提供端到端性能调优工具msprof-analyze:用户采集到性能数据后,由MindStudio Training Tools的性能工具msprof-analyze提供统计、分析以及相关的调优建议。 - -## NPU性能数据采集 - -目前MindStudio Training Tools工具主要支持对Ascend PyTorch Profiler接口采集的性能数据进行分析,请参考官方文档:[Ascend PyTorch Profiler数据采集与分析](https://www.hiascend.com/document/detail/zh/canncommercial/80RC1/devaids/auxiliarydevtool/atlasprofiling_16_0006.html)。 - -### 环境和依赖 - -- 硬件环境请参见《[昇腾产品形态说明](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fquickstart%2Fquickstart%2Fquickstart_18_0002.html)》。 -- 软件环境请参见《[CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fsoftwareinst%2Finstg%2Finstg_0000.html%3FMode%3DPmIns%26OS%3DUbuntu%26Software%3DcannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。 - -以上环境依赖请根据实际环境选择适配的版本。 - -### 版本配套说明 - -- Ascend PyTorch Profiler接口支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和Python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 -- Ascend PyTorch Profiler接口支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 - -### 采集方式一:通过with语句进行采集 - -```python -import torch_npu -experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False -) -with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU - ], - record_shapes=True, - profile_memory=True, - with_stack=True, - experimental_config=experimental_config, - schedule=torch_npu.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiling_data") -) as prof: - # 模型训练代码 - for epoch, data in enumerate(dataloader): - train_model_one_step(model, data) - prof.step() -``` - -### 采集方式二:start,stop方式进行采集 - -```python -import torch_npu -experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False -) -prof = torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU - ], - record_shapes=True, - profile_memory=True, - with_stack=True, - experimental_config=experimental_config, - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiling_data")) -# 模型训练代码 -for epoch, data in enumerate(dataloader): - if epoch == 11: - prof.start() - train_model_one_step(model, data) - prof.step() - if epoch == 11: - prof.stop() -``` - -### NPU性能数据目录结构 - -ascend pytorch profiler数据目录结构如下: - -``` -|- ascend_pytorch_profiling - |- * _ascend_pt - |- ASCEND_PROFILER_OUTPUT - |- trace_view.json - |- FRAMEWORK - |- PROF_XXX - |- profiler_info.json - |- * _ascend_pt -``` - -## 安装 - -性能工具的安装方式包括:**pip安装**、**下载whl包安装**和**源代码编译安装**。 - -### pip安装 - -```shell -pip install msprof-analyze -``` - -使用`pip install msprof-analyze==版本号`可安装指定版本的包,支持1.2.1及之后版本,版本号参见“**下载whl包安装**”。 - -pip命令会自动安装最新的包及其配套依赖。 - -提示如下信息则表示安装成功。 - -```bash -Successfully installed msprof-analyze-{version} -``` - -### 下载whl包安装 - -1. whl包获取。 - - 请通过下表链接下载profiler工具whl包。 - -| profiler版本 | 发布日期 | 下载链接 | 校验码 | -|------------|------------|-------------------------------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ | -| 2.0.0 | 2025-02-08 | [msprof_analyze-2.0.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/2.0.0/msprof_analyze-2.0.0-py3-none-any.whl) | 8e44e5f3e7681c377bb2657a600ad9841d3bed11061ddd7844c30e8a97242101 | -| 1.3.4 | 2025-01-20 | [msprof_analyze-1.3.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.4/msprof_analyze-1.3.4-py3-none-any.whl) | 8de92188d1a97105fb14cadcb0875ccd5f66629ee3bb25f37178da1906f4cce2 | -| 1.3.3 | 2024-12-26 | [msprof_analyze-1.3.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.3/msprof_analyze-1.3.3-py3-none-any.whl) | 27676f2eee636bd0c65243f81e292c7f9d30d7f985c772ac9cbaf10b54d3584e | -| 1.3.2 | 2024-12-20 | [msprof_analyze-1.3.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.2/msprof_analyze-1.3.2-py3-none-any.whl) | ceb227e751ec3a204135be13801f1deee6a66c347f1bb3cdaef596872874df06 | -| 1.3.1 | 2024-12-04 | [msprof_analyze-1.3.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.1/msprof_analyze-1.3.1-py3-none-any.whl) | eae5548804314110a649caae537f2c63320fc70ec41ce1167f67c1d674d8798e | -| 1.3.0 | 2024-10-12 | [msprof_analyze-1.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.0/msprof_analyze-1.3.0-py3-none-any.whl) | 8b09758c6b5181bb656a95857c32852f898c370e7f1041e5a08e4f10d5004d48 | -| 1.2.5 | 2024-09-25 | [msprof_analyze-1.2.5-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.5/msprof_analyze-1.2.5-py3-none-any.whl) | aea8ae8deac07b5b4980bd2240da27d0eec93b9ace9ea9eb2e3a05ae9072018b | -| 1.2.4 | 2024-09-19 | [msprof_analyze-1.2.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.4/msprof_analyze-1.2.4-py3-none-any.whl) | 7c392e72c3347c4034fd3fdfcccb1f7936c24d9c3eb217e2cc05bae1347e5ab7 | -| 1.2.3 | 2024-08-29 | [msprof_analyze-1.2.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.3/msprof_analyze-1.2.3-py3-none-any.whl) | 354a55747f64ba1ec6ee6fe0f05a53e84e1b403ee0341ec40cc216dd25fda14c | -| 1.2.2 | 2024-08-23 | [msprof_analyze-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.2/msprof_analyze-1.2.2-py3-none-any.whl) | ed92a8e4eaf5ada8a2b4079072ec0cc42501b1b1f2eb00c8fdcb077fecb4ae02 | -| 1.2.1 | 2024-08-14 | [msprof_analyze-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.1/msprof_analyze-1.2.1-py3-none-any.whl) | 7acd477417bfb3ea29029dadf175d019ad3212403b7e11dc1f87e84c2412c078 | -| 1.2.0 | 2024-07-25 | [msprof_analyze-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.0/msprof_analyze-1.2.0-py3-none-any.whl) | 6a4366e3beca40b4a8305080e6e441d6ecafb5c05489e5905ac0265787555f37 | -| 1.1.2 | 2024-07-12 | [msprof_analyze-1.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.2/msprof_analyze-1.1.2-py3-none-any.whl) | af62125b1f9348bf491364e03af712fc6d0282ccee3fb07458bc9bbef82dacc6 | -| 1.1.1 | 2024-06-20 | [msprof_analyze-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.1/msprof_analyze-1.1.1-py3-none-any.whl) | 76aad967a3823151421153d368d4d2f8e5cfbcb356033575e0b8ec5acea8e5e4 | -| 1.1.0 | 2024-05-28 | [msprof_analyze-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.0/msprof_analyze-1.1.0-py3-none-any.whl) | b339f70e7d1e45e81f289332ca64990a744d0e7ce6fdd84a8d82e814fa400698 | -| 1.0 | 2024-05-10 | [msprof_analyze-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.0/msprof_analyze-1.0-py3-none-any.whl) | 95b2f41c8c8e8afe4887b738c8cababcb4f412e1874483b6adae4a025fcbb7d4 | - -2. whl包校验。 - - 1. 根据以上下载链接下载whl包到Linux安装环境。 - - 2. 进入whl包所在目录,执行如下命令。 - - ``` - sha256sum {name}.whl - ``` - - {name}为whl包名称。 - - 若回显呈现对应版本whl包一致的**校验码**,则表示下载了正确的性能工具whl安装包。示例如下: - - ```bash - sha256sum msprof_analyze-1.0-py3-none-any.whl - xx *msprof_analyze-1.0-py3-none-any.whl - ``` - -3. whl包安装。 - - 执行如下命令进行安装。 - - ```bash - pip3 install ./msprof_analyze-{version}-py3-none-any.whl - ``` - - 提示如下信息则表示安装成功。 - - ```bash - Successfully installed msprof_analyze-{version} - ``` - -### 源代码编译安装 - -1. 安装依赖。 - - 编译前需要安装wheel。 - - ```bash - pip3 install wheel - ``` - -2. 下载源码。 - - ```bash - git clone https://gitee.com/ascend/mstt.git - ``` - -3. 编译whl包。 - - ```bash - cd mstt/profiler/msprof_analyze - pip3 install -r requirements.txt && python3 setup.py bdist_wheel - ``` - - 以上命令执行完成后在mstt/profiler/msprof_analyze/dist目录下生成性能工具whl安装包`msprof_analyze-{version}-py3-none-any.whl`。 - -4. 安装。 - - 执行如下命令进行性能工具安装。 - - ```bash - cd dist - pip3 install ./msprof_analyze-{version}-py3-none-any.whl - ``` - -## 卸载和更新 - -若需要更新工具,请先卸载旧版本后再重新安装新版本,如下操作: - -1. 卸载 - - ```bash - pip3 uninstall msprof-analyze - ``` - -2. 更新 - - ```bash - pip3 install ./msprof_analyze-{version}-py3-none-any.whl - ``` - -## 工具使用 - -```bash -msprof-analyze advisor [-h] -``` - -```bash -msprof-analyze compare [-h] -``` - -```bash -msprof-analyze cluster [-h] -``` - -```bash -msprof-analyze auto-completion [-h] -``` - -``` -msprof-analyze [-h] [-v] -``` - -| 参数 | 说明 | -| -------------------- | ------------------------------------------------------------ | -| advisor | [advisor](./advisor/README.md)。将Ascend PyTorch Profiler或者msprof采集的PyThon场景性能数据进行分析,并输出性能调优建议。 | -| compare | [compare_tools(性能比对工具)](./compare_tools/README.md)。提供NPU与GPU性能拆解功能以及算子、通信、内存性能的比对功能。 | -| cluster | [cluster_analyse(集群分析工具)](./cluster_analyse/README.md)。提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合Ascend Insight的集群分析功能使用。 | -| auto-completion | 自动补全。配置后在当前视图下配置msprof-analyze工具所有的子参数时,可以使用Tab将所有子参数自动补全。 | -| -v,-V
    --version | 查看版本号。 | -| -h,-H
    --help | 命令行参数帮助信息。 | - +# 性能工具 + +MindStudio Training Tools工具针对训练&大模型场景,提供端到端性能调优工具msprof-analyze:用户采集到性能数据后,由MindStudio Training Tools的性能工具msprof-analyze提供统计、分析以及相关的调优建议。 + +## NPU性能数据采集 + +目前MindStudio Training Tools工具主要支持对Ascend PyTorch Profiler接口采集的性能数据进行分析,请参考官方文档:[Ascend PyTorch Profiler数据采集与分析](https://www.hiascend.com/document/detail/zh/canncommercial/80RC1/devaids/auxiliarydevtool/atlasprofiling_16_0006.html)。 + +### 环境和依赖 + +- 硬件环境请参见《[昇腾产品形态说明](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fquickstart%2Fquickstart%2Fquickstart_18_0002.html)》。 +- 软件环境请参见《[CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fsoftwareinst%2Finstg%2Finstg_0000.html%3FMode%3DPmIns%26OS%3DUbuntu%26Software%3DcannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。 + +以上环境依赖请根据实际环境选择适配的版本。 + +### 版本配套说明 + +- Ascend PyTorch Profiler接口支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和Python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 +- Ascend PyTorch Profiler接口支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 + +### 采集方式一:通过with语句进行采集 + +```python +import torch_npu +experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False +) +with torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + experimental_config=experimental_config, + schedule=torch_npu.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiling_data") +) as prof: + # 模型训练代码 + for epoch, data in enumerate(dataloader): + train_model_one_step(model, data) + prof.step() +``` + +### 采集方式二:start,stop方式进行采集 + +```python +import torch_npu +experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False +) +prof = torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + experimental_config=experimental_config, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiling_data")) +# 模型训练代码 +for epoch, data in enumerate(dataloader): + if epoch == 11: + prof.start() + train_model_one_step(model, data) + prof.step() + if epoch == 11: + prof.stop() +``` + +### NPU性能数据目录结构 + +ascend pytorch profiler数据目录结构如下: + +``` +|- ascend_pytorch_profiling + |- * _ascend_pt + |- ASCEND_PROFILER_OUTPUT + |- trace_view.json + |- FRAMEWORK + |- PROF_XXX + |- profiler_info.json + |- * _ascend_pt +``` + +## 安装 + +性能工具的安装方式包括:**pip安装**、**下载whl包安装**和**源代码编译安装**。 + +### pip安装 + +```shell +pip install msprof-analyze +``` + +使用`pip install msprof-analyze==版本号`可安装指定版本的包,支持1.2.1及之后版本,版本号参见“**下载whl包安装**”。 + +pip命令会自动安装最新的包及其配套依赖。 + +提示如下信息则表示安装成功。 + +```bash +Successfully installed msprof-analyze-{version} +``` + +### 下载whl包安装 + +1. whl包获取。 + + 请通过下表链接下载profiler工具whl包。 + +| profiler版本 | 发布日期 | 下载链接 | 校验码 | +|------------|------------|-------------------------------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ | +| 2.0.0 | 2025-02-08 | [msprof_analyze-2.0.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/2.0.0/msprof_analyze-2.0.0-py3-none-any.whl) | 8e44e5f3e7681c377bb2657a600ad9841d3bed11061ddd7844c30e8a97242101 | +| 1.3.4 | 2025-01-20 | [msprof_analyze-1.3.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.4/msprof_analyze-1.3.4-py3-none-any.whl) | 8de92188d1a97105fb14cadcb0875ccd5f66629ee3bb25f37178da1906f4cce2 | +| 1.3.3 | 2024-12-26 | [msprof_analyze-1.3.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.3/msprof_analyze-1.3.3-py3-none-any.whl) | 27676f2eee636bd0c65243f81e292c7f9d30d7f985c772ac9cbaf10b54d3584e | +| 1.3.2 | 2024-12-20 | [msprof_analyze-1.3.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.2/msprof_analyze-1.3.2-py3-none-any.whl) | ceb227e751ec3a204135be13801f1deee6a66c347f1bb3cdaef596872874df06 | +| 1.3.1 | 2024-12-04 | [msprof_analyze-1.3.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.1/msprof_analyze-1.3.1-py3-none-any.whl) | eae5548804314110a649caae537f2c63320fc70ec41ce1167f67c1d674d8798e | +| 1.3.0 | 2024-10-12 | [msprof_analyze-1.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.0/msprof_analyze-1.3.0-py3-none-any.whl) | 8b09758c6b5181bb656a95857c32852f898c370e7f1041e5a08e4f10d5004d48 | +| 1.2.5 | 2024-09-25 | [msprof_analyze-1.2.5-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.5/msprof_analyze-1.2.5-py3-none-any.whl) | aea8ae8deac07b5b4980bd2240da27d0eec93b9ace9ea9eb2e3a05ae9072018b | +| 1.2.4 | 2024-09-19 | [msprof_analyze-1.2.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.4/msprof_analyze-1.2.4-py3-none-any.whl) | 7c392e72c3347c4034fd3fdfcccb1f7936c24d9c3eb217e2cc05bae1347e5ab7 | +| 1.2.3 | 2024-08-29 | [msprof_analyze-1.2.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.3/msprof_analyze-1.2.3-py3-none-any.whl) | 354a55747f64ba1ec6ee6fe0f05a53e84e1b403ee0341ec40cc216dd25fda14c | +| 1.2.2 | 2024-08-23 | [msprof_analyze-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.2/msprof_analyze-1.2.2-py3-none-any.whl) | ed92a8e4eaf5ada8a2b4079072ec0cc42501b1b1f2eb00c8fdcb077fecb4ae02 | +| 1.2.1 | 2024-08-14 | [msprof_analyze-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.1/msprof_analyze-1.2.1-py3-none-any.whl) | 7acd477417bfb3ea29029dadf175d019ad3212403b7e11dc1f87e84c2412c078 | +| 1.2.0 | 2024-07-25 | [msprof_analyze-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.0/msprof_analyze-1.2.0-py3-none-any.whl) | 6a4366e3beca40b4a8305080e6e441d6ecafb5c05489e5905ac0265787555f37 | +| 1.1.2 | 2024-07-12 | [msprof_analyze-1.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.2/msprof_analyze-1.1.2-py3-none-any.whl) | af62125b1f9348bf491364e03af712fc6d0282ccee3fb07458bc9bbef82dacc6 | +| 1.1.1 | 2024-06-20 | [msprof_analyze-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.1/msprof_analyze-1.1.1-py3-none-any.whl) | 76aad967a3823151421153d368d4d2f8e5cfbcb356033575e0b8ec5acea8e5e4 | +| 1.1.0 | 2024-05-28 | [msprof_analyze-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.0/msprof_analyze-1.1.0-py3-none-any.whl) | b339f70e7d1e45e81f289332ca64990a744d0e7ce6fdd84a8d82e814fa400698 | +| 1.0 | 2024-05-10 | [msprof_analyze-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.0/msprof_analyze-1.0-py3-none-any.whl) | 95b2f41c8c8e8afe4887b738c8cababcb4f412e1874483b6adae4a025fcbb7d4 | + +2. whl包校验。 + + 1. 根据以上下载链接下载whl包到Linux安装环境。 + + 2. 进入whl包所在目录,执行如下命令。 + + ``` + sha256sum {name}.whl + ``` + + {name}为whl包名称。 + + 若回显呈现对应版本whl包一致的**校验码**,则表示下载了正确的性能工具whl安装包。示例如下: + + ```bash + sha256sum msprof_analyze-1.0-py3-none-any.whl + xx *msprof_analyze-1.0-py3-none-any.whl + ``` + +3. whl包安装。 + + 执行如下命令进行安装。 + + ```bash + pip3 install ./msprof_analyze-{version}-py3-none-any.whl + ``` + + 提示如下信息则表示安装成功。 + + ```bash + Successfully installed msprof_analyze-{version} + ``` + +### 源代码编译安装 + +1. 安装依赖。 + + 编译前需要安装wheel。 + + ```bash + pip3 install wheel + ``` + +2. 下载源码。 + + ```bash + git clone https://gitee.com/ascend/mstt.git + ``` + +3. 编译whl包。 + + ```bash + cd mstt/profiler/msprof_analyze + pip3 install -r requirements.txt && python3 setup.py bdist_wheel + ``` + + 以上命令执行完成后在mstt/profiler/msprof_analyze/dist目录下生成性能工具whl安装包`msprof_analyze-{version}-py3-none-any.whl`。 + +4. 安装。 + + 执行如下命令进行性能工具安装。 + + ```bash + cd dist + pip3 install ./msprof_analyze-{version}-py3-none-any.whl + ``` + +## 卸载和更新 + +若需要更新工具,请先卸载旧版本后再重新安装新版本,如下操作: + +1. 卸载 + + ```bash + pip3 uninstall msprof-analyze + ``` + +2. 更新 + + ```bash + pip3 install ./msprof_analyze-{version}-py3-none-any.whl + ``` + +## 工具使用 + +```bash +msprof-analyze advisor [-h] +``` + +```bash +msprof-analyze compare [-h] +``` + +```bash +msprof-analyze cluster [-h] +``` + +```bash +msprof-analyze auto-completion [-h] +``` + +``` +msprof-analyze [-h] [-v] +``` + +| 参数 | 说明 | +| -------------------- | ------------------------------------------------------------ | +| advisor | [advisor](./advisor/README.md)。将Ascend PyTorch Profiler或者msprof采集的PyThon场景性能数据进行分析,并输出性能调优建议。 | +| compare | [compare_tools(性能比对工具)](./compare_tools/README.md)。提供NPU与GPU性能拆解功能以及算子、通信、内存性能的比对功能。 | +| cluster | [cluster_analyse(集群分析工具)](./cluster_analyse/README.md)。提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合Ascend Insight的集群分析功能使用。 | +| auto-completion | 自动补全。配置后在当前视图下配置msprof-analyze工具所有的子参数时,可以使用Tab将所有子参数自动补全。 | +| -v,-V
    --version | 查看版本号。 | +| -h,-H
    --help | 命令行参数帮助信息。 | + diff --git a/profiler/msprof_analyze/advisor/README.md b/profiler/msprof_analyze/advisor/README.md index befdf89fbe9542c69b5ac0e94d163e11f34c4fad..2c9e055a119847134f08337c559d4012b4ea31fc 100644 --- a/profiler/msprof_analyze/advisor/README.md +++ b/profiler/msprof_analyze/advisor/README.md @@ -90,7 +90,6 @@ msprof-analyze advisor命令行包含如下三个参数: | | slow link | 慢链路识别 | PyTorch、MindSpore | | computation | AICPU Issues | AI CPU调优 | PyTorch、MindSpore | | | Operator Dynamic Shape Issues | 识别动态Shape算子 | PyTorch | -| | AI Core Performance analysis | MatMul、FlashAttentionScore、AI_VECTOR_CORE和MIX_AIV类算子的性能分析 | PyTorch | | | Block Dim | Block Dim算子调优 | PyTorch、MindSpore | | | Operator No Bound Issues | 算子瓶颈分析 | PyTorch、MindSpore | | | Fusion Issues | 融合算子图调优 | PyTorch、MindSpore | @@ -104,7 +103,6 @@ msprof-analyze advisor命令行包含如下三个参数: | | SyncBatchNorm Issues | BatchNorm同步检测 | PyTorch、MindSpore | | | Synchronize Stream Issues | 流同步检测 | PyTorch、MindSpore | | | GC Analysis | 识别异常垃圾回收事件。需要Ascend PyTorch Profiler采集时开启experimental_config下的gc_delect_threshold功能 | PyTorch | -| | Fusible Operator Analysis | 检测具有Host瓶颈或者MTE瓶颈的算子序列,可用于代码优化或开发可融合算子 | PyTorch、MindSpore | | dataloader | Slow Dataloader Issues | 异常dataloader检测 | PyTorch、MindSpore | | memory | Memory Operator Issues | 识别异常的内存申请释放操作 | PyTorch、MindSpore | | comparison | Kernel compare of Rank\* Step\* and Rank\* Step\* | 识别标杆和待比对性能数据的Kernel数据(无标杆场景是集群内部快慢卡的性能数据对比,有标杆场景是两个集群之间存在明显耗时差异的相同卡之间的性能数据对比) | PyTorch、MindSpore | @@ -235,7 +233,7 @@ communication模块从通信维度进行分析,目前支持通信小包检测 ![byte_alignment](/img/byte_alignment.png) -computation模块从device计算性能维度进行分析,能够识别AI CPU、动态Shape、AI Core Performance analysis、Dlock Dim、算子瓶颈、融合算子图、AI Core算子降频分析等问题并给出相应建议。此处不再详细展开,按照报告进行调优即可。示例如下: +computation模块从device计算性能维度进行分析,能够识别AI CPU、动态Shape、Dlock Dim、算子瓶颈、融合算子图、AI Core算子降频分析等问题并给出相应建议。此处不再详细展开,按照报告进行调优即可。示例如下: ![computation_1](./img/computation_1.png) @@ -243,8 +241,6 @@ computation模块从device计算性能维度进行分析,能够识别AI CPU、 ![op_no_bound](./img/op_no_bound.png) -![AI Core Performance analysis](./img/AI Core Performance analysis.png) - 上图中torch_npu.npu.set_compile_mode接口介绍请参见[torch_npu.npu.set_compile_mode](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000880.html);AICPU算子替换样例可参考《[Samples of AI CPU Operator Replacement](https://gitee.com/ascend/mstt/blob/master/profiler/msprof_analyze/advisor/doc/Samples%20of%20AI%20CPU%20Operator%20Replacement.md)》。 当存在pp stage(流水线并行)时,computation会按stage分析,每个stage就是一个流水线切分,比如0\~7卡为stage-0、8\~15卡为stage-1。 @@ -257,22 +253,7 @@ dataloader模块包含Slow Dataloader Issues,主要检测异常高耗时的dat 上图中的`pin_memory`(内存锁定)和`num_workers`(数据加载是子流程数量)参数为[数据加载优化](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/ptmoddevg/trainingmigrguide/performance_tuning_0019.html)使用。 -schedule模块包GC Analysis、含亲和API、aclOpCompile、SyncBatchNorm、SynchronizeStream和Fusible Operator Analysis等多项检测。 - -其中Fusible Operator Analysis解析结果仅打屏展示和保存在`mstt_advisor_{timestamp}.xlsx`文件中,包含“基于host瓶颈的算子序列分析”和“基于mte瓶颈的算子序列分析”页签,如下图: - -![Fusible Operator Analysis](/img/Fusible Operator Analysis.png) - -| 字段 | 说明 | -| ------------------ | ------------------------------------------------------------ | -| start index | 序列起始算子在kernel details.csv或op_summary.csv中索引位置(不包含表头,起始索引为0)。 | -| end index | 序列末尾算子在kernel details.csv或op_summary.csv中索引位置。 | -| total time(us) | 算子序列总耗时(包含算子间隙),单位us。 | -| execution time(us) | 序列中算子执行总耗时,单位us。 | -| mte time(us) | 序列中算子搬运总耗时,单位us。 | -| occurrences | 序列出现次数。 | -| mte bound | 是否为MTE瓶颈。 | -| host bound | 是否为Host瓶颈。 | +schedule模块包GC Analysis、含亲和API、aclOpCompile、SyncBatchNorm、SynchronizeStream等多项检测。 如下图示例,GC Analysis提示存在异常垃圾回收事件,用户可以通过有效的Python内存管理、使用`gc.set_threshold()`调整垃圾回收阈值、使用gc.disable()禁用gc等方法处理GC问题。 diff --git a/profiler/msprof_analyze/advisor/advisor_backend/cluster_advice/slow_link_advice.py b/profiler/msprof_analyze/advisor/advisor_backend/cluster_advice/slow_link_advice.py index 6d2a0638913d759817b091a013d7fbce9df09f63..2024adf8f6a020a5e09ce41949f9815831d7b563 100644 --- a/profiler/msprof_analyze/advisor/advisor_backend/cluster_advice/slow_link_advice.py +++ b/profiler/msprof_analyze/advisor/advisor_backend/cluster_advice/slow_link_advice.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import os from collections import defaultdict from msprof_analyze.advisor.advisor_backend.common_func_advisor.constant import Constant @@ -42,7 +41,7 @@ class SlowLinkAdvice(ClusterAdviceBase): self.SDMA_TIME_MS: 0, self.SDMA_SIZE_MB: 0, } - self.rank_bw_dict = defaultdict(lambda: copy.deepcopy(default_value)) + self.rank_bw_dict = defaultdict(lambda: default_value.copy()) @staticmethod def compute_ratio(dividend: float, divisor: float): diff --git a/profiler/msprof_analyze/advisor/advisor_backend/common_func_advisor/constant.py b/profiler/msprof_analyze/advisor/advisor_backend/common_func_advisor/constant.py index 162a9fd2fdde15e02d2897106b43f52bca99bde1..077bf0074ccc5edc1bbf0814d2d3d72b1c5475e7 100644 --- a/profiler/msprof_analyze/advisor/advisor_backend/common_func_advisor/constant.py +++ b/profiler/msprof_analyze/advisor/advisor_backend/common_func_advisor/constant.py @@ -214,7 +214,7 @@ class CoreType: AICPU = "AI_CPU" MIX_AIV = "MIX_AIV" MIX_AIC = "MIX_AIC" - HCCL = "COMMUNICATION" + HCCL = "HCCL" class PerfColor(Enum): diff --git a/profiler/msprof_analyze/advisor/analyzer/analyzer_controller.py b/profiler/msprof_analyze/advisor/analyzer/analyzer_controller.py index bde9e5cd3454a85853a6fbcfbd0ade060ebc229b..d923ba978f8d8797ba0b41902fa24920afd61dad 100644 --- a/profiler/msprof_analyze/advisor/analyzer/analyzer_controller.py +++ b/profiler/msprof_analyze/advisor/analyzer/analyzer_controller.py @@ -1,947 +1,947 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -import logging -import json -import sys -import os -import platform -import multiprocessing as mp -from multiprocessing import Manager -from pathlib import Path - -import psutil - -from msprof_analyze.prof_common.additional_args_manager import AdditionalArgsManager -from msprof_analyze.advisor.analyzer.cluster.slow_rank_analyzer import SlowRankAnalyzer -from msprof_analyze.advisor.analyzer.cluster.slow_link_analyzer import SlowLinkAnalyzer -from msprof_analyze.advisor.analyzer.computation.pp_stage_computation_analyzer import PPStageComputationAnalyzer -from msprof_analyze.advisor.analyzer.overall.overall_summary_analyzer import OverallSummaryAnalyzer -from msprof_analyze.advisor.config.config import Config -from msprof_analyze.advisor.common.analyzer_scopes import SupportedScopes -from msprof_analyze.advisor.common.async_analysis_status import AsyncAnalysisStatus -from msprof_analyze.advisor.common.enum_params_parser import EnumParamsParser -from msprof_analyze.advisor.utils.utils import Timer, safe_index_value, safe_division, safe_index, convert_to_int -from msprof_analyze.advisor.interface.interface import Interface -from msprof_analyze.cluster_analyse.cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor -from msprof_analyze.cluster_analyse.cluster_data_preprocess.mindspore_data_preprocessor import MindsporeDataPreprocessor -from msprof_analyze.prof_common.path_manager import PathManager -from msprof_analyze.prof_common.constant import Constant - -# 以spawn模式启动多进程,避免fork主进程资源。如果主进程逻辑较为复杂,fork可能会导致异常。 -mp.set_start_method("spawn", force=True) -logger = logging.getLogger() - - -class AsyncParams: - """处理用户异步请求的输入参数,包括cli arguments和环境变量两类参数.""" - user_valid_arguments = {} - user_valid_envs = {} - user_non_enum_params = {} - user_invalid_values = [] - user_total_params = {} - - @staticmethod - def parse_async_list_params(key, value, option_values, key_type, value_type): - if isinstance(value, list): - value_list = value - else: - value_list = [_.strip(" ") for _ in str(value).split(",")] - - if sorted(value_list) not in [sorted(option) for option in option_values]: - AsyncParams.user_invalid_values.append( - {"key": key, "invalid value": value, "optional values": option_values, - "required value type": value_type}) - return - if key_type == EnumParamsParser.ENVS: - AsyncParams.user_valid_envs[key.upper()] = ",".join(value_list) - elif key_type == EnumParamsParser.ARGUMENTS: - AsyncParams.user_valid_arguments[key] = value_list - - @staticmethod - def parse_async_int_params(key, value, option_values, key_type, value_type): - if convert_to_int(value) not in option_values: - AsyncParams.user_invalid_values.append( - {"key": key, "invalid value": value, "optional values": option_values, - "required value type": value_type}) - return - - if key_type == EnumParamsParser.ENVS: - AsyncParams.user_valid_envs[key.upper()] = str(convert_to_int(value)) - elif key_type == EnumParamsParser.ARGUMENTS: - AsyncParams.user_valid_arguments[key] = convert_to_int(value) - - @staticmethod - def parse_async_str_params(key, value, option_values, key_type, value_type): - if str(value) not in option_values: - AsyncParams.user_invalid_values.append( - {"key": key, "invalid value": value, "optional values": option_values, - "required value type": value_type}) - return - if key_type == EnumParamsParser.ENVS: - AsyncParams.user_valid_envs[key.upper()] = str(value) - elif key_type == EnumParamsParser.ARGUMENTS: - AsyncParams.user_valid_arguments[key] = str(value) - - @staticmethod - def parse_async_boolean_params(key, value, option_values, key_type, value_type): - - if str(value).lower() not in ["true", "false"]: - AsyncParams.user_invalid_values.append( - {"key": key, "invalid value": value, "optional values": option_values, - "required value type": value_type}) - return - - if key_type == EnumParamsParser.ENVS: - AsyncParams.user_valid_envs[key.upper()] = str(value) - elif key_type == EnumParamsParser.ARGUMENTS: - AsyncParams.user_valid_arguments[key] = str(value).lower() == "true" - - @staticmethod - def parse_params(user_async_params): - params_parser = EnumParamsParser() - valid_env_keys = [key.lower() for key in params_parser.get_envs_keys()] - valid_arg_keys = [key.lower() for key in params_parser.get_arguments_keys()] - - for key, value in user_async_params.items(): - key = key.lower() - if key not in valid_env_keys + valid_arg_keys: - AsyncParams.user_non_enum_params[key] = value - continue - - if key in valid_env_keys: - # 环境变量均大写,异步调用入参到analyzer controller时支持用户使用小写配置环境变量 - option_values = params_parser.get_options(key.upper()) - value_type = params_parser.get_value_type(key.upper()) - key_type = params_parser.ENVS - else: - option_values = params_parser.get_options(key) - value_type = params_parser.get_value_type(key) - key_type = params_parser.ARGUMENTS - - if hasattr(AsyncParams, f"parse_async_{value_type}_params"): - getattr(AsyncParams, f"parse_async_{value_type}_params")(key, value, option_values, key_type, - value_type) - - AsyncParams.user_total_params["async_analysis_env"] = AsyncParams.user_valid_envs - AsyncParams.user_total_params.update(AsyncParams.user_valid_arguments) - AsyncParams.user_total_params.update(AsyncParams.user_non_enum_params) - - -class AnalyzerController: - CLUSTER_RANK_THRESHOLD = 2 - SDMA_SUPPORT_SCOPES = [SupportedScopes.BANDWIDTH_CONTENTION_DETECTION, SupportedScopes.BYTE_ALIGNMENT_DETECTION] - RDMA_SUPPORT_SCOPES = [SupportedScopes.PACKET] - COMMUNICATION_MAPPING = { - SlowLinkAnalyzer.SDMA: SDMA_SUPPORT_SCOPES, - SlowLinkAnalyzer.RDMA: RDMA_SUPPORT_SCOPES - } - - def __init__(self): - self.dimensions = Interface.all_dimension - self.kwargs = {} - self.args_manager = None - self.slow_rank_analyzer = None - self.slow_link_analyzer = None - self.cluster_local_data_map = {} - self.default_rank_id = None - self.rank_id_map = {} - self._is_cluster = False - self.analysis_process_resp = Manager().dict() - - @staticmethod - def _set_analysis_process_priority(pid): - # 将分析进程优先级设置为最低,避免因为分析进程阻塞其他任务进程,unix上19表示最低优先级 - unix_process_lowest_priority = 19 - windows_platform = "windows" - linux_platform = "linux" - p = psutil.Process(pid) - if platform.system().lower() == windows_platform: - p.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) - elif platform.system().lower() == linux_platform: - p.nice(unix_process_lowest_priority) - - @staticmethod - def _check_profiling_path_valid(profiling_path): - PathManager.input_path_common_check(profiling_path) - - if not Path(profiling_path).exists(): - logger.error("Profiling path is not existed. Invalid profiling path: %s", profiling_path) - return False - - return True - - - @staticmethod - def _get_step_rank_for_cluster_statistic_diff(target_cluster_statistic_data, benchmark_cluster_statistic_data, - headers, dimension, get_max=False): - if dimension not in headers: - logger.error("Error dimension %s for cluster statistics data, optionals are %s.", dimension, headers) - return None, None, None - - dimension_index = safe_index_value(headers, dimension) - diff_record = [] - # 对比目标profiling和benchmark profiling 每张卡的计算和下发和带宽,取计算、下发、带宽差异最大的卡进行下一步分析 - for target_row_data, benchmark_row_data in zip(target_cluster_statistic_data, benchmark_cluster_statistic_data): - target_data = safe_index(target_row_data, dimension_index) - benchmark_data = safe_index(benchmark_row_data, dimension_index) - - if not isinstance(target_data, (int, float)) or not isinstance(benchmark_data, (int, float)): - continue - diff_record.append(target_data - benchmark_data) - - if SlowRankAnalyzer.compute_max_gap_ratio(diff_record, safe_division(sum(diff_record), len( - diff_record))) < SlowRankAnalyzer.RATIO_THRESHOLD: - return None, None, None - - value = max(diff_record) if get_max else min(diff_record) - value_index = safe_index_value(diff_record, value) - - step_value_index = safe_index_value(headers, "step") - rank_id_value_index = safe_index_value(headers, "rank_id") - - step = safe_index(safe_index(target_cluster_statistic_data, value_index, []), step_value_index) - benchmark_step = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), step_value_index) - target_rank_id = safe_index(safe_index(target_cluster_statistic_data, value_index, []), rank_id_value_index) - benchmark_rank_id = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), - rank_id_value_index) - - if target_rank_id != benchmark_rank_id: - logger.error( - "Rank ids of target profiling must keep the same as benchmark profiling, skip cluster comparison") - return None, None, None - - return step, benchmark_step, target_rank_id - - @staticmethod - def _init_async_analysis_env(kwargs): - envs = kwargs.get("async_analysis_env", {}) - for key, value in envs.items(): - os.environ[key] = value - - def format_async_analysis_params(self, pid, async_resp, dimensions, kwargs): - - AsyncParams.parse_params(kwargs) - dimensions = AsyncParams.user_total_params.get("analysis_dimensions") or dimensions - - if AsyncParams.user_invalid_values: - error_msg = "Got invalid arguments as follows: \n " - for index, invalid_value in enumerate(AsyncParams.user_invalid_values): - error_msg += f"{index + 1}. Key '{invalid_value.get('key')}', " \ - f"invalid value '{invalid_value.get('invalid value')}', " \ - f"optional valid values '{invalid_value.get('optional values')}', " \ - f"required value type '{invalid_value.get('required value type')}'.\n " - self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, - status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED) - raise ValueError(error_msg) - - logger.warning("User parameters for async analysis is as follows:\n %s", - json.dumps(AsyncParams.user_total_params, indent=4)) - return dimensions, AsyncParams.user_total_params - - def do_analysis(self, dimensions, **kwargs): - pid = os.getpid() - resp = {"id": pid} - self.args_manager = AdditionalArgsManager() - self.args_manager.init(kwargs) - output_path = kwargs.get("output_path") - - AnalyzerController._set_analysis_process_priority(pid) - if kwargs.get("is_async_analysis"): - del kwargs["is_async_analysis"] - dimensions, kwargs = self.format_async_analysis_params(pid, resp, dimensions, kwargs) - AnalyzerController._init_async_analysis_env(kwargs) - - try: - if output_path: - - PathManager.check_input_directory_path(output_path) - if os.path.exists(output_path): - PathManager.check_path_owner_consistent([output_path]) - else: - PathManager.make_dir_safety(output_path) - - Config().set_config("_work_path", output_path) - Config().set_log_path(f"mstt_advisor_{Timer().strftime}.xlsx") - - self._do_analysis(dimensions, pid=pid, async_resp=resp, **kwargs) - except Exception as e: - self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.INNER_ERROR_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED, error_msg=str(e)) - logger.error(e) - raise RuntimeError("Do analysis error.") from e - - def async_do_analysis(self, dimensions, **kwargs): - """ Deploy a online service to start async analysis job, wrap this api by flask or tornado and so on, - then could query the analysis status by restful api. - You can view file 'profiler/msprof_analyze/advisor/config/enum_parameters.yaml' to obtain detailed - information for all the args listed below. - - Args: - dimensions: analysis dimension, normally set as Interface.all_dimension, support specific dimension analysis - such as ['computation'] or ['computation', 'schedule'] - cann_version: cann version of your runtime, inpact on the analysis of affinity api and AICPU operators - profiling_type: profiling type of your runtime - profiling_version: profiling version of your runtime, inpact on the analysis of affinity api - analysis_dimensions: can overwite dimensions. - advisor_analyze_processes: number of processes to use while the training params pipeline parallel(pp) >1, - can reduce the time of analysis. - disable_profiling_comparison: disable comparison of operators(including npu computation operator and - cpu torch aten operator), can reduce the time of analysis. - disable_affinity_api: disable analysis of affinity api, normally set as 'True' while you training job - has been trained on NPU for a long time and suddenly shows performance degradation. - output_path: analysis output path(including html and xlsx). - - Example: - >>> # initialize a global analyzer controller - >>> analyzer_controller = AnalyzerController() - >>> analysis_kwargs = dict(advisor_analyze_processes=2, disable_profiling_comparison=True) - >>> - >>> async_analysis_process = analyzer_controller.async_do_analysis( - >>> Interface.all_dimension, **analysis_kwargs) - >>> - >>> - >>> # query the job status every second - >>> while True: - >>> response = analyzer_controller.get_response_by_pid(async_analysis_process.pid) - >>> print(f'analysis response is {response}') - >>> if response.get("status") in ["success", "failed"]: - >>> async_analysis_process.join() - >>> break - >>> time.sleep(1) - """ - kwargs["is_async_analysis"] = True - - async_analysis_process = mp.Process(target=self.do_analysis, args=(dimensions,), kwargs=kwargs, - name="Async advisor performance analysis") - async_analysis_process.start() - self._update_analysis_process_resp(async_analysis_process.pid, {"id": async_analysis_process.pid}, - status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, - status=AsyncAnalysisStatus.ANALYZING) - return async_analysis_process - - def get_response_by_pid(self, pid): - def _is_pid_exists(pid): - try: - psutil.Process(pid) - return True - except psutil.NoSuchProcess: - return False - - pid_not_exist_response = dict(id=pid, status_code=AsyncAnalysisStatus.NOT_FOUND_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED, - error_msg="The advisor task id does not exist") - if pid not in self.analysis_process_resp: - return pid_not_exist_response - - response = self.analysis_process_resp.get(pid) - if response.get("status") not in [AsyncAnalysisStatus.FAILED, - AsyncAnalysisStatus.SUCCESS] and not _is_pid_exists(pid): - return pid_not_exist_response - return response - - def single_rank_analysis(self, profiling_path, benchmark_profiling_path=None): - job_list = [] - - profiling_path = self._get_profiling_path_by_rank(profiling_path) - benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) - - # 单卡场景无集群分析 - for dim in [Interface.CLUSTER]: - if dim in self.dimensions: - self.dimensions.remove(dim) - - for dimension in self.dimensions: - dimension_analysis_func_name = f"{dimension}_analysis" - if not hasattr(self, dimension_analysis_func_name): - continue - logger.info("Start %s analysis", dimension) - job_list += getattr(self, dimension_analysis_func_name)(profiling_path) - - if benchmark_profiling_path: - # kernel/api 比对 - compare_profiling_list = [ - dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, - compare_mode=Constant.KERNEL_COMPARE), - dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, - compare_mode=Constant.API_COMPARE) - ] - - job_list += self._profiling_comparison(compare_profiling_list) - else: - self.overall(profiling_path) - - return job_list - - def do_cluster_analysis(self, profiling_path, benchmark_profiling_path=None): - job_list = [] - - # 单集群profiling分析:下发、通信、计算、显存/内存 - for dimension in self.dimensions: - dimension_analysis_func_name = f"cluster_{dimension}_analysis" - if not hasattr(self, dimension_analysis_func_name): - continue - logger.info("Start cluster %s analysis", dimension) - job_list += getattr(self, dimension_analysis_func_name)(profiling_path) - - self.overall(profiling_path) - - if benchmark_profiling_path: - # 两个集群profiling比对分析 - job_list += self._cluster_profiling_comparison(profiling_path, benchmark_profiling_path) - return job_list - - def overall(self, profiling_path): - from msprof_analyze.advisor.analyzer.overall.environment_variable_analyzer import EnvironmentVariableAnalyzer - env_analyzer = EnvironmentVariableAnalyzer(profiling_path) - env_analyzer.optimize() - - if self._is_cluster: - self.slow_rank_analyzer.optimize(template_key=Interface.OVERALL) - self.slow_link_analyzer.optimize(template_key=Interface.OVERALL) - else: - overall_analyzer = OverallSummaryAnalyzer(profiling_path) - overall_analyzer.optimize() - - def schedule_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, - **kwargs): - # 任意单卡的下发分析 - - input_kwargs = copy.deepcopy(self.kwargs) - job_list = [] - - input_kwargs["profiling_path"] = profiling_path - input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path - input_kwargs["step"] = step - input_kwargs["benchmark_step"] = benchmark_step - input_kwargs["rank"] = kwargs.get("rank") - input_kwargs["step_duration"] = kwargs.get("step_duration") - - for dimension in [Interface.SCHEDULE]: - for scope in Interface.get_scope(dimension): - interface = Interface(**input_kwargs) - job_list.append((dimension, scope, interface, input_kwargs)) - return job_list - - def computation_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, - benchmark_step=None, stage=None, **kwargs): - # 任意单卡的计算分析 - - input_kwargs = copy.deepcopy(self.kwargs) - input_kwargs["profiling_path"] = profiling_path - input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path - input_kwargs["step"] = step - input_kwargs["benchmark_step"] = benchmark_step - input_kwargs["stage"] = stage - input_kwargs["rank"] = kwargs.get("rank") - input_kwargs["step_duration"] = kwargs.get("step_duration") - job_list = [] - - for dimension in [Interface.COMPUTATION]: - for scope in Interface.get_scope(dimension): - if scope == SupportedScopes.STAGE_COMPUTE: - continue - interface = Interface(**input_kwargs) - job_list.append((dimension, scope, interface, input_kwargs)) - return job_list - - def memory_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, **kwargs): - # 任意单卡的内存分析 - - input_kwargs = copy.deepcopy(self.kwargs) - job_list = [] - - input_kwargs["profiling_path"] = profiling_path - input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path - input_kwargs["step"] = step - input_kwargs["benchmark_step"] = benchmark_step - input_kwargs["rank"] = kwargs.get("rank") - input_kwargs["step_duration"] = kwargs.get("step_duration") - - for dimension in [Interface.MEMORY]: - for scope in Interface.get_scope(dimension): - interface = Interface(**input_kwargs) - job_list.append((dimension, scope, interface, input_kwargs)) - return job_list - - def communication_analysis(self, profiling_path, benchmark_profiling_path=None, **kwargs): - - job_list = [] - supported_trans_type = [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA] - step = kwargs.get("step", None) - benchmark_step = kwargs.get("benchmark_step", None) - bandwidth_type = kwargs.get("bandwidth_type", None) - scope = kwargs.get("scope", None) - if bandwidth_type is not None and bandwidth_type not in supported_trans_type: - logger.error("Error transit type %s, optionals are %s", bandwidth_type, supported_trans_type) - return job_list - - job_list += self._communication_analysis(profiling_path=profiling_path, - benchmark_profiling_path=benchmark_profiling_path, - step=step, benchmark_step=benchmark_step, - scope=scope, bandwidth_type=bandwidth_type) - - return job_list - - def cluster_schedule_analysis(self, profiling_path): - # 目标集群profiling数据下发分析,不包含两个集群profiling数据的比对分析 - - job_list = [] - global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) - - info_msg = "For cluster schedule analysis, " - slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") - if slow_rank_id is not None: - info_msg += f"maximum free for rank {slow_rank_id}" - else: - slow_rank_id = self.default_rank_id - info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" - - fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") - - slow_step = global_step_rank.get("maximum", {}).get("step") - fast_step = global_step_rank.get("minimum", {}).get("step") - - if slow_step is not None: - info_msg += f" and step {slow_step}" - logger.info(info_msg) - - kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), - benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), - step=slow_step, benchmark_step=fast_step, - rank=slow_rank_id, benchmark_rank=fast_rank_id, - compare_mode=Constant.API_COMPARE, - step_duration=self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step)) - - job_list += self.schedule_analysis(**kwargs) - - rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id - if not self.kwargs.get("benchmark_profiling_path") and rank_id_valid: - # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 - logger.info("Enable schedule comparison of fast and slow rank/step") - job_list += self._profiling_comparison([kwargs]) - return job_list - - def cluster_communication_analysis(self, profiling_path): - job_list = [] - - for dimension in [Interface.COMMUNICATION]: - for scope in Interface.get_scope(dimension): - analyzer_class = Interface.get_analyzer(dimension, scope) - if hasattr(analyzer_class, "requires_cluster_dataset") and getattr(analyzer_class, - "requires_cluster_dataset"): - - # 如果不依赖数据集,或者依赖的是ClusterDataset,则不用根据带宽确定需要分析的特定rank - kwargs = copy.deepcopy(self.kwargs) - kwargs["profiling_path"] = profiling_path - interface = Interface(**kwargs) - job_list.append((dimension, scope, interface, kwargs)) - else: - # 非ClusterDataset场景,需要根据带宽大小分析特定的rank - for bandwidth_type in [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA]: - global_step_rank = self.slow_link_analyzer.get_global_step_rank(bandwidth_type) - # 获取带宽最小的卡进行分析 - target_rank_id = global_step_rank.get("minimum", {}).get("rank_id") - if target_rank_id is None: - target_rank_id = self.default_rank_id - step = global_step_rank.get("minimum", {}).get("step") - analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, target_rank_id) - - info_msg = f"Minimum {bandwidth_type} bandwidth for rank {target_rank_id} " - if step: - info_msg += f"and step {step}" - logger.info(info_msg) - - job_list += self.communication_analysis(analysis_profiling_path, step=step, - bandwidth_type=bandwidth_type, scope=scope) - - return job_list - - def cluster_computation_analysis(self, profiling_path): - # 目标集群profiling数据计算分析,不包含两个集群profiling数据的比对分析;如果有pp stage,则对不同stage进行计算分析 - - job_list = [] - global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.COMPUTE) - stage_step_rank = self.slow_rank_analyzer.get_stage_step_rank(SlowRankAnalyzer.COMPUTE) - - if stage_step_rank: - job_list = self._stage_computation_analysis(profiling_path, stage_step_rank, job_list) - else: - job_list = self._global_computation_analysis(profiling_path, global_step_rank, job_list) - return job_list - - def cluster_memory_analysis(self, profiling_path): - # 目标集群profiling数据内存分析,当前memory识别的两个算子,导致的问题都是大的free,因此选择FREE最慢的卡进行分析 - - job_list = [] - global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) - - info_msg = "For cluster memory analysis, " - slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") - if slow_rank_id is not None: - info_msg += f"maximum free for rank {slow_rank_id}" - else: - slow_rank_id = self.default_rank_id - info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" - - slow_step = global_step_rank.get("maximum", {}).get("step") - if slow_step is not None: - info_msg += f" and step {slow_step}" - logger.info(info_msg) - - analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, slow_rank_id) - step_duration = self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step) - job_list += self.memory_analysis(analysis_profiling_path, step=slow_step, rank=slow_rank_id, - step_duration=step_duration) - return job_list - - def _do_analysis(self, dimensions, pid=0, async_resp=None, **kwargs): - self.dimensions = dimensions - self.kwargs = kwargs - result_list = [] - profiling_path = PathManager.get_realpath(self.kwargs.get("profiling_path")) - benchmark_profiling_path = self.kwargs.get("benchmark_profiling_path") - PathManager.check_path_owner_consistent([profiling_path]) - if benchmark_profiling_path: - benchmark_profiling_path = PathManager.get_realpath(benchmark_profiling_path) - PathManager.check_path_owner_consistent([benchmark_profiling_path]) - - if not self._check_profiling_path_valid(profiling_path): - error_msg = f"Got invalid argument '-d/--profiling_path' {profiling_path}, skip analysis" - self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, - status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED) - logger.error(error_msg) - return - - - if benchmark_profiling_path and not self._check_profiling_path_valid(benchmark_profiling_path): - error_msg = (f"Got invalid argument '-bp/--benchmark_profiling_path' {benchmark_profiling_path}, " - f"skip analysis") - self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, - status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED) - logger.error(error_msg) - return - - self._is_cluster = self._is_cluster_profiling(profiling_path) - if benchmark_profiling_path: - # 构建benchmark profiling的map,用于根据rank获取profiling路径,否则无法进行比对 - is_benchmark_cluster = self._is_cluster_profiling(benchmark_profiling_path) - is_comparison_path_valid = (self._is_cluster and is_benchmark_cluster) or ( - not self._is_cluster and not is_benchmark_cluster) - if not is_comparison_path_valid: - error_msg = f"Only support profiling comparison for '1 npu vs 1 gpu/npu' and 'multi npus vs multi npus'" - self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, - status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED) - logger.error(error_msg) - return - - if not self._is_cluster: - job_list = self.single_rank_analysis(profiling_path, benchmark_profiling_path) - else: - self.slow_rank_analyzer = SlowRankAnalyzer(profiling_path, output_path=self.kwargs.get("output_path")) - self.slow_link_analyzer = SlowLinkAnalyzer(profiling_path, output_path=self.kwargs.get("output_path")) - job_list = self.do_cluster_analysis(profiling_path, benchmark_profiling_path) - - for i, (dimension, scope, interface, kwargs) in enumerate(job_list[::-1]): - result_list.append( - interface.get_result(dimension, scope, render_html=i == len(job_list) - 1, output_dict=False, - **kwargs) - ) - - for result in result_list[::-1]: - if result and hasattr(result, "show"): - result.show() - break - self._get_analysis_finished_resp(pid, async_resp) - - def _get_scopes(self, scope=None, bandwidth_type=SlowLinkAnalyzer.SDMA): - """ - Args: - scope: analyzer type - bandwidth_type: analysis standard - Returns: - scope lists - """ - scopes = [] - if scope: - if scope in self.COMMUNICATION_MAPPING.get(bandwidth_type, self.SDMA_SUPPORT_SCOPES): - scopes.append(scope) - return scopes - for dimension in [Interface.COMMUNICATION]: - for scope_ in Interface.get_scope(dimension): - if scope_ in self.SDMA_SUPPORT_SCOPES or scope_ in self.RDMA_SUPPORT_SCOPES: - scopes.append(scope_) - return scopes - - def _communication_analysis(self, **child_kwargs): - kwargs = copy.deepcopy(self.kwargs) - job_list = [] - - kwargs["profiling_path"] = child_kwargs.get("profiling_path", "") - kwargs["benchmark_profiling_path"] = child_kwargs.get("benchmark_profiling_path", "") - kwargs["step"] = child_kwargs.get("step", -1) - kwargs["benchmark_step"] = child_kwargs.get("benchmark_step", -1) - bandwidth_type = child_kwargs.get("bandwidth_type", SlowLinkAnalyzer.SDMA) - scope = child_kwargs.get("scope", None) - - for scope_ in self._get_scopes(scope, bandwidth_type): - interface = Interface(**kwargs) - job_list.append((Interface.COMMUNICATION, scope_, interface, kwargs)) - - return job_list - - def _profiling_comparison(self, compare_profiling_list): - job_list = [] - disable_profiling_comparison = os.getenv(Constant.DISABLE_PROFILING_COMPARISON) - if disable_profiling_comparison is not None and disable_profiling_comparison.lower() == "true": - logger.info( - "Skip profiling comparison due to longer processing time due to env 'DISABLE_PROFILING_COMPARISON'") - return job_list - - for index, _kwargs in enumerate(compare_profiling_list): - kwargs = copy.deepcopy(self.kwargs) - kwargs.update(_kwargs) - compare_profiling_list[index] = kwargs - - compare_kwargs = { - "profiling_path": kwargs.get("profiling_path"), - "compare_profiling_list": compare_profiling_list, - } - - interface = Interface(**compare_kwargs) - job_list.append((Interface.COMPARISON, SupportedScopes.COMPARISON, interface, compare_kwargs)) - - return job_list - - def _cluster_profiling_comparison(self, profiling_path, benchmark_profiling_path): - # 从计算、下发和通信三个维度对集群profiling数据进行对比 - - job_list = [] - benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) - benchmark_slow_rank_analyzer = SlowRankAnalyzer(benchmark_profiling_path) - benchmark_slow_link_analyzer = SlowLinkAnalyzer(benchmark_profiling_path) - - # 计算和下发分析 - job_list += self._cluster_data_comparison(profiling_path, - benchmark_profiling_path, - self.slow_rank_analyzer, - benchmark_slow_rank_analyzer, - get_max=True) - - # 通信分析 - job_list += self._cluster_data_comparison(profiling_path, - benchmark_profiling_path, - self.slow_link_analyzer, - benchmark_slow_link_analyzer, - get_max=False) - return job_list - - def _cluster_data_comparison(self, profiling_path, benchmark_profiling_path, target_cluster_analyzer, - benchmark_cluster_analyzer, get_max=False): - # #low rank/slow link结果逐行对比获取差值最大的rank和step进行单卡分析 - job_list = [] - - if isinstance(target_cluster_analyzer, SlowRankAnalyzer): - comparison_dims = [SlowRankAnalyzer.COMPUTE, SlowRankAnalyzer.FREE] - comparison_modes = [Constant.KERNEL_COMPARE, Constant.API_COMPARE] - elif isinstance(target_cluster_analyzer, SlowLinkAnalyzer): - comparison_dims = [SlowLinkAnalyzer.SDMA_BANDWIDTH, SlowLinkAnalyzer.RDMA_BANDWIDTH] - comparison_modes = [None, None] - else: - return job_list - - target_data = target_cluster_analyzer.format_datas.get("data", []) - benchmark_data = benchmark_cluster_analyzer.format_datas.get("data", []) - headers = benchmark_cluster_analyzer.format_datas.get("headers", []) - - if len(target_data) != len(benchmark_data): - logger.warning( - "The product of ranks and steps of Benchmark profiling is not equals to target profiling, " - "skip cluster comparison.") - return job_list - - compare_profiling_list = [] - for dimension, compare_mode in zip(comparison_dims, comparison_modes): - step, benchmark_step, rank_id_for_comparison = AnalyzerController._get_step_rank_for_cluster_statistic_diff( - target_data, - benchmark_data, - headers, - dimension, - get_max=get_max - ) - - rank_profiling_path = self._get_profiling_path_by_rank(profiling_path, rank_id_for_comparison) - rank_benchmark_profiling_path = self._get_profiling_path_by_rank( - benchmark_profiling_path, - rank_id_for_comparison - ) - - if rank_id_for_comparison is None: - # rank id为空则无法获取对应rank的profiling路径,无法进行比较 - continue - - compare_profiling_list.append( - dict(profiling_path=rank_profiling_path, benchmark_profiling_path=rank_benchmark_profiling_path, - step=step, benchmark_step=benchmark_step, - rank=rank_id_for_comparison, benchmark_rank=rank_id_for_comparison, compare_mode=compare_mode) - ) - - if not compare_profiling_list: - return job_list - - job_list += self._profiling_comparison(compare_profiling_list) - return job_list - - def _is_cluster_profiling(self, profiling_path): - if os.path.isfile(profiling_path): - return False - path_list = [os.path.join(profiling_path, dir_name) for dir_name in os.listdir(profiling_path)] - ascend_pt_dirs = [path for path in path_list if os.path.isdir(path) and path.endswith("ascend_pt")] - ascend_ms_dirs = [path for path in path_list if os.path.isdir(path) and path.endswith("ascend_ms")] - if ascend_ms_dirs and ascend_pt_dirs: - logger.error("Cannot analyze pytorch and mindspore meantime.") - return False - if not ascend_pt_dirs and not ascend_ms_dirs: - return False - if ascend_ms_dirs and not ascend_pt_dirs: - data_processor = MindsporeDataPreprocessor(ascend_ms_dirs) - elif ascend_pt_dirs and not ascend_ms_dirs: - data_processor = PytorchDataPreprocessor(ascend_pt_dirs) - - self.cluster_local_data_map[profiling_path] = data_processor.get_data_map() - - if not self.cluster_local_data_map or not self.cluster_local_data_map.get(profiling_path): - return False - - self.default_rank_id = list(self.cluster_local_data_map[profiling_path].keys())[0] - - return len(self.cluster_local_data_map[profiling_path]) >= self.CLUSTER_RANK_THRESHOLD - - def _get_profiling_path_by_rank(self, profiling_path, rank_id=None): - - if not profiling_path: - return profiling_path - - return self._get_target_profiling_path_for_local(profiling_path, rank_id) - - def _get_target_profiling_path_for_local(self, profiling_path, rank_id): - rank_id_map = self.cluster_local_data_map.get(profiling_path, {}) - if rank_id is None or not rank_id_map: - return profiling_path - - if rank_id in rank_id_map: - return rank_id_map.get(rank_id) - - local_first_rank_id = sorted(list(map(int, rank_id_map.keys())))[0] - logger.warning("Target rank id %s does not exist in local profiling data %s, use rank %s for analysis", - rank_id, profiling_path, local_first_rank_id) - return rank_id_map.get(local_first_rank_id) - - def _update_analysis_process_resp(self, pid, resp, **kwargs): - if kwargs: - resp.update(kwargs) - self.analysis_process_resp[pid] = resp - - def _get_analysis_finished_resp(self, pid, resp): - advisor_output_file_prefix = f"mstt_advisor_{Timer().strftime}" - html_path = os.path.join(Config().work_path, f"{advisor_output_file_prefix}.html") - xlsx_path = os.path.join(Config().work_path, "log", f"{advisor_output_file_prefix}.xlsx") - if os.path.exists(html_path) and os.path.exists(xlsx_path): - result_files = {"html": html_path, "xlsx": xlsx_path} - self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, - status=AsyncAnalysisStatus.SUCCESS, result_files=result_files) - else: - self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, - status=AsyncAnalysisStatus.FAILED, - error_msg="No optimization suggestions, please check your input path.") - - def _stage_computation_analysis(self, profiling_path, stage_step_rank, job_list): - # 对不同pp stage取min max进行分析 - logger.info("Steps and ranks to be analyzed of different pipeline parallel stages are %s", - json.dumps(stage_step_rank)) - - stages_profiling_path = [] - for stage, step_rank_info in stage_step_rank.items(): - rank_id = step_rank_info.get("maximum", {}).get("rank_id") - step = step_rank_info.get("maximum", {}).get("step") - benchmark_rank_id = step_rank_info.get("minimum", {}).get("rank_id") - benchmark_step = step_rank_info.get("minimum", {}).get("step") - - info_msg = f"For {stage}, slow rank is {rank_id}" - if step: - info_msg += f", step is {step}" - logger.info(info_msg) - - stages_profiling_path.append( - dict( - stage=stage, rank=rank_id, step=step, benchmark_rank=benchmark_rank_id, - benchmark_step=benchmark_step, - profiling_path=self._get_profiling_path_by_rank(profiling_path, rank_id), - benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, benchmark_rank_id), - compare_mode=Constant.KERNEL_COMPARE, - step_duration=self.slow_rank_analyzer.get_step_duration(rank_id, step) - ) - ) - Interface.add_analyzer(Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, PPStageComputationAnalyzer) - compute_analysis_kwargs = {"stages_profiling_path": stages_profiling_path, "profiling_path": profiling_path} - - job_list.append((Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, Interface(**compute_analysis_kwargs), - compute_analysis_kwargs)) - if not self.kwargs.get("benchmark_profiling_path"): - logger.info("Enable computation comparison of fast and slow rank/step in different pp stages") - job_list += self._profiling_comparison(stages_profiling_path) - return job_list - - def _global_computation_analysis(self, profiling_path, global_step_rank, job_list): - # 不区分stage,对所有卡取Min max进行分析 - logger.info("Without pipeline parallel stage, steps and ranks to be analyzed are %s", - json.dumps(global_step_rank)) - slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") - if slow_rank_id is not None: - info_msg = f"Maximum computation time for rank {slow_rank_id}" - else: - slow_rank_id = self.default_rank_id - info_msg = f"No slow rank with computation time, analysis for default rank {slow_rank_id}" - slow_step = global_step_rank.get("maximum", {}).get("step") - # 如果没有标杆profiling数据的rank id,说明没有快慢卡问题,直接对默认rank id进行分析,因此这里取值为None - fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") - fast_step = global_step_rank.get("minimum", {}).get("step") - - if slow_step is not None: - info_msg += f" and step {slow_step}, " - if fast_rank_id is not None: - info_msg += f"minimum computation time for rank {fast_rank_id}" - if fast_step is not None: - info_msg += f" and step {fast_step}" - logger.info(info_msg) - - kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), - benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), - step=slow_step, benchmark_step=fast_step, rank=slow_rank_id, benchmark_rank=fast_rank_id, - compare_mode=Constant.KERNEL_COMPARE, - step_duration=self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step)) - - job_list += self.computation_analysis(**kwargs) - - rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id - if not self.kwargs.get("benchmark_profiling_path") and rank_id_valid: - # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 - logger.info("Enable computation comparison of fast and slow rank/step") - job_list += self._profiling_comparison([kwargs]) - return job_list +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import logging +import json +import sys +import os +import platform +import multiprocessing as mp +from multiprocessing import Manager +from pathlib import Path + +import psutil + +from msprof_analyze.prof_common.additional_args_manager import AdditionalArgsManager +from msprof_analyze.advisor.analyzer.cluster.slow_rank_analyzer import SlowRankAnalyzer +from msprof_analyze.advisor.analyzer.cluster.slow_link_analyzer import SlowLinkAnalyzer +from msprof_analyze.advisor.analyzer.computation.pp_stage_computation_analyzer import PPStageComputationAnalyzer +from msprof_analyze.advisor.analyzer.overall.overall_summary_analyzer import OverallSummaryAnalyzer +from msprof_analyze.advisor.config.config import Config +from msprof_analyze.advisor.common.analyzer_scopes import SupportedScopes +from msprof_analyze.advisor.common.async_analysis_status import AsyncAnalysisStatus +from msprof_analyze.advisor.common.enum_params_parser import EnumParamsParser +from msprof_analyze.advisor.utils.utils import Timer, safe_index_value, safe_division, safe_index, convert_to_int +from msprof_analyze.advisor.interface.interface import Interface +from msprof_analyze.cluster_analyse.cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor +from msprof_analyze.cluster_analyse.cluster_data_preprocess.mindspore_data_preprocessor import MindsporeDataPreprocessor +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.prof_common.constant import Constant + +# 以spawn模式启动多进程,避免fork主进程资源。如果主进程逻辑较为复杂,fork可能会导致异常。 +mp.set_start_method("spawn", force=True) +logger = logging.getLogger() + + +class AsyncParams: + """处理用户异步请求的输入参数,包括cli arguments和环境变量两类参数.""" + user_valid_arguments = {} + user_valid_envs = {} + user_non_enum_params = {} + user_invalid_values = [] + user_total_params = {} + + @staticmethod + def parse_async_list_params(key, value, option_values, key_type, value_type): + if isinstance(value, list): + value_list = value + else: + value_list = [_.strip(" ") for _ in str(value).split(",")] + + if sorted(value_list) not in [sorted(option) for option in option_values]: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = ",".join(value_list) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = value_list + + @staticmethod + def parse_async_int_params(key, value, option_values, key_type, value_type): + if convert_to_int(value) not in option_values: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(convert_to_int(value)) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = convert_to_int(value) + + @staticmethod + def parse_async_str_params(key, value, option_values, key_type, value_type): + if str(value) not in option_values: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(value) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = str(value) + + @staticmethod + def parse_async_boolean_params(key, value, option_values, key_type, value_type): + + if str(value).lower() not in ["true", "false"]: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(value) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = str(value).lower() == "true" + + @staticmethod + def parse_params(user_async_params): + params_parser = EnumParamsParser() + valid_env_keys = [key.lower() for key in params_parser.get_envs_keys()] + valid_arg_keys = [key.lower() for key in params_parser.get_arguments_keys()] + + for key, value in user_async_params.items(): + key = key.lower() + if key not in valid_env_keys + valid_arg_keys: + AsyncParams.user_non_enum_params[key] = value + continue + + if key in valid_env_keys: + # 环境变量均大写,异步调用入参到analyzer controller时支持用户使用小写配置环境变量 + option_values = params_parser.get_options(key.upper()) + value_type = params_parser.get_value_type(key.upper()) + key_type = params_parser.ENVS + else: + option_values = params_parser.get_options(key) + value_type = params_parser.get_value_type(key) + key_type = params_parser.ARGUMENTS + + if hasattr(AsyncParams, f"parse_async_{value_type}_params"): + getattr(AsyncParams, f"parse_async_{value_type}_params")(key, value, option_values, key_type, + value_type) + + AsyncParams.user_total_params["async_analysis_env"] = AsyncParams.user_valid_envs + AsyncParams.user_total_params.update(AsyncParams.user_valid_arguments) + AsyncParams.user_total_params.update(AsyncParams.user_non_enum_params) + + +class AnalyzerController: + CLUSTER_RANK_THRESHOLD = 2 + SDMA_SUPPORT_SCOPES = [SupportedScopes.BANDWIDTH_CONTENTION_DETECTION, SupportedScopes.BYTE_ALIGNMENT_DETECTION] + RDMA_SUPPORT_SCOPES = [SupportedScopes.PACKET] + COMMUNICATION_MAPPING = { + SlowLinkAnalyzer.SDMA: SDMA_SUPPORT_SCOPES, + SlowLinkAnalyzer.RDMA: RDMA_SUPPORT_SCOPES + } + + def __init__(self): + self.dimensions = Interface.all_dimension + self.kwargs = {} + self.args_manager = None + self.slow_rank_analyzer = None + self.slow_link_analyzer = None + self.cluster_local_data_map = {} + self.default_rank_id = None + self.rank_id_map = {} + self._is_cluster = False + self.analysis_process_resp = Manager().dict() + + @staticmethod + def _set_analysis_process_priority(pid): + # 将分析进程优先级设置为最低,避免因为分析进程阻塞其他任务进程,unix上19表示最低优先级 + unix_process_lowest_priority = 19 + windows_platform = "windows" + linux_platform = "linux" + p = psutil.Process(pid) + if platform.system().lower() == windows_platform: + p.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) + elif platform.system().lower() == linux_platform: + p.nice(unix_process_lowest_priority) + + @staticmethod + def _check_profiling_path_valid(profiling_path): + PathManager.input_path_common_check(profiling_path) + + if not Path(profiling_path).exists(): + logger.error("Profiling path is not existed. Invalid profiling path: %s", profiling_path) + return False + + return True + + + @staticmethod + def _get_step_rank_for_cluster_statistic_diff(target_cluster_statistic_data, benchmark_cluster_statistic_data, + headers, dimension, get_max=False): + if dimension not in headers: + logger.error("Error dimension %s for cluster statistics data, optionals are %s.", dimension, headers) + return None, None, None + + dimension_index = safe_index_value(headers, dimension) + diff_record = [] + # 对比目标profiling和benchmark profiling 每张卡的计算和下发和带宽,取计算、下发、带宽差异最大的卡进行下一步分析 + for target_row_data, benchmark_row_data in zip(target_cluster_statistic_data, benchmark_cluster_statistic_data): + target_data = safe_index(target_row_data, dimension_index) + benchmark_data = safe_index(benchmark_row_data, dimension_index) + + if not isinstance(target_data, (int, float)) or not isinstance(benchmark_data, (int, float)): + continue + diff_record.append(target_data - benchmark_data) + + if SlowRankAnalyzer.compute_max_gap_ratio(diff_record, safe_division(sum(diff_record), len( + diff_record))) < SlowRankAnalyzer.RATIO_THRESHOLD: + return None, None, None + + value = max(diff_record) if get_max else min(diff_record) + value_index = safe_index_value(diff_record, value) + + step_value_index = safe_index_value(headers, "step") + rank_id_value_index = safe_index_value(headers, "rank_id") + + step = safe_index(safe_index(target_cluster_statistic_data, value_index, []), step_value_index) + benchmark_step = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), step_value_index) + target_rank_id = safe_index(safe_index(target_cluster_statistic_data, value_index, []), rank_id_value_index) + benchmark_rank_id = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), + rank_id_value_index) + + if target_rank_id != benchmark_rank_id: + logger.error( + "Rank ids of target profiling must keep the same as benchmark profiling, skip cluster comparison") + return None, None, None + + return step, benchmark_step, target_rank_id + + @staticmethod + def _init_async_analysis_env(kwargs): + envs = kwargs.get("async_analysis_env", {}) + for key, value in envs.items(): + os.environ[key] = value + + def format_async_analysis_params(self, pid, async_resp, dimensions, kwargs): + + AsyncParams.parse_params(kwargs) + dimensions = AsyncParams.user_total_params.get("analysis_dimensions") or dimensions + + if AsyncParams.user_invalid_values: + error_msg = "Got invalid arguments as follows: \n " + for index, invalid_value in enumerate(AsyncParams.user_invalid_values): + error_msg += f"{index + 1}. Key '{invalid_value.get('key')}', " \ + f"invalid value '{invalid_value.get('invalid value')}', " \ + f"optional valid values '{invalid_value.get('optional values')}', " \ + f"required value type '{invalid_value.get('required value type')}'.\n " + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + raise ValueError(error_msg) + + logger.warning("User parameters for async analysis is as follows:\n %s", + json.dumps(AsyncParams.user_total_params, indent=4)) + return dimensions, AsyncParams.user_total_params + + def do_analysis(self, dimensions, **kwargs): + pid = os.getpid() + resp = {"id": pid} + self.args_manager = AdditionalArgsManager() + self.args_manager.init(kwargs) + output_path = kwargs.get("output_path") + + AnalyzerController._set_analysis_process_priority(pid) + if kwargs.get("is_async_analysis"): + del kwargs["is_async_analysis"] + dimensions, kwargs = self.format_async_analysis_params(pid, resp, dimensions, kwargs) + AnalyzerController._init_async_analysis_env(kwargs) + + try: + if output_path: + + PathManager.check_input_directory_path(output_path) + if os.path.exists(output_path): + PathManager.check_path_owner_consistent([output_path]) + else: + PathManager.make_dir_safety(output_path) + + Config().set_config("_work_path", output_path) + Config().set_log_path(f"mstt_advisor_{Timer().strftime}.xlsx") + + self._do_analysis(dimensions, pid=pid, async_resp=resp, **kwargs) + except Exception as e: + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.INNER_ERROR_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, error_msg=str(e)) + logger.error(e) + raise RuntimeError("Do analysis error.") from e + + def async_do_analysis(self, dimensions, **kwargs): + """ Deploy a online service to start async analysis job, wrap this api by flask or tornado and so on, + then could query the analysis status by restful api. + You can view file 'profiler/msprof_analyze/advisor/config/enum_parameters.yaml' to obtain detailed + information for all the args listed below. + + Args: + dimensions: analysis dimension, normally set as Interface.all_dimension, support specific dimension analysis + such as ['computation'] or ['computation', 'schedule'] + cann_version: cann version of your runtime, inpact on the analysis of affinity api and AICPU operators + profiling_type: profiling type of your runtime + profiling_version: profiling version of your runtime, inpact on the analysis of affinity api + analysis_dimensions: can overwite dimensions. + advisor_analyze_processes: number of processes to use while the training params pipeline parallel(pp) >1, + can reduce the time of analysis. + disable_profiling_comparison: disable comparison of operators(including npu computation operator and + cpu torch aten operator), can reduce the time of analysis. + disable_affinity_api: disable analysis of affinity api, normally set as 'True' while you training job + has been trained on NPU for a long time and suddenly shows performance degradation. + output_path: analysis output path(including html and xlsx). + + Example: + >>> # initialize a global analyzer controller + >>> analyzer_controller = AnalyzerController() + >>> analysis_kwargs = dict(advisor_analyze_processes=2, disable_profiling_comparison=True) + >>> + >>> async_analysis_process = analyzer_controller.async_do_analysis( + >>> Interface.all_dimension, **analysis_kwargs) + >>> + >>> + >>> # query the job status every second + >>> while True: + >>> response = analyzer_controller.get_response_by_pid(async_analysis_process.pid) + >>> print(f'analysis response is {response}') + >>> if response.get("status") in ["success", "failed"]: + >>> async_analysis_process.join() + >>> break + >>> time.sleep(1) + """ + kwargs["is_async_analysis"] = True + + async_analysis_process = mp.Process(target=self.do_analysis, args=(dimensions,), kwargs=kwargs, + name="Async advisor performance analysis") + async_analysis_process.start() + self._update_analysis_process_resp(async_analysis_process.pid, {"id": async_analysis_process.pid}, + status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, + status=AsyncAnalysisStatus.ANALYZING) + return async_analysis_process + + def get_response_by_pid(self, pid): + def _is_pid_exists(pid): + try: + psutil.Process(pid) + return True + except psutil.NoSuchProcess: + return False + + pid_not_exist_response = dict(id=pid, status_code=AsyncAnalysisStatus.NOT_FOUND_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, + error_msg="The advisor task id does not exist") + if pid not in self.analysis_process_resp: + return pid_not_exist_response + + response = self.analysis_process_resp.get(pid) + if response.get("status") not in [AsyncAnalysisStatus.FAILED, + AsyncAnalysisStatus.SUCCESS] and not _is_pid_exists(pid): + return pid_not_exist_response + return response + + def single_rank_analysis(self, profiling_path, benchmark_profiling_path=None): + job_list = [] + + profiling_path = self._get_profiling_path_by_rank(profiling_path) + benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) + + # 单卡场景无集群分析 + for dim in [Interface.CLUSTER]: + if dim in self.dimensions: + self.dimensions.remove(dim) + + for dimension in self.dimensions: + dimension_analysis_func_name = f"{dimension}_analysis" + if not hasattr(self, dimension_analysis_func_name): + continue + logger.info("Start %s analysis", dimension) + job_list += getattr(self, dimension_analysis_func_name)(profiling_path) + + if benchmark_profiling_path: + # kernel/api 比对 + compare_profiling_list = [ + dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, + compare_mode=Constant.KERNEL_COMPARE), + dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, + compare_mode=Constant.API_COMPARE) + ] + + job_list += self._profiling_comparison(compare_profiling_list) + else: + self.overall(profiling_path) + + return job_list + + def do_cluster_analysis(self, profiling_path, benchmark_profiling_path=None): + job_list = [] + + # 单集群profiling分析:下发、通信、计算、显存/内存 + for dimension in self.dimensions: + dimension_analysis_func_name = f"cluster_{dimension}_analysis" + if not hasattr(self, dimension_analysis_func_name): + continue + logger.info("Start cluster %s analysis", dimension) + job_list += getattr(self, dimension_analysis_func_name)(profiling_path) + + self.overall(profiling_path) + + if benchmark_profiling_path: + # 两个集群profiling比对分析 + job_list += self._cluster_profiling_comparison(profiling_path, benchmark_profiling_path) + return job_list + + def overall(self, profiling_path): + from msprof_analyze.advisor.analyzer.overall.environment_variable_analyzer import EnvironmentVariableAnalyzer + env_analyzer = EnvironmentVariableAnalyzer(profiling_path) + env_analyzer.optimize() + + if self._is_cluster: + self.slow_rank_analyzer.optimize(template_key=Interface.OVERALL) + self.slow_link_analyzer.optimize(template_key=Interface.OVERALL) + else: + overall_analyzer = OverallSummaryAnalyzer(profiling_path) + overall_analyzer.optimize() + + def schedule_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, + **kwargs): + # 任意单卡的下发分析 + + input_kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["rank"] = kwargs.get("rank") + input_kwargs["step_duration"] = kwargs.get("step_duration") + + for dimension in [Interface.SCHEDULE]: + for scope in Interface.get_scope(dimension): + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def computation_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, + benchmark_step=None, stage=None, **kwargs): + # 任意单卡的计算分析 + + input_kwargs = copy.deepcopy(self.kwargs) + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["stage"] = stage + input_kwargs["rank"] = kwargs.get("rank") + input_kwargs["step_duration"] = kwargs.get("step_duration") + job_list = [] + + for dimension in [Interface.COMPUTATION]: + for scope in Interface.get_scope(dimension): + if scope == SupportedScopes.STAGE_COMPUTE: + continue + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def memory_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, **kwargs): + # 任意单卡的内存分析 + + input_kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["rank"] = kwargs.get("rank") + input_kwargs["step_duration"] = kwargs.get("step_duration") + + for dimension in [Interface.MEMORY]: + for scope in Interface.get_scope(dimension): + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def communication_analysis(self, profiling_path, benchmark_profiling_path=None, **kwargs): + + job_list = [] + supported_trans_type = [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA] + step = kwargs.get("step", None) + benchmark_step = kwargs.get("benchmark_step", None) + bandwidth_type = kwargs.get("bandwidth_type", None) + scope = kwargs.get("scope", None) + if bandwidth_type is not None and bandwidth_type not in supported_trans_type: + logger.error("Error transit type %s, optionals are %s", bandwidth_type, supported_trans_type) + return job_list + + job_list += self._communication_analysis(profiling_path=profiling_path, + benchmark_profiling_path=benchmark_profiling_path, + step=step, benchmark_step=benchmark_step, + scope=scope, bandwidth_type=bandwidth_type) + + return job_list + + def cluster_schedule_analysis(self, profiling_path): + # 目标集群profiling数据下发分析,不包含两个集群profiling数据的比对分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) + + info_msg = "For cluster schedule analysis, " + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg += f"maximum free for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" + + fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + + slow_step = global_step_rank.get("maximum", {}).get("step") + fast_step = global_step_rank.get("minimum", {}).get("step") + + if slow_step is not None: + info_msg += f" and step {slow_step}" + logger.info(info_msg) + + kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), + step=slow_step, benchmark_step=fast_step, + rank=slow_rank_id, benchmark_rank=fast_rank_id, + compare_mode=Constant.API_COMPARE, + step_duration=self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step)) + + job_list += self.schedule_analysis(**kwargs) + + rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id + if not self.kwargs.get("benchmark_profiling_path") and rank_id_valid: + # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 + logger.info("Enable schedule comparison of fast and slow rank/step") + job_list += self._profiling_comparison([kwargs]) + return job_list + + def cluster_communication_analysis(self, profiling_path): + job_list = [] + + for dimension in [Interface.COMMUNICATION]: + for scope in Interface.get_scope(dimension): + analyzer_class = Interface.get_analyzer(dimension, scope) + if hasattr(analyzer_class, "requires_cluster_dataset") and getattr(analyzer_class, + "requires_cluster_dataset"): + + # 如果不依赖数据集,或者依赖的是ClusterDataset,则不用根据带宽确定需要分析的特定rank + kwargs = copy.deepcopy(self.kwargs) + kwargs["profiling_path"] = profiling_path + interface = Interface(**kwargs) + job_list.append((dimension, scope, interface, kwargs)) + else: + # 非ClusterDataset场景,需要根据带宽大小分析特定的rank + for bandwidth_type in [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA]: + global_step_rank = self.slow_link_analyzer.get_global_step_rank(bandwidth_type) + # 获取带宽最小的卡进行分析 + target_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + if target_rank_id is None: + target_rank_id = self.default_rank_id + step = global_step_rank.get("minimum", {}).get("step") + analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, target_rank_id) + + info_msg = f"Minimum {bandwidth_type} bandwidth for rank {target_rank_id} " + if step: + info_msg += f"and step {step}" + logger.info(info_msg) + + job_list += self.communication_analysis(analysis_profiling_path, step=step, + bandwidth_type=bandwidth_type, scope=scope) + + return job_list + + def cluster_computation_analysis(self, profiling_path): + # 目标集群profiling数据计算分析,不包含两个集群profiling数据的比对分析;如果有pp stage,则对不同stage进行计算分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.COMPUTE) + stage_step_rank = self.slow_rank_analyzer.get_stage_step_rank(SlowRankAnalyzer.COMPUTE) + + if stage_step_rank: + job_list = self._stage_computation_analysis(profiling_path, stage_step_rank, job_list) + else: + job_list = self._global_computation_analysis(profiling_path, global_step_rank, job_list) + return job_list + + def cluster_memory_analysis(self, profiling_path): + # 目标集群profiling数据内存分析,当前memory识别的两个算子,导致的问题都是大的free,因此选择FREE最慢的卡进行分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) + + info_msg = "For cluster memory analysis, " + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg += f"maximum free for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" + + slow_step = global_step_rank.get("maximum", {}).get("step") + if slow_step is not None: + info_msg += f" and step {slow_step}" + logger.info(info_msg) + + analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, slow_rank_id) + step_duration = self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step) + job_list += self.memory_analysis(analysis_profiling_path, step=slow_step, rank=slow_rank_id, + step_duration=step_duration) + return job_list + + def _do_analysis(self, dimensions, pid=0, async_resp=None, **kwargs): + self.dimensions = dimensions + self.kwargs = kwargs + result_list = [] + profiling_path = PathManager.get_realpath(self.kwargs.get("profiling_path")) + benchmark_profiling_path = self.kwargs.get("benchmark_profiling_path") + PathManager.check_path_owner_consistent([profiling_path]) + if benchmark_profiling_path: + benchmark_profiling_path = PathManager.get_realpath(benchmark_profiling_path) + PathManager.check_path_owner_consistent([benchmark_profiling_path]) + + if not self._check_profiling_path_valid(profiling_path): + error_msg = f"Got invalid argument '-d/--profiling_path' {profiling_path}, skip analysis" + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + + if benchmark_profiling_path and not self._check_profiling_path_valid(benchmark_profiling_path): + error_msg = (f"Got invalid argument '-bp/--benchmark_profiling_path' {benchmark_profiling_path}, " + f"skip analysis") + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + self._is_cluster = self._is_cluster_profiling(profiling_path) + if benchmark_profiling_path: + # 构建benchmark profiling的map,用于根据rank获取profiling路径,否则无法进行比对 + is_benchmark_cluster = self._is_cluster_profiling(benchmark_profiling_path) + is_comparison_path_valid = (self._is_cluster and is_benchmark_cluster) or ( + not self._is_cluster and not is_benchmark_cluster) + if not is_comparison_path_valid: + error_msg = f"Only support profiling comparison for '1 npu vs 1 gpu/npu' and 'multi npus vs multi npus'" + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + if not self._is_cluster: + job_list = self.single_rank_analysis(profiling_path, benchmark_profiling_path) + else: + self.slow_rank_analyzer = SlowRankAnalyzer(profiling_path, output_path=self.kwargs.get("output_path")) + self.slow_link_analyzer = SlowLinkAnalyzer(profiling_path, output_path=self.kwargs.get("output_path")) + job_list = self.do_cluster_analysis(profiling_path, benchmark_profiling_path) + + for i, (dimension, scope, interface, kwargs) in enumerate(job_list[::-1]): + result_list.append( + interface.get_result(dimension, scope, render_html=i == len(job_list) - 1, output_dict=False, + **kwargs) + ) + + for result in result_list[::-1]: + if result and hasattr(result, "show"): + result.show() + break + self._get_analysis_finished_resp(pid, async_resp) + + def _get_scopes(self, scope=None, bandwidth_type=SlowLinkAnalyzer.SDMA): + """ + Args: + scope: analyzer type + bandwidth_type: analysis standard + Returns: + scope lists + """ + scopes = [] + if scope: + if scope in self.COMMUNICATION_MAPPING.get(bandwidth_type, self.SDMA_SUPPORT_SCOPES): + scopes.append(scope) + return scopes + for dimension in [Interface.COMMUNICATION]: + for scope_ in Interface.get_scope(dimension): + if scope_ in self.SDMA_SUPPORT_SCOPES or scope_ in self.RDMA_SUPPORT_SCOPES: + scopes.append(scope_) + return scopes + + def _communication_analysis(self, **child_kwargs): + kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + kwargs["profiling_path"] = child_kwargs.get("profiling_path", "") + kwargs["benchmark_profiling_path"] = child_kwargs.get("benchmark_profiling_path", "") + kwargs["step"] = child_kwargs.get("step", -1) + kwargs["benchmark_step"] = child_kwargs.get("benchmark_step", -1) + bandwidth_type = child_kwargs.get("bandwidth_type", SlowLinkAnalyzer.SDMA) + scope = child_kwargs.get("scope", None) + + for scope_ in self._get_scopes(scope, bandwidth_type): + interface = Interface(**kwargs) + job_list.append((Interface.COMMUNICATION, scope_, interface, kwargs)) + + return job_list + + def _profiling_comparison(self, compare_profiling_list): + job_list = [] + disable_profiling_comparison = os.getenv(Constant.DISABLE_PROFILING_COMPARISON) + if disable_profiling_comparison is not None and disable_profiling_comparison.lower() == "true": + logger.info( + "Skip profiling comparison due to longer processing time due to env 'DISABLE_PROFILING_COMPARISON'") + return job_list + + for index, _kwargs in enumerate(compare_profiling_list): + kwargs = copy.deepcopy(self.kwargs) + kwargs.update(_kwargs) + compare_profiling_list[index] = kwargs + + compare_kwargs = { + "profiling_path": kwargs.get("profiling_path"), + "compare_profiling_list": compare_profiling_list, + } + + interface = Interface(**compare_kwargs) + job_list.append((Interface.COMPARISON, SupportedScopes.COMPARISON, interface, compare_kwargs)) + + return job_list + + def _cluster_profiling_comparison(self, profiling_path, benchmark_profiling_path): + # 从计算、下发和通信三个维度对集群profiling数据进行对比 + + job_list = [] + benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) + benchmark_slow_rank_analyzer = SlowRankAnalyzer(benchmark_profiling_path) + benchmark_slow_link_analyzer = SlowLinkAnalyzer(benchmark_profiling_path) + + # 计算和下发分析 + job_list += self._cluster_data_comparison(profiling_path, + benchmark_profiling_path, + self.slow_rank_analyzer, + benchmark_slow_rank_analyzer, + get_max=True) + + # 通信分析 + job_list += self._cluster_data_comparison(profiling_path, + benchmark_profiling_path, + self.slow_link_analyzer, + benchmark_slow_link_analyzer, + get_max=False) + return job_list + + def _cluster_data_comparison(self, profiling_path, benchmark_profiling_path, target_cluster_analyzer, + benchmark_cluster_analyzer, get_max=False): + # #low rank/slow link结果逐行对比获取差值最大的rank和step进行单卡分析 + job_list = [] + + if isinstance(target_cluster_analyzer, SlowRankAnalyzer): + comparison_dims = [SlowRankAnalyzer.COMPUTE, SlowRankAnalyzer.FREE] + comparison_modes = [Constant.KERNEL_COMPARE, Constant.API_COMPARE] + elif isinstance(target_cluster_analyzer, SlowLinkAnalyzer): + comparison_dims = [SlowLinkAnalyzer.SDMA_BANDWIDTH, SlowLinkAnalyzer.RDMA_BANDWIDTH] + comparison_modes = [None, None] + else: + return job_list + + target_data = target_cluster_analyzer.format_datas.get("data", []) + benchmark_data = benchmark_cluster_analyzer.format_datas.get("data", []) + headers = benchmark_cluster_analyzer.format_datas.get("headers", []) + + if len(target_data) != len(benchmark_data): + logger.warning( + "The product of ranks and steps of Benchmark profiling is not equals to target profiling, " + "skip cluster comparison.") + return job_list + + compare_profiling_list = [] + for dimension, compare_mode in zip(comparison_dims, comparison_modes): + step, benchmark_step, rank_id_for_comparison = AnalyzerController._get_step_rank_for_cluster_statistic_diff( + target_data, + benchmark_data, + headers, + dimension, + get_max=get_max + ) + + rank_profiling_path = self._get_profiling_path_by_rank(profiling_path, rank_id_for_comparison) + rank_benchmark_profiling_path = self._get_profiling_path_by_rank( + benchmark_profiling_path, + rank_id_for_comparison + ) + + if rank_id_for_comparison is None: + # rank id为空则无法获取对应rank的profiling路径,无法进行比较 + continue + + compare_profiling_list.append( + dict(profiling_path=rank_profiling_path, benchmark_profiling_path=rank_benchmark_profiling_path, + step=step, benchmark_step=benchmark_step, + rank=rank_id_for_comparison, benchmark_rank=rank_id_for_comparison, compare_mode=compare_mode) + ) + + if not compare_profiling_list: + return job_list + + job_list += self._profiling_comparison(compare_profiling_list) + return job_list + + def _is_cluster_profiling(self, profiling_path): + if os.path.isfile(profiling_path): + return False + path_list = [os.path.join(profiling_path, dir_name) for dir_name in os.listdir(profiling_path)] + ascend_pt_dirs = [path for path in path_list if os.path.isdir(path) and path.endswith("ascend_pt")] + ascend_ms_dirs = [path for path in path_list if os.path.isdir(path) and path.endswith("ascend_ms")] + if ascend_ms_dirs and ascend_pt_dirs: + logger.error("Cannot analyze pytorch and mindspore meantime.") + return False + if not ascend_pt_dirs and not ascend_ms_dirs: + return False + if ascend_ms_dirs and not ascend_pt_dirs: + data_processor = MindsporeDataPreprocessor(ascend_ms_dirs) + elif ascend_pt_dirs and not ascend_ms_dirs: + data_processor = PytorchDataPreprocessor(ascend_pt_dirs) + + self.cluster_local_data_map[profiling_path] = data_processor.get_data_map() + + if not self.cluster_local_data_map or not self.cluster_local_data_map.get(profiling_path): + return False + + self.default_rank_id = list(self.cluster_local_data_map[profiling_path].keys())[0] + + return len(self.cluster_local_data_map[profiling_path]) >= self.CLUSTER_RANK_THRESHOLD + + def _get_profiling_path_by_rank(self, profiling_path, rank_id=None): + + if not profiling_path: + return profiling_path + + return self._get_target_profiling_path_for_local(profiling_path, rank_id) + + def _get_target_profiling_path_for_local(self, profiling_path, rank_id): + rank_id_map = self.cluster_local_data_map.get(profiling_path, {}) + if rank_id is None or not rank_id_map: + return profiling_path + + if rank_id in rank_id_map: + return rank_id_map.get(rank_id) + + local_first_rank_id = sorted(list(map(int, rank_id_map.keys())))[0] + logger.warning("Target rank id %s does not exist in local profiling data %s, use rank %s for analysis", + rank_id, profiling_path, local_first_rank_id) + return rank_id_map.get(local_first_rank_id) + + def _update_analysis_process_resp(self, pid, resp, **kwargs): + if kwargs: + resp.update(kwargs) + self.analysis_process_resp[pid] = resp + + def _get_analysis_finished_resp(self, pid, resp): + advisor_output_file_prefix = f"mstt_advisor_{Timer().strftime}" + html_path = os.path.join(Config().work_path, f"{advisor_output_file_prefix}.html") + xlsx_path = os.path.join(Config().work_path, "log", f"{advisor_output_file_prefix}.xlsx") + if os.path.exists(html_path) and os.path.exists(xlsx_path): + result_files = {"html": html_path, "xlsx": xlsx_path} + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, + status=AsyncAnalysisStatus.SUCCESS, result_files=result_files) + else: + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, + error_msg="No optimization suggestions, please check your input path.") + + def _stage_computation_analysis(self, profiling_path, stage_step_rank, job_list): + # 对不同pp stage取min max进行分析 + logger.info("Steps and ranks to be analyzed of different pipeline parallel stages are %s", + json.dumps(stage_step_rank)) + + stages_profiling_path = [] + for stage, step_rank_info in stage_step_rank.items(): + rank_id = step_rank_info.get("maximum", {}).get("rank_id") + step = step_rank_info.get("maximum", {}).get("step") + benchmark_rank_id = step_rank_info.get("minimum", {}).get("rank_id") + benchmark_step = step_rank_info.get("minimum", {}).get("step") + + info_msg = f"For {stage}, slow rank is {rank_id}" + if step: + info_msg += f", step is {step}" + logger.info(info_msg) + + stages_profiling_path.append( + dict( + stage=stage, rank=rank_id, step=step, benchmark_rank=benchmark_rank_id, + benchmark_step=benchmark_step, + profiling_path=self._get_profiling_path_by_rank(profiling_path, rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, benchmark_rank_id), + compare_mode=Constant.KERNEL_COMPARE, + step_duration=self.slow_rank_analyzer.get_step_duration(rank_id, step) + ) + ) + Interface.add_analyzer(Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, PPStageComputationAnalyzer) + compute_analysis_kwargs = {"stages_profiling_path": stages_profiling_path, "profiling_path": profiling_path} + + job_list.append((Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, Interface(**compute_analysis_kwargs), + compute_analysis_kwargs)) + if not self.kwargs.get("benchmark_profiling_path"): + logger.info("Enable computation comparison of fast and slow rank/step in different pp stages") + job_list += self._profiling_comparison(stages_profiling_path) + return job_list + + def _global_computation_analysis(self, profiling_path, global_step_rank, job_list): + # 不区分stage,对所有卡取Min max进行分析 + logger.info("Without pipeline parallel stage, steps and ranks to be analyzed are %s", + json.dumps(global_step_rank)) + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg = f"Maximum computation time for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg = f"No slow rank with computation time, analysis for default rank {slow_rank_id}" + slow_step = global_step_rank.get("maximum", {}).get("step") + # 如果没有标杆profiling数据的rank id,说明没有快慢卡问题,直接对默认rank id进行分析,因此这里取值为None + fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + fast_step = global_step_rank.get("minimum", {}).get("step") + + if slow_step is not None: + info_msg += f" and step {slow_step}, " + if fast_rank_id is not None: + info_msg += f"minimum computation time for rank {fast_rank_id}" + if fast_step is not None: + info_msg += f" and step {fast_step}" + logger.info(info_msg) + + kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), + step=slow_step, benchmark_step=fast_step, rank=slow_rank_id, benchmark_rank=fast_rank_id, + compare_mode=Constant.KERNEL_COMPARE, + step_duration=self.slow_rank_analyzer.get_step_duration(slow_rank_id, slow_step)) + + job_list += self.computation_analysis(**kwargs) + + rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id + if not self.kwargs.get("benchmark_profiling_path") and rank_id_valid: + # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 + logger.info("Enable computation comparison of fast and slow rank/step") + job_list += self._profiling_comparison([kwargs]) + return job_list diff --git a/profiler/msprof_analyze/advisor/analyzer/communication/base_communication_analyzer.py b/profiler/msprof_analyze/advisor/analyzer/communication/base_communication_analyzer.py index 5fbbf0c56dc204711eb37f47d11e67c65f9d3897..73724ee29988064f8e2b862ac49e5bd6b27f8762 100644 --- a/profiler/msprof_analyze/advisor/analyzer/communication/base_communication_analyzer.py +++ b/profiler/msprof_analyze/advisor/analyzer/communication/base_communication_analyzer.py @@ -1,22 +1,22 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer - - -class BaseCommunicationAnalyzer(BaseAnalyzer): - requires_cluster_dataset = True - - def __init__(self, collection_path, n_processes: int = 1, **kwargs): - super().__init__(collection_path, n_processes, **kwargs) +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer + + +class BaseCommunicationAnalyzer(BaseAnalyzer): + requires_cluster_dataset = True + + def __init__(self, collection_path, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py b/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py deleted file mode 100644 index 23ec775e275134e8a99336b005d9f8f198660245..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer -from msprof_analyze.advisor.analyzer.computation.ai_core_performance.ai_core_performance_checker import \ - AICorePerformanceChecker -from msprof_analyze.advisor.dataset.profiling.profiling_dataset import ProfilingDataset -from msprof_analyze.advisor.result.result import OptimizeResult -from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor -from msprof_analyze.advisor.display.html.render import HTMLRender - -logger = logging.getLogger() - - -class AICorePerformanceAnalyzer(BaseAnalyzer): - dataset_cls_list = [ProfilingDataset] - - def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: - super().__init__(collection_path, n_processes, **kwargs) - profiling_key = ProfilingDataset.get_key() - self.profiling_dataset = self.get_first_data_by_key(self.dataset_list, profiling_key) - self.result = OptimizeResult() - self.html_render = HTMLRender() - self.html = None - - def optimize(self, **kwargs): - add_render_list = kwargs.get("add_render_list", True) - ai_core_perf_checker = AICorePerformanceChecker() - ai_core_perf_checker.data_filter(self.profiling_dataset) - if not ai_core_perf_checker.ai_core_performance_issues: - return self.result - ai_core_perf_checker.check_ai_core_performance(self.profiling_dataset) - ai_core_perf_checker.make_record(self.result) - self.html = ai_core_perf_checker.make_render(self.html_render, - add_render_list, - priority=self.get_priority(), - rank=kwargs.get("rank")) - return self.result - - def get_priority(self, max_mem_op_dur=None): - return PriorityBackgroundColor.low \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py b/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py deleted file mode 100644 index fa62cd6f8958e28320d19e09d8ef1dae5609d03f..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py +++ /dev/null @@ -1,562 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -from functools import reduce -from msprof_analyze.advisor.dataset.profiling.profiling_dataset import ProfilingDataset -from msprof_analyze.advisor.result.item import OptimizeItem, OptimizeRecord -from msprof_analyze.advisor.result.result import OptimizeResult -from msprof_analyze.prof_common.additional_args_manager import AdditionalArgsManager -from msprof_analyze.prof_common.file_manager import FileManager - -logger = logging.getLogger() - - -class AICorePerformanceChecker: - """ - operator performance checker - """ - _CHECKER = "AICorePerformanceChecker" - CUBE_OPERATOR_MEMORY_SIZE_MB = 100 - INNER_AXIS_256 = 256 - INNER_AXIS_128 = 128 - - def __init__(self): - self.result = dict() - self.ai_core_performance_issues = False - self._desc = "" - self.cube_dict = {} - self.fa_dict = {} - self.fa_list = [] - self.vector_dict = {} - self.load_aicore_perf_rules() - - @staticmethod - def get_operator_list(cube_dict, profiling_dataset): - operator_list = [] - for op in profiling_dataset.op_summary.op_list: - if op.op_name in cube_dict: - key = op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1] - if key in cube_dict[op.op_name]: - operator_list.append(op) - return operator_list - - @staticmethod - def get_vector_list(profiling_dataset, vector_dict): - vector_list = [] - for op_name in vector_dict: - for shape in vector_dict[op_name]: - for operator in profiling_dataset.op_summary.op_list: - if operator.op_name == op_name and operator.input_shapes[1:-1] + "-" + operator.output_shapes[ - 1:-1] == shape: - vector_list.append(operator) - return vector_list - - @staticmethod - def safe_divide(numerator, denominator): - if denominator == 0: - logger.warning("Warning: Division by zero is not allowed.") - return None - return numerator / denominator - - @staticmethod - def memory_size(operator): - memory = 0 - input_shapes = operator.input_shapes[1:-1].split(";") - output_shapes = operator.output_shapes[1:-1] - for shapes in input_shapes: - if "," not in shapes and shapes != "": - # 多的一维是 bias ,预先乘2 - memory += int(shapes) * 2 - continue - memory += reduce(lambda x, y: x * y, map(int, shapes.split(","))) - memory += reduce(lambda x, y: x * y, map(int, output_shapes.split(","))) - return memory * 2 / 1024 / 1024 - - def load_aicore_perf_rules(self): - language = AdditionalArgsManager().language - rule_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), - "rules", language, "aicore_performance.yaml" - ) - - if not os.path.exists(rule_path): - logger.warning("Skip analyze aicpu issues, because %s does not exist.", rule_path) - - self.language = language - self.aicore_rules = FileManager.read_yaml_file(rule_path) - self._cube_problem = self.aicore_rules.get("cube_problem") - self._fa_problem = self.aicore_rules.get("fa_problem") - self._vector_problem = self.aicore_rules.get("vector_problem") - self._desc = self.aicore_rules.get("description") - self._bound_desc = self.aicore_rules.get("bound_description") - self._opti_desc = self.aicore_rules.get("optimization_description") - self._affinity_desc = self.aicore_rules.get("affinity_description") - self._cube_affinity_desc = self.aicore_rules.get("cube_affinity_desc") - self._fa_affinity_desc_head_dim_128 = self.aicore_rules.get("fa_affinity_desc_head_dim_128") - self._fa_affinity_desc_seq_len_128 = self.aicore_rules.get("fa_affinity_desc_seq_len_128") - self._fa_affinity_desc_head_dim_seq_len_128 = self.aicore_rules.get("fa_affinity_desc_head_dim_seq_len_128") - self._suggestion = self.aicore_rules.get("suggestion") - self._affinity_suggestion = self.aicore_rules.get("affinity_suggestion") - self._bound_suggestion = self.aicore_rules.get("bound_suggestion") - self._opti_suggestion = self.aicore_rules.get("optimization_suggestion") - self._operator_rules = {"cube_operators": self.aicore_rules.get("cube_operators"), - "fa_operators": self.aicore_rules.get("fa_operators"), - "vector_operators": self.aicore_rules.get("vector_operators")} - - def data_filter(self, profiling_dataset: ProfilingDataset): - if not self.check_task_list(profiling_dataset): - return - - operator_list = profiling_dataset.op_summary.op_list - total_duration = sum(float(operator.task_duration) for operator in operator_list) - if (total_duration == 0): - return - cube_memory_dict, vector_type_dict = {}, {} - - for op in operator_list: - shapes = op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1] - # preliminary filter cube operator - if op.task_type == "AI_CORE" and "matmul" in op.op_type.lower(): - cube_memory_dict.setdefault(op.op_name, {}).setdefault(shapes, 0) - cube_memory_dict[op.op_name][shapes] += self.memory_size(op) - continue - - # filter fa operator - if op.op_type == "FlashAttentionScore": - self.fa_dict.setdefault(op.op_name, set()).add(shapes) - self.fa_list.append(op) - elif op.op_type == "FlashAttentionScoreGrad": - self.fa_dict.setdefault(op.op_name, set()).add(shapes + "-grad") - self.fa_list.append(op) - - # preliminary filter vector operator - if op.task_type in ["AI_VECTOR_CORE", "MIX_AIV"]: - vector_type_dict.setdefault(op.op_type, set()).add(op) - - # filter cube operator - for op_name in cube_memory_dict: - for shapes in cube_memory_dict[op_name]: - if cube_memory_dict[op_name][shapes] >= self.CUBE_OPERATOR_MEMORY_SIZE_MB: - self.cube_dict.setdefault(op_name, set()).add(shapes) - - # filter vector operator - for op_type in vector_type_dict: - duration_group_by_time = sum(float(op.task_duration) for op in vector_type_dict[op_type]) - if (duration_group_by_time / total_duration) >= 0.01 or duration_group_by_time >= 1000000: - for op in vector_type_dict[op_type]: - shapes = op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1] - self.vector_dict.setdefault(op.op_name, set()).add(shapes) - - if any([self.cube_dict, self.fa_dict, self.vector_dict]): - self.ai_core_performance_issues = True - - def check_ai_core_performance(self, promoting_dataset: ProfilingDataset): - for operator_type in ["cube", "fa", "vector"]: - try: - self.result[operator_type] = getattr(self, f"check_{operator_type}_operator")(promoting_dataset) - except (IndexError, ValueError, AttributeError) as e: - logger.warning(f"Failed to check ai core performance {operator_type} operator, {e}.") - self.result[operator_type] = [] - - if not any([self.result["cube"], self.result["fa"], self.result["vector"]]): - self.ai_core_performance_issues = False - - def check_cube_operator(self, profiling_dataset: ProfilingDataset): - cube_dict = self.cube_dict - suggestion = self._cube_affinity_desc - optimization_queue, bound_queue, affinity_queue = [], [], [] - operator_list = self.get_operator_list(cube_dict, profiling_dataset) - for op in cube_dict: - for shape in cube_dict[op]: - affinity_flag = self._check_cube_inner_axis(shape) - if not affinity_flag: - dtype, shape_duration = None, 0. - for operator in operator_list: - if (operator.op_name == op and - operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape): - dtype = operator.input_data_types - shape_duration += float(operator.task_duration) - affinity_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "duration": shape_duration, - "suggestion": suggestion}) - else: - shape_list = [] - for operator in operator_list: - if (operator.op_name == op and operator.input_shapes[1:-1] + "-" + - operator.output_shapes[1:-1] == shape): - shape_list.append(operator) - shape_duration = sum(float(operator.task_duration) for operator in shape_list) - dtype = shape_list[0].input_data_types if shape_list else None - bound, optimization = self.del_cube_operator_bound(shape_list) - if bound is None and optimization is None: - continue - if bound: - bound_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "bound": bound, - "duration": shape_duration}) - else: - optimization_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "optimization": round(optimization * 100, 2)}) - return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5], - sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5], - sorted(affinity_queue, key=lambda x: x["duration"], reverse=True)[:5]] - - def del_cube_operator_bound(self, shape_list): - bound, optimization, aic_mac_ratio, aic_mte2_ratio, length = "", 0., 0., 0., 0 - for operator in shape_list: - try: - aic_mac_ratio += float(operator.aic_mac_ratio) - aic_mte2_ratio += float(operator.aic_mte2_ratio) - length += 1 - except ValueError: - continue - aic_mac_ratio = self.safe_divide(aic_mac_ratio, length) - aic_mte2_ratio = self.safe_divide(aic_mte2_ratio, length) - if aic_mac_ratio is None or aic_mte2_ratio is None: - return None, None - aic_mac_ratio_rule, aic_mte2_ratio_rule = None, None - for operator_rule in self._operator_rules["cube_operators"]: - if operator_rule["target"] == "aic_mac_ratio": - aic_mac_ratio_rule = operator_rule - elif operator_rule["target"] == "aic_mte2_ratio": - aic_mte2_ratio_rule = operator_rule - if (aic_mac_ratio >= aic_mac_ratio_rule["threshold"] - and aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"]): - bound = aic_mac_ratio_rule["bound"] + "_and_" + aic_mte2_ratio_rule["bound"] + "_bound" - elif aic_mac_ratio >= aic_mte2_ratio_rule["threshold"]: - bound = aic_mac_ratio_rule["bound"] - elif aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"]: - bound = aic_mte2_ratio_rule["bound"] - else: - optimization = max(aic_mac_ratio_rule["threshold"] - aic_mac_ratio, - aic_mte2_ratio_rule["threshold"] - aic_mte2_ratio) - return bound, optimization - - def check_fa_operator(self, profiling_dataset: ProfilingDataset): - fa_list, fa_dict = self.fa_list, self.fa_dict - optimization_queue, bound_queue, affinity_queue = [], [], [] - # 不亲和算子筛选 - for op in fa_dict: - for shape in fa_dict[op]: - affinity_flag, dtype, shape_duration, suggestion = self._check_fa_inner_axis(fa_list, op, shape) - if affinity_flag: - # 不亲和算子 计算耗时,加入affinity_queue - affinity_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "suggestion": suggestion, - "duration": shape_duration}) - else: - # 处理bound算子和优化算子 - if len(shape.split("-")) > 2: - bound, optimization, dtype, shape_duration = self.del_fa_operator_bound_grad(op, shape, fa_list) - else: - bound, optimization, dtype, shape_duration = self.del_fa_operator_bound(op, shape, fa_list) - if bound is None and optimization is None: - continue - if bound: - bound_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "bound": bound, - "duration": shape_duration}) - else: - optimization_queue.append({"op_name": op, - "shape": shape.split("-")[0], - "dtype": dtype, - "optimization": round(optimization * 100, 2)}) - - return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5], - sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5], - sorted(affinity_queue, key=lambda x: x["duration"], reverse=True)[:5]] - - def del_fa_operator_bound_grad(self, op, shape, fa_list): - aic_fixpipe_ratio, aic_mte2_ratio, shape_duration, optimization, length = 0., 0., 0., 0., 0 - bound, dtype = "", None - for operator in fa_list: - if (operator.op_name == op and - operator.input_shapes[1:-1] + "-" + - operator.output_shapes[1:-1] + "-grad" == shape): - try: - aic_fixpipe_ratio += float(operator.aic_fixpipe_ratio) - aic_mte2_ratio += float(operator.aic_mte2_ratio) - shape_duration += float(operator.task_duration) - dtype = operator.input_data_types - length += 1 - except ValueError: - continue - aic_fixpipe_ratio = self.safe_divide(aic_fixpipe_ratio, length) - aic_mte2_ratio = self.safe_divide(aic_mte2_ratio, length) - if aic_mte2_ratio is None or aic_fixpipe_ratio is None: - return None, None, None - aic_fixpipe_ratio_rule, aic_mte2_ratio_rule = None, None - for rule in self._operator_rules["fa_operators"]: - if rule["target"] == "aic_fixpipe_ratio": - aic_fixpipe_ratio_rule = rule - elif rule["target"] == "aic_mte2_ratio": - aic_mte2_ratio_rule = rule - if (aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"] and - aic_fixpipe_ratio >= aic_fixpipe_ratio_rule["threshold"]): - bound = aic_fixpipe_ratio_rule["bound"] + "_and_" + aic_mte2_ratio_rule["bound"] + "_bound" - elif aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"]: - bound = aic_mte2_ratio_rule["bound"] - elif aic_fixpipe_ratio >= aic_fixpipe_ratio_rule["threshold"]: - bound = aic_fixpipe_ratio_rule["bound"] - else: - optimization = max(aic_fixpipe_ratio_rule["threshold"] - aic_fixpipe_ratio, - aic_mte2_ratio_rule["threshold"] - aic_mte2_ratio) - return bound, optimization, dtype, shape_duration - - def del_fa_operator_bound(self, op, shape, fa_list): - aiv_vec_ratio, aic_mte2_ratio, shape_duration, optimization, length = 0., 0., 0., 0., 0 - bound, dtype = "", None - for operator in fa_list: - if (operator.op_name == op and - operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape): - try: - aiv_vec_ratio += float(operator.aiv_vec_ratio) - aic_mte2_ratio += float(operator.aic_mte2_ratio) - shape_duration += float(operator.task_duration) - length += 1 - except ValueError: - continue - aiv_vec_ratio = self.safe_divide(aiv_vec_ratio, length) - aic_mte2_ratio = self.safe_divide(aic_mte2_ratio, length) - if aiv_vec_ratio is None or aic_mte2_ratio is None: - return None, None, None - aiv_vec_ratio_rule, aic_mte2_ratio_rule = None, None - for rule in self._operator_rules["fa_operators"]: - if rule["target"] == "aiv_vec_ratio": - aiv_vec_ratio_rule = rule - elif rule["target"] == "aic_mte2_ratio": - aic_mte2_ratio_rule = rule - if (aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"] - and aiv_vec_ratio >= aiv_vec_ratio_rule["threshold"]): - bound = aic_mte2_ratio_rule["bound"] + "_and_" + aiv_vec_ratio_rule["bound"] + "_bound" - elif aic_mte2_ratio >= aic_mte2_ratio_rule["threshold"]: - bound = aic_mte2_ratio_rule["bound"] - elif aiv_vec_ratio >= aiv_vec_ratio_rule["threshold"]: - bound = aiv_vec_ratio_rule["bound"] - else: - optimization = max(aiv_vec_ratio_rule["threshold"] - aiv_vec_ratio, - aic_mte2_ratio_rule["threshold"] - aic_mte2_ratio) - return bound, optimization, dtype, shape_duration - - def check_vector_operator(self, profiling_dataset: ProfilingDataset): - vector_dict = self.vector_dict - optimization_queue, bound_queue = [], [] - vector_list = self.get_vector_list(profiling_dataset, vector_dict) - for op_name in vector_dict: - for shape in vector_dict[op_name]: - aiv_vec_ratio, aiv_mte2_ratio, aiv_mte3_ratio, shape_duration = 0., 0., 0., 0. - length, dtype = 0, "" - for operator in vector_list: - if (operator.op_name == op_name and - operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape): - try: - aiv_vec_ratio += float(operator.aiv_vec_ratio) - aiv_mte2_ratio += float(operator.aiv_mte2_ratio) - aiv_mte3_ratio += float(operator.aiv_mte3_ratio) - shape_duration += float(operator.task_duration) - dtype = operator.input_data_types - length += 1 - except ValueError: - continue - aiv_vec_ratio = self.safe_divide(aiv_vec_ratio, length) - aiv_mte2_ratio = self.safe_divide(aiv_mte2_ratio, length) - aiv_mte3_ratio = self.safe_divide(aiv_mte3_ratio, length) - if aiv_vec_ratio is None or aiv_mte2_ratio is None or aiv_mte3_ratio is None: - continue - bound, optimization = self.del_vector_operator_bound(aiv_mte2_ratio, aiv_mte3_ratio, aiv_vec_ratio) - if bound: - bound_queue.append({"op_name": op_name, - "shape": shape.split("-")[0], - "bound": bound, - "dtype": dtype, - "duration": shape_duration}) - else: - optimization_queue.append({"op_name": op_name, - "shape": shape.split("-")[0], - "dtype": dtype, - "optimization": round(optimization * 100, 2)}) - return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5], - sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5]] - - def del_vector_operator_bound(self, aiv_mte2_ratio, aiv_mte3_ratio, aiv_vec_ratio): - bound, optimization = "", 0 - aiv_vec_ratio_rule, aiv_mte2_ratio_rule, aiv_mte3_ratio_rule, total_rule = None, None, None, None - for operator_rule in self._operator_rules["vector_operators"]: - if operator_rule["target"] == "aiv_vec_ratio": - aiv_vec_ratio_rule = operator_rule - elif operator_rule["target"] == "aiv_mte2_ratio": - aiv_mte2_ratio_rule = operator_rule - elif operator_rule["target"] == "aiv_mte3_ratio": - aiv_mte3_ratio_rule = operator_rule - elif operator_rule["target"] == "total": - total_rule = operator_rule - if aiv_vec_ratio + aiv_mte2_ratio + aiv_mte3_ratio >= total_rule["threshold"]: - bound = total_rule["bound"] - elif aiv_mte2_ratio >= aiv_mte2_ratio_rule["threshold"]: - bound = aiv_mte2_ratio_rule["bound"] - elif aiv_mte3_ratio >= aiv_mte3_ratio_rule["threshold"]: - bound = aiv_mte3_ratio_rule["bound"] - elif aiv_vec_ratio >= aiv_vec_ratio_rule["threshold"]: - bound = aiv_vec_ratio_rule["bound"] - else: - optimization = max(aiv_vec_ratio_rule["threshold"] - aiv_vec_ratio, - aiv_mte2_ratio_rule["threshold"] - aiv_mte2_ratio, - aiv_mte3_ratio_rule["threshold"] - aiv_mte3_ratio) - return bound, optimization - - def draw_record(self, op_type: str, result: OptimizeResult): - suggestion_keys = ['opti', 'bound', 'affinity'] - desc = dict.fromkeys(suggestion_keys, "") - problem_map = { - 'cube': self._cube_problem, - 'fa': self._fa_problem, - 'vector': self._vector_problem - } - if op_type not in problem_map: - return - optimization_item = OptimizeItem(problem_map[op_type], self._desc, [self._suggestion]) - result.add(OptimizeRecord(optimization_item)) - headers = [ - "Type", - "Description and Suggestion", - ] - result.add_detail(problem_map[op_type], headers=headers) - for opti_issue in self.result[op_type][0]: - opti_sugg = self._opti_suggestion.format(**opti_issue) - desc["opti"] += opti_sugg - if desc["opti"]: - result.add_detail(problem_map[op_type], detail=[self._opti_desc, desc["opti"]]) - for bound_issue in self.result[op_type][1]: - bound_sugg = self._bound_suggestion.format(**bound_issue) - desc["bound"] += bound_sugg - if desc["bound"]: - result.add_detail(problem_map[op_type], detail=[self._bound_desc, desc["bound"]]) - if op_type == "vector": # vector 类型没有亲和性建议 - return - for affinity_issue in self.result[op_type][2]: - affinity_sugg = self._affinity_suggestion.format(**affinity_issue) - desc["affinity"] += affinity_sugg - if desc["affinity"]: - result.add_detail(problem_map[op_type], detail=[self._affinity_desc, desc["affinity"]]) - - def make_record(self, result: OptimizeResult): - """ - make record for what and how to optimize - """ - if not self.ai_core_performance_issues: - return self.ai_core_performance_issues - if any(self.result["cube"]): - self.draw_record("cube", result) - if any(self.result["fa"]): - self.draw_record("fa", result) - if any(self.result["vector"]): - self.draw_record("vector", result) - - return True - - def make_render(self, html_render, add_render_list=True, **kwargs): - if not self.ai_core_performance_issues: - return self.ai_core_performance_issues - - priority = kwargs.get("priority") - return html_render.render_template(key="computation", - template_dir="templates", - template_name="ai_core_performance.html", - format_result=self.result, - language=self.language, - add_render_list=add_render_list, - priority_background_color=priority, - rank=kwargs.get("rank")) - - def check_task_list(self, profiling_dataset: ProfilingDataset) -> bool: - if not hasattr(profiling_dataset, "op_summary"): - logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op summary") - return False - if not hasattr(profiling_dataset.op_summary, "op_list"): - logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op_list") - return False - if (not hasattr(profiling_dataset.op_summary.op_list[0], "input_shapes") or - not hasattr(profiling_dataset.op_summary.op_list[0], "input_data_types")): - logger.warning("Skip %s checker because of not containing input datas", self._CHECKER) - return False - return True - - def _check_cube_inner_axis(self, shape): - # 判断输入shape内轴是否为256的倍数 - shapes = shape.split("-")[0].split(";") - if (len(shape.split("-")[0].split(";")[0].split(","))) == 4: - # NZ格式 - b_axis, c_axis = int(shapes[0].split(",")[1]), int(shapes[0].split(",")[2]) - f_axis, g_axis = int(shapes[1].split(",")[1]), int(shapes[1].split(",")[2]) - return (b_axis * c_axis % self.INNER_AXIS_256 == 0) and (f_axis * g_axis % self.INNER_AXIS_256 == 0) - elif (len(shape.split("-")[0].split(";")[0].split(","))) == 2: - # ND格式 - l_axis, k_axis = int(shapes[0].split(",")[1]), int(shapes[1].split(",")[1]) - return (l_axis % self.INNER_AXIS_256 == 0) and (k_axis % self.INNER_AXIS_256 == 0) - else: - return False - - def _check_fa_inner_axis(self, fa_list, op, shape): - shape_duration = 0. - affinity_flag = False - dtype = None - suggestion = "" - if "varlen" in op.lower(): - # 处理变长算子 如果不亲和则affinity_flag为False - inner_axis = int(shape.split("-")[0].split(";")[0].split(",")[2]) - if inner_axis % self.INNER_AXIS_128 != 0: - affinity_flag = True - suggestion = self._fa_affinity_desc_head_dim_128 - for operator in fa_list: - if (operator.op_name == op and - operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape): - shape_duration += float(operator.task_duration) - dtype = operator.input_data_types - else: - # 处理定长算子 如果不亲和则affinity_flag为False - head_dim = 0 - seq_len = int(shape.split("-")[1].split(";")[0].split(",")[2]) - input_first_tensor = shape.split("-")[0].split(";")[0].split(",") - if len(input_first_tensor) == 3: - head_dim = int(input_first_tensor[2]) / int(shape.split("-")[1].split(";")[0].split(",")[1]) - else: - head_dim = int(input_first_tensor[3]) - if head_dim % self.INNER_AXIS_128 != 0 and seq_len % self.INNER_AXIS_128 != 0: - affinity_flag = True - suggestion = self._fa_affinity_desc_head_dim_seq_len_128 - elif head_dim % self.INNER_AXIS_128 != 0: - affinity_flag = True - suggestion = self._fa_affinity_desc_head_dim_128 - elif seq_len % self.INNER_AXIS_128 != 0: - affinity_flag = True - suggestion = self._fa_affinity_desc_seq_len_128 - if affinity_flag: - for operator in fa_list: - if (operator.op_name == op and - operator.input_shapes[1:-1] + "-" + - operator.output_shapes[1:-1] == shape): - shape_duration += float(operator.task_duration) - dtype = operator.input_data_types - return affinity_flag, dtype, shape_duration, suggestion diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/operator_checker.py b/profiler/msprof_analyze/advisor/analyzer/computation/operator_checker.py index 4be0fc66ae8b8f75ca0518228cbdccde1a0d7c1e..ab9d4228b470ee515ed912ab018badbba3ec2e67 100644 --- a/profiler/msprof_analyze/advisor/analyzer/computation/operator_checker.py +++ b/profiler/msprof_analyze/advisor/analyzer/computation/operator_checker.py @@ -52,7 +52,6 @@ class OperatorChecker(VersionControl): self._tune_op_list: List[str] = [] self.prompt_class = BasePrompt.get_prompt_class("OperatorChecker") - self.rank_id = self.prompt_class.RANK_ID self.pytorch_op_tune_suggestion = self.prompt_class.PYTORCH_OPERATOR_TUNE_SUGGESTION self.mslite_op_tune_suggestion = self.prompt_class.MSLITE_OPERATOR_TUNE_SUGGESTION self.pytorch_release_suggestion = self.prompt_class.PYTORCH_RELEASE_SUGGESTION @@ -119,7 +118,7 @@ class OperatorChecker(VersionControl): """ if rank is not None: - self._problem = self.rank_id.format(rank) + self._problem.lower() + self._problem = self.prompt_class.RANK_ID.format(rank) + self._problem.lower() task_duration_list = [float(op_info.get_attr("task_duration")) for op_info in self._op_list @@ -302,7 +301,7 @@ class OperatorChecker(VersionControl): def format_suggestion_content(self, profiling_data: ProfilingDataset) -> None: if profiling_data.prof_type == EnumParamsParser().profiling_type.ascend_pytorch_profiler: self._suggestion.append(self.pytorch_op_tune_suggestion) - elif profiling_data.prof_type == EnumParamsParser().profiling_type.mslite: + elif profiling_data.prof_type == EnumParamsParser.profiling_type.mslite: self._suggestion.append(self.mslite_op_tune_suggestion) def _check_data(self, profiling_data): diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/pp_stage_computation_analyzer.py b/profiler/msprof_analyze/advisor/analyzer/computation/pp_stage_computation_analyzer.py index 2780204b2064ed628ee686d91e82169818955eb7..2a08e668e140e090c6a3fb7f65fdbaf01310e741 100644 --- a/profiler/msprof_analyze/advisor/analyzer/computation/pp_stage_computation_analyzer.py +++ b/profiler/msprof_analyze/advisor/analyzer/computation/pp_stage_computation_analyzer.py @@ -1,118 +1,118 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from multiprocessing import Manager - -from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer -from msprof_analyze.advisor.common.analyzer_scopes import SupportedScopes -from msprof_analyze.advisor.display.html.render import HTMLRender -from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor -from msprof_analyze.advisor.interface.interface import Interface -from msprof_analyze.advisor.utils.utils import ParallelJob, get_analyze_processes -from msprof_analyze.advisor.result.result import OptimizeResult -from msprof_analyze.advisor.result.item import OptimizeItem, OptimizeRecord - -logger = logging.getLogger() - - -class PPStageComputationAnalyzer(BaseAnalyzer): - - def __init__(self, collection_path, **kwargs): - super().__init__(collection_path, **kwargs) - self.collection_path = collection_path - self._stages_rendered_html = Manager().list() - self._multiprocess_result = Manager().dict() - # html render不能序列化,无法用多进程,放到optimize里面初始化 - self.html_render = None - self.result = None - - @staticmethod - def _get_valid_sheet_name(sheet_name, prefix): - if not sheet_name.lower().startswith(prefix.lower()): - sheet_name = f"{prefix} {sheet_name}" - return sheet_name - - def optimize(self, stages_profiling_path, **kwargs): - pp_stage_processes = min(get_analyze_processes(), len(stages_profiling_path)) - if pp_stage_processes <= 1: - for stage_profiling_path in stages_profiling_path: - self._optimize(**stage_profiling_path) - else: - logger.info("Start to parallel analysis of pp stages, number of processes is %s", pp_stage_processes) - parallel_stage_analysis_job = ParallelJob(self._optimize, stages_profiling_path, - "Computation analysis of Pipeline parallel stages") - parallel_stage_analysis_job.start(pp_stage_processes) - self._merge_multiprocess_result() - - self.make_render() - self.html_render = HTMLRender() - return self.result - - def make_render(self): - HTMLRender().render_template(key="computation", - template_dir="templates", - template_name="pp_stage_computation_analysis.html", - stages_rendered_html=list(self._stages_rendered_html), - priority_background_color=PriorityBackgroundColor.high) - - def get_priority(self, max_mem_op_dur=None): - pass - - def _optimize(self, profiling_path, **kwargs): - stage_html_record = dict(stage=kwargs.get("stage"), rank=kwargs.get("rank"), step=kwargs.get("step")) - kwargs["add_render_list"] = False - - # stage 并行分析时,避免调用本身,即SupportedScopes.STAGE_COMPUTE - scopes = Interface.get_scope(Interface.COMPUTATION) - stage_analyzer_list = [Interface.get_analyzer(Interface.COMPUTATION, scope) - for scope in scopes - if scope != SupportedScopes.STAGE_COMPUTE] - - for analyzer_cls in stage_analyzer_list: - analyzer = analyzer_cls(collection_path=profiling_path, **kwargs) - result = analyzer.optimize(**kwargs) - if hasattr(result, "data") and result.data: - self.result = result - if hasattr(analyzer, "html") and analyzer.html: - if "html_list" not in stage_html_record: - stage_html_record["html_list"] = [] - stage_html_record["html_list"].append(analyzer.html) - self._stages_rendered_html.append(stage_html_record) - self._multiprocess_result[f"rank {kwargs.get('rank')}".capitalize()] = result.data - - def _merge_multiprocess_result(self): - self.result = OptimizeResult() - for key, result_data in self._multiprocess_result.items(): - problem_data = result_data.get("problems", {}).get("data", []) - if not problem_data: - continue - - for row in problem_data: - if len(row) < 3: - continue - issue_name, desc, suggestion = row[:3] - sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) - optimization_item = OptimizeItem(sheet_name, desc, [suggestion]) - self.result.add(OptimizeRecord(optimization_item)) - del result_data["problems"] - - for issue_name, issue_details in result_data.items(): - headers = issue_details.get("headers", []) - data = issue_details.get("data", []) - sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) - self.result.add_detail(sheet_name, headers=headers) - - for row in data: - self.result.add_detail(sheet_name, detail=row) +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from multiprocessing import Manager + +from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer +from msprof_analyze.advisor.common.analyzer_scopes import SupportedScopes +from msprof_analyze.advisor.display.html.render import HTMLRender +from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor +from msprof_analyze.advisor.interface.interface import Interface +from msprof_analyze.advisor.utils.utils import ParallelJob, get_analyze_processes +from msprof_analyze.advisor.result.result import OptimizeResult +from msprof_analyze.advisor.result.item import OptimizeItem, OptimizeRecord + +logger = logging.getLogger() + + +class PPStageComputationAnalyzer(BaseAnalyzer): + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.collection_path = collection_path + self._stages_rendered_html = Manager().list() + self._multiprocess_result = Manager().dict() + # html render不能序列化,无法用多进程,放到optimize里面初始化 + self.html_render = None + self.result = None + + @staticmethod + def _get_valid_sheet_name(sheet_name, prefix): + if not sheet_name.lower().startswith(prefix.lower()): + sheet_name = f"{prefix} {sheet_name}" + return sheet_name + + def optimize(self, stages_profiling_path, **kwargs): + pp_stage_processes = min(get_analyze_processes(), len(stages_profiling_path)) + if pp_stage_processes <= 1: + for stage_profiling_path in stages_profiling_path: + self._optimize(**stage_profiling_path) + else: + logger.info("Start to parallel analysis of pp stages, number of processes is %s", pp_stage_processes) + parallel_stage_analysis_job = ParallelJob(self._optimize, stages_profiling_path, + "Computation analysis of Pipeline parallel stages") + parallel_stage_analysis_job.start(pp_stage_processes) + self._merge_multiprocess_result() + + self.make_render() + self.html_render = HTMLRender() + return self.result + + def make_render(self): + HTMLRender().render_template(key="computation", + template_dir="templates", + template_name="pp_stage_computation_analysis.html", + stages_rendered_html=list(self._stages_rendered_html), + priority_background_color=PriorityBackgroundColor.high) + + def get_priority(self, max_mem_op_dur=None): + pass + + def _optimize(self, profiling_path, **kwargs): + stage_html_record = dict(stage=kwargs.get("stage"), rank=kwargs.get("rank"), step=kwargs.get("step")) + kwargs["add_render_list"] = False + + # stage 并行分析时,避免调用本身,即SupportedScopes.STAGE_COMPUTE + scopes = Interface.get_scope(Interface.COMPUTATION) + stage_analyzer_list = [Interface.get_analyzer(Interface.COMPUTATION, scope) + for scope in scopes + if scope != SupportedScopes.STAGE_COMPUTE] + + for analyzer_cls in stage_analyzer_list: + analyzer = analyzer_cls(collection_path=profiling_path, **kwargs) + result = analyzer.optimize(**kwargs) + if hasattr(result, "data") and result.data: + self.result = result + if hasattr(analyzer, "html") and analyzer.html: + if "html_list" not in stage_html_record: + stage_html_record["html_list"] = [] + stage_html_record["html_list"].append(analyzer.html) + self._stages_rendered_html.append(stage_html_record) + self._multiprocess_result[f"rank {kwargs.get('rank')}".capitalize()] = result.data + + def _merge_multiprocess_result(self): + self.result = OptimizeResult() + for key, result_data in self._multiprocess_result.items(): + problem_data = result_data.get("problems", {}).get("data", []) + if not problem_data: + continue + + for row in problem_data: + if len(row) < 3: + continue + issue_name, desc, suggestion = row[:3] + sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) + optimization_item = OptimizeItem(sheet_name, desc, [suggestion]) + self.result.add(OptimizeRecord(optimization_item)) + del result_data["problems"] + + for issue_name, issue_details in result_data.items(): + headers = issue_details.get("headers", []) + data = issue_details.get("data", []) + sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) + self.result.add_detail(sheet_name, headers=headers) + + for row in data: + self.result.add_detail(sheet_name, detail=row) diff --git a/profiler/msprof_analyze/advisor/analyzer/schedule/fusible_ops/fusible_operator_checker.py b/profiler/msprof_analyze/advisor/analyzer/schedule/fusible_ops/fusible_operator_checker.py index 9070a8036047f7976ca7e9a7ab81bd5bf9632af6..3ab54b0dbb8729c8297606a471ce67e55715b2b8 100644 --- a/profiler/msprof_analyze/advisor/analyzer/schedule/fusible_ops/fusible_operator_checker.py +++ b/profiler/msprof_analyze/advisor/analyzer/schedule/fusible_ops/fusible_operator_checker.py @@ -88,7 +88,7 @@ class FusibleOperatorChecker: @staticmethod def check_hccl(task: OpInfo): - return (task.task_type in ["COMMUNICATION", "HCCL"] or + return (task.task_type == "HCCL" or any(task.op_name.lower().startswith(item) for item in ["hcom", "lccl", "lcoc"])) @staticmethod diff --git a/profiler/msprof_analyze/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py b/profiler/msprof_analyze/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py index 1e75d4e8969d57d54f55eb477165e6379664b817..48506da62646cf337380be7d6c7eb6779161889e 100644 --- a/profiler/msprof_analyze/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py +++ b/profiler/msprof_analyze/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py @@ -1,46 +1,46 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer -from msprof_analyze.advisor.result.result import OptimizeResult -from msprof_analyze.advisor.analyzer.schedule.syncbn.syncbn_checker import SyncBNChecker -from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor -from msprof_analyze.advisor.display.html.render import HTMLRender -from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset - -logger = logging.getLogger() - - -class SyncBNAnalyzer(BaseAnalyzer): - dataset_cls_list = [ScheduleAnalysisDataset] - - def __init__(self, collection_path, **kwargs): - super().__init__(collection_path, **kwargs) - self.result = OptimizeResult() - self.html_render = HTMLRender() - key = ScheduleAnalysisDataset.get_key() - self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) - - @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) - def optimize(self, **kwargs): - syncbn_checker = SyncBNChecker() - syncbn_checker.check_syncbn(self.timeline_event_dataset) - syncbn_checker.make_record(self.result) - syncbn_checker.make_render(self.html_render, priority=self.get_priority(), rank=kwargs.get("rank")) - return self.result - - def get_priority(self, max_mem_op_dur=None): +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer +from msprof_analyze.advisor.result.result import OptimizeResult +from msprof_analyze.advisor.analyzer.schedule.syncbn.syncbn_checker import SyncBNChecker +from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor +from msprof_analyze.advisor.display.html.render import HTMLRender +from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset + +logger = logging.getLogger() + + +class SyncBNAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + key = ScheduleAnalysisDataset.get_key() + self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + syncbn_checker = SyncBNChecker() + syncbn_checker.check_syncbn(self.timeline_event_dataset) + syncbn_checker.make_record(self.result) + syncbn_checker.make_render(self.html_render, priority=self.get_priority(), rank=kwargs.get("rank")) + return self.result + + def get_priority(self, max_mem_op_dur=None): return PriorityBackgroundColor.high \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py b/profiler/msprof_analyze/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py index ea095e1968f67d4762280bb4dfe180bddde4368e..4ac82fd827186e8afbf93391ce4b109e7cefcf38 100644 --- a/profiler/msprof_analyze/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py +++ b/profiler/msprof_analyze/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py @@ -1,48 +1,48 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer -from msprof_analyze.advisor.analyzer.schedule.synchronize_stream.synchronize_stream_checker import \ - SynchronizeStreamChecker -from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset -from msprof_analyze.advisor.display.html.render import HTMLRender -from msprof_analyze.advisor.result.result import OptimizeResult - -logger = logging.getLogger() - - -class SynchronizeStreamAnalyzer(BaseAnalyzer): - dataset_cls_list = [ScheduleAnalysisDataset] - - def __init__(self, collection_path, **kwargs): - super().__init__(collection_path, **kwargs) - self.result = OptimizeResult() - self.html_render = HTMLRender() - - key = ScheduleAnalysisDataset.get_key() - self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) - - @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) - def optimize(self, **kwargs): - synchronize_stream_checker = SynchronizeStreamChecker() - synchronize_stream_checker.check_synchronize(self.timeline_event_dataset) - synchronize_stream_checker.make_record(self.result) - synchronize_stream_checker.make_render(self.html_render, priority=self.get_priority(synchronize_stream_checker), - rank=kwargs.get("rank")) - return self.result - - def get_priority(self, max_mem_op_dur): - return max_mem_op_dur.priority +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from msprof_analyze.advisor.analyzer.base_analyzer import BaseAnalyzer +from msprof_analyze.advisor.analyzer.schedule.synchronize_stream.synchronize_stream_checker import \ + SynchronizeStreamChecker +from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from msprof_analyze.advisor.display.html.render import HTMLRender +from msprof_analyze.advisor.result.result import OptimizeResult + +logger = logging.getLogger() + + +class SynchronizeStreamAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + + key = ScheduleAnalysisDataset.get_key() + self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + synchronize_stream_checker = SynchronizeStreamChecker() + synchronize_stream_checker.check_synchronize(self.timeline_event_dataset) + synchronize_stream_checker.make_record(self.result) + synchronize_stream_checker.make_render(self.html_render, priority=self.get_priority(synchronize_stream_checker), + rank=kwargs.get("rank")) + return self.result + + def get_priority(self, max_mem_op_dur): + return max_mem_op_dur.priority diff --git a/profiler/msprof_analyze/advisor/common/analyzer_scopes.py b/profiler/msprof_analyze/advisor/common/analyzer_scopes.py index 6a6261c7b75e721c0a9df75f35ecb3cd2aa1e487..07ceef769440b39c93aeaaf15ded5ad99fc3f4b3 100644 --- a/profiler/msprof_analyze/advisor/common/analyzer_scopes.py +++ b/profiler/msprof_analyze/advisor/common/analyzer_scopes.py @@ -1,4 +1,5 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,4 +41,3 @@ class SupportedScopes: FUSIBLE_OPERATOR_ANALYSIS = "fusible_operator_analysis" CONJECTURED_GC_ANALYSIS = "conjectured_analysis" COMPARISON = "comparison" - AICORE_PERFORMANCE_ANALYSIS = "ai_core_performance_analysis" diff --git a/profiler/msprof_analyze/advisor/common/async_analysis_status.py b/profiler/msprof_analyze/advisor/common/async_analysis_status.py index 98bb458105421b38395f745f2913311a24a5ce40..2d314b5cb0d1994f28d74f876395a04f0d8eedee 100644 --- a/profiler/msprof_analyze/advisor/common/async_analysis_status.py +++ b/profiler/msprof_analyze/advisor/common/async_analysis_status.py @@ -1,27 +1,27 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -class AsyncAnalysisStatus: - FAILED = "failed" - SUCCESS = "success" - ANALYZING = "analyzing" - - BAD_REQUEST_STATUS_CODE = 400 - NOT_FOUND_STATUS_CODE = 404 - INNER_ERROR_STATUS_CODE = 500 - NON_FAILED_STATUS_CODE = 200 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +class AsyncAnalysisStatus: + FAILED = "failed" + SUCCESS = "success" + ANALYZING = "analyzing" + + BAD_REQUEST_STATUS_CODE = 400 + NOT_FOUND_STATUS_CODE = 404 + INNER_ERROR_STATUS_CODE = 500 + NON_FAILED_STATUS_CODE = 200 diff --git a/profiler/msprof_analyze/advisor/common/enum_params_parser.py b/profiler/msprof_analyze/advisor/common/enum_params_parser.py index ebf81ae38c249f4701e46f9a05b5cb9f86db635c..7158af929f5711a32de835b426bbe91c2000a401 100644 --- a/profiler/msprof_analyze/advisor/common/enum_params_parser.py +++ b/profiler/msprof_analyze/advisor/common/enum_params_parser.py @@ -1,104 +1,104 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging -import typing - -from msprof_analyze.advisor.common.timeline.event import AdvisorDict -from msprof_analyze.advisor.utils.utils import singleton -from msprof_analyze.prof_common.file_manager import FileManager - -logger = logging.getLogger() - - -@singleton -class EnumParamsParser(): - # 枚举变量抽象成yaml文件,统一管理,便于第三方服务对接advisor时调用当前类查询所有枚举变量参数的默认值和可选值 - - ARGUMENTS = "arguments" - ENVS = "envs" - OPTIONS = "options" - DEFAULT = "default" - TYPE = "type" - STR_TYPE = "str" - LIST_TYPE = "list" - INT_TYPE = "int" - BOOLEAN_TYPE = "boolean" - - def __init__(self): - enum_params_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config", - "enum_parameters.yaml") - self.enum_params = FileManager.read_yaml_file(enum_params_path) - self._set_value() - - def get_keys(self): - return list(self.get_arguments_keys()) + list(self.get_envs_keys()) - - def get_arguments_keys(self): - return list(self.enum_params.get(self.ARGUMENTS, {}).keys()) - - def get_envs_keys(self): - return list(self.enum_params.get(self.ENVS, {}).keys()) - - def get_options(self, key, filter_func=None): - options = [] - for param_type in [self.ARGUMENTS, self.ENVS]: - if key not in self.enum_params.get(param_type, {}): - continue - options = self.enum_params.get(param_type, {}).get(key, {}).get(self.OPTIONS, []) - - if not options: - logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) - - if filter_func is not None and callable(filter_func): - options = [value for value in options if filter_func(value)] - - return options - - def get_value_type(self, key): - for param_type in [self.ARGUMENTS, self.ENVS]: - if key not in self.enum_params.get(param_type, {}): - continue - value_type = self.enum_params.get(param_type, {}).get(key, {}).get(self.TYPE, self.STR_TYPE) - return value_type - return self.STR_TYPE - - def get_default(self, key): - default_value = None - for param_type in [self.ARGUMENTS, self.ENVS]: - if key not in self.enum_params.get(param_type, {}): - continue - default_value = self.enum_params.get(param_type, {}).get(key, {}).get(self.DEFAULT, []) - - if not default_value: - logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) - - return default_value - - def _set_value(self): - - for key in self.get_keys(): - - if not hasattr(self, key): - setattr(self, str(key), AdvisorDict()) - - options = self.get_options(key) - - for value in options: - if not isinstance(value, typing.Hashable): - continue - getattr(self, key)[str(value)] = value +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import typing + +from msprof_analyze.advisor.common.timeline.event import AdvisorDict +from msprof_analyze.advisor.utils.utils import singleton +from msprof_analyze.prof_common.file_manager import FileManager + +logger = logging.getLogger() + + +@singleton +class EnumParamsParser(): + # 枚举变量抽象成yaml文件,统一管理,便于第三方服务对接advisor时调用当前类查询所有枚举变量参数的默认值和可选值 + + ARGUMENTS = "arguments" + ENVS = "envs" + OPTIONS = "options" + DEFAULT = "default" + TYPE = "type" + STR_TYPE = "str" + LIST_TYPE = "list" + INT_TYPE = "int" + BOOLEAN_TYPE = "boolean" + + def __init__(self): + enum_params_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config", + "enum_parameters.yaml") + self.enum_params = FileManager.read_yaml_file(enum_params_path) + self._set_value() + + def get_keys(self): + return list(self.get_arguments_keys()) + list(self.get_envs_keys()) + + def get_arguments_keys(self): + return list(self.enum_params.get(self.ARGUMENTS, {}).keys()) + + def get_envs_keys(self): + return list(self.enum_params.get(self.ENVS, {}).keys()) + + def get_options(self, key, filter_func=None): + options = [] + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + options = self.enum_params.get(param_type, {}).get(key, {}).get(self.OPTIONS, []) + + if not options: + logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) + + if filter_func is not None and callable(filter_func): + options = [value for value in options if filter_func(value)] + + return options + + def get_value_type(self, key): + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + value_type = self.enum_params.get(param_type, {}).get(key, {}).get(self.TYPE, self.STR_TYPE) + return value_type + return self.STR_TYPE + + def get_default(self, key): + default_value = None + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + default_value = self.enum_params.get(param_type, {}).get(key, {}).get(self.DEFAULT, []) + + if not default_value: + logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) + + return default_value + + def _set_value(self): + + for key in self.get_keys(): + + if not hasattr(self, key): + setattr(self, str(key), AdvisorDict()) + + options = self.get_options(key) + + for value in options: + if not isinstance(value, typing.Hashable): + continue + getattr(self, key)[str(value)] = value diff --git a/profiler/msprof_analyze/advisor/config/enum_parameters.yaml b/profiler/msprof_analyze/advisor/config/enum_parameters.yaml index 678fe72b43c7f5b2fd66b3f38c3114cc9793cd50..534859eb9d08887ca35a65b12db70f5cca4a1716 100644 --- a/profiler/msprof_analyze/advisor/config/enum_parameters.yaml +++ b/profiler/msprof_analyze/advisor/config/enum_parameters.yaml @@ -1,58 +1,58 @@ -arguments: - cann_version: - type: str - options: - - 6.3.RC2 - - 7.0.RC1 - - 7.0.0 - - 8.0.RC1 - - 8.0.RC2 - - 8.0.0 - default: 8.0.0 - - torch_version: - type: str - options: - - 1.11.0 - - 2.1.0 - default: 2.1.0 - mindspore_version: - type: str - options: - - 2.3.0 - - 2.4.0 - default: 2.4.0 - analysis_dimensions: - type: list - options: - - [ computation, communication, schedule, memory ] - - [ computation ] - - [ communication ] - - [ schedule ] - - [ memory ] - default: [ computation, communication, schedule, memory ] - - profiling_type: - type: str - options: - - pytorch - - mslite - - msprof - - mindspore - default: pytorch - -envs: - ADVISOR_ANALYZE_PROCESSES: - type: int - options: [ 1, 2, 3, 4, 5, 6, 7, 8 ] - default: 1 - - DISABLE_PROFILING_COMPARISON: - type: boolean - options: [ true, false ] - default: false - - DISABLE_AFFINITY_API: - type: boolean - options: [ true, false ] - default: false +arguments: + cann_version: + type: str + options: + - 6.3.RC2 + - 7.0.RC1 + - 7.0.0 + - 8.0.RC1 + - 8.0.RC2 + - 8.0.0 + default: 8.0.0 + + torch_version: + type: str + options: + - 1.11.0 + - 2.1.0 + default: 2.1.0 + mindspore_version: + type: str + options: + - 2.3.0 + - 2.4.0 + default: 2.4.0 + analysis_dimensions: + type: list + options: + - [ computation, communication, schedule, memory ] + - [ computation ] + - [ communication ] + - [ schedule ] + - [ memory ] + default: [ computation, communication, schedule, memory ] + + profiling_type: + type: str + options: + - pytorch + - mslite + - msprof + - mindspore + default: pytorch + +envs: + ADVISOR_ANALYZE_PROCESSES: + type: int + options: [ 1, 2, 3, 4, 5, 6, 7, 8 ] + default: 1 + + DISABLE_PROFILING_COMPARISON: + type: boolean + options: [ true, false ] + default: false + + DISABLE_AFFINITY_API: + type: boolean + options: [ true, false ] + default: false diff --git a/profiler/msprof_analyze/advisor/dataset/cluster/cluster_dataset.py b/profiler/msprof_analyze/advisor/dataset/cluster/cluster_dataset.py index 4489dde44621e5650f664cd8e28262f2df613c84..b47f6d4518b45d84497fe4eac87cfe11d0fccb04 100644 --- a/profiler/msprof_analyze/advisor/dataset/cluster/cluster_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/cluster/cluster_dataset.py @@ -50,8 +50,8 @@ class ClusterDataset(Dataset): if self.is_cluster_analysis_output_exist(): return parameter = { - Constant.PROFILING_PATH: self.collection_path, - Constant.MODE: "all", + Constant.COLLECTION_PATH: self.collection_path, + Constant.ANALYSIS_MODE: "all", Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.output_path } logger.info("cluster analysis is in the process, please wait...") diff --git a/profiler/msprof_analyze/advisor/dataset/communication/hccl_detail_dataset.py b/profiler/msprof_analyze/advisor/dataset/communication/hccl_detail_dataset.py index a1d5425b5431b9dc8149b957ae8deb95a2f9295d..fac5603b99bfd4956503fc76d6355edb8da54941 100644 --- a/profiler/msprof_analyze/advisor/dataset/communication/hccl_detail_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/communication/hccl_detail_dataset.py @@ -39,8 +39,7 @@ class HcclDetailDataset: @staticmethod def _get_hccl_pid(tasks: List[TaskInfo]): for task in tasks: - if task.name == "process_name" and hasattr(task, "args") \ - and task.args.get("name", None) in ["Communication", "HCCL"]: + if task.name == "process_name" and hasattr(task, "args") and task.args.get("name", None) == "HCCL": return task.pid return -1 diff --git a/profiler/msprof_analyze/advisor/display/html/priority_background_color.py b/profiler/msprof_analyze/advisor/display/html/priority_background_color.py index f5b89b232f4f0b2b04ec559149fc96768997ea85..6b03747a81b532364816e171b846adac1f1883fa 100644 --- a/profiler/msprof_analyze/advisor/display/html/priority_background_color.py +++ b/profiler/msprof_analyze/advisor/display/html/priority_background_color.py @@ -1,19 +1,19 @@ -# Copyright (c) 2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -class PriorityBackgroundColor: - high = "#B5495B" - medium = "#fcaf17" - low = "#65c294" +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class PriorityBackgroundColor: + high = "#B5495B" + medium = "#fcaf17" + low = "#65c294" diff --git a/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html b/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html deleted file mode 100644 index 77e5e0cb55200efdf5b854e03ac2844ddc631a8f..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html +++ /dev/null @@ -1,159 +0,0 @@ -{% if format_result|length > 0 %} -
    -

    AI CORE Performance Analysis

    -
    - {% if language == "cn" %} - {% set title_ns = namespace(type='类别', desc='描述及建议', opti_set='性能优化算子集合', bound_set='bound算子集合', affinity_set='不亲和算子集合', - opti_refer=' 参考性能优化空间: ', bound_refer=' bound类型为: ', affinity_refer=' 不亲和类型为: ', title_desc='算子相关分析,参考如下: ') %} - {% else %} - {% set title_ns = namespace(type='Type', desc='Description and Suggestion', opti_set='set of performance optimization operators', - bound_set='set of bound operators', affinity_set='set of unaffine operators', opti_refer=' refer to Performance Optimization Space: ', - bound_refer=' bound type: ', affinity_refer=' type of disaffinity: ', title_desc=' Operator related analysis, referenced below: ') %} - {% endif %} - {% if format_result.cube[0]|length + format_result.cube[1]|length + format_result.cube[2]|length > 0 %} - MatMul{{ title_ns.title_desc }} -
    -
    - - - - - {% set opti_ns = namespace(total_opti='') %} - {% for opti in format_result.cube[0] %} - {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
    " ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% endif %} - {% endfor %} - {% if opti_ns.total_opti|length > 0 %} -
    - - - - {% endif %} - {% set bound_ns = namespace(total_bound='') %} - {% for bound in format_result.cube[1] %} - {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
    " ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% endif %} - {% endfor %} - {% if bound_ns.total_bound|length > 0 %} -
    - - - - {% endif %} - {% set affinity_ns = namespace(total_affinity='') %} - {% for affinity in format_result.cube[2] %} - {% if not loop.first %} - {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "
    " ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} - {% else %} - {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} - {% endif %} - {% endfor %} - {% if affinity_ns.total_affinity|length > 0 %} -
    - - - - {% endif %} -
    {{ title_ns.type }}{{ title_ns.desc }}
    {{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }}
    {{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }}
    {{ title_ns.affinity_set }}{{ affinity_ns.total_affinity | safe }}
    - {% endif %} - - {% if format_result.fa[0]|length + format_result.fa[1]|length + format_result.fa[2]|length > 0 %} - FA{{ title_ns.title_desc }} -
    - - - - - - {% set opti_ns = namespace(total_opti='') %} - {% for opti in format_result.fa[0] %} - {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
    " ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% endif %} - {% endfor %} - {% if opti_ns.total_opti|length > 0 %} - - - - - {% endif %} - {% set bound_ns = namespace(total_bound='') %} - {% for bound in format_result.fa[1] %} - {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
    " ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% endif %} - {% endfor %} - {% if bound_ns.total_bound|length > 0 %} - - - - - {% endif %} - {% set affinity_ns = namespace(total_affinity='') %} - {% for affinity in format_result.fa[2] %} - {% if not loop.first %} - {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "
    " ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} - {% else %} - {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} - {% endif %} - {% endfor %} - {% if affinity_ns.total_affinity|length > 0 %} - - - - - {% endif %} -
    {{ title_ns.type }}{{ title_ns.desc }}
    {{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }}
    {{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }}
    {{ title_ns.affinity_set }}{{ affinity_ns.total_affinity | safe }}
    - {% endif %} - - {% if format_result.vector[0]|length + format_result.vector[1]|length > 0 %} - Vector{{ title_ns.title_desc }} -
    - - - - - - {% set opti_ns = namespace(total_opti='') %} - {% for opti in format_result.vector[0] %} - {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
    " ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} - {% endif %} - {% endfor %} - {% if opti_ns.total_opti|length > 0 %} - - - - - {% endif %} - {% set bound_ns = namespace(total_bound='') %} - {% for bound in format_result.vector[1] %} - {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
    " ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} - {% endif %} - {% endfor %} - {% if bound_ns.total_bound|length > 0 %} - - - - - {% endif %} -
    {{ title_ns.type }}{{ title_ns.desc }}
    {{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }}
    {{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }}
    - {% endif %} -
    -
    -{% endif %} \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/display/html/templates/comparison.html b/profiler/msprof_analyze/advisor/display/html/templates/comparison.html index 5963e75308c447a386f50517587e857c237fc061..b81802d6b0505ca4a21e5174a0158b800d4a43ec 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/comparison.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/comparison.html @@ -1,25 +1,25 @@ -{% if rows|length > 0 %} -
    -

    {{ sheet_name }}

    -
    - Issue: {{ desc }} -

    - - - {% for header in headers %} - - {% endfor %} - - - {% for row in rows %} - - {% for element in row %} - - {% endfor %} - - {% endfor %} -
    {{ header }}
    {{ element|safe }}
    - -
    -
    +{% if rows|length > 0 %} +
    +

    {{ sheet_name }}

    +
    + Issue: {{ desc }} +

    + + + {% for header in headers %} + + {% endfor %} + + + {% for row in rows %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
    {{ header }}
    {{ element|safe }}
    + +
    +
    {% endif %} \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/display/html/templates/memory.html b/profiler/msprof_analyze/advisor/display/html/templates/memory.html index a3d75877b60ef3481a13572fbd6b0e2bb5eaf2a0..2bf57f46a1ee1b38302b9096f07b6754ca41ce82 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/memory.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/memory.html @@ -1,21 +1,21 @@ -
    -

    Memory Operator Issues

    -
    - {% if rank is not none %} - Analysis of rank {{ rank|safe }}. - {% endif %} - {{ desc }} - - - - - - {% for suggestion in suggestions %} - - - - {% endfor %} -
    Suggestions
    {{ loop.index }}. {{ suggestion|safe }}
    - -
    -
    +
    +

    Memory Operator Issues

    +
    + {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + {{ desc }} + + + + + + {% for suggestion in suggestions %} + + + + {% endfor %} +
    Suggestions
    {{ loop.index }}. {{ suggestion|safe }}
    + +
    +
    diff --git a/profiler/msprof_analyze/advisor/display/html/templates/pp_stage_computation_analysis.html b/profiler/msprof_analyze/advisor/display/html/templates/pp_stage_computation_analysis.html index 189e6fadf863e5d1ec930690e9ef8b3012d15c51..6d2792f31ae635fe2d5863a902de50f7fa76b46f 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/pp_stage_computation_analysis.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/pp_stage_computation_analysis.html @@ -1,19 +1,19 @@ -{% if stages_rendered_html|length > 0 %} -
    -

    Pipeline Parallel Stages Issues

    -
    - {% for stage_html in stages_rendered_html %} -
    -

    {{ stage_html['stage']|safe }}

    -
    - Description: analysis for slow rank {{ stage_html['rank']|safe }} in current stage -

    - {% for html in stage_html['html_list'] %} - {{ html|safe }} - {% endfor %} -
    -
    - {% endfor %} -
    -
    +{% if stages_rendered_html|length > 0 %} +
    +

    Pipeline Parallel Stages Issues

    +
    + {% for stage_html in stages_rendered_html %} +
    +

    {{ stage_html['stage']|safe }}

    +
    + Description: analysis for slow rank {{ stage_html['rank']|safe }} in current stage +

    + {% for html in stage_html['html_list'] %} + {{ html|safe }} + {% endfor %} +
    +
    + {% endfor %} +
    +
    {% endif %} \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/display/html/templates/slow_dataloader.html b/profiler/msprof_analyze/advisor/display/html/templates/slow_dataloader.html index b9ce7a574ab2a838633cb7c5181cfecb737097c9..2a3b2c4462aa6666b3ed42cc77995698ed8ce1c3 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/slow_dataloader.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/slow_dataloader.html @@ -1,21 +1,21 @@ -
    -

    Slow Dataloader Issues

    -
    - {% if rank is not none %} - Analysis of rank {{ rank|safe }}. - {% endif %} - {{ desc }} - - - - - - {% for suggestion in suggestions %} - - - - {% endfor %} -
    Suggestions
    {{ loop.index }}. {{ suggestion|safe }}
    - -
    -
    +
    +

    Slow Dataloader Issues

    +
    + {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + {{ desc }} + + + + + + {% for suggestion in suggestions %} + + + + {% endfor %} +
    Suggestions
    {{ loop.index }}. {{ suggestion|safe }}
    + +
    +
    diff --git a/profiler/msprof_analyze/advisor/display/html/templates/sync_batchnorm.html b/profiler/msprof_analyze/advisor/display/html/templates/sync_batchnorm.html index 402404c8a43706ec4a598300eec42c7d2b7767cc..ea322276645ae9ca374f699ed7dcbaec1caad1d8 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/sync_batchnorm.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/sync_batchnorm.html @@ -1,33 +1,33 @@ - -
    -

    SyncBatchNorm Issues

    -
    - {% if rank is not none %} - Analysis of rank {{ rank|safe }}. - {% endif %} - {{ desc }} - - - - - {% for item in solutions %} - {% set rowloop = loop %} - {% for key, value in item.items() %} - - - - {% endfor %} - {% endfor %} -
    Suggestions
    {{ rowloop.index }}. {{ value.desc }}
    - - More efficient code of syncbn forward as follows: - {% for item in solutions %} - {% for key, value in item.items() %} - {% if 'efficient_code' in value %} -
    {{ value.efficient_code|safe }}
    - {% endif %} - {% endfor %} - {% endfor %} - -
    -
    + +
    +

    SyncBatchNorm Issues

    +
    + {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + {{ desc }} + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + {% endfor %} + {% endfor %} +
    Suggestions
    {{ rowloop.index }}. {{ value.desc }}
    + + More efficient code of syncbn forward as follows: + {% for item in solutions %} + {% for key, value in item.items() %} + {% if 'efficient_code' in value %} +
    {{ value.efficient_code|safe }}
    + {% endif %} + {% endfor %} + {% endfor %} + +
    +
    diff --git a/profiler/msprof_analyze/advisor/display/html/templates/synchronize_stream.html b/profiler/msprof_analyze/advisor/display/html/templates/synchronize_stream.html index eb132a6315d223ed36b096c7f6087cb53ca071d4..8636740275a66a5a7ba46703b978385cbc2df3a3 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/synchronize_stream.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/synchronize_stream.html @@ -1,26 +1,26 @@ -
    -

    Synchronize Stream Issues

    -
    - {% if rank is not none %} - Analysis of rank {{ rank|safe }}. - {% endif %} - {{ desc }} - - - - - - - {% for item in solutions %} - {% set rowloop = loop %} - {% for key, value in item.items() %} - - - - - {% endfor %} - {% endfor %} -
    Suggestions
    {{ rowloop.index }}. {{ value.desc }}
    - -
    -
    +
    +

    Synchronize Stream Issues

    +
    + {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + {{ desc }} + + + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + + {% endfor %} + {% endfor %} +
    Suggestions
    {{ rowloop.index }}. {{ value.desc }}
    + +
    +
    diff --git a/profiler/msprof_analyze/advisor/img/AI Core Performance analysis.png b/profiler/msprof_analyze/advisor/img/AI Core Performance analysis.png deleted file mode 100644 index 37708366c990fb899a9b4a846dc81fa43d5e1d43..0000000000000000000000000000000000000000 Binary files a/profiler/msprof_analyze/advisor/img/AI Core Performance analysis.png and /dev/null differ diff --git a/profiler/msprof_analyze/advisor/img/Fusible Operator Analysis.png b/profiler/msprof_analyze/advisor/img/Fusible Operator Analysis.png deleted file mode 100644 index 332b9ff838130e0daa625691aef88059dd31918d..0000000000000000000000000000000000000000 Binary files a/profiler/msprof_analyze/advisor/img/Fusible Operator Analysis.png and /dev/null differ diff --git a/profiler/msprof_analyze/advisor/interface/interface.py b/profiler/msprof_analyze/advisor/interface/interface.py index 99359174de6ecac9257189f6d3c820f39aca9f72..b3afefee57c8c62030af17130f79413238588f8f 100644 --- a/profiler/msprof_analyze/advisor/interface/interface.py +++ b/profiler/msprof_analyze/advisor/interface/interface.py @@ -1,4 +1,5 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,8 +44,6 @@ from msprof_analyze.advisor.analyzer.schedule.gc.gc_analyzer import GcAnalyzer from msprof_analyze.advisor.analyzer.schedule.conjectured_gc.conjectured_gc_analyzer import ConjecturedGcAnalyzer from msprof_analyze.advisor.analyzer.comparison.comparison_analyzer import ComparisonAnalyzer from msprof_analyze.advisor.analyzer.schedule.fusible_ops.fusible_operator_analyzer import FusibleOperatorAnalyzer -from msprof_analyze.advisor.analyzer.computation.ai_core_performance.ai_core_performance_analyzer import \ - AICorePerformanceAnalyzer logger = logging.getLogger() @@ -75,8 +74,7 @@ class Interface: SupportedScopes.OPERATOR_NO_BOUND_ANALYSIS: OperatorBoundAnalyzer, SupportedScopes.BLOCK_DIM_ANALYSIS: BlockDimAnalyzer, SupportedScopes.GRAPH: FusionOPAnalyzer, - SupportedScopes.FREQ_ANALYSIS: AICoreFreqAnalyzer, - SupportedScopes.AICORE_PERFORMANCE_ANALYSIS: AICorePerformanceAnalyzer + SupportedScopes.FREQ_ANALYSIS: AICoreFreqAnalyzer }), COMMUNICATION: OrderedDict({SupportedScopes.PACKET: PacketAnalyzer, SupportedScopes.COMMUNICATION_RETRANSMISSION_DETECTION: RDMARetransmissionAnalyzer, diff --git a/profiler/msprof_analyze/advisor/rules/cn/aicore_performance.yaml b/profiler/msprof_analyze/advisor/rules/cn/aicore_performance.yaml deleted file mode 100644 index dcdc3e188f4684c4a80e5a3e064878fb823e3b70..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/advisor/rules/cn/aicore_performance.yaml +++ /dev/null @@ -1,48 +0,0 @@ -cube_problem: "Cube算子性能分析" -fa_problem: "FA算子性能分析" -vector_problem: "Vector算子性能分析" -description: "提供一些AICORE算子的参考瓶颈" -bound_description: "bound算子集合" -optimization_description: "性能优化算子集合" -affinity_description: "不亲和算子集合" -cube_affinity_desc: "内轴无法被256整除" -fa_affinity_desc_head_dim_128: "D不能被128整除" -fa_affinity_desc_seq_len_128: "S不能被128整除" -fa_affinity_desc_head_dim_seq_len_128: "D和S均不能被128整除" -suggestion: "请根据亲和性、bound类型或优化空间尝试分析筛选出来的算子" -affinity_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} 有不亲和特征: {suggestion}\n" -bound_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} bound类型为: {bound} bound\n" -optimization_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} 疑似有性能优化空间,参考性能优化空间: {optimization}%\n" - -cube_operators: - - target: aic_mac_ratio - bound: mac - threshold: 0.8 - - target: aic_mte2_ratio - bound: mte2 - threshold: 0.95 - -fa_operators: - - target: aic_mte2_ratio - bound: mac - threshold: 0.8 - - target: aic_fixpipe_ratio - bound: fixpipe - threshold: 0.75 - - target: aiv_vec_ratio - bound: vec - threshold: 0.75 - -vector_operators: - - target: total - bound: vec_mte2_mte3 - threshold: 0.9 - - target: aiv_vec_ratio - bound: vec - threshold: 0.7 - - target: aiv_mte2_ratio - bound: mte2 - threshold: 0.7 - - target: aiv_mte3_ratio - bound: mte3 - threshold: 0.7 \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml b/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml deleted file mode 100644 index 68ab59f16937880c7428330811005297d1551b0d..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml +++ /dev/null @@ -1,48 +0,0 @@ -cube_problem: "Cube operator performance analysis" -fa_problem: "FA operator performance analysis" -vector_problem: "Vector operator performance analysis" -description: "Provide some reference bottlenecks for the AICORE operator" -bound_description: "set of bound operators" -optimization_description: "set of performance optimization operators" -affinity_description: "set of unaffine operators" -cube_affinity_desc: "Then inner axis is not divisible by 256" -fa_affinity_desc_head_dim_128: "D is not divisible by 128" -fa_affinity_desc_seq_len_128: "S is not divisible by 128" -fa_affinity_desc_head_dim_seq_len_128: "Neither D nor S is not divisible by 128" -suggestion: "Please try to analyze the filtered operators based on affinity, bound type or optimization space" -affinity_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} with disaffection characteristics: {suggestion}\n" -bound_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} bound type: {bound} bound\n" -optimization_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} suspect there is room for performance optimization, refer to Performance Optimization Space: {optimization}%\n" - -cube_operators: - - target: aic_mac_ratio - bound: mac - threshold: 0.8 - - target: aic_mte2_ratio - bound: mte2 - threshold: 0.95 - -fa_operators: - - target: aic_mte2_ratio - bound: mac - threshold: 0.8 - - target: aic_fixpipe_ratio - bound: fixpipe - threshold: 0.75 - - target: aiv_vec_ratio - bound: vec - threshold: 0.75 - -vector_operators: - - target: total - bound: vec_mte2_mte3 - threshold: 0.9 - - target: aiv_vec_ratio - bound: vec - threshold: 0.7 - - target: aiv_mte2_ratio - bound: mte2 - threshold: 0.7 - - target: aiv_mte3_ratio - bound: mte3 - threshold: 0.7 \ No newline at end of file diff --git a/profiler/msprof_analyze/advisor/rules/timeline_fusion_ops.yaml b/profiler/msprof_analyze/advisor/rules/timeline_fusion_ops.yaml index 34de80add2ec849a648e64cf6c8b1e3edb1f0cc5..3337c938625ccd4b4ea77a0dafa9879222cf1bfe 100644 --- a/profiler/msprof_analyze/advisor/rules/timeline_fusion_ops.yaml +++ b/profiler/msprof_analyze/advisor/rules/timeline_fusion_ops.yaml @@ -66,23 +66,4 @@ torch_npu.npu_geglu: [ "(slice|chunk)-gelu-mul", "(slice|chunk)-mul-gelu" ] torch_npu.npu_group_norm_silu: [ "group_norm-silu" ] torch.addmm: [ "mul-mul-add" ] - torch_npu.npu_add_layer_norm: [ "add-layer_norm" ] - -- cann_version: 8.0.RC3 - torch_version: [1.11.0, 2.1.0] - unique_id: 4 - inherit_unique_id: 3 - operator_rules: - aten: - add: - mindspeed.ops.npu_matmul_add: [ "matmul-add" ] - -- cann_version: 8.0.RC3 - torch_version: [1.11.0, 2.1.0] - unique_id: 5 - inherit_unique_id: 4 - operator_rules: - aten: - add: - mindspeed.ops.npu_moe_token_permute: ["argsort-argsort-index_select"] - mindspeed.ops.npu_moe_token_unpermute: ["index_select-mul-reduce_sum"] \ No newline at end of file + torch_npu.npu_add_layer_norm: [ "add-layer_norm" ] \ No newline at end of file diff --git a/profiler/msprof_analyze/cli/cluster_cli.py b/profiler/msprof_analyze/cli/cluster_cli.py index adaf0f8d7cab8eff139125fbb4699ea962e3a427..0cdb2bd2b10b2ede411d10221e36a51e3f015e12 100644 --- a/profiler/msprof_analyze/cli/cluster_cli.py +++ b/profiler/msprof_analyze/cli/cluster_cli.py @@ -34,7 +34,6 @@ context_settings['ignore_unknown_options'] = True @click.option("--parallel_mode", type=str, help="context mode", default="concurrent") @click.option("--export_type", help="recipe export type", type=click.Choice(["db", "notebook"]), default="db") @click.option("--rank_list", type=str, help="Rank id list", default='all') -@click.option("--step_id", type=int, help="Step id", default=Constant.VOID_STEP) @click.argument('args', nargs=-1) def cluster_cli(**kwargs) -> None: Interface(kwargs).run() diff --git a/profiler/msprof_analyze/cli/entrance.py b/profiler/msprof_analyze/cli/entrance.py index 534a9b133c7e60d1442cb290490a79e9256ce43d..0aa61f1b6aee2a5b6b321e8e3fb7a04ed63ff98a 100644 --- a/profiler/msprof_analyze/cli/entrance.py +++ b/profiler/msprof_analyze/cli/entrance.py @@ -22,6 +22,7 @@ from msprof_analyze.cli.complete_cli import auto_complete_cli from msprof_analyze.cli.compare_cli import compare_cli from msprof_analyze.cli.cluster_cli import cluster_cli from msprof_analyze.advisor.version import print_version_callback, cli_version +from msprof_analyze.cli.precheck_cli import precheck_cli logger = logging.getLogger() CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'], @@ -31,7 +32,8 @@ COMMAND_PRIORITY = { "advisor": 1, "compare": 2, "cluster": 3, - "auto-completion": 4 + "precheck": 4, + "auto-completion": 5 } @@ -66,5 +68,6 @@ def msprof_analyze_cli(**kwargs): msprof_analyze_cli.add_command(analyze_cli, name="advisor") msprof_analyze_cli.add_command(compare_cli, name="compare") msprof_analyze_cli.add_command(cluster_cli, name="cluster") +msprof_analyze_cli.add_command(precheck_cli, name="precheck") msprof_analyze_cli.add_command(auto_complete_cli, name="auto-completion") diff --git a/profiler/msprof_analyze/cli/precheck_cli.py b/profiler/msprof_analyze/cli/precheck_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..c70b540ce9e3a8c9f718fe6be34917f8fed7b85d --- /dev/null +++ b/profiler/msprof_analyze/cli/precheck_cli.py @@ -0,0 +1,159 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import ipaddress +import logging +from functools import wraps + +import click + +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) +CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help']) + + +@click.group(context_settings=CONTEXT_SETTINGS) +def precheck_cli(): + """Profiler precheck tool""" + pass + + +def common_options(f): + """Common options for both precheck and runner commands""" + + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + wrapper = click.option('--master_addr', required=True, + help='IP address of the master node (first node in the cluster)')(wrapper) + wrapper = click.option('--master_port', type=int, default=29500, + help='Port on the master node for communication. Default is 29500')(wrapper) + wrapper = click.option('--nnodes', type=int, required=True, + help='Total number of nodes in the distributed setup')(wrapper) + wrapper = click.option('--nproc_per_node', type=int, required=True, + help='Number of processes to run per node')(wrapper) + wrapper = click.option('--node_prof_save_dir', default='', callback=PathManager.expanduser_for_cli, + help='Directory for saving node profiling data')(wrapper) + wrapper = click.option('--master_prof_gather_dir', default='', callback=PathManager.expanduser_for_cli, + help='Directory for saving gathered profiling data in master node')(wrapper) + wrapper = click.option('--output_dir', default='./output', callback=PathManager.expanduser_for_cli, + help='Directory to save profiling dump data, logs, and advisor reports')(wrapper) + wrapper = click.option('--task_name', default='', + help='Name of the task or experiment')(wrapper) + wrapper = click.option('--static', is_flag=True, + help='If set, run profiling in static mode')(wrapper) + wrapper = click.option('--profiling_cmd', default="", + help='Command to run the profiler script')(wrapper) + wrapper = click.option('--prof_in_shared_storage', is_flag=True, + help='If set, skip data collection as profiling data is in shared storage')(wrapper) + return wrapper + + +def validate_ip_list(ctx, param, value): + if not value: + return [] + try: + ips = [ip.strip() for ip in value.split(',')] + # Validate each IP + for ip in ips: + ipaddress.ip_address(ip) + return ips + except ValueError as e: + raise click.BadParameter(f'Invalid IP address in list: {e}') + + +@precheck_cli.command(context_settings=CONTEXT_SETTINGS, + name="start_all", + short_help='Start precheck on all nodes via ssh') +@common_options +@click.option('--host_ips', + callback=validate_ip_list, + help='Comma-separated list of IP addresses for nodes in distributed training (e.g., "192.168.1.1,192.168.1.2")') +@click.option('--python_path', default=sys.executable, callback=PathManager.expanduser_for_cli, + help='Path to the Python interpreter') +@click.option('--host_config_file', default='', callback=PathManager.expanduser_for_cli, + help='Path to the host configuration file (CSV format with node connection details)') +def precheck_start_all(**kwargs): + """Run precheck command""" + # Add validation + if not kwargs.get('host_ips') and not kwargs.get('host_config_file'): + raise click.UsageError('Either --host_ips or --host_config_file must be specified') + + if kwargs.get('host_ips') and kwargs.get('host_config_file'): + raise click.UsageError('Cannot specify both --host_ips and --host_config_file') + + from msprof_analyze.precheck.manager.args_manager import PrecheckArgsManager + from msprof_analyze.precheck.__main__ import main as precheck_main + + args = PrecheckArgsManager(type('Args', (), kwargs)) + click.echo(args) + precheck_main(args) + + +@precheck_cli.command(context_settings=CONTEXT_SETTINGS, + name="start_node", + short_help='Start one node precheck, if your nnodes > 1, you need to run this command on each node') +@common_options +@click.option('--node_rank', type=int, required=True, + help='Rank of the current node') +def precheck_start_node(**kwargs): + """Run precheck runner command""" + from msprof_analyze.precheck.manager.args_manager import PrecheckRunnerArgsManager + from msprof_analyze.precheck.runner.__main__ import main as runner_main + + args = PrecheckRunnerArgsManager(type('Args', (), kwargs)) + click.echo(args) + + runner_main(args) + + +@precheck_cli.command(context_settings=CONTEXT_SETTINGS, + name="env", + short_help='execute environment precheck') +@click.option('--nproc_per_node', type=int, required=True, + help='Number of processes to run per node') +@click.option('--nnodes', type=int, required=True, + help='Total number of nodes in the distributed setup') +@click.option('--node_rank', type=int, required=True, + help='Rank of the current node') +@click.option('--master_addr', type=str, required=False, + help='IP address of the master node', default="localhost") +@click.option('--master_port', type=int, required=False, + help='Port on the master node for communication', default=6000) +@click.option('--tensor-model-parallel-size', type=int, required=False, + help='Degree of tensor parallelism', default=1) +@click.option('--pipeline-model-parallel-size', type=int, required=False, + help='Degree of pipeline parallelism', default=1) +@click.option('--context-parallel-size', type=int, required=False, + help='Degree of context parallelism', default=1) +@click.option('--expert-model-parallel-size', type=int, required=False, + help='Degree of expert parallelism', default=1) +@click.option('--output', type=str, required=False, + help='Output path', default="./output") +@click.option('--check-type', type=str, required=False, + help='Environment precheck type', default="all") +def environment_precheck(**kwargs): + from msprof_analyze.precheck.precheck import Precheck + + click.echo(kwargs) + Precheck.env_precheck(**kwargs) + + +if __name__ == '__main__': + precheck_cli() \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/README.md b/profiler/msprof_analyze/cluster_analyse/README.md index 325a0984793297dfac28673f04a582ea7b4316b9..e488ab85f76c40da44d188ac421abe7f3b533da1 100644 --- a/profiler/msprof_analyze/cluster_analyse/README.md +++ b/profiler/msprof_analyze/cluster_analyse/README.md @@ -1,329 +1,193 @@ -# 集群分析工具 -cluster_analyse(集群分析工具)是在集群场景下,通过此工具来进行集群数据的分析,当前主要对基于通信域的迭代内耗时分析、通信时间分析以及通信矩阵分析为主, 从而定位慢卡、慢节点以及慢链路问题。 - -## 性能数据采集 -当前集群调优工具主要支持PyTorch场景的Ascend PyTorch Profiler采集方式和MindSpore场景的MindSpore Profiler采集方式下的集群数据。 - -此工具只需要NPU的性能数据作为输入。 - -Ascend PyTorch Profiler采集方法请参见《[NPU性能数据采集](https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze)》,MindSpore Profiler采集方法请参见《[性能调试](https://www.mindspore.cn/mindinsight/docs/zh-CN/r2.3/performance_profiling_ascend.html)》。 - -我们要求至少是L1级别的数据。 -```python -experimental_config = torch_npu.profiler._ExperimentalConfig( - profiler_level=torch_npu.profiler.ProfilerLevel.Level1 -) -``` -### 确认数据是否可用 - -打开采集到的某张卡数据(\*ascend_pt、\*ascend_ms结尾的文件夹),可用的数据应该具备: - -- ./profiler_info_x.json, -- ./ASCEND_PROFILER_OUTPUT/step_trace_time.csv, -- ./ASCEND_PROFILER_OUTPUT/trace_view.json, -- ./ASCEND_PROFILER_OUTPUT/kernel_details.csv, -- ./ASCEND_PROFILER_OUTPUT/communication.json, -- ./ASCEND_PROFILER_OUTPUT/communication_matrix.json - -或者具备: - -- analysis.db -- ascend_pytorch_profiler_{rank_id}.db - -以上csv、json文件与db文件只能存在一类,否则集群分析工具解析异常。MindSpore场景暂不支持以上db文件。 - -确认这几个文件生成后,继续下面的集群分析。 - -## 数据汇聚与解析 - -### 操作步骤 - -1. 参见《[性能工具](../README.md)》完成工具安装。建议安装最新版本。 - -2. 将所有卡的数据拷贝并汇集到一个目录下,运行以下命令,在该目录下即可生成cluster_analysis_output文件夹。 - - ```bash - msprof-analyze cluster -d {cluster profiling data path} [-m mode] [-o output_path] [--data_simplification] [--force] - ``` - - 或 - - ```bash - python3 cluster_analysis.py -d {cluster profiling data path} [-m mode] [-o output_path] [--data_simplification] [--force] - ``` - - 参数说明: - - | 参数名 | 说明 | 是否必选 | - | --------------------- | ------------------------------------------------------------ | -------- | - | --profiling_path或-d | 性能数据汇集目录。未配置-o参数时,运行分析脚本之后会在该目录下自动创建cluster_analysis_output文件夹,保存分析数据。 | 是 | - | --output_path或-o | 自定义输出路径,运行分析脚本之后会在该目录下自动创建cluster_analysis_output文件夹,保存分析数据。 | 否 | - | --mode或-m | 数据解析模式,取值详见“**--mode参数说明**”表。 | 否 | - | --data_simplification | 数据精简模式。对于数据量过大的性能数据db文件,可以通过配置该参数将数据精简,并提高工具分析效率。配置该参数表示开启数据精简,默认未配置表示关闭。 | 否 | - | --force | 强制执行cluster。配置后可强制跳过如下情况:
    指定的目录、文件的用户属主不属于当前用户,忽略属主判断直接执行。
    csv文件大于5G、json文件大于10G、db文件大于8G,忽略文件过大判断直接执行。
    配置该参数表示开启强制执行,默认未配置表示关闭。 | 否 | - | --parallel_mode | 设置收集多卡、多节点db数据时的并发方式。取值为concurrent(使用concurrent.feature进程池实现并发)。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum和自定义分析参数时可配置此参数。** | 否 | - | --export_type | 设置导出的数据形式。取值为db(.db格式文件)和notebook(Jupyter Notebook文件),默认值为db。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum和自定义分析参数时可配置此参数。** | 否 | - | --rank_list | 对特定Rank上的数据进行统计,默认值为all(表示对所有Rank进行统计),须根据实际卡的Rank ID配置。应配置为大于等于0的整数,若所配置的值大于实际训练所运行的卡的Rank ID,则仅解析合法的RankID的数据,比如当前环境Rank ID为0到7,实际训练运行0到3卡,此时若配置Rank ID为0, 3, 4或不存在的10等其他值,则仅解析0和3。配置示例:--rank_list 0, 1, 2。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum和自定义分析参数时可配置此参数。** | 否 | - | --step_id | 性能数据Step ID,配置后对该Step的性能数据进行分析。需配置性能数据中实际存在的Step ID,默认未配置,表示全量分析。配置示例:--step_id=1。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum和自定义分析参数时可配置此参数。** | 否 | - | --top_num | 设置TopN耗时的通信算子的数量,默认值为15,配置示例:--top_num 20。
    **只有-m配置hccl_sum时可配置此参数。** | 否 | - | --exclude_op_name | 控制compute_op_name结果是否包含op_name,示例:--exclude_op_name,后面不需要跟参数。
    **只有-m配置compute_op_sum时可配置此参数。** | 否 | - - --mode参数说明: - - | 参数名 | 说明 | 是否必选 | - | -------------------- | ------------------------------------------------------------ | -------- | - | communication_matrix | 解析通信矩阵数据。 | 否 | - | communication_time | 解析通信耗时数据。 | 否 | - | all | 解析内容包括:
    通信矩阵communication_matrix
    通信耗时数据communication_time
    汇总集群内的节点信息(基于ascend_pytorch_profiler_{rank_id}.db生成)
    --mode参数默认值为all。 | 否 | - | cann_api_sum | 集群API性能数据汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/CannApiSum目录下输出交付件stats.ipynb。 | 否 | - | compute_op_sum | 集群场景性能数据的device运行算子信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/ComputeOpSum目录下输出交付件stats.ipynb;可根据实际情况决定是否是否打开--exclude_op_name。 | 否 | - | hccl_sum | 集合通信算子耗时分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/HcclSum目录下输出交付件stats.ipynb。 | 否 | - | mstx_sum | 集群场景mstx打点信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/MstxSum目录下输出交付件stats.ipynb。 | 否 | - | 自定义分析参数 | 与cann_api_sum、compute_op_sum、hccl_sum等参数功能类似,用户可自定义一套性能数据的分析规则,需要详细了解性能分析的开发人员,具体开发指导请参见“[自定义分析规则开发指导](#自定义分析规则开发指导)”。 | 否 | - - --parallel_mode参数示例如下: - - ```bash - msprof-analyze cluster -d {cluster profiling data path} -m cann_api_sum --parallel_mode concurrent - ``` - - 或 - - ```bash - python3 cluster_analysis.py -d {cluster profiling data path} -m cann_api_sum --parallel_mode concurrent - ``` - - -### 交付件 - -集群分析工具的交付件通过MindStudio Insight工具展示,详见《[MindStudio Insight用户指南](https://www.hiascend.com/document/detail/zh/mindstudio/70RC2/GUI-baseddevelopmenttool/msascendinsightug/AscendInsight_0002.html)》。 - -#### cluster_step_trace_time.csv - -数据解析模式为communication_matrix、communication_time或all时均生成。 - -A列: Step数,是采集性能数据时设置的,一般来说集群性能数据采集一个step足够,如果采集多个step,需要先筛选一下。 - -B列: Type,主要分两种,rank和stage, 和后面的index强相关,可以理解为一个是单卡rank,一个是rank group(pp 并行的stage),如果type为stage,则后面D-K列信息为rank group下的最大值。 - -C列:Index,与type相关,表示卡号。 - -D列:Computing, 此列统计计算时间。 - -E列:Communication(Not Overlapped),此列统计未被掩盖的通信耗时。 - -F列:Overlapped,统计计算与通信重叠的耗时。 - -G列:Communication,通信时间的全部耗时。 - -H列:Free,空闲时间,指device侧既不在通信也不在计算的耗时,可能在做sdma拷贝或者空等。 - -I列:Stage时间,I、J、K列属于pp并行时有效的数值,stage时间代表除receive算子时间外的时间。 - -J列:Bubble时间,指receive时间的总和。 - -K列:Communication(Not Overlapped and Exclude Receive)指剔除receive算子外的并且不被掩盖的通信时间。 - -L列:Preparing,指迭代开始到首个计算或通信算子运行的时间。 - -M列:DP Index,指集群数据按照并行策略切分后所属DP组的索引, 如果没有采集则不显示。 - -N列:PP Index,指集群数据按照并行策略切分后所属PP组的索引,如果没有采集则不显示。 - -O列:TP Index,指集群数据按照并行策略切分后所属TP组的索引,如果没有采集则不显示。 - -**Tips**:先筛选B列type为stage, 看stage间是否有问题,再筛选B列type为rank,看rank是否有问题,根据以下几点排查。 - -* 根据Computing的时间差异判断是否有慢卡,或者有负载不均衡的现象。 - -* 根据Free统计是否有host bound或者分布不均现象。 - -* 根据Communication(Not Overlapped and Exclude Receive)时间判断是否通信耗时占比过大。 - -* 根据Bubble时间的占比和理论计算公式判断bubble设置是否合理,是否stage间有不均衡现象。 - -以上时间理论上都应该处于持平状态,即最大值小于最小值5%,否则就可能出现慢卡。 - -#### cluster_communication_matrix.json - -数据解析模式为communication_matrix或all时生成。 - -直接打开json(vscode或json查看器), 搜索"Total", 会有多个搜索结果,一般来说链路带宽信息的结构: - -```bash -{src_rank}-{dst_rank}: { - "Transport Type": "LOCAL", - "Transit Time(ms)": 0.02462, - "Transit Size(MB)": 16.777216, - "Bandwidth(GB/s)": 681.4466 -} -``` -**Tips**:可以根据rank互联的带宽以及链路类型,判断是否有慢链路的问题。 - -- "LOCAL"是片内拷贝,速度最高。 -- “HCCS”或“PCIE”是节点内片间拷贝,速度居中。 -- “RDMA”是节点间拷贝,速度最低。 - -#### cluster_communication.json - -数据解析模式为communication_time或all时生成。 - -主要为通信耗时数据。 - -#### cluster_analysis.db - -解析analysis.db或ascend_pytorch_profiler_{rank_id}.db生成的交付件,根据数据解析模式不同而解析不同的数据,可以使用MindStudio Insight工具展示。 - -#### communication_group.json - -记录通信域信息,解析analysis.db生成的交付件,collective表示集合通信域,P2P表示点对点通信,用户无须关注该文件。 - -#### stats.ipynb - -- 数据解析模式为cann_api_sum时生成,保存在cluster_analysis_output/CannApiSum目录下。 - - 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群API耗时信息。 - -- 数据解析模式为compute_op_sum时生成,保存在cluster_analysis_output/ComputeOpSum目录下。 - - 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群计算算子耗时分析(将集群所有计算算子进行汇总并以图表展示),集群Rank计算算子耗时分析(将每个Rank的计算算子进行各自汇总)。 - -- 数据解析模式为hccl_sum时生成,保存在cluster_analysis_output/HcclSum目录下。 - - 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群通信算子耗时分析(将集群所有通信算子进行汇总并以图表展示),集群Rank通信算子耗时分析(将每个Rank的通信算子进行各自汇总)、Top通信算子信息展示。 - -- 数据解析模式为mstx_sum时生成,保存在cluster_analysis_output/MstxSum目录下。 - - 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群场景mstx打点信息,分为框架侧、CANN侧和Device侧三部分的打点信息。 - -## 附录 - -### 自定义分析规则开发指导 - -自定义分析规则是基于对Profiling的analysis.db和ascend_pytorch_profiler_{rank_id}.db文件进行性能数据分析而开发。与cann_api_sum、compute_op_sum、hccl_sum等参数功能实现类似,可自定义一套性能数据的分析规则,方法如下: - -1. 在mstt工具代码仓profiler/msprof_analyze/cluster_analyse/recipes目录下创建xxx目录和xxx.py文件。 - - 例如:profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py,其中目录名和文件名要保持一致,该目录名也会作为使用msprof-analyze cluster工具启动该自定义分析的开关参数。 - -2. 在xxx.py文件进行性能数据分析规则的开发,开发要求继承BaseRecipeAnalysis,实现run函数。 - - 典型的run函数实现: - - ```python - def run(self, context): - mapper_res = self.mapper_func(context) - self.reducer_func(mapper_res) - if self._export_type == "db": - self.save_db() - elif self._export_type == "notebook": - self.save_notebook() - else: - logger.error("Unknown export type.") - ``` - - 1. `mapper_func`函数:多卡数据查询并合并返回结果。由于集群数据每张卡的数据处理是同样的,因此采用context并行处理集群数据并将结果按序拼装返回。开发只需要实现单卡数据处理的函数`self._mapper_fun`。 - - ```python - def mapper_func(self, context): - return context.wait( - context.map( - self._mapper_func, - self._get_rank_db(), - analysis_class=self._recipe_name - ) - ) - ``` - - ```python - def _mapper_func(self, data_map, analysis_class): - """ - Extract the profiling data required for cluster analysis from each device, and then aggregate the - results from each device to be processed by a reduce function. - Params: - data_map: eg. {"RANK_ID": 1, "profiler_db_path": "xxxx/ascend_pytorch_profiler_1.db"} - analysis_class: hccl_sum, compute_op_sum, cann_api_sum, mstx_sum...... - """ - pass - ``` - - 2. `reducer_func`函数:对多卡结果分析处理。接收`mapper_func`函数的返回值,进行进一步的集群数据的汇总分析,数据结构采用dataframe。 - - 3. `save_db`函数:分析结果保存在cluster_analysis.db中。 - - 4. `save_notebook`函数:分析结果以csv和stats.ipynb的形式保存。 - -3. `self._mapper_fun`函数依赖单db数据查询,可通过可通过如下两种方式。 - - 1. 使用DatabaseService可配置单表的查询。 - - 可参考:https://gitee.com/ascend/mstt/blob/pre-research/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py - - 使用样例: - - ```Python - service = DatabaseService(profiler_db_path) - service.add_table_for_query("ENUM_HCCL_DATA_TYPE", ["id", "name"]) # 第一个参数:表名;第二个参数:字段列表,默认为None,当不填写时表明select * - service.add_table_for_query("STRING_IDS", ["id", "value"]) #可 以添加多个表 - df_dict = service.query_data() # 将配置的所有表按序查询,以dict形式返回,key为表名,value为数据库查询结果dataframe数据类型 - ``` - - 2. 维护在msprof_analyze/prof_exports目录下,新建一个py文件,需继承自BaseStatsExport(注:新增之前可以看现有的是否可用,避免重复)如下示例: - - ```Python - from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport - - QUERY = """ - SELECT - NAME_IDS.value AS "OpName", - TYPE_IDS.value AS "OpType", - round(endNs - startNs) AS "Duration", - GROUP_NAME_IDS.value AS "GroupName" - FROM - COMMUNICATION_OP - LEFT JOIN - STRING_IDS AS TYPE_IDS - ON TYPE_IDS.id == COMMUNICATION_OP.opType - LEFT JOIN - STRING_IDS AS NAME_IDS - ON NAME_IDS.id == COMMUNICATION_OP.opName - LEFT JOIN - STRING_IDS AS GROUP_NAME_IDS - ON GROUP_NAME_IDS.id == COMMUNICATION_OP.groupName - """ - - - class HcclSumExport(BaseStatsExport): - def __init__(self, db_path, recipe_name): - super().__init__(db_path, recipe_name) - self._query = QUERY - ``` - - 使用样例:df = HcclSumExport(profiler_db_path, analysis_class).read_export_db(),返回的数据类型是dataframe。 - -4. 分析规则增加拓展参数。 - - 实现函数add_parser_argument,样例如下: - - ```Python - @classmethod - def add_parser_argument(cls, parser): - parser.add_argument("--top_num", type=str, help="Duration cost top count", default=cls.DEFAULT_TOP_NUM) - ``` - - 从self._extra_args里获取对应的扩展参数: - - ```Python - def __init__(self, params): - super().__init__(params) - top_num = self._extra_args.get(self.TOP_NUM, self.DEFAULT_TOP_NUM) - self.top_num = int(top_num) if isinstance(top_num, str) and top_num.isdigit() else self.DEFAULT_TOP_NUM - ``` - -5. 执行自定义分析规则命令。 - - ```bash - msprof-analyze cluster -d {cluster profiling data path} --mode xxx --top_num 10 - ``` - - +# 集群分析工具 +cluster_analyse(集群分析工具)是在集群场景下,通过此工具来进行集群数据的分析,当前主要对基于通信域的迭代内耗时分析、通信时间分析以及通信矩阵分析为主, 从而定位慢卡、慢节点以及慢链路问题。 + +## 性能数据采集 +当前集群调优工具主要支持PyTorch场景的Ascend PyTorch Profiler采集方式和MindSpore场景的MindSpore Profiler采集方式下的集群数据。 + +此工具只需要NPU的性能数据作为输入。 + +Ascend PyTorch Profiler采集方法请参见《[NPU性能数据采集](https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze)》,MindSpore Profiler采集方法请参见《[性能调试](https://www.mindspore.cn/mindinsight/docs/zh-CN/r2.3/performance_profiling_ascend.html)》。 + +我们要求至少是L1级别的数据。 +```python +experimental_config = torch_npu.profiler._ExperimentalConfig( + profiler_level=torch_npu.profiler.ProfilerLevel.Level1 +) +``` +### 确认数据是否可用 + +打开采集到的某张卡数据(\*ascend_pt、\*ascend_ms结尾的文件夹),可用的数据应该具备: + +- ./profiler_info_x.json, +- ./ASCEND_PROFILER_OUTPUT/step_trace_time.csv, +- ./ASCEND_PROFILER_OUTPUT/trace_view.json, +- ./ASCEND_PROFILER_OUTPUT/kernel_details.csv, +- ./ASCEND_PROFILER_OUTPUT/communication.json, +- ./ASCEND_PROFILER_OUTPUT/communication_matrix.json + +或者具备: + +- analysis.db +- ascend_pytorch_profiler_{rank_id}.db + +以上csv、json文件与db文件只能存在一类,否则集群分析工具解析异常。MindSpore场景暂不支持以上db文件。 + +确认这几个文件生成后,继续下面的集群分析。 + +## 数据汇聚与解析 + +### 操作步骤 + +1. 参见《[性能工具](../README.md)》完成工具安装。建议安装最新版本。 + + 将所有卡的数据拷贝并汇集到一个目录下,运行以下命令,在该目录下即可生成cluster_analysis_output文件夹。 + + ```bash + msprof-analyze cluster -d {cluster profiling data path} [-m mode] [-o output_path] [--data_simplification] [--force] + ``` + + 或 + + ```bash + python3 cluster_analysis.py -d {cluster profiling data path} [-m mode] [-o output_path] [--data_simplification] [--force] + ``` + + 参数说明: + + | 参数名 | 说明 | 是否必选 | + | --------------------- | ------------------------------------------------------------ | -------- | + | --profiling_path或-d | 性能数据汇集目录。未配置-o参数时,运行分析脚本之后会在该目录下自动创建cluster_analysis_output文件夹,保存分析数据。 | 是 | + | --output_path或-o | 自定义输出路径,运行分析脚本之后会在该目录下自动创建cluster_analysis_output文件夹,保存分析数据。 | 否 | + | --mode或-m | 数据解析模式,取值详见“**--mode参数说明**”表。 | 否 | + | --data_simplification | 数据精简模式。对于数据量过大的性能数据db文件,可以通过配置该参数将数据精简,并提高工具分析效率。配置该参数表示开启数据精简,默认未配置表示关闭。 | 否 | + | --force | 强制执行cluster。配置后可强制跳过如下情况:
    指定的目录、文件的用户属主不属于当前用户,忽略属主判断直接执行。
    csv文件大于5G、json文件大于10G、db文件大于8G,忽略文件过大判断直接执行。
    配置该参数表示开启强制执行,默认未配置表示关闭。 | 否 | + | --parallel_mode | 设置收集多卡、多节点db数据时的并发方式。取值为concurrent(使用concurrent.feature进程池实现并发)。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum时可配置此参数。** | 否 | + | --export_type | 设置导出的数据形式。取值为db(.db格式文件)和notebook(Jupyter Notebook文件),默认值为db。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum时可配置此参数。** | 否 | + | --rank_list | 对特定Rank上的数据进行统计,默认值为all(表示对所有Rank进行统计),须根据实际卡的Rank ID配置。应配置为大于等于0的整数,若所配置的值大于实际训练所运行的卡的Rank ID,则仅解析合法的RankID的数据,比如当前环境Rank ID为0到7,实际训练运行0到3卡,此时若配置Rank ID为0, 3, 4或不存在的10等其他值,则仅解析0和3。配置示例:--rank_list 0, 1, 2。
    **只有-m配置cann_api_sum、compute_op_sum、hccl_sum、mstx_sum时可配置此参数。** | 否 | + | --top_num | 设置TopN耗时的通信算子的数量,默认值为15,配置示例:--top_num 20。
    **只有-m配置hccl_sum时可配置此参数。** | 否 | + | --exclude_op_name | 控制compute_op_name结果是否包含op_name,示例:--exclude_op_name,后面不需要跟参数。
    **只有-m配置compute_op_sum时可配置此参数。** | 否 | + | --bp | 要对比的标杆集群数据,示例:--bp {bp_cluster_profiling_path},表示profiling_path和bp_cluster_profiling_path的数据进行对比。
    **只有-m配置cluster_time_compare_summary时可配置此参数。** | 否 | + + --mode参数说明: + + --mode参数设置不同的数据解析模式,可调用不同的分析能力,交付件详细内容请参见[recipe结果和cluster_analysis.db交付件表结构说明](#recipe结果和cluster_analysisdb交付件表结构说明)。 + + | 参数名 | 说明 | 是否必选 | + |------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----| + | communication_matrix | 解析通信矩阵数据。 | 否 | + | communication_time | 解析通信耗时数据。 | 否 | + | all | 同时解析通信矩阵communication_matrix和通信耗时数据communication_time,--mode参数默认值为all。 | 否 | + | cann_api_sum | 集群API性能数据汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/CannApiSum目录下输出交付件stats.ipynb。 | 否 | + | compute_op_sum | 集群场景性能数据的device运行算子信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/ComputeOpSum目录下输出交付件stats.ipynb;可根据实际情况决定是否是否打开--exclude_op_name。 | 否 | + | hccl_sum | 集合通信算子耗时分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/HcclSum目录下输出交付件stats.ipynb。 | 否 | + | mstx_sum | 集群场景mstx打点信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/MstxSum目录下输出交付件stats.ipynb。 | 否 | + | slow_link | 集群慢链路异常分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/SlowLink目录下输出交付件stats.ipynb。 | 否 | + | cluster_time_summary | 集群场景性能数据分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db和analysis.db文件。--export_type为db时,输出交付件cluster_analysis.db,db里面有ClusterTimeSummary,不支持导出notebook。 | 否 | + | cluster_time_compare_summary | 集群场景性能数据对比分析,使用前集群数据必须先分析cluster_time_summary,需要配合--bp参数使用。输入性能数据需要基于cluster_analysis_output下的cluster_analysis.db文件。--export_type为db时,输出交付件cluster_analysis.db,db文件中有对比结果的表ClusterTimeCompareSummary,不支持导出notebook。 | 否 | + | slow_rank_pp_stage | 集群场景性能数据pp stage通信对比分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输入性能数据中MetaData表如果没有包含训练任务的并行策略,则需要通过--tp --pp --dp手动传入,数据类型为正整数。--export_type为db时,输出交付件cluster_analysis.db,db文件中有分析结果PPAnalysisResult和P2PAnalysisResult,不支持导出notebook。 | 否 | + | freq_analysis | 集群场景aicore frequency信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。打印输出是否aicore存在空闲(频率为800MHz)、异常(频率不为1800MHz或800MHz)的现象。如果有,则在输出交付件cluster_analysis.db增加对应的卡和频率信息。 | 否 | + | ep_load_balance | 集群场景moe负载信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输出交付件cluster_analysis.db增加EPTokensSummary, TopEPTokensInfo分析表格。 | 否 | + | mstx2comm | 基于ascend_pytorch_profiler_{rank_id}.db文件,将通信内置打点数据转换成通信算子。 | 否 | + | slow_rank | 集群场景通信算子快慢卡汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输出交付件cluster_analysis.db中展示各个rank按照当前的快慢卡统计算法得出的快慢卡影响次数。 | | + | p2p_pairing | 集群场景P2P算子生成全局关联索引,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输出的关联索引会作为一个新的字段`opConnectionId`附在原性能数据ascend_pytorch_profiler_{rank_id}.db文件的`COMMUNICATION_OP`的表中。 | 否 | + | filter_db | 基于ascend_pytorch_profiler_{rank_id}.db文件,提取通信类大算子数据,计算类关键函数和框架关键函数,节约90%+ 内存,促进快速全局分析。 | 否 | + | pp_chart | 基于打点后的ascend_pytorch_profiler_{rank_id}.db文件,分析打点数据,还原pp流水图 | 否 | + | 自定义分析参数 | 与cann_api_sum、compute_op_sum、hccl_sum等参数功能类似,用户可自定义一套性能数据的分析规则,需要详细了解性能分析的开发人员,具体开发指导请参见“[自定义分析规则开发指导](#自定义分析规则开发指导)”。 | 否 | + + --parallel_mode参数示例如下: + + ```bash + msprof-analyze cluster -d {cluster profiling data path} -m cann_api_sum --parallel_mode concurrent + ``` + + 或 + + ```bash + python3 cluster_analysis.py -d {cluster profiling data path} -m cann_api_sum --parallel_mode concurrent + ``` + + +### 交付件 + +集群分析工具的交付件通过MindStudio Insight工具展示,详见《[MindStudio Insight用户指南](https://www.hiascend.com/document/detail/zh/mindstudio/70RC2/GUI-baseddevelopmenttool/msascendinsightug/AscendInsight_0002.html)》。 + +#### cluster_step_trace_time.csv + +数据解析模式为communication_matrix、communication_time或all时均生成。 + +A列: Step数,是采集性能数据时设置的,一般来说集群性能数据采集一个step足够,如果采集多个step,需要先筛选一下。 + +B列: Type,主要分两种,rank和stage, 和后面的index强相关,可以理解为一个是单卡rank,一个是rank group(pp 并行的stage),如果type为stage,则后面D-K列信息为rank group下的最大值。 + +C列:Index,与type相关,表示卡号。 + +D列:Computing, 此列统计计算时间。 + +E列:Communication(Not Overlapped),此列统计未被掩盖的通信耗时。 + +F列:Overlapped,统计计算与通信重叠的耗时。 + +G列:Communication,通信时间的全部耗时。 + +H列:Free,空闲时间,指device侧既不在通信也不在计算的耗时,可能在做sdma拷贝或者空等。 + +I列:Stage时间,I、J、K列属于pp并行时有效的数值,stage时间代表除receive算子时间外的时间。 + +J列:Bubble时间,指receive时间的总和。 + +K列:Communication(Not Overlapped and Exclude Receive)指剔除receive算子外的并且不被掩盖的通信时间。 + +L列:Preparing,指迭代开始到首个计算或通信算子运行的时间。 + +M列:DP Index,指集群数据按照并行策略切分后所属DP组的索引, 如果没有采集则不显示。 + +N列:PP Index,指集群数据按照并行策略切分后所属PP组的索引,如果没有采集则不显示。 + +O列:TP Index,指集群数据按照并行策略切分后所属TP组的索引,如果没有采集则不显示。 + +**Tips**:先筛选B列type为stage, 看stage间是否有问题,再筛选B列type为rank,看rank是否有问题,根据以下几点排查。 + +* 根据Computing的时间差异判断是否有慢卡,或者有负载不均衡的现象。 + +* 根据Free统计是否有host bound或者分布不均现象。 + +* 根据Communication(Not Overlapped and Exclude Receive)时间判断是否通信耗时占比过大。 + +* 根据Bubble时间的占比和理论计算公式判断bubble设置是否合理,是否stage间有不均衡现象。 + +以上时间理论上都应该处于持平状态,即最大值小于最小值5%,否则就可能出现慢卡。 + +#### cluster_communication_matrix.json + +数据解析模式为communication_matrix或all时生成。 + +直接打开json(vscode或json查看器), 搜索"Total", 会有多个搜索结果,一般来说链路带宽信息的结构: + +```bash +{src_rank}-{dst_rank}: { + "Transport Type": "LOCAL", + "Transit Time(ms)": 0.02462, + "Transit Size(MB)": 16.777216, + "Bandwidth(GB/s)": 681.4466 +} +``` +**Tips**:可以根据rank互联的带宽以及链路类型,判断是否有慢链路的问题。 + +- "LOCAL"是片内拷贝,速度最高。 +- “HCCS”或“PCIE”是节点内片间拷贝,速度居中。 +- “RDMA”是节点间拷贝,速度最低。 + +#### cluster_communication.json + +数据解析模式为communication_time或all时生成。 + +主要为通信耗时数据。 + +#### cluster_analysis.db + +解析analysis.db或ascend_pytorch_profiler_{rank_id}.db生成的交付件,根据数据解析模式不同而解析不同的数据,详情介绍请参见[recipe结果和cluster_analysis.db交付件表结构说明](https://gitee.com/ascend/mstt/tree/pre-research/profiler/msprof_analyze/docs/recipe_output_format.md) + +## 附录 + +### 自定义分析规则开发指导 +详情介绍请参见[自定义分析规则开发指导](https://gitee.com/ascend/mstt/tree/pre-research/profiler/msprof_analyze/docs/custom_analysis_guide.md) diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py index cb280978c41a639a2f5d17e2a0ff08ed3a9962d6..c5cb2652a1159f9bb645b96c4f60535c74a67859 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py @@ -1,92 +1,92 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os - -from msprof_analyze.cluster_analyse.analysis.base_analysis import BaseAnalysis -from msprof_analyze.prof_common.db_manager import DBManager -from msprof_analyze.cluster_analyse.common_func.utils import increase_shared_value -from msprof_analyze.prof_common.path_manager import PathManager -from msprof_analyze.prof_common.constant import Constant -from msprof_analyze.prof_common.logger import get_logger -from msprof_analyze.prof_common.file_manager import FileManager - - -logger = get_logger() - - -class ClusterBaseInfoAnalysis(BaseAnalysis): - KEY_DISTRIBUTED_ARGS = "distributed_args" - - def __init__(self, param: dict): - super().__init__(param) - self.distributed_args = {} - - def run(self, completed_processes=None, lock=None): - if self.data_type != Constant.DB: - if completed_processes and lock: - increase_shared_value(completed_processes, lock) - logger.info("ClusterBaseInfoAnalysis skipped, since data type is not db") - return - if not self.extract_base_info(): - logger.warning("ClusterBaseInfoAnalysis skipped, since no metadata or distributed args found") - return - self.dump_db() - if completed_processes and lock: - increase_shared_value(completed_processes, lock) - logger.info("ClusterBaseInfoAnalysis completed") - - def dump_db(self): - if not self.distributed_args: - return - output_path = os.path.join(self.cluster_analysis_output_path, Constant.CLUSTER_ANALYSIS_OUTPUT) - PathManager.make_dir_safety(output_path) - result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) - conn, curs = DBManager.create_connect_db(result_db) - DBManager.create_tables(result_db, Constant.TABLE_CLUSTER_BASE_INFO) - save_distributed_args = [[json.dumps(self.distributed_args)]] - sql = "insert into {} values ({value})".format(Constant.TABLE_CLUSTER_BASE_INFO, - value="?," * (len(save_distributed_args[0]) - 1) + "?") - DBManager.executemany_sql(conn, sql, save_distributed_args) - DBManager.destroy_db_connect(conn, curs) - - def extract_base_info(self): - file_list = self.get_profiler_metadata_file() - if not file_list: - return False - for file_path in file_list: - try: - meta_data = FileManager.read_json_file(file_path) - except RuntimeError as e: - logger.error("Read json failed. %s", str(e)) - continue - if not meta_data.get(self.KEY_DISTRIBUTED_ARGS): - continue - for key, value in meta_data[self.KEY_DISTRIBUTED_ARGS].items(): - if key == "rank": - continue - self.distributed_args.setdefault(key, value) - return True - return False - - def get_profiler_metadata_file(self): - meta_file_list = [] - for root, _, files in os.walk(self.collection_path): - for file_name in files: - if file_name == Constant.PROFILER_METADATA: - meta_file_list.append(os.path.join(root, file_name)) - return meta_file_list - - +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +from msprof_analyze.cluster_analyse.analysis.base_analysis import BaseAnalysis +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.cluster_analyse.common_func.utils import increase_shared_value +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.file_manager import FileManager + + +logger = get_logger() + + +class ClusterBaseInfoAnalysis(BaseAnalysis): + KEY_DISTRIBUTED_ARGS = "distributed_args" + + def __init__(self, param: dict): + super().__init__(param) + self.distributed_args = {} + + def run(self, completed_processes=None, lock=None): + if self.data_type != Constant.DB: + if completed_processes and lock: + increase_shared_value(completed_processes, lock) + logger.info("ClusterBaseInfoAnalysis skipped, since data type is not db") + return + if not self.extract_base_info(): + logger.warning("ClusterBaseInfoAnalysis skipped, since no metadata or distributed args found") + return + self.dump_db() + if completed_processes and lock: + increase_shared_value(completed_processes, lock) + logger.info("ClusterBaseInfoAnalysis completed") + + def dump_db(self): + if not self.distributed_args: + return + output_path = os.path.join(self.cluster_analysis_output_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + PathManager.make_dir_safety(output_path) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + conn, curs = DBManager.create_connect_db(result_db) + DBManager.create_tables(result_db, Constant.TABLE_CLUSTER_BASE_INFO) + save_distributed_args = [[json.dumps(self.distributed_args)]] + sql = "insert into {} values ({value})".format(Constant.TABLE_CLUSTER_BASE_INFO, + value="?," * (len(save_distributed_args[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, save_distributed_args) + DBManager.destroy_db_connect(conn, curs) + + def extract_base_info(self): + file_list = self.get_profiler_metadata_file() + if not file_list: + return False + for file_path in file_list: + try: + meta_data = FileManager.read_json_file(file_path) + except RuntimeError as e: + logger.error("Read json failed. %s", str(e)) + continue + if not meta_data.get(self.KEY_DISTRIBUTED_ARGS): + continue + for key, value in meta_data[self.KEY_DISTRIBUTED_ARGS].items(): + if key == "rank": + continue + self.distributed_args.setdefault(key, value) + return True + return False + + def get_profiler_metadata_file(self): + meta_file_list = [] + for root, _, files in os.walk(self.collection_path): + for file_name in files: + if file_name == Constant.PROFILER_METADATA: + meta_file_list.append(os.path.join(root, file_name)) + return meta_file_list + + diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py index a87803438aef3a733c41588413ff2281b85ae418..3a538509b88e7ce3996aa539e49bf714bb163766 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import os from collections import defaultdict @@ -108,7 +107,7 @@ class CommMatrixAnalysis(BaseAnalysis): Constant.OP_NAME: '' } for op_name, op_dict in step_dict.items(): - link_info = defaultdict(lambda: copy.deepcopy(default_value)) + link_info = defaultdict(lambda: default_value.copy()) for rank_id, rank_dict in op_dict.items(): process_link_key(rank_id, rank_dict) step_dict[op_name] = convert_local_to_global_rank() @@ -120,7 +119,7 @@ class CommMatrixAnalysis(BaseAnalysis): Constant.TRANSIT_SIZE_MB: 0, Constant.OP_NAME: '' } - total_op_info = defaultdict(lambda: copy.deepcopy(default_value)) + total_op_info = defaultdict(lambda: default_value.copy()) for op_name, op_dict in step_dict.items(): if self.check_add_op(op_name): for link_key, link_dict in op_dict.items(): diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py index 61daa5b943d4a718d90f80203bac3fc948202199..47846522a9543511b3e55579d05b814d1ca9717d 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import os from collections import defaultdict @@ -79,7 +78,7 @@ class CommunicationAnalysis(BaseAnalysis): Constant.COMMUNICATION_TIME_INFO: defaultdict(float), Constant.COMMUNICATION_BANDWIDTH_INFO: {} } - total_rank_dict = defaultdict(lambda: copy.deepcopy(default_value)) + total_rank_dict = defaultdict(lambda: default_value.copy()) for _, rank_dict in comm_ops.items(): for rank_id, communication_op_info in rank_dict.items(): for com_info, com_info_dict in communication_op_info.items(): diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py index d7d71908506256eca3b8bd884a593188546189ce..6464bb732ddf57b2790d99ac7148ce3ecaf327ce 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py @@ -142,7 +142,6 @@ def cluster_analysis_main(): parser.add_argument("--parallel_mode", type=str, help="context mode", default="concurrent") parser.add_argument("--export_type", type=str, help="recipe export type", choices=["db", "notebook"], default="db") parser.add_argument("--rank_list", type=str, help="Rank id list", default='all') - parser.add_argument("--step_id", type=int, help="Step id", default=Constant.VOID_STEP) args, extra_args = parser.parse_known_args() parameter = vars(args) diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/context.py b/profiler/msprof_analyze/cluster_analyse/common_func/context.py index b41972c0d21ac73fb5b9f0291cddec8d9a06b94a..e4f716e90d991645de514e2bc6ecd12920c0c9e1 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/context.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/context.py @@ -16,6 +16,7 @@ import os from functools import partial from concurrent import futures +from collections import defaultdict from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger @@ -68,6 +69,7 @@ class ConcurrentContext(Context): super().__init__() self._custom = executor is None self._executor = executor or futures.ProcessPoolExecutor(max_workers=os.cpu_count()) + self.future_dict = defaultdict(list) def __enter__(self): if self._executor is None: @@ -88,3 +90,11 @@ class ConcurrentContext(Context): def wait(self, waitable): return waitable + + def submit(self, name, func, *args, **kwargs): + self.future_dict[name].append(self._executor.submit(func, *args, **kwargs)) + + def wait_all_futures(self): + for _, future_list in self.future_dict.items(): + for future in future_list: + future.result() \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py b/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py index 27daae78cb9004a2f713b9ebf9bc6ab916dd9325..f0f2f6735b87f1c4515a3ef6fd47c3c14a5b03b2 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py @@ -39,3 +39,26 @@ class TableConstant: DST_RANK = "dst_rank" TRANSPORT_TYPE = "transport_type" OPNAME = "op_name" + # CommunicationGroupMapping + GROUP_ID = "group_id" + PG_NAME = "pg_name" + NAME = "name" + VALUE = "value" + + +class ProfilerTableConstant: + + # COMMUNICATION OP + OP_ID = "opId" + OP_NAME = "opName" + START_NS = "startNS" + END_NS = "endNS" + CONNECTION_ID = "connectionId" + GROUP_NAME = "groupName" + RELAY = "relay" + RETRY = "retry" + DATA_TYPE = "dataType" + ALG_TYPE = "algType" + COUNT = "count" + OP_TYPE = "opType" + WAIT_NS = "waitNS" diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py b/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py index 42c509694cfd1a896f60ec6b282de040f22204b6..7c948ead594dcf5c67d1e70ff417b7bedf2b9265 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py @@ -31,10 +31,7 @@ class TablesConfig: ], "CommunicationGroupMap": [ ("type", "TEXT, null"), - ("rank_set", "TEXT, null"), - ("group_name", "TEXT, null"), - ("group_id", "TEXT, null"), - ("pg_name", "TEXT, null") + ("rank_set", "TEXT, null") ], "ClusterCommAnalyzerBandwidthMap": [ ("rank_set", "TEXT, null"), @@ -133,10 +130,8 @@ class TablesConfig: ], "CommunicationGroupMappingMap": [ ("type", "TEXT, null"), - ("rank_set", "TEXT, null"), ("group_name", "TEXT, null"), - ("group_id", "TEXT, null"), - ("pg_name", "TEXT, null") + ("rank_set", "TEXT, null") ], "ClusterBaseInfoMap": [ ("distributed_args", "TEXT, null") diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/utils.py b/profiler/msprof_analyze/cluster_analyse/common_func/utils.py index f2ba499d6f42986c1b2ecca49998f33d766c2d21..e3515857ad26a5eda8d3ddd6a3a09746dd0fd9ae 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/utils.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/utils.py @@ -92,3 +92,44 @@ def double_hash(data): hash_values[1] = (hash_values[1] * prime[1] + ord(d)) & uint32_max return ((hash_values[0] << uint32_bits) | hash_values[1]) + + +def calculate_zscore(x, mean, std): + if std != 0: + zscore = (x - mean) / std + elif x > mean: + zscore = 100 + else: + zscore = -100 + return zscore + + +def detect_outliers_z_score(data, threshold=3): + """ + 使用 Z-Score 方法判断是否存在异常值。 + Z-Score 是一种统计方法,用于衡量数据点与均值的标准差距离。 + 如果某个数据点的 Z-Score 超过阈值(默认为3),则认为它是异常值。 + + 返回值: + - True:存在异常值 + - False:不存在异常值 + """ + # 计算数据的均值 + mean = np.mean(data) # 均值表示数据的中心位置 + + # 计算数据的标准差 + std = np.std(data) # 标准差表示数据的离散程度 + + # 如果标准差为0,直接返回 False(不存在异常值) + if std == 0: + return False + + # 计算 Z-Score 的上阈值和下阈值 + z_scores_upper_threshold = threshold * std + mean + z_scores_lower_threshold = -threshold * std + mean + + # 判断是否存在 Z-Score 超过阈值的数据点 + has_outliers = any(x > z_scores_upper_threshold or x < z_scores_lower_threshold for x in data) + + # 返回是否存在异常值的布尔值 + return has_outliers diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/__init__.py b/profiler/msprof_analyze/cluster_analyse/communication_group/__init__.py index de0604079e1323b2749bc801a6e8326893c73498..7101187a2c2619f3b1c20dded14b433950b4c662 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/__init__.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/__init__.py @@ -11,4 +11,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py index 2c02bfdbf1bdd22dec838b56b7eb0c9c9872cac2..8f6625f8f6bbf646cfd77099b70c36398680f67a 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py @@ -18,21 +18,15 @@ from abc import abstractmethod from collections import defaultdict from copy import deepcopy from multiprocessing import Pool -import pandas as pd from msprof_analyze.cluster_analyse.cluster_utils.data_transfer_adapter import DataTransferAdapter -from msprof_analyze.cluster_analyse.common_func.utils import double_hash from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger -from msprof_analyze.prof_common.file_manager import FileManager logger = get_logger() class BaseCommunicationGroup: - KEY_PARALLEL_GROUP_INFO = "parallel_group_info" - KEY_COMM_GROUP_PARALLEL_INFO = "comm_group_parallel_info" - def __init__(self, params: dict): self.collection_path = params.get(Constant.COLLECTION_PATH) self.cluster_analysis_output_path = params.get(Constant.CLUSTER_ANALYSIS_OUTPUT_PATH) @@ -44,11 +38,9 @@ class BaseCommunicationGroup: self.collective_group_dict = defaultdict(set) self.p2p_comm_group = [] self.communication_group = {} - self.parallel_group_info = {} self.communication_ops = [] self.matrix_ops = [] self.adapter = DataTransferAdapter() - self.comm_group_parallel_info_df = None def load_communication_data(self): comm_op_dirs = [] @@ -120,18 +112,6 @@ class BaseCommunicationGroup: def read_communication_func(self, params: tuple): pass - def read_parallel_group_info(self): - for _, profiling_dir_path in self.data_map.items(): - meta_file = os.path.join(profiling_dir_path, Constant.PROFILER_METADATA) - if not os.path.exists(meta_file): - continue - meta_data = FileManager.read_json_file(meta_file) - if self.KEY_PARALLEL_GROUP_INFO not in meta_data: - continue - for group_id, group_info in meta_data[self.KEY_PARALLEL_GROUP_INFO].items(): - if group_id not in self.parallel_group_info: - self.parallel_group_info[group_id] = group_info - def analyze_communication_data(self): for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict: for step_id, step_id_dict in rank_id_comm_dict.items(): @@ -165,11 +145,9 @@ class BaseCommunicationGroup: def generate(self): self.load_communication_data() self.analyze_communication_data() - self.read_parallel_group_info() self.set_p2p_groups() self.generate_collective_communication_group() self.generate_p2p_communication_group() - self.analyze_parallel_group_info() self.dump_data() return self.collect_comm_data() @@ -237,32 +215,6 @@ class BaseCommunicationGroup: Constant.COMM_OP_INFO: op_link_info }) - def analyze_parallel_group_info(self): - # create comm group dataframe - comm_group_cols = ["type", "rank_set", "group_name"] - comm_group_df = pd.DataFrame(columns=comm_group_cols) - for group_name, rank_set in self.collective_group_dict.items(): - comm_group_df.loc[comm_group_df.shape[0]] = [Constant.COLLECTIVE, list(rank_set), group_name] - - # create parallel group dataframe - parallel_group_cols = ["group_name", "group_id", "pg_name"] - parallel_group_df = pd.DataFrame(columns=parallel_group_cols) - for group_id, parallel_info in self.parallel_group_info.items(): - group_name = str(double_hash(group_id)) # group_name is hashed group_id - pg_name = parallel_info.get("group_name", "") - if not pg_name: - continue - parallel_group_df.loc[parallel_group_df.shape[0]] = [group_name, group_id, pg_name] - - # merge by group_name - df = pd.merge(comm_group_df, parallel_group_df, on='group_name', how='left') - # add p2p group - for rank_set in self.communication_group[Constant.P2P]: - df.loc[df.shape[0]] = [Constant.P2P, list(rank_set), None, None, None] - df.fillna("", inplace=True) - - self.comm_group_parallel_info_df = df - class UnionFind(object): """Disjoint Set Union""" diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py index 99b55fb9956fff23ba36d9f4b80ba05caa33562c..7d1b4ec250ba1d25079a86f1b0bf95fd2c8906aa 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py @@ -76,9 +76,12 @@ class CommunicationDBGroup(BaseCommunicationGroup): return rank_id, comm_data, comm_matrix_data def dump_data(self): - self.comm_group_parallel_info_df["rank_set"] = (self.comm_group_parallel_info_df["rank_set"]. - apply(lambda x: "(" + ",".join(str(i) for i in x) + ")")) - res = self.comm_group_parallel_info_df.values.tolist() + res = [] + for data_type, data_list in self.communication_group.items(): + for data in data_list: + rank_set = "(" + ",".join(str(i) for i in data) + ")" + data = [data_type, rank_set] + res.append(data) dump_group_db(res, self.COMMUNICATION_GROUP_TABLE, self.cluster_analysis_output_path) @@ -145,9 +148,16 @@ class CommunicationDBGroupOptimized(BaseCommunicationGroup): return comm_data_dict def dump_data(self): - self.comm_group_parallel_info_df["rank_set"] = (self.comm_group_parallel_info_df["rank_set"]. - apply(lambda x: "(" + ",".join(str(i) for i in x) + ")")) - res = self.comm_group_parallel_info_df.values.tolist() + res = [] + for data_type, data_list in self.communication_group.items(): + if data_type == Constant.P2P: + for data in data_list: + rank_set = "(" + ",".join(str(i) for i in data) + ")" + res.append([data_type, "", rank_set]) + continue + for group_name, data in data_list: + rank_set = "(" + ",".join(str(i) for i in data) + ")" + res.append([data_type, group_name, rank_set]) dump_group_db(res, self.COMMUNICATION_GROUP_MAPPING_TABLE, self.cluster_analysis_output_path) def _merge_data_with_rank(self, rank_id: int, data_list: list): diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py index 2975050da0706870136ad4d8e84f28c56ded4718..97948228264f7b6fb2aed8d8b8766b3515626d40 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py @@ -14,7 +14,6 @@ # limitations under the License. import os -from copy import deepcopy from msprof_analyze.cluster_analyse.communication_group.base_communication_group import BaseCommunicationGroup from msprof_analyze.prof_common.file_manager import FileManager @@ -27,10 +26,8 @@ class CommunicationJsonGroup(BaseCommunicationGroup): super().__init__(params) def dump_data(self): - res = deepcopy(self.communication_group) - res[self.KEY_COMM_GROUP_PARALLEL_INFO] = self.comm_group_parallel_info_df.to_dict(orient="records") FileManager.create_json_file( - self.cluster_analysis_output_path, res, self.COMMUNICATION_GROUP_JSON + self.cluster_analysis_output_path, self.communication_group, self.COMMUNICATION_GROUP_JSON ) def read_communication_func(self: any, params: tuple): diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py index a8b503592536e529b4a9043058284f0094b08038..6a0273c8cfa8fcbea0e6f9f1e81a9a422624fc3a 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import json import os import shutil import sys @@ -24,6 +25,7 @@ import pandas as pd from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.cluster_analyse.common_func.utils import convert_unit from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.database_service import DatabaseService from msprof_analyze.prof_common.logger import get_logger from msprof_analyze.prof_common.path_manager import PathManager @@ -33,8 +35,10 @@ logger = get_logger() class BaseRecipeAnalysis(ABC): UNIT = "Us" DB_UNIT = "Ns" - RANK_LIST = "rank_list" + TP_SIZE = "tensor_model_parallel_size" + PP_SIZE = "pipeline_model_parallel_size" + DP_SIZE = "data_parallel_size" def __init__(self, params): self._collection_dir = params.get(Constant.COLLECTION_PATH, "") @@ -49,7 +53,6 @@ class BaseRecipeAnalysis(ABC): rank_list = params.get(Constant.RANK_LIST, 'all') self._rank_list = rank_list if rank_list == "all" else [int(rank) for rank in rank_list.split(",") if rank.isdigit()] - self._step_id = params.get(Constant.STEP_ID, Constant.VOID_STEP) self._extra_args = self.get_extra_argument(params.get(Constant.EXTRA_ARGS)) PathManager.make_dir_safety(self._output_path) @@ -107,7 +110,7 @@ class BaseRecipeAnalysis(ABC): result_db = custom_db_path if custom_db_path else os.path.join(self.output_path, file_name) conn, cursor = DBManager.create_connect_db(result_db) if isinstance(data, pd.DataFrame): - data.to_sql(table_name, conn, if_exists='replace', index=True) + data.to_sql(table_name, conn, if_exists='replace', index=index) else: logger.error(f"Unknown dump data type: {type(data)}") DBManager.destroy_db_connect(conn, cursor) @@ -144,6 +147,67 @@ class BaseRecipeAnalysis(ABC): if helper_file_path is not None: shutil.copy(helper_file_path, helper_output_path) + def map_rank_pp_stage(self, distributed_args): + tp_size = distributed_args.get(self.TP_SIZE, 1) + pp_size = distributed_args.get(self.PP_SIZE, 1) + dp_size = distributed_args.get(self.DP_SIZE, 1) + rank_pp_stage_map = {} + rank = 0 + for i in range(pp_size): + for _ in range(tp_size * dp_size): + rank_pp_stage_map[rank] = i + rank += 1 + return rank_pp_stage_map + + def load_distributed_args(self): + tp_size = self._extra_args.get("tp", None) + pp_size = self._extra_args.get("pp", None) + dp_size = self._extra_args.get("dp", None) + if tp_size and pp_size and dp_size: + if tp_size <= 0 or pp_size <= 0 or dp_size <= 0: + logger.error("Invalid distributed_args, tp pp dp < 0.") + return None + return { + self.TP_SIZE: tp_size, + self.DP_SIZE: dp_size, + self.PP_SIZE: pp_size, + } + else: + rank_id = list(self._data_map.keys())[0] + profiler_db_path = self._data_map[rank_id] + db_path = os.path.join(profiler_db_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + if os.path.exists(db_path): + try: + service = DatabaseService(db_path) + service.add_table_for_query("META_DATA", ["name", "value"]) + df = service.query_data().get("META_DATA", None) + distributed_args = df.loc[df["name"] == "distributed_args", "value"] + if distributed_args.empty: + distributed_args = {} + logger.error("Distributed args not in profiling files, please input manually.") + else: + distributed_args = json.loads(distributed_args.values[0]) + except Exception as err: + logger.error(err) + logger.error("Distributed args not in profiling files, please input manually.") + return None + tp_size = distributed_args.get(self.TP_SIZE, 1) + pp_size = distributed_args.get(self.PP_SIZE, 1) + dp_size = distributed_args.get(self.DP_SIZE, 1) + if not isinstance(tp_size, int) or not isinstance(pp_size, int) or not isinstance(dp_size, int): + logger.error("Invalid distributed_args in profiling files, please input manually.") + return None + if tp_size <= 0 or pp_size <= 0 or dp_size <= 0: + logger.error("Invalid distributed_args in profiling files, please input manually.") + return None + return { + self.TP_SIZE: tp_size, + self.PP_SIZE: pp_size, + self.DP_SIZE: dp_size, + } + logger.error(f"Db_file: {db_path} not exist.") + return None + def _get_rank_db(self): invalid_rank_id = [] if self._rank_list == 'all': @@ -158,55 +222,27 @@ class BaseRecipeAnalysis(ABC): db_paths = [] for rank_id in rank_ids: rank_path = self._data_map[rank_id] - db_path = os.path.join(rank_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") - if os.path.exists(db_path): - db_paths.append({Constant.RANK_ID: rank_id, Constant.PROFILER_DB_PATH: db_path, - Constant.STEP_RANGE: self._get_step_range(db_path)}) + profiler_db_path = os.path.join(rank_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + analysis_db_path = os.path.join(rank_path, Constant.SINGLE_OUTPUT, f"analysis.db") + if not os.path.exists(profiler_db_path): + logger.warning(f"Profiler DB file not found, rank id: {rank_id}, db path: {profiler_db_path}") + continue + db_path_dict = {Constant.RANK_ID: rank_id, Constant.PROFILER_DB_PATH: profiler_db_path} + if os.path.exists(analysis_db_path): + db_path_dict[Constant.ANALYSIS_DB_PATH] = analysis_db_path else: - logger.warning(f"DB file not found, rank id: {rank_id}, db path: {db_path}.") + logger.warning(f"Analysis DB file not found, rank id: {rank_id}, db path: {analysis_db_path}") + db_paths.append(db_path_dict) if invalid_rank_id: - logger.warning(f"Invalid Rank id: [{','.join(invalid_rank_id)}].") + logger.warning(f"Invalid Rank id : [{','.join(invalid_rank_id)}].") return db_paths - def _get_step_range(self, db_path): - step_range = {} - if self._step_id == Constant.VOID_STEP: - return step_range - conn, cursor = DBManager.create_connect_db(db_path) - if not DBManager.judge_table_exists(cursor, "STEP_TIME"): - logger.error(f"The STEP_TIME table does not exist in the database: {db_path}, " - f"the parameter step_id will not take effect.") - DBManager.destroy_db_connect(conn, cursor) - return step_range - - step_time = [] - sql = f"select id, startNs, endNs from STEP_TIME" - try: - step_time = DBManager.fetch_all_data(cursor, sql) - except Exception as err: - logger.error(err) - finally: - DBManager.destroy_db_connect(conn, cursor) - - for step_data in step_time: - if step_data.get("id") == self._step_id: - step_range = step_data - break - if not step_range: - step_list = ", ".join([str(step.get("id", "")) for step in step_time]) - logger.error(f"Invalid step_id {self._step_id} in the database: {db_path}, " - f"step_id must be an element of the set ({step_list}), " - f"the parameter step_id will not take effect.") - return step_range - def _mapper_func(self, data_map, analysis_class): """ Extract the profiling data required for cluster analysis from each device, and then aggregate the results from each device to be processed by a reduce function. Params: - data_map: eg. {"RANK_ID": 1, - "profiler_db_path": "xxxx/ascend_pytorch_profiler_1.db", - "step_range": {"id": 2, "startNs": 12345, "endNs": 12443]} + data_map: eg. {"RANK_ID": 1, "profiler_db_path": "xxxx/ascend_pytorch_profiler_1.db"} analysis_class: hccl_sum, compute_op_sum, cann_api_sum, mstx_sum…… """ - pass + pass \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py index 22cd2c64aeb09a417c1915bfbaaed0cc49bd8b00..01e9b3da79737ec032ec942073007a1f8fb3da60 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py @@ -90,15 +90,15 @@ class CannApiSum(BaseRecipeAnalysis): self.add_helper_file("cluster_display.py") def save_db(self): - self.dump_data(self._stats_rank_data, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "CannApiSumRank") + self.dump_data(self._stats_rank_data, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "CannApiSumRank", + index=False) self.dump_data(self._stats_data, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "CannApiSum") def _mapper_func(self, data_map, analysis_class): profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) rank_id = data_map.get(Constant.RANK_ID) - step_range = data_map.get(Constant.STEP_RANGE) - df = CannApiSumExport(profiler_db_path, analysis_class, step_range).read_export_db() + df = CannApiSumExport(profiler_db_path, analysis_class).read_export_db() if df is None or df.empty: logger.warning(f"There is no stats data in {profiler_db_path}.") return None, None - return rank_id, df + return rank_id, df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/stats.ipynb b/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/stats.ipynb index 2bc1b77e9b14777b57771313233beb7fa255d2e9..c97f039c5a01a6e7cce2968d569d79e137e76f8c 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/stats.ipynb +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/stats.ipynb @@ -72,7 +72,7 @@ "outputs": [], "source": [ "per_rank_df = pd.read_csv(\"rank_stats.csv\")\n", - "cluster_display.display_stats_per_operation(per_rank_df, box=False, scatter=False)" + "cluster_display.display_stats_per_operation(per_rank_df, xaxis_title='rank', yaxis_title='duration (ns)')" ] } ], diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py index 5a23a280fff9b3c0492f1c8cd2fac20824afb708..7e8913948f78231e5eb65da0806dc6834dcf5e31 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_display.py @@ -14,11 +14,14 @@ # limitations under the License. import logging +import math +import matplotlib.pyplot as plt import numpy as np import pandas as pd import plotly.graph_objects as go from IPython.display import display, HTML from ipywidgets import Dropdown, fixed, interact +from msprof_analyze.cluster_analyse.common_func.utils import calculate_zscore logger = logging.getLogger("cluster_display") @@ -189,6 +192,29 @@ def display_graph(figs, x_axis, y_axes, title=None, figs.append(fig) +def display_bar(x_axis, y_axes, title=None, y_index=None): + if isinstance(y_axes, pd.DataFrame): + data = y_axes.set_index(x_axis) + elif isinstance(y_axes, dict): + data = pd.DataFrame(y_axes, index=x_axis) + elif isinstance(y_axes, pd.Series): + data = pd.DataFrame({"": y_axes}, index=x_axis) + elif isinstance(y_axes, np.ndarray): + data = pd.DataFrame({"": pd.Series(y_axes)}, index=x_axis) + else: + return + + fig = data.plot.bar(title=title) + fig.bar_label(fig.containers[0]) + if y_index is not None and y_index in y_axes: + # get index of the top1 + top1_indices = data[y_index].nlargest(1).index + # change the color for the top1 + for i, bar in enumerate(fig.patches): + if data.index[i] in top1_indices: + bar.set_color('#FFA500') # highlight in orange + + def display_stats_per_rank_groups_combobox(rank_stats_gdf): names = list(rank_stats_gdf.groups.keys()) if len(names) > 1: @@ -238,3 +264,64 @@ def display_stats_optional_combobox(options, display_func, args, description="Op dropdown.value = options[0] elif len(options) == 1: display_func(options[0], args) + + +def compute_quantile_intervals(lst, num_intervals): + lst.sort(reverse=False) + if len(lst) > num_intervals: + min_value = min(lst) + max_value = max(lst) + interval_size = len(lst) / num_intervals + result = [min_value] + for i in range(1, num_intervals): + index = int(math.ceil(i * interval_size)) - 1 + result.append(lst[index]) + result.append(max_value) + else: + result = lst + return result[::-1] + + +def process_data(df, group_cols, value_col, num_intervals): + grouped = df.groupby(group_cols)[value_col].apply(list).to_dict() + data = {k: compute_quantile_intervals(v, num_intervals) for k, v in grouped.items()} + max_len = max(len(v) for v in data.values()) + data_dict = { + k: v + [np.nan] * (max_len - len(v)) + for k, v in data.items() + } + # 使用sorted()函数和lambda表达式对字典的键进行排序,reverse=True表示降序排列 + sorted_items = sorted(data_dict.items(), key=lambda item: item[0], reverse=True) + # 将排序后的列表转换为字典 + data_dict = dict(sorted_items) + data_dealed = pd.DataFrame(data_dict) + return data_dealed + + +def plot_data(df, title, ylabel): + ax = df.plot(kind='bar', figsize=(12, 6)) + ax.set_title(title, fontsize=14) + ax.set_xlabel('opTypeRelatedRanksDataSize', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.legend(title='Percentiles', bbox_to_anchor=(1.05, 1)) + plt.tight_layout() + plt.show() + + +def display_transmittime_bar(slowlinkops_df, ratio_set=0.05, optype='hcom_allGather_', + relatedranks=5, datasize=1024): + slowlinkops_df_f = slowlinkops_df[(slowlinkops_df['opType'] == optype) & + (slowlinkops_df['relatedRanks'] == relatedranks) & (slowlinkops_df['dataSize'] == datasize)] + slowlinkops_df_f['relatedRanks'] = slowlinkops_df_f['relatedRanks'].apply(str) + slowlinkops_df_f['dataSize'] = slowlinkops_df_f['dataSize'].apply(str) + slowlinkops_df_f['opTypeRelatedRanksDataSize'] = slowlinkops_df_f['opType'] + \ + slowlinkops_df_f['relatedRanks'] + '_' + slowlinkops_df_f['dataSize'] + slowlinkops_df_f['transmitTime_Zscore'] = slowlinkops_df_f['transmitTime'].apply( + lambda x: calculate_zscore(x, slowlinkops_df_f['transmitTime'].mean(), slowlinkops_df_f['transmitTime'].std())) + num_intervals = int(1 / ratio_set) + + data_tt = process_data(slowlinkops_df_f, 'opTypeRelatedRanksDataSize', 'transmitTime', num_intervals) + data_ttzscore = process_data(slowlinkops_df_f, 'opTypeRelatedRanksDataSize', 'transmitTime_Zscore', num_intervals) + + plot_data(data_tt, 'Transmit Time Distribution', 'Time (ns)') + plot_data(data_ttzscore, 'Z-Score of Transmit Time Distribution', 'Z-Score') \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_compare_summary/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_compare_summary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_compare_summary/cluster_time_compare_summary.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_compare_summary/cluster_time_compare_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..ef8c1acc9765bada7e3c021dd1f6db42351ca3b3 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_compare_summary/cluster_time_compare_summary.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.database_service import DatabaseService +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.path_manager import PathManager + +logger = get_logger() + + +class ClusterTimeCompareSummary(BaseRecipeAnalysis): + BP = "bp" # 被对比的路径参数 + TABLE_CLUSTER_TIME_COMPARE_SUMMARY = "ClusterTimeCompareSummary" + CLUSTER_TIME_SUMMARY_COLUMNS = [ + "rank", + "step", + "stepTime", + "computation", + "communicationNotOverlapComputation", + "communicationOverlapComputation", + "communication", + "free", + "communicationWaitStageTime", + "communicationTransmitStageTime", + "memory", + "memoryNotOverlapComputationCommunication", + "taskLaunchDelayAvgTime" + ] + + def __init__(self, params): + super().__init__(params) + self.db_path = os.path.join(self._collection_dir, Constant.CLUSTER_ANALYSIS_OUTPUT, + Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + self.base_db_path = os.path.join(self._extra_args.get(self.BP, ""), Constant.CLUSTER_ANALYSIS_OUTPUT, + Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + self.compare_result = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @classmethod + def add_parser_argument(cls, parser): + BaseRecipeAnalysis.add_parser_argument(parser) + parser.add_argument('--bp', type=PathManager.expanduser_for_argumentparser, default="", + help="base profiling data path") + + def run(self, context=None): + logger.info("ClusterTimeCompareSummary starts running.") + if not self.check_params_is_valid(): + return + self.get_compare_data() + self.save_db() + + def check_params_is_valid(self) -> bool: + base_path = self._extra_args.get(self.BP, "") + if not base_path: + logger.error("Must specify the --bp parameter.") + return False + if self._export_type == Constant.NOTEBOOK: + logger.error("For cluster_time_compare_summary, the export_type parameter only supports db.") + return False + try: + PathManager.check_input_directory_path(base_path) # 校验目录 + except RuntimeError: + logger.error(f"{base_path} is not valid.") + return False + if not DBManager.check_tables_in_db(self.db_path, Constant.TABLE_CLUSTER_TIME_SUMMARY): + logger.error(f"{Constant.TABLE_CLUSTER_TIME_SUMMARY} in {self.db_path} does not exist.") + return False + if not DBManager.check_tables_in_db(self.base_db_path, Constant.TABLE_CLUSTER_TIME_SUMMARY): + logger.error(f"{Constant.TABLE_CLUSTER_TIME_SUMMARY} in {self.base_db_path} does not exist.") + return False + return True + + + def get_compare_data(self): + database_service_for_db = DatabaseService(self.db_path) + database_service_for_db.add_table_for_query(Constant.TABLE_CLUSTER_TIME_SUMMARY, + self.CLUSTER_TIME_SUMMARY_COLUMNS) + cluster_time_summary_df_dict = database_service_for_db.query_data() + cluster_time_summary_df = cluster_time_summary_df_dict.get(Constant.TABLE_CLUSTER_TIME_SUMMARY) + database_service_for_base_db = DatabaseService(self.base_db_path) + database_service_for_base_db.add_table_for_query(Constant.TABLE_CLUSTER_TIME_SUMMARY, + self.CLUSTER_TIME_SUMMARY_COLUMNS) + base_cluster_time_summary_df_dict = database_service_for_base_db.query_data() + base_cluster_time_summary_df = base_cluster_time_summary_df_dict.get(Constant.TABLE_CLUSTER_TIME_SUMMARY) + index_cols = ["rank", "step"] + current_df = cluster_time_summary_df.set_index(index_cols) + base_df = base_cluster_time_summary_df.set_index(index_cols).add_suffix("Base") + merged_df = current_df.join(base_df).reset_index() + columns_order = index_cols + for col in self.CLUSTER_TIME_SUMMARY_COLUMNS: + if col not in index_cols: + base_col = f"{col}Base" + diff_col = f"{col}Diff" + merged_df[diff_col] = merged_df[col] - merged_df[base_col] + columns_order.extend([col, base_col, diff_col]) + self.compare_result = merged_df[columns_order].dropna() + + def save_db(self): + self.dump_data(self.compare_result, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + self.TABLE_CLUSTER_TIME_COMPARE_SUMMARY, index=False) \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_summary/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_summary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_summary/cluster_time_summary.py b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_summary/cluster_time_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9b8f47242bfa5d9fb80df22918a2e235314632 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/cluster_time_summary/cluster_time_summary.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.context import ConcurrentContext +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.cluster_analyse.common_func.utils import double_hash +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.cluster_analyse.recipes.communication_group_map.communication_group_map import CommunicationGroupMap +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.cluster_time_summary_export import CommunicationTimeExport +from msprof_analyze.prof_exports.cluster_time_summary_export import MemoryAndDispatchTimeExport +from msprof_analyze.prof_common.database_service import DatabaseService +from msprof_analyze.prof_common.db_manager import DBManager + +logger = get_logger() + + +class OverlapInfo: + def __init__(self, start, end, overlap_type): + self.start = start + self.end = end + self.type = overlap_type + + +class ClusterTimeSummary(BaseRecipeAnalysis): + COMPUTING_TYPE = 0 + COMMUNICATION_TYPE = 1 + MEMORY_TYPE = 4 + STEP_TIME = "step_time" + STEP_TRACE = "step_trace" + COMMUNICATION = "communication" + MEMORY_AND_DISPATCH = "memory_and_dispatch" + + def __init__(self, params): + super().__init__(params) + self.params = params + self.db_paths = self._get_rank_db() + self.stats_data = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @classmethod + def get_memory_not_overlap(cls, df: pd.DataFrame): + memory_not_overlap_time = 0 # free的时间段里面memory的总时间(异步拷贝) + cur_block = OverlapInfo(df.iloc[0]["start"], df.iloc[0]["start"], -1) + for time_info in df.itertuples(): + if cur_block.type == cls.MEMORY_TYPE: + tmp_start = cur_block.start + tmp_end = cur_block.end if time_info.start > cur_block.end else time_info.start + if tmp_start < tmp_end: + memory_not_overlap_time += tmp_end - tmp_start + if time_info.start > cur_block.end: + cur_block.end = time_info.end + cur_block.type = time_info.type + cur_block.start = time_info.start + else: + cur_block.type = time_info.type if time_info.end > cur_block.end else cur_block.type + cur_block.start = cur_block.end if time_info.end > cur_block.end else time_info.end + cur_block.end = time_info.end if time_info.end > cur_block.end else cur_block.end + # 此处为了添加最后一块数据 + if cur_block.type == cls.MEMORY_TYPE: + memory_not_overlap_time += cur_block.end - cur_block.start + return memory_not_overlap_time / Constant.TIME_UNIT_SCALE + + @classmethod + def calculate_dispatch_time(cls, df: pd.DataFrame) -> pd.DataFrame: + filtered_df = df[df['type'].isin([cls.COMPUTING_TYPE, cls.COMMUNICATION_TYPE])] + result = filtered_df.groupby(['step'])['dispatch'].mean().reset_index() + result = result.rename(columns={'dispatch': 'taskLaunchDelayAvgTime'}) + return result + + @classmethod + def calculate_memory_time(cls, df: pd.DataFrame) -> pd.DataFrame: + filtered_df = df[df['type'].isin([cls.MEMORY_TYPE])].copy() + filtered_df['memory'] = filtered_df['end'] - filtered_df['start'] + result = filtered_df.groupby(['step'])['memory'].sum().reset_index() + result['memory'] = result['memory'] / Constant.TIME_UNIT_SCALE + return result + + def calculate_step_time(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + data_service = DatabaseService(profiler_db_path) + data_service.add_table_for_query(Constant.TABLE_STEP_TIME, ["id", "startNs", "endNs"]) + df = data_service.query_data().get(Constant.TABLE_STEP_TIME) + if df is None or df.empty: + logger.warning(f"There is no STEP_TIME data in {profiler_db_path}.") + return None + df["stepTime"] = (df["endNs"] - df["startNs"]) / Constant.TIME_UNIT_SCALE + result_df = df[["id", "stepTime"]].rename(columns={"id": "step"}) + result_df.insert(0, "rank", rank_id) + return result_df + + def calculate_step_trace_time(self, data_map, analysis_class): + analysis_db_path = data_map.get(Constant.ANALYSIS_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + data_service = DatabaseService(analysis_db_path) + data_service.add_table_for_query(Constant.TABLE_STEP_TRACE, ["step", "computing", + "communication_not_overlapped", "overlapped", + "communication", "free", ]) + df = data_service.query_data().get(Constant.TABLE_STEP_TRACE) + if df is None or df.empty: + logger.warning(f"There is no stats data in {analysis_db_path}.") + return None + df.insert(0, "rank", rank_id) + df["step"] = df["step"].astype(int) + return df + + def calculate_communication_time(self, data_map, analysis_class): + analysis_db_path = data_map.get(Constant.PROFILER_DB_PATH) + df = CommunicationTimeExport(analysis_db_path, analysis_class).read_export_db() + return df + + def calculate_transmit_and_wait_df(self, communication_df): + transmit_and_wait_df = pd.DataFrame(columns=["rank", "step", "communicationWaitStageTime", + "communicationTransmitStageTime"]) + cluster_db_path = os.path.join(self.output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + data_service = DatabaseService(cluster_db_path) + data_service.add_table_for_query(Constant.TABLE_COMMUNICATION_GROUP_MAPPING, + [TableConstant.RANK_SET, TableConstant.GROUP_ID]) + df_dict = data_service.query_data() + rank_set_df = df_dict.get(Constant.TABLE_COMMUNICATION_GROUP_MAPPING, None) + if rank_set_df is None or rank_set_df.empty: + logger.error(f"There is no {Constant.TABLE_COMMUNICATION_GROUP_MAPPING} data in {cluster_db_path}.") + return transmit_and_wait_df + + # 将"[2]"或者"[2,4,6,8]"这样从CommunicationGroupMapping的rank_set列读取出来的字符串转换为集合 + def parse_rank_set(rank_set): + try: + ranks_list = json.loads(rank_set) + return set(ranks_list) + except Exception as e: + logger.error(f"Failed to parse rank_set: {rank_set}, error: {e}") + return set() + + rank_set_df[TableConstant.RANK_SET] = rank_set_df[TableConstant.RANK_SET].apply(parse_rank_set) + # 这里两个表里面的group_name类型不一致 + group_to_ranks = dict(zip(rank_set_df[TableConstant.GROUP_ID], rank_set_df[TableConstant.RANK_SET])) + + # 自定义 filter 函数,检查一个 group 是否包含所有 required_ranks + def valid_group(group): + group_name = group.name[0] # group.name 是 (groupName, opName, step) 的元组 + required_ranks = group_to_ranks.get(group_name, set()) + actual_ranks = set(group['rank']) + return required_ranks.issubset(actual_ranks) + + communication_df["groupName"] = communication_df["groupName"].apply(double_hash) + filtered_df = (communication_df.groupby(["groupName", "opName", "step"], group_keys=False). + filter(valid_group)) + if filtered_df.empty: + logger.warning("No group satisfies the required rank set condition.") + return transmit_and_wait_df + filtered_df["communicationTransmitStageTime"] = \ + filtered_df.groupby(["groupName", "opName", "step"])["communication_time"].transform("min") + filtered_df["communicationWaitStageTime"] = \ + filtered_df["communication_time"] - filtered_df["communicationTransmitStageTime"] + transmit_and_wait_df = filtered_df.groupby(["rank", "step"])[ + ["communicationWaitStageTime", "communicationTransmitStageTime"]].sum().reset_index() + return transmit_and_wait_df + + def calculate_memory_and_dispatch_time(self, data_map, analysis_class): + """ + rank step memory memoryNotOverlapComputationCommunication taskLaunchDelayAvgTime + 0 1 120 150 200 + 0 2 130 150 200 + """ + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + df = MemoryAndDispatchTimeExport(profiler_db_path, analysis_class).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + memory_df = ClusterTimeSummary.calculate_memory_time(df) + memory_not_overlap_df = (df.groupby(["step"]).apply(ClusterTimeSummary.get_memory_not_overlap). + reset_index(name="memoryNotOverlapComputationCommunication")) + dispatch_df = ClusterTimeSummary.calculate_dispatch_time(df) + result_df = pd.merge(memory_df, memory_not_overlap_df, on='step', how='inner') + result_df = pd.merge(result_df, dispatch_df, on='step', how='inner') + result_df.insert(0, "rank", rank_id) + return result_df + + def aggregate_stats(self, context: ConcurrentContext): + step_time_df_list = [future.result() for future in context.future_dict[ClusterTimeSummary.STEP_TIME]] + step_trace_df_list = [future.result() for future in context.future_dict[ClusterTimeSummary.STEP_TRACE]] + communication_df_list = [ + future.result() + for future in context.future_dict[ClusterTimeSummary.COMMUNICATION] + ] + memory_and_dispatch_df_list = [ + future.result() + for future in context.future_dict[ClusterTimeSummary.MEMORY_AND_DISPATCH] + ] + step_time_df = pd.concat(step_time_df_list, ignore_index=True) + step_trace_df = pd.concat(step_trace_df_list, ignore_index=True) + communication_df = pd.concat(communication_df_list, ignore_index=True) + memory_and_dispatch_df = pd.concat(memory_and_dispatch_df_list, ignore_index=True) + transmit_and_wait_df = self.calculate_transmit_and_wait_df(communication_df) + all_dfs = [step_time_df, step_trace_df, transmit_and_wait_df, memory_and_dispatch_df] + merged_df = all_dfs[0] + for df in all_dfs[1:]: + merged_df = pd.merge(merged_df, df, on=['rank', 'step'], how='outer') + # 根据 step 和 rank 列对合并后的 DataFrame 进行排序 + merged_df = merged_df.sort_values(by=['rank', 'step']) + merged_df["free"] = merged_df["free"] - merged_df["memoryNotOverlapComputationCommunication"] + # 单卡场景,通信传输时间和等待时间全部置0 + if len(communication_df_list) == 1: + merged_df[['communicationWaitStageTime', 'communicationTransmitStageTime']] = 0 + merged_df = merged_df.rename(columns={ + 'computing': 'computation', + 'overlapped': 'communicationOverlapComputation', + 'communication_not_overlapped': 'communicationNotOverlapComputation'}) + return merged_df.sort_values(by=['rank', 'step']) + + def mapper_func(self, context: ConcurrentContext): + for db_map in self.db_paths: + context.submit(self.STEP_TIME, self.calculate_step_time, db_map, self._recipe_name) + context.submit(self.STEP_TRACE, self.calculate_step_trace_time, db_map, self._recipe_name) + context.submit(self.COMMUNICATION, self.calculate_communication_time, + db_map, self._recipe_name) + context.submit(self.MEMORY_AND_DISPATCH, self.calculate_memory_and_dispatch_time, + db_map, self._recipe_name) + + def run(self, context: ConcurrentContext): + db_path = os.path.join(self.output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + if not DBManager.check_tables_in_db(db_path, Constant.TABLE_COMMUNICATION_GROUP_MAPPING): + if not self.run_communication_group_map_recipe(context) or \ + not DBManager.check_tables_in_db(db_path, Constant.TABLE_COMMUNICATION_GROUP_MAPPING): + logger.error("Create CommunicationGroupMap table failed!") + return + logger.info("ClusterTimeSummary init.") + self.mapper_func(context) + context.wait_all_futures() + self.stats_data = self.aggregate_stats(context) + if self._export_type == Constant.DB: + self.save_db() + else: + logger.warning("cluster_time_summary only supports export db.") + + def run_communication_group_map_recipe(self, context): + """ + Run Recipe to create CommunicationGroupMapping table + """ + logger.info(f"Run CommunicationGroupMap recipe first to get {Constant.TABLE_COMMUNICATION_GROUP_MAPPING} table") + try: + group_map_recipe = CommunicationGroupMap(self.params) + group_map_recipe.run(context) + except Exception as e: + logger.error(f"Run CommunicationGroupMap recipe failed: {e}!") + return False + return True + + def save_db(self): + self.dump_data(self.stats_data, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + Constant.TABLE_CLUSTER_TIME_SUMMARY, index=False) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py new file mode 100644 index 0000000000000000000000000000000000000000..37214274c7ee9b55bf4508cfa75395a1e822e14d --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_group_map/communication_group_map.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.utils import double_hash +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class CommunicationGroupMap(BaseRecipeAnalysis): + GLOBAL_RANKS = "global_ranks" + + def __init__(self, params): + super().__init__(params) + logger.info("CommunicationGroupMap init.") + self.group_df = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @staticmethod + def get_comm_type_from_op_name(op_name: str): + op_name_lower = op_name.lower() + return Constant.P2P if ("send" in op_name_lower or "receive" in op_name_lower or "recv" in op_name_lower) \ + else Constant.COLLECTIVE + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error(f"CommGroupMap: {self._export_type} is not supported for export type.") + + def reducer_func(self, mapper_res): + # concat and process all comm group + comm_group_df_list = [df for df, _ in mapper_res] + comm_group_combined_df = pd.concat(comm_group_df_list).drop_duplicates() + # concat all parallel group info + parallel_info_df_list = [df for _, df in mapper_res] + parallel_info_combined_df = pd.concat(parallel_info_df_list).drop_duplicates() + # merge by group_name + group_df = pd.merge(comm_group_combined_df, parallel_info_combined_df, on=TableConstant.GROUP_NAME, how="left") + group_df.fillna("", inplace=True) + group_df.sort_values(by=[TableConstant.TYPE], inplace=True) + # column order + column_order = [TableConstant.TYPE, TableConstant.RANK_SET, TableConstant.GROUP_NAME, + TableConstant.GROUP_ID, TableConstant.PG_NAME] + self.group_df = group_df[column_order] + + def save_db(self): + self.dump_data(self.group_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + Constant.TABLE_COMMUNICATION_GROUP_MAPPING, index=False) + + def _mapper_func(self, data_map, analysis_class): + # read CommAnalyzerTime table + analysis_db_path = data_map.get(Constant.ANALYSIS_DB_PATH) + analysis_data_service = DatabaseService(analysis_db_path) + analysis_data_service.add_table_for_query(Constant.TABLE_COMM_ANALYZER_TIME, + [TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME]) + comm_time_res = analysis_data_service.query_data() + # process comm_time_df: group_name, type + comm_time_df = comm_time_res.get(Constant.TABLE_COMM_ANALYZER_TIME) + comm_time_df[TableConstant.TYPE] = (comm_time_df[TableConstant.HCCL_OP_NAME]. + apply(lambda x: self.get_comm_type_from_op_name(x))) + comm_time_df = comm_time_df.drop(columns=[TableConstant.HCCL_OP_NAME]) + comm_time_df = comm_time_df.drop_duplicates() + + # read META_DATA table + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + profiler_data_service = DatabaseService(profiler_db_path) + profiler_data_service.add_table_for_query(Constant.TABLE_META_DATA, + [TableConstant.NAME, TableConstant.VALUE]) + meta_data_res = profiler_data_service.query_data() + meta_data_df = meta_data_res.get(Constant.TABLE_META_DATA) + # process parallel_info_df + parallel_info_df = pd.DataFrame(columns=[TableConstant.GROUP_NAME, TableConstant.GROUP_ID, + TableConstant.PG_NAME, TableConstant.RANK_SET]) + if Constant.PARALLEL_GROUP_INFO not in meta_data_df[TableConstant.NAME].values: + return comm_time_df, parallel_info_df + info_str = meta_data_df.loc[meta_data_df[TableConstant.NAME] == Constant.PARALLEL_GROUP_INFO, + TableConstant.VALUE].values[0] + info_dict = json.loads(info_str) + for group_id, parallel_info in info_dict.items(): + group_name = str(double_hash(group_id)) # group_name is hashed group_id + pg_name = parallel_info.get(TableConstant.GROUP_NAME, "") + rank_set = parallel_info.get(self.GLOBAL_RANKS, []) + if not pg_name or not rank_set: + continue + parallel_info_df.loc[parallel_info_df.shape[0]] = [group_name, group_id, pg_name, str(rank_set)] + + return comm_time_df, parallel_info_df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/compute_op_sum/compute_op_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/compute_op_sum/compute_op_sum.py index 528534be399e3ceacadbe7d1acf7294d7b3ff37d..a5d44c3f17f6d2a31f097506c38d829c18d5d74f 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/compute_op_sum/compute_op_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/compute_op_sum/compute_op_sum.py @@ -108,11 +108,10 @@ class ComputeOpSum(BaseRecipeAnalysis): def _mapper_func(self, data_map, analysis_class): profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) rank_id = data_map.get(Constant.RANK_ID) - step_range = data_map.get(Constant.STEP_RANGE) if self.exclude_op_name: - df = ComputeOpSumExportExcludeOpName(profiler_db_path, analysis_class, step_range).read_export_db() + df = ComputeOpSumExportExcludeOpName(profiler_db_path, analysis_class).read_export_db() else: - df = ComputeOpSumExport(profiler_db_path, analysis_class, step_range).read_export_db() + df = ComputeOpSumExport(profiler_db_path, analysis_class).read_export_db() if df is None or df.empty: logger.warning(f"There is no stats data in {profiler_db_path}.") return None diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c3eef9ba18b942dc03580ca5c6ee1a1da14607 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.ep_load_balance_export import InputShapeExport +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class EPLoadBalance(BaseRecipeAnalysis): + + EP_TOKENS_SUMMARY = "EPTokensSummary" + TOP_EP_TOKENS_INFO = "TopEPTokensInfo" + META_DATA = "META_DATA" + Top_Num = 20 + GROUPEP = "exp" + + def __init__(self, params): + super().__init__(params) + logger.info("EPLoadBalance init.") + self.ep_tokens_summary = None + self.top_ep_tokens_map = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def process_input_shapes(self, df): + def calculate_seqlength(shape_str): + shape_str = shape_str.strip('"') + parts = shape_str.split(";") + non_empty_parts = [part for part in parts if part] + # 取前 n-2 个有值的部分 + if len(non_empty_parts) > 1: + non_empty_parts = non_empty_parts[: len(non_empty_parts) - 2] + else: + return None + seqlength = 0 + for part in non_empty_parts: + part = part.strip() + try: + first_dim = int(part.split(",")[0]) + except (IndexError, ValueError) as e: + return None + seqlength += first_dim + return seqlength + + df["InputShapes"] = df["InputShapes"].apply(calculate_seqlength) + return df + + def reducer_func(self, mapper_res): + mapper_res = list(filter(lambda df: df is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None.") + return + for i, df in enumerate(mapper_res): + mapper_res[i] = self.process_input_shapes(df) + mapper_res = [df.dropna() for df in mapper_res] + for df in mapper_res: + df["epRanks"] = df["epRanks"].apply(lambda x: ",".join(map(str, x))) + combined_df = pd.concat(mapper_res) + self.ep_tokens_summary = combined_df.groupby(["Rank", "epRanks"]).agg({"InputShapes": "sum"}).reset_index() + self.ep_tokens_summary.columns = ["rank", "epRanks", "inputShapesSummary"] + self.top_ep_tokens_map = ( + self.ep_tokens_summary.groupby("epRanks")["inputShapesSummary"] + .agg(tokensDiff=lambda x: x.max() - x.min()) + .reset_index() + ) + self.top_ep_tokens_map = self.top_ep_tokens_map.sort_values(by="tokensDiff", ascending=False).head(self.Top_Num) + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + + if self._export_type == "db": + self.save_db() + else: + logger.error("ep_load_balance is only supported for db export type.") + + def save_db(self): + self.dump_data(self.ep_tokens_summary, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.EP_TOKENS_SUMMARY, + index=False) + self.dump_data(self.top_ep_tokens_map, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TOP_EP_TOKENS_INFO, + index=False) + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + step_range = data_map.get(Constant.STEP_RANGE) + analysis_data_service = DatabaseService(profiler_db_path, {}) + analysis_data_service.add_table_for_query(self.META_DATA) + meta_map = analysis_data_service.query_data()[self.META_DATA] + parallel_group_info = meta_map.loc[meta_map['name'] == 'parallel_group_info', 'value'].iloc[0] + try: + data_dict = json.loads(parallel_group_info) + except json.JSONDecodeError as e: + logger.error(f"{profiler_db_path}'s parallel_group_info is illegal") + return None + if not isinstance(data_dict, dict): + raise TypeError('{} must be dict, not {}.'.format(data_dict, type(data_dict).__name__)) + for _, value in data_dict.items(): + if value["group_name"] == self.GROUPEP: + global_ranks = value["global_ranks"] + break + df = InputShapeExport(profiler_db_path, analysis_class, step_range).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + df["Rank"] = rank_id + df["epRanks"] = [global_ranks] * len(df) + return df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/filter_db.py b/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/filter_db.py new file mode 100644 index 0000000000000000000000000000000000000000..29db8f637376fa04629e5b728a2aa53c9251944c --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/filter_db/filter_db.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.prof_exports.filter_db_export import OPFilter +from msprof_analyze.prof_exports.filter_db_export import TaskFilter +from msprof_analyze.prof_exports.filter_db_export import CANNFilter +from msprof_analyze.prof_exports.filter_db_export import PYTORCHFilter + +logger = get_logger() + +FILTER_COMPUTE = "COMPUTE_TASK_INFO" +FILTER_TASK = "TASK" +FILTER_CANN = "CANN_API" +FILTER_PYTORCH = "PYTORCH_API" + + +class DatabaseFilter(BaseRecipeAnalysis): + def __init__(self, params): + super().__init__(params) + logger.info("filter_db init.") + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def run(self, context): + mapper_res = self.mapper_func(context) + logger.info("Filtering database completed.") + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + + paths = profiler_db_path.split(os.path.sep) + sub_path = os.path.join(*paths[-3:-1]) + + output_path = os.path.join(self._output_path, "filter_db", sub_path) + PathManager.make_dir_safety(output_path) + + filtered_db = os.path.join(output_path, f"ascend_pytorch_profiler_{rank_id}.db") + shutil.copyfile(profiler_db_path, filtered_db) + + conn, cursor = DBManager.create_connect_db(filtered_db) + + op = OPFilter(filtered_db, analysis_class).read_export_db() + op.to_sql(FILTER_COMPUTE, conn, if_exists="replace", index=False) + task = TaskFilter(filtered_db, analysis_class).read_export_db() + task.to_sql(FILTER_TASK, conn, if_exists="replace", index=False) + cann = CANNFilter(filtered_db, analysis_class).read_export_db() + cann.to_sql(FILTER_CANN, conn, if_exists="replace", index=False) + pytorch = PYTORCHFilter(filtered_db, analysis_class).read_export_db() + pytorch.to_sql(FILTER_PYTORCH, conn, if_exists="replace", index=False) + + DBManager.execute_sql(conn, "DROP TABLE IF EXISTS COMMUNICATION_TASK_INFO;") + DBManager.execute_sql(conn, "DROP TABLE IF EXISTS TASK_PMU_INFO;") + + cursor.execute("VACUUM;") + conn.commit() + + DBManager.destroy_db_connect(conn, cursor) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..cf510efac40f829c3fdd976e5d948a321f41f016 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class FreqAnalysis(BaseRecipeAnalysis): + COMMON_FREQ = 1800 + FREE_FREQ = 800 + + def __init__(self, params): + super().__init__(params) + self.free_freq_ranks = [] + self.abnormal_freq_ranks = [] + self.abnormal_freq_ranks_map = {} + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def reducer_func(self, mapper_res): + if self._is_msprof: + logger.warning("Freq analysis do not support msprof db now.") + return + mapper_res = list(filter(lambda res: res[0] is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None, load profiling data failed.") + return + for freqs, rank_id in mapper_res: + if freqs == [self.COMMON_FREQ]: + continue + elif set(freqs) == {self.COMMON_FREQ, self.FREE_FREQ}: + self.free_freq_ranks.append(rank_id) + else: + self.abnormal_freq_ranks.append(rank_id) + self.abnormal_freq_ranks_map[rank_id] = str(freqs) + self.free_freq_ranks.sort() + self.abnormal_freq_ranks.sort() + + def save_db(self): + if len(self.free_freq_ranks) > 0: + logger.info(f"Found {len(self.free_freq_ranks)} ranks with free time, " + f"aicore frequency in {[self.FREE_FREQ, self.COMMON_FREQ]}.") + free_ranks_df = pd.DataFrame() + free_ranks_df["rankId"] = self.free_freq_ranks + free_ranks_df["aicoreFrequency"] = str([self.FREE_FREQ, self.COMMON_FREQ]) + free_ranks_df.set_index(["rankId"], inplace=True) + self.dump_data(free_ranks_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "FreeFrequencyRanks") + else: + logger.info("No rank found with free time.") + if len(self.abnormal_freq_ranks) > 0: + logger.info(f"Found {len(self.abnormal_freq_ranks)} ranks with abnormal aicore frequency.") + + abnormal_ranks_df = pd.DataFrame.from_dict(self.abnormal_freq_ranks_map, + orient="index", columns=["aicoreFrequency"]) + abnormal_ranks_df = abnormal_ranks_df.reset_index().rename(columns={"index": "rankId"}) + abnormal_ranks_df.set_index(["rankId"], inplace=True) + self.dump_data(abnormal_ranks_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "AbnormalFrequencyRanks") + else: + logger.info("No rank found with abnormal aicore frequency.") + if len(self.free_freq_ranks) > 0 or len(self.abnormal_freq_ranks) > 0: + logger.info("Please verify result in output file.") + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error("Frequence analysis is not supported for notebook export type.") + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + service = DatabaseService(profiler_db_path, None) + service.add_table_for_query("AICORE_FREQ", ["deviceId", "freq"]) + service.add_table_for_query("RANK_DEVICE_MAP", ["rankId"]) + service_res = service.query_data() + aic_freq = service_res.get("AICORE_FREQ", None) + rank_id = service_res.get("RANK_DEVICE_MAP", None) + if aic_freq is None or aic_freq.empty: + logger.error(f"No aic freq data found in {profiler_db_path}.") + return None, None + if rank_id is None or rank_id.empty: + logger.error(f"No rank_id data found in {profiler_db_path}.") + return None, None + rank_id = rank_id["rankId"].values[0] + freq_arr = aic_freq["freq"].values + freqs = list(set(freq_arr)) + freqs.sort() + return freqs, rank_id diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/hccl_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/hccl_sum.py index 84ff40ac7e5d78d6ea30127739e18dfd1654e2c0..a78603ee0ac2894fb8b60a21f411e7fef9d144db 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/hccl_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/hccl_sum.py @@ -128,10 +128,9 @@ class HcclSum(BaseRecipeAnalysis): def _mapper_func(self, data_map, analysis_class): profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) rank_id = data_map.get(Constant.RANK_ID) - step_range = data_map.get(Constant.STEP_RANGE) - df = HcclSumExport(profiler_db_path, analysis_class, step_range).read_export_db() + df = HcclSumExport(profiler_db_path, analysis_class).read_export_db() if df is None or df.empty: logger.warning(f"There is no stats data in {profiler_db_path}.") return None df["Rank"] = rank_id - return df + return df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/stats.ipynb b/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/stats.ipynb index 51a08a854b97161ba8e88ec94809b728582d6631..87f8c6d736240531e2c28c0cf33df087ecfe38e8 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/stats.ipynb +++ b/profiler/msprof_analyze/cluster_analyse/recipes/hccl_sum/stats.ipynb @@ -4,9 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# COMMUNICATION Summary\n", + "# HCCL Summary\n", "\n", - "集群场景通信算子数据分析\n", + "集群场景Hccl算子数据分析\n", "\n", "主要包含以下3个统计内容:\n", "1. 按算子类型分组的,整个集群通信算子耗时的统计情况\n", diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7101187a2c2619f3b1c20dded14b433950b4c662 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py new file mode 100644 index 0000000000000000000000000000000000000000..77a7095abbba9af1bcd3750714dc73c34a15925d --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py @@ -0,0 +1,178 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.mstx2commop_export import Mstx2CommopExport +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + +TABLE_COMMUNICATION_OP = "COMMUNICATION_OP" +TABLE_STRING_IDS = "STRING_IDS" + + +def double_hash(data): + uint32_bits = 32 + uint32_max = 0xFFFFFFFF # 32 位无符号整数的最大值 + prime = [29, 131] + hash_values = [0, 0] + + for d in data: + hash_values[0] = (hash_values[0] * prime[0] + ord(d)) & uint32_max + hash_values[1] = (hash_values[1] * prime[1] + ord(d)) & uint32_max + + return ((hash_values[0] << uint32_bits) | hash_values[1]) + + +class Mstx2Commop(BaseRecipeAnalysis): + + def __init__(self, params): + super().__init__(params) + logger.info("Mstx2Commop init.") + self.communication_op = None + self.string_ids_insert = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def run(self, context): + self.mapper_func(context) + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + data_service = DatabaseService(profiler_db_path) + data_service.add_table_for_query("ENUM_HCCL_DATA_TYPE", ["id", "name"]) + data_service.add_table_for_query("STRING_IDS", ["id", "value"]) + df_dict = data_service.query_data() + + df = Mstx2CommopExport(profiler_db_path, analysis_class).read_export_db() + + if df is None or df.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + + df_hccl_dt = df_dict.get("ENUM_HCCL_DATA_TYPE") + + if df_hccl_dt is None or df_hccl_dt.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + + df_string_ids = df_dict.get("STRING_IDS") + + if df_string_ids is None or df_string_ids.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + + value_len = 4 + optype_index, op_start_index = 0, 9 + groupname_index, datatype_index, count_index = 1, 2, 3 + + # json格式数据转化 + if df.loc[0, 'value'][0] == '{': + df['value'] = df['value'].apply(lambda x: json.loads(x)) + df['opType_primal'] = df['value'].apply(lambda x: x['opName'] + '_') + df['groupName_primal'] = df['value'].apply(lambda x: x['groupName']) + df['dataType'] = df['value'].apply(lambda x: x['dataType']) + df['count'] = df['value'].apply(lambda x: x['count']) + # 非json格式数据转化 + else: + df['value_list'] = df['value'].apply(lambda x: x.split(',')) + df['value_list_len'] = df['value_list'].apply(len) + df = df[df['value_list_len'] == value_len] + df['opType_primal'] = df['value_list'].apply(lambda x: 'hcom_' + x[optype_index][op_start_index:] + '_') + df['groupName_primal'] = df['value_list'].apply(lambda x: x[groupname_index]) + df['dataType'] = df['value_list'].apply(lambda x: x[datatype_index]) + df['count'] = df['value_list'].apply(lambda x: x[count_index]) + + df['groupName_hash'] = df['groupName_primal'].apply(double_hash).apply(str) + + df['gN_oT'] = df['groupName_primal'] + df['opType_primal'] + + gnot_set = set(list(df['gN_oT'])) + + df_concat = pd.DataFrame() + for g_o in gnot_set: + df_split = df[df['gN_oT'] == g_o] + df_split = df_split.copy() + df_split['queue'] = list(range(len(df_split))) + df_concat = pd.concat([df_concat, df_split], axis=0) + + df_concat['queue'] = df_concat['queue'].apply(str) + + df_concat['groupId'] = df_concat['groupName_hash'].apply(lambda x: "_" + x[-3:]) + + df_concat['opName_primal'] = df_concat['opType_primal'] + df_concat['groupId'] + '_' + df_concat['queue'] + '_1' + + df_concat['opId'] = list(range(len(df_concat))) + df_concat['relay'] = None + df_concat['retry'] = None + df_concat['algType'] = None + + df_hccl_dt['name'] = df_hccl_dt['name'].apply(lambda x: x.lower()) + hccl_data_type_dict = dict(zip(df_hccl_dt['name'], df_hccl_dt['id'])) + + string_ids_dict = dict(zip(df_string_ids['value'], df_string_ids['id'])) + + string_ids_max = df_string_ids['id'].max() + + df_concat['dataType'] = df_concat['dataType'].apply(lambda x: hccl_data_type_dict[x]) + + df_concat['string_id_opType_primal'] = df_concat['opType_primal'].apply( + lambda x: 1 if x in string_ids_dict else 0) + df_concat['string_id_opName_primal'] = df_concat['opName_primal'].apply( + lambda x: 1 if x in string_ids_dict else 0) + df_concat['string_id_groupName_primal'] = df_concat['groupName_primal'].apply( + lambda x: 1 if x in string_ids_dict else 0) + optype_primal_list = list(set(df_concat[df_concat['string_id_opType_primal'] == 0]['opType_primal'])) + opname_primal_list = list(set(df_concat[df_concat['string_id_opName_primal'] == 0]['opName_primal'])) + groupname_primal_list = list(set(df_concat[df_concat['string_id_groupName_primal'] == 0]['groupName_primal'])) + + special_primal_list = optype_primal_list + opname_primal_list + groupname_primal_list + special_id_list = list(range(string_ids_max + 1, string_ids_max + len(special_primal_list) + 1)) + + special_id_dict = dict(zip(special_primal_list, special_id_list)) + + df_concat['opType'] = df_concat['opType_primal'].apply( + lambda x: string_ids_dict[x] if x in string_ids_dict else special_id_dict[x] + ) + df_concat['opName'] = df_concat['opName_primal'].apply( + lambda x: string_ids_dict[x] if x in string_ids_dict else special_id_dict[x] + ) + df_concat['groupName'] = df_concat['groupName_primal'].apply( + lambda x: string_ids_dict[x] if x in string_ids_dict else special_id_dict[x] + ) + + communication_op = df_concat[ + ['opName', 'startNs', 'endNs', 'connectionId', 'groupName', 'opId', 'relay', 'retry', 'dataType', 'algType', + 'count', 'opType']] + communication_op = communication_op.copy() + communication_op.sort_values('startNs', ascending=True, inplace=True) + communication_op.set_index('opId', inplace=True) + string_ids_insert = list(map(list, zip(special_id_list, special_primal_list))) + + DBManager.insert_data_into_db(data_map.get(Constant.PROFILER_DB_PATH), TABLE_STRING_IDS, string_ids_insert) + + self.dump_data(data=communication_op, file_name=data_map.get(Constant.PROFILER_DB_PATH), + table_name=TABLE_COMMUNICATION_OP, custom_db_path=data_map.get(Constant.PROFILER_DB_PATH)) + + return data_map.get(Constant.RANK_ID) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py index bfbcc6ffb49c6457cd54a9413e8bf7a145ec365b..69b4b056850b85a856634c7feb0121cfcb34494b 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py @@ -154,11 +154,10 @@ class MstxSum(BaseRecipeAnalysis): def _mapper_func(self, data_map, analysis_class): profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) rank_id = data_map.get(Constant.RANK_ID) - step_range = data_map.get(Constant.STEP_RANGE) - step_df = MstxStepExport(profiler_db_path, analysis_class, step_range).read_export_db() + step_df = MstxStepExport(profiler_db_path, analysis_class).read_export_db() if step_df is None or step_df.empty: step_df = pd.DataFrame({"start_ns": [0], "end_ns": [float("inf")], "step_id": [0]}) - mark_df = MstxMarkExport(profiler_db_path, analysis_class, step_range).read_export_db() + mark_df = MstxMarkExport(profiler_db_path, analysis_class).read_export_db() if mark_df is None or mark_df.empty: logger.warning(f"There is no mark data in {profiler_db_path}.") return None @@ -195,4 +194,4 @@ class MstxSum(BaseRecipeAnalysis): mark_stats_df["step_id"] = mark_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df) rename_mark_msg_name(mark_stats_df) mark_stats_df = format_columns(mark_stats_df).set_index("Name", drop=True) - return mark_stats_df + return mark_stats_df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a355e5a7f08206fc39dda4646817224c067f29f7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py new file mode 100644 index 0000000000000000000000000000000000000000..692d47b0734e43a0598c019115e069b625e71f73 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py @@ -0,0 +1,242 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from json import JSONDecodeError + +import numpy as np +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.cluster_analyse.common_func.table_constant import ProfilerTableConstant +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.p2p_pairing_export import P2PPairingExport + + +logger = get_logger() + + +class P2PPairing(BaseRecipeAnalysis): + + P2P_OP_NAME_PATTERN = r"^hcom_([Ss]end|[Rr](ecv|eceive))__\d+_\d+_\d+$" + DOMAIN_ID_EXTRACT_PATTERN = r"__(\d+)_\d+_\d+" + RECEIVE_OP_MATCH_PATTERN = r"[Rr]ecv|[Rr]eceive" + VALID_DST_RANK_TASK_TYPE = [Constant.NOTIFY_RECORD, Constant.NOTIFY_WAIT] + # intermediate dataframe column names + COL_NAME_IS_UNIQUE_VALUE = "isUniqueValue" + COL_NAME_OP_DST_RANK = "opDstRank" + COL_NAME_DOMAIN_ID = "domainId" + COL_NAME_IS_RECEIVE = "isReceive" + COL_NAME_OP_NAMING_INDEX = "opNamingIndex" + # output column name + COL_NAME_P2P_CONNECTION_ID = "opConnectionId" + # export params + TARGET_TABLE_NAME = Constant.TABLE_COMMUNICATION_OP + + def __init__(self, params): + super().__init__(params) + logger.info("P2PPairing init.") + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def run(self, context): + self.mapper_func(context) + logger.info("P2PPairing completed.") + + def update_connection_info_to_table(self, df_result, profiler_db_path): + """ + 将生成好的连接ID添加至COMMUNICATION OP表中,新增列`opConnectionId`。目前只处理Send和Recv算子,对应的opId会更新具体的连接ID, + 否则置空 + """ + conn, cursor = DBManager.create_connect_db(profiler_db_path) + ret = DBManager.check_columns_exist(cursor, self.TARGET_TABLE_NAME, {self.COL_NAME_P2P_CONNECTION_ID}) + if ret is None: + logger.error("Failed to connect to the database. Please check the database configurations") + return + if self.COL_NAME_P2P_CONNECTION_ID in ret: + logger.error(f"`{self.COL_NAME_P2P_CONNECTION_ID}` already exists in the {self.TARGET_TABLE_NAME}. " + f"Exiting to prevent result overwrite.") + return + DBManager.execute_sql( + conn, + f"ALTER TABLE {self.TARGET_TABLE_NAME} ADD COLUMN {self.COL_NAME_P2P_CONNECTION_ID} TEXT" + ) + DBManager.execute_sql( + conn, + f"UPDATE {self.TARGET_TABLE_NAME} SET {self.COL_NAME_P2P_CONNECTION_ID} = NULL" + ) + DBManager.executemany_sql( + conn, + f""" + UPDATE {self.TARGET_TABLE_NAME} + SET {self.COL_NAME_P2P_CONNECTION_ID} = ? + WHERE {ProfilerTableConstant.OP_NAME} = ?;""", + [(row[self.COL_NAME_P2P_CONNECTION_ID], row[P2PPairingExport.CO_OP_NAME]) + for _, row in df_result.iterrows()] + ) + DBManager.destroy_db_connect(conn, cursor) + + def generate_p2p_connection_index(self, df): + """ + 生成每一个P2P的算子的对应连接ID,连接ID的生成规则按照`通信域_Send卡号_Recv卡号_算子index`。 + 其中通信域为通信域字符串的哈希值后三位表示;Send卡和Recv卡分别为这个通信域内的local rank号;算子index是这两张卡之间按时间线排序, + 出现Send和Recv算子已有的频次。比如说,一个算子的名称为`hcom_send_233_58_1`,自己在通信域内的rank号为0,对端的rank号为1;在这之前 + 并没有存在0卡向1卡的Send任务。因此生成的id为`233_0_1_0` + """ + df[self.COL_NAME_DOMAIN_ID] = df[P2PPairingExport.OP_NAME]. \ + str.extract(self.DOMAIN_ID_EXTRACT_PATTERN)[0] + df[self.COL_NAME_IS_RECEIVE] = df[P2PPairingExport.OP_NAME]. \ + str.contains(self.RECEIVE_OP_MATCH_PATTERN) + df.loc[ + df[self.COL_NAME_IS_RECEIVE], [P2PPairingExport.SRC_RANK, self.COL_NAME_OP_DST_RANK] + ] = df.loc[ + df[self.COL_NAME_IS_RECEIVE], [self.COL_NAME_OP_DST_RANK, P2PPairingExport.SRC_RANK] + ].values + df[self.COL_NAME_OP_NAMING_INDEX] = df.sort_values(by=[P2PPairingExport.START_TIME]). \ + groupby([P2PPairingExport.SRC_RANK, self.COL_NAME_OP_DST_RANK]).cumcount() + df[self.COL_NAME_P2P_CONNECTION_ID] = (df[self.COL_NAME_DOMAIN_ID].astype(str) + "_" + + df[P2PPairingExport.SRC_RANK].astype(str) + "_" + + df[self.COL_NAME_OP_DST_RANK].astype(str) + "_" + + df[self.COL_NAME_OP_NAMING_INDEX].astype(str)) + return df.reset_index() + + def fine_filtering_src_dst_ranks(self, df: pd.DataFrame): + """ + 精筛符合条件的数据: + 1、小算子任务包含了“Notify_Record”和“Notify_Wait”的数据 + 2、上一步得到的数据中对端卡号是否一致,如果不一致则会抛出warning + 3、步骤1得到数据中本端卡号是否一致,如果不一致则会报出error返回空值 + """ + df = df[df[P2PPairingExport.TASK_TYPE].isin(self.VALID_DST_RANK_TASK_TYPE)] + + def check_dst_rank_unique(group): + return group[P2PPairingExport.DST_RANK].nunique() == 1 + + unique_dst_rank: pd.DataFrame = (df.groupby(P2PPairingExport.OP_NAME).apply(check_dst_rank_unique)) + + def get_dst_rank_value(group): + if group[P2PPairingExport.DST_RANK].nunique() == 1: + return group[P2PPairingExport.DST_RANK].iloc[0] + return np.nan + + dst_rank_value: pd.DataFrame = (df.groupby(P2PPairingExport.OP_NAME, group_keys=False). + apply(get_dst_rank_value)) + + df = df.copy() + df[self.COL_NAME_IS_UNIQUE_VALUE] = df[P2PPairingExport.OP_NAME].map(unique_dst_rank) + df[self.COL_NAME_OP_DST_RANK] = df[P2PPairingExport.OP_NAME].map(dst_rank_value) + df[self.COL_NAME_OP_DST_RANK] = df[self.COL_NAME_OP_DST_RANK].fillna(Constant.INVALID_RANK_NUM) + df[self.COL_NAME_OP_DST_RANK] = df[self.COL_NAME_OP_DST_RANK].astype(df[P2PPairingExport.DST_RANK].dtype) + + check_dst_rank_unique_false: pd.DataFrame = df[~df[self.COL_NAME_IS_UNIQUE_VALUE]] + if not check_dst_rank_unique_false.empty: + logger.warning(f"There are communication op entries with multiple destination ranks! " + f"Please check the corresponding profiler database file.") + + df = df[df[self.COL_NAME_IS_UNIQUE_VALUE]] + + src_rank_unique_values: int = df[P2PPairingExport.SRC_RANK].nunique() + if src_rank_unique_values != 1: + logger.error(f"There are communication op entries with multiple source ranks! " + f"Please check the corresponding profiler database file.") + return None + return df.reset_index() + + def filter_data_by_group_name(self, df: pd.DataFrame): + """ + 初步筛选出目标数据: + 1、筛选出Send和Recv的算子 + 2、筛选出同一opId在COMMUNICATION OP中groupName和COMMUNICATION TASK INFO中groupName一致的数据 + """ + df = df[df[P2PPairingExport.OP_NAME].str.match(self.P2P_OP_NAME_PATTERN)] + filtered_df = df[df[P2PPairingExport.CO_GROUP_NAME] == df[P2PPairingExport.CTI_GROUP_NAME]] + anomaly_group_match = df[df[P2PPairingExport.CO_GROUP_NAME] != df[P2PPairingExport.CTI_GROUP_NAME]] + if not anomaly_group_match.empty: + logger.warning(f"Group name mismatch in {len(anomaly_group_match)} entries. Please check the" + f" profiler database in communication task info.") + return filtered_df.reset_index() + + def extract_pp_group_from_metadata(self, profiler_parent_path) -> any: + """ + 从profiler_metadata.json的文件中获取pp通信域的信息 + """ + metadata_path = os.path.join(profiler_parent_path, Constant.PROFILER_METADATA) + try: + if os.path.exists(metadata_path): + metadata = FileManager.read_json_file(metadata_path) + parallel_group_info: dict = metadata.get(Constant.PARALLEL_GROUP_INFO, None) if metadata else None + else: + raise FileNotFoundError(f"No `{Constant.PROFILER_METADATA}` found in {profiler_parent_path}.") + except (FileNotFoundError, JSONDecodeError) as e: + logger.error(f"Failed to load profiler metadata: {e}") + return None + + if parallel_group_info is None: + logger.error(f"No key name `{Constant.PARALLEL_GROUP_INFO}` found in {metadata_path}") + return None + + pp_group_info = [] + for name in parallel_group_info: + each_group_info: dict = parallel_group_info[name] + if each_group_info[Constant.GROUP_NAME] == Constant.PP: + pp_group_info.append(parallel_group_info[name]) + if not pp_group_info: + logger.error(f"No pipeline parallel info found in {metadata_path}") + return None + + return pp_group_info + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path: str = data_map.get(Constant.PROFILER_DB_PATH) + profiler_parent_path: str = os.path.dirname(os.path.dirname(profiler_db_path)) + + df: pd.DataFrame = P2PPairingExport(profiler_db_path, analysis_class).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + + pp_group_info = self.extract_pp_group_from_metadata(profiler_parent_path) # 暂时没用到,预留给后续确认用全局rank + if pp_group_info is None: + logger.error(f"Cannot obtain pipeline parallel info from the metadata. " + f"Please check the corresponding {Constant.PROFILER_METADATA}") + + df = self.filter_data_by_group_name(df) + if df.empty: + return None + + df_filtered = self.fine_filtering_src_dst_ranks(df.copy()) + if df_filtered is None: + logger.error("Got error when trying to match rank numbers!") + return None + + df_result = df_filtered.groupby([P2PPairingExport.OP_NAME, P2PPairingExport.CO_OP_NAME]).agg( + { + P2PPairingExport.START_TIME: "first", + P2PPairingExport.SRC_RANK: "first", + self.COL_NAME_OP_DST_RANK: "first" + } + ).reset_index() + + df_result = self.generate_p2p_connection_index(df_result) + + df_result = df_result[[P2PPairingExport.CO_OP_NAME, self.COL_NAME_P2P_CONNECTION_ID]] + + self.update_connection_info_to_table(df_result, profiler_db_path) + return data_map.get(Constant.RANK_ID) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7101187a2c2619f3b1c20dded14b433950b4c662 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/pp_chart.py b/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/pp_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..358c645492274c12cd7d859914e71feabe0a9895 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/pp_chart/pp_chart.py @@ -0,0 +1,292 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import json +import os +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService +from msprof_analyze.prof_exports.pp_chart_export import PPChartExport + +logger = get_logger() + + +def filter_non_overlapping(df: pd.DataFrame) -> pd.DataFrame: + if df.empty: + return df + result = [] + last_end = -1 + for _, row in df.iterrows(): + if row['startNs'] >= last_end: + result.append(row) + last_end = row['endNs'] + return pd.DataFrame(result) + + +class PPChart(BaseRecipeAnalysis): + FORWARD_STAGE_0 = "FORWARD_STAGE_0" + FORWARD_STAGE_1 = "FORWARD_STAGE_1" # 表示一个microbatch在同一张卡的两个stage + BACKWARD_STAGE_0 = "BACKWARD_STAGE_0" + BACKWARD_STAGE_1 = "BACKWARD_STAGE_1" + STEP_TASK_INFO = "StepTaskInfo" + LOGITS = "logits" + + def __init__(self, params): + super().__init__(params) + logger.info("PPChart init.") + self.micro_batch_id_dict = defaultdict(list) + self.pp_stage_mstx_num = defaultdict(int) + self.micro_batch_num = None + self.pp_type = None + self.distributed_args = self.load_distributed_args() + self.load_pp_info() + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @staticmethod + def generate_dualpipev_schedule(pp_size, num_microbatches): + num_microbatches = num_microbatches * 2 + num_warmup_stages = [0] * pp_size + num_interleaved_forward_stages = [0] * pp_size + num_1b1w1f_stages = [0] * pp_size + num_overlap_stages = [0] * pp_size + num_1b1overlap_stages = [0] * pp_size + num_interleaved_backward_stages = [0] * pp_size + num_cooldown_stages = [0] * pp_size + pp_size *= 2 + for i in range(pp_size // 2): + num_warmup_stages[i] = pp_size - 2 - i * 2 + num_interleaved_forward_stages[i] = i + 1 # 每个单位是一组1f1f + num_1b1w1f_stages[i] = pp_size // 2 - i - 1 + num_overlap_stages[i] = num_microbatches - pp_size * 2 + i * 2 + 2 + num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2 + num_interleaved_backward_stages[i] = i + 1 + num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1] + schedule_all_stages = { + 'warmup': num_warmup_stages, + 'interleaved_forward': num_interleaved_forward_stages, + '1b1w1f': num_1b1w1f_stages, + 'overlap': num_overlap_stages, + '1b1overlap': num_1b1overlap_stages, + 'interleaved_backward': num_interleaved_backward_stages, + 'cooldown': num_cooldown_stages + } + return schedule_all_stages + + def calculate_micro_batch_id_for_dualpipev(self): + pp_size = self.distributed_args.get(self.PP_SIZE) + if self.micro_batch_num is None or self.micro_batch_num < pp_size * 2: + logger.error("The micro_batch_num is less than pp_size * 2, please set it to a larger value.") + return + schedule_all_stages = self.generate_dualpipev_schedule(pp_size, self.micro_batch_num) + cur_micro_batch_id_dict = defaultdict(dict) + flag = defaultdict(bool) # 标识最后一个阶段是BACKWARD_STAGE_0开头还是BACKWARD_STAGE_1开头 + for stage_name, stage_num in schedule_all_stages.items(): + for i, num in enumerate(stage_num): + last_forward_id_0 = cur_micro_batch_id_dict[i].setdefault(self.FORWARD_STAGE_0, -1) + last_forward_id_1 = cur_micro_batch_id_dict[i].setdefault(self.FORWARD_STAGE_1, -1) + last_backward_id_0 = cur_micro_batch_id_dict[i].setdefault(self.BACKWARD_STAGE_0, -1) + last_backward_id_1 = cur_micro_batch_id_dict[i].setdefault(self.BACKWARD_STAGE_1, -1) + if stage_name == "warmup": + self.micro_batch_id_dict[i].extend([[str(x), 0] for x in range(num)]) + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_0] = num - 1 + self.pp_stage_mstx_num[i] += num + elif stage_name == "interleaved_forward": + for j in range(num): + self.micro_batch_id_dict[i].append([str(last_forward_id_0 + j + 1), 1]) + self.micro_batch_id_dict[i].append([str(self.micro_batch_num + j), 1]) + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_0] += 1 + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_1] = self.micro_batch_num + num - 1 + self.pp_stage_mstx_num[i] += num * 2 + elif stage_name == "1b1w1f": + for j in range(num): + if i == 0: + self.micro_batch_id_dict[i].append([self.LOGITS, 2]) + self.pp_stage_mstx_num[i] += 1 + self.micro_batch_id_dict[i].append([f"{self.micro_batch_num + j}b", 2]) + self.micro_batch_id_dict[i].append([f"{self.micro_batch_num + j}w", 2]) + self.micro_batch_id_dict[i].append([str(last_forward_id_1 + j + 1), 2]) + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_1] += 1 + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_1] = self.micro_batch_num + num - 1 + self.pp_stage_mstx_num[i] += num * 3 + elif stage_name == "overlap": + for j in range(num // 2): + if i == 0: + self.micro_batch_id_dict[i].append([self.LOGITS, 3]) + self.pp_stage_mstx_num[i] += 1 + if i == pp_size - 1 and j == 0: + self.micro_batch_id_dict[i].append([f"{last_forward_id_0 + j + 1}F", 3]) + self.micro_batch_id_dict[i].append([f"{last_backward_id_1 + j + 1}B", 3]) + self.pp_stage_mstx_num[i] += 1 + else: + self.micro_batch_id_dict[i].append( + [f"{last_forward_id_0 + j + 1}F+{last_backward_id_1 + j + 1}B", 3]) + self.micro_batch_id_dict[i].append( + [f"{last_forward_id_1 + j + 1}F+{last_backward_id_0 + j + 1}B", 3]) + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_0] += 1 + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_1] += 1 + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_0] += 1 + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_1] += 1 + self.pp_stage_mstx_num[i] += num + elif stage_name == "1b1overlap": + for j in range(num // 2): + if i == 0: + self.micro_batch_id_dict[i].append([self.LOGITS, 4]) + self.pp_stage_mstx_num[i] += 1 + self.micro_batch_id_dict[i].append([f"{last_backward_id_1 + j + 1}B", 4]) + self.micro_batch_id_dict[i].append( + [f"{last_forward_id_1 + j + 1}F+{last_backward_id_0 + j + 1}B", 4]) + cur_micro_batch_id_dict[i][self.FORWARD_STAGE_1] += 1 + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_0] += 1 + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_1] += 1 + self.pp_stage_mstx_num[i] += num + elif stage_name == "interleaved_backward": + for j in range(num): + if j % 2 == 0: + if i == 0: + self.micro_batch_id_dict[i].append([self.LOGITS, 5]) + self.pp_stage_mstx_num[i] += 1 + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_1 + j // 2 + 1}B"), 5]) + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_1] += 1 + flag[i] = True + else: + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_0 + j // 2 + 1}B"), 5]) + cur_micro_batch_id_dict[i][self.BACKWARD_STAGE_0] += 1 + flag[i] = False + self.pp_stage_mstx_num[i] += num + elif stage_name == "cooldown": + self.pp_stage_mstx_num[i] += pp_size # 不开dw分离 + while last_backward_id_0 < self.micro_batch_num - 1 or \ + last_backward_id_1 < self.micro_batch_num * 2 - 1: + if flag[i]: + if last_backward_id_0 < self.micro_batch_num - 1: + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_0 + 1}B"), 6]) + last_backward_id_0 += 1 + if last_backward_id_1 < self.micro_batch_num * 2 - 1: + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_1 + 1}B"), 6]) + last_backward_id_1 += 1 + else: + if last_backward_id_1 < self.micro_batch_num * 2 - 1: + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_1 + 1}B"), 6]) + last_backward_id_1 += 1 + if last_backward_id_0 < self.micro_batch_num - 1: + self.micro_batch_id_dict[i].append([str(f"{last_backward_id_0 + 1}B"), 6]) + last_backward_id_0 += 1 + + def load_pp_info(self): + rank_id = list(self._data_map.keys())[0] + profiler_db_path = self._data_map[rank_id] + db_path = os.path.join(profiler_db_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + if not os.path.exists(profiler_db_path): + logger.error(f"Db_file: {db_path} not exist.") + return + try: + service = DatabaseService(db_path) + service.add_table_for_query("META_DATA", ["name", "value"]) + df = service.query_data().get("META_DATA", None) + if df is None: + logger.warning(f"There is no META_DATA in {db_path}.") + return + pp_info = df.loc[df["name"] == "pp_info", "value"] + if pp_info.empty: + logger.warning("pp_info not in profiling files, please input manually.") + return + else: + pp_info = json.loads(pp_info.values[0]) + self.micro_batch_num = pp_info.get("microbatch_num") + self.pp_type = pp_info.get("pp_type").lower() + except Exception as err: + logger.error(err) + logger.error("pp_info not in profiling files, please input manually.") + + def mapper_func_for_dualpipev(self, context): + return context.wait( + context.map( + self._mapper_func_for_dualpipev, + self._get_rank_db(), + analysis_class=self._recipe_name, + rank_pp_stage_map=self.map_rank_pp_stage(self.distributed_args), + pp_stage_mstx_num=self.pp_stage_mstx_num, + micro_batch_id_dict=self.micro_batch_id_dict + ) + ) + + def run(self, context): + if self.distributed_args is None: + logger.warning("The parallel strategy is lost.") + if self.pp_type == "dualpipev": + self.calculate_micro_batch_id_for_dualpipev() + res = self.mapper_func_for_dualpipev(context) # 忽略返回值 + else: + res = self.mapper_func(context) # 忽略返回值 + if res: + logger.info("PPChart finished.") + + def _mapper_func_for_dualpipev(self, data_map, analysis_class, rank_pp_stage_map, pp_stage_mstx_num, + micro_batch_id_dict): + """ + rank_pp_stage_map: 记录rank与pp_stage的映射,可以知道某个rank属于哪个pp_stage + pp_stage_mstx_num: 每个pp_stage预期的前反向的总打点数 + micro_batch_id_dict: 每个pp_stage的microbatch_id信息以及属于dualpipeV的哪个阶段,示例如下 + { + 0: [ ["0", 0], [ "2", 0], ..., ["7F+13B", 3], ...] + .... + } + """ + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + df = PPChartExport(profiler_db_path, analysis_class).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no mstx data in {profiler_db_path}.") + return + rank_id = data_map.get(Constant.RANK_ID) + pp_stage = rank_pp_stage_map.get(rank_id) + if pp_stage is None: + logger.error(f"The rank {rank_id} does not belong to any PP stage.") + return + df = filter_non_overlapping(df) + df["name"] = "" + df["type"] = 0 + + def match_mstx_name(group): + if len(group) != pp_stage_mstx_num[pp_stage]: + logger.error(f"The number of mstx_count should be {pp_stage_mstx_num[pp_stage]}, not {len(group)}.") + return group + for idx, (i, row) in enumerate(group.iterrows()): + micro_batch_id_info = micro_batch_id_dict[pp_stage][idx] + group.at[i, "name"] = micro_batch_id_info[0] + group.at[i, "type"] = micro_batch_id_info[1] + return group + df = df.groupby("step").apply(match_mstx_name) + result = df[["name", "startNs", "endNs", "type"]] + self.dump_data(data=result, file_name="", table_name=self.STEP_TASK_INFO, index=False, + custom_db_path=data_map.get(Constant.PROFILER_DB_PATH)) + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + df = PPChartExport(profiler_db_path, analysis_class).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no mstx data in {profiler_db_path}.") + return + df["name"] = df["msg"].apply(lambda x: "FP" if "forward" in x.lower() else "BP") + df['type'] = df['name'].map({'FP': 0, 'BP': 1}) + result = df[["name", "startNs", "endNs", "type"]] + self.dump_data(data=result, file_name="", table_name=self.STEP_TASK_INFO, index=False, + custom_db_path=data_map.get(Constant.PROFILER_DB_PATH)) \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_calc/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_calc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_calc/slow_calc.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_calc/slow_calc.py new file mode 100644 index 0000000000000000000000000000000000000000..6a600addf14dfbabef268a7aa6807b29617ea191 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_calc/slow_calc.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +import tqdm +import numpy as np +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService + +from msprof_analyze.prof_exports.slow_calc_export import SlowCalcExport +from msprof_analyze.cluster_analyse.common_func.utils import calculate_zscore + + +logger = get_logger() + +STR_COLUMN_RANK = "rank" + +GROUP_KEY_LIST = ["opName", "inputShapes", "outputShapes"] + + +class SlowCalc(BaseRecipeAnalysis): + TABLE_SLOW_CALC_SUM = "SlowCalcSum" + TABLE_SLOW_CALC_TETAIL = "SlowCalcDetail" + + def __init__(self, params): + super().__init__(params) + logger.info("SlowRankCheck init.") + self.path_output = Path(self._output_path) + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @staticmethod + def check_abnormal(df_total) -> pd.DataFrame: + str_name_p = "p_score" + str_name_z = "z_score" + str_name_index = "the_index" + + # * 计算偏离最小值很多的算子 + df_total: pd.DataFrame = df_total + df_total.loc[:, str_name_index] = range(len(df_total)) + df_total.set_index(str_name_index, inplace=True) + df_total.loc[:, str_name_p] = np.nan + df_total.loc[:, str_name_z] = np.nan + + all_df_total = df_total.groupby(GROUP_KEY_LIST) + + # * 第一层: shape 和算子名称分组 + dict_abnormal = {} + for group_i in tqdm.tqdm(all_df_total.groups, desc="Analyzing..."): + # * 第二层:rank 分组 + the_index_i = all_df_total.get_group(group_i).index + detail_i = df_total.loc[the_index_i] + + np_duration = detail_i[[Constant.DURATION_TIME]].to_numpy().reshape(-1) + + m = np.mean(np_duration) + sd = np.std(np_duration) + + if m == 0: + p = 1 + else: + p = (detail_i[Constant.DURATION_TIME].to_numpy() - m) / m + z = calculate_zscore(np_duration, m, sd) + + df_total.loc[the_index_i, str_name_p] = p + df_total.loc[the_index_i, str_name_z] = z + + detail_i = df_total.loc[the_index_i] # * 重新索引,更新写入的数据 + df_abnormal_i = detail_i.groupby(STR_COLUMN_RANK, as_index=False).agg( + count=(Constant.DURATION_TIME, len), + time_mean=(Constant.DURATION_TIME, "mean"), + time_max=(Constant.DURATION_TIME, "max"), + p_mean=(str_name_p, "mean"), + p_max=(str_name_p, "max"), + z_mean=(str_name_z, "mean"), + z_max=(str_name_z, "max"), + ) + + for ii, k in enumerate(GROUP_KEY_LIST): + df_abnormal_i[k] = group_i[ii] + + dict_abnormal[group_i] = df_abnormal_i + + logger.info("Summarizing...") + df_abnormal = pd.concat(dict_abnormal.values(), axis=0) + + dict_df = { + SlowCalc.TABLE_SLOW_CALC_SUM: df_abnormal, + SlowCalc.TABLE_SLOW_CALC_TETAIL: df_total, + } + return dict_df + + @staticmethod + def _mapper_func(data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + + db_service = DatabaseService(profiler_db_path) + db_service.add_table_for_query("STRING_IDS") + dict_table = db_service.query_data() + if "STRING_IDS" not in dict_table: + logger.error(f"No STRING_IDS in database ({profiler_db_path}).") + return None + name_data = dict_table["STRING_IDS"] + + df_calc = SlowCalcExport(profiler_db_path, analysis_class).read_export_db() + + dict_id_name = dict(zip(name_data["id"], name_data["value"])) + df_calc["inputShapes"] = df_calc["inputShapes"].map(dict_id_name) + df_calc["outputShapes"] = df_calc["outputShapes"].map(dict_id_name) + df_calc["opName"] = df_calc["opName"].map(dict_id_name) + df_calc["opType"] = df_calc["opType"].map(dict_id_name) + df_calc[STR_COLUMN_RANK] = rank_id + + return df_calc + + def run(self, context): + mapper_res = self.mapper_func(context) + logger.info("Collecting op info completed.") + self.reducer_func(mapper_res) + logger.info("Summarying completed.") + + def reducer_func(self, mapper_res): + if mapper_res is None or len(mapper_res) == 0: + logger.error("mapper_res is None.") + return + list_df = [i for i in mapper_res if i is not None] + total_df = pd.concat(list_df, axis=0) + + dict_df = self.check_abnormal(total_df) + self.save_summary(dict_df) + + def save_summary(self, dict_df): + self.slow_calc_sum = dict_df[self.TABLE_SLOW_CALC_SUM] + self.slow_calc_detail = dict_df[self.TABLE_SLOW_CALC_TETAIL] + + if self._export_type == "db": + self.save_db() + elif self._export_type == "notebook": + self.save_notebook() + else: + logger.warning(f"Unknown export type [{self._export_type}]. Defalut to save db.") + self.save_db() + + def save_notebook(self): + self.dump_data(self.slow_calc_sum, "slow_calc_sum" + Constant.CSV_SUFFIX, index=False) + self.dump_data(self.slow_calc_detail, "slow_calc_detail" + Constant.CSV_SUFFIX, index=False) + + def save_db(self): + self.dump_data( + self.slow_calc_sum, + Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + self.TABLE_SLOW_CALC_SUM, + index=False, + ) + self.dump_data( + self.slow_calc_detail, + Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + self.TABLE_SLOW_CALC_TETAIL, + index=False, + ) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/slow_link.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/slow_link.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c5e5fe7d8004687a8cd0ab6eae929659b662b9 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/slow_link.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict + +import pandas as pd +import numpy as np +from tqdm import tqdm + +from msprof_analyze.cluster_analyse.common_func.utils import describe_duration +from msprof_analyze.cluster_analyse.common_func.utils import detect_outliers_z_score +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.slow_link_export import SlowLinkExport + +logger = get_logger() + + +class SlowLink(BaseRecipeAnalysis): + TABLE_SLOW_LINK_SUM = "SlowLinkSum" + TABLE_SLOW_LINK_OPS = "SlowLinkOps" + + TOP_NUM = "top_num" + DEFAULT_TOP_NUM = 15 + + def __init__(self, params): + super().__init__(params) + logger.info("SlowLink init.") + self.slow_link_sum = [] + self.slow_link_ops = [] + top_num = self._extra_args.get(self.TOP_NUM, self.DEFAULT_TOP_NUM) + self.top_num = int(top_num) if isinstance(top_num, str) and top_num.isdigit() else self.DEFAULT_TOP_NUM + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @classmethod + def add_parser_argument(cls, parser): + parser.add_argument("--top_num", type=str, help="Duration cost top count", default=cls.DEFAULT_TOP_NUM) + + def merge_func(self, mapper_res): + # 过滤掉mapper_res中为None的元素 + mapper_res = list(filter(lambda df: df is not None, mapper_res)) + + # 如果过滤后mapper_res为空,记录错误并返回 + if not mapper_res: + logger.error("Mapper data is empty. Please check the input or data source.") + return + dataframes = [pd.DataFrame(item) for item in mapper_res] + mapper_res = pd.concat(dataframes, ignore_index=True) + # 从mapper_res中提取各个字段的值 + rank_id_arr = mapper_res["rankId"].values # 提取rankId数组 + num_ranks = len(rank_id_arr) # 获取rankId数组的长度 + group_name_arr = mapper_res["groupName"].values # 提取groupName数组 + communication_time_arr = mapper_res["communicationTime"].values # 提取通信时间数组 + op_name_arr = mapper_res["opName"].values # 提取操作名称数组 + + # 初始化用于存储分组信息的字典和数组 + process_group = defaultdict(lambda: defaultdict(list)) # 用于存储按组和操作名分组的索引 + transmit_time_arr = np.zeros(num_ranks, dtype=np.int64) # 初始化传输时间数组 + related_ranks_arr = np.zeros(num_ranks, dtype=np.int32) # 初始化相关rank数量数组 + + # 遍历所有记录,按groupName和opName分组 + for idx in range(num_ranks): + # 如果操作名称中包含"send"或"receive",跳过(可能是发送或接收操作) + if "send" in op_name_arr[idx] or "receive" in op_name_arr[idx]: + continue + # 将当前索引添加到对应的分组中 + process_group[group_name_arr[idx]][op_name_arr[idx]].append(idx) + + # 遍历分组后的数据,计算每个操作的传输时间和相关rank数量 + for _, ops_same_group in tqdm(process_group.items(), desc="Processing database data..."): + for _, ops in ops_same_group.items(): + # 提取当前分组中所有操作的通信时间 + communication_time_list = [communication_time_arr[op_idx] for op_idx in ops] + # 计算最小通信时间作为传输时间 + transmit_time = min(communication_time_list) + # 计算当前分组中操作的数量作为相关rank数量 + related_ranks_num = len(ops) + + # 更新传输时间和相关rank数量数组 + for op_idx in ops: + transmit_time_arr[op_idx] = transmit_time + related_ranks_arr[op_idx] = related_ranks_num + + # 将计算得到的传输时间和相关rank数量添加到mapper_res中 + mapper_res.insert(mapper_res.shape[1], 'transmitTime', transmit_time_arr) + mapper_res.insert(mapper_res.shape[1], 'relatedRanks', related_ranks_arr) + + # 调用过滤函数处理mapper_res + self.filter_func(mapper_res) + + def filter_func(self, mapper_res): + """ + 处理数据,分组并检测异常值。 + """ + # 按 opType, dataSize, related_ranks 分组 + grouped = mapper_res.groupby(['opType', 'dataSize', 'relatedRanks']) + + for _, group in grouped: + # 提取分组数据中的 transmit_time 列 + transmit_time_data = group['transmitTime'].values + + # 检测异常值 + outliers = detect_outliers_z_score(transmit_time_data) + + if outliers: + # 如果存在异常值,将整个分组数据存入 Slow_Link_Ops + self.slow_link_ops.append(group) + + if self.slow_link_ops: + self.slow_link_ops = pd.concat(self.slow_link_ops, ignore_index=True) + # 重置索引并去掉多余的索引列 + data = pd.DataFrame(self.slow_link_ops) + + # 按 'opType', 'dataSize', 'related_ranks' 分组 + grouped = data.groupby(['opType', 'dataSize', 'relatedRanks']) + + # 计算统计信息 + group_data = describe_duration(grouped['transmitTime']) + + # 找到每个组中 transmit_time 最小值和最大值对应的 rankId + min_rank = grouped['transmitTime'].idxmin().map(data['rankId']) + max_rank = grouped['transmitTime'].idxmax().map(data['rankId']) + + # 将最大值和最小值对应的 rankId 添加到 group_data + group_data['maxRank'] = max_rank.values + group_data['minRank'] = min_rank.values + + # 构造 filteringName + group_data['opTypeRelatedRanksDataSize'] = group_data.index.map(lambda x: f"{x[0]}{x[2]}_{x[1]}") + # 将 filteringName 移动到第一列 + cols = ['opTypeRelatedRanksDataSize'] + [col for col in group_data.columns if + col != 'opTypeRelatedRanksDataSize'] + group_data = group_data[cols] + + # 重置索引 + group_data = group_data.reset_index(drop=True) + # 计算最大值和最小值与均值的绝对值 + group_data['abs_max_mean'] = abs(group_data['MaxNs'] - group_data['MeanNs']) + group_data['abs_min_mean'] = abs(group_data['MinNs'] - group_data['MeanNs']) + + # 计算最大值和最小值与均值的绝对值中的较大值 + group_data['max_abs_mean'] = group_data[['abs_max_mean', 'abs_min_mean']].max(axis=1) + + # 计算偏移比值 + group_data['offsetRatio'] = group_data['max_abs_mean'] / group_data['StdNs'] + + # 按偏移比值降序排序 + group_data = group_data.sort_values(by='offsetRatio', ascending=False) + + # 根据 self.top_num 筛选出偏移比值最大的前 N 条记录 + group_data = group_data.head(self.top_num) + + # 删除辅助列 'abs_max_mean', 'abs_min_mean', 'max_abs_mean' + group_data = group_data.drop(columns=['abs_max_mean', 'abs_min_mean', 'max_abs_mean']) + + # 调整列的顺序,将 offsetRatio 移到 MinRank 和 MaxRank 之前 + columns = [col for col in group_data.columns if col not in ['maxRank', 'minRank', 'offsetRatio']] + columns.insert(len(columns), 'offsetRatio') # 将 offsetRatio 插入到倒数第三的位置 + columns.extend(['maxRank', 'minRank']) # 添加 MaxRank 和 MinRank 到列的最后 + + # 重新排列列的顺序 + group_data = group_data[columns] + + # 在处理 group_data 的最后部分并保存 + self.slow_link_sum = group_data + + def run(self, context): + if self.top_num <= 0: + logger.warning(f"SlowLink: top_num is set to a invalid value, " + f"it will be reset to default value({self.DEFAULT_TOP_NUM}).") + self.top_num = self.DEFAULT_TOP_NUM + mapper_res = self.mapper_func(context) + self.merge_func(mapper_res) + + if self._export_type == "db": + self.save_db() + elif self._export_type == "notebook": + self.save_notebook() + else: + logger.error("Unknown export type.") + + def save_notebook(self): + self.dump_data(self.slow_link_sum, "slow_link_sum.csv", index=False) + self.dump_data(self.slow_link_ops, "slow_link_ops.csv", index=False) + self.create_notebook("stats.ipynb") + self.add_helper_file("cluster_display.py") + + def save_db(self): + self.dump_data(self.slow_link_sum, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_SLOW_LINK_SUM, + index=False) + self.dump_data(self.slow_link_ops, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TABLE_SLOW_LINK_OPS, + index=False) + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + rank_id = data_map.get(Constant.RANK_ID) + df = SlowLinkExport(profiler_db_path, analysis_class).read_export_db() + if df is None or df.empty: + logger.warning(f"There is no stats data in {profiler_db_path}.") + return None + df.insert(0, "rankId", rank_id) + return df \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/stats.ipynb b/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/stats.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..30edbc245379aa6b02e8895427bc7ad5db6656b3 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_link/stats.ipynb @@ -0,0 +1,111 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# SLOWLINK Summary\n", + "\n", + "集群场景快慢卡数据分析\n", + "\n", + "主要包含以下2个统计内容:\n", + "1. 按算子类型分组的,整个集群通信算子耗时的统计情况\n", + "2. 整个集群异常的opType_relatedRanks_dataSize" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 数据准备" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "display(HTML(\"\"))\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import pandas as pd\n", + "pd.set_option(\"display.max_rows\", 100)\n", + "pd.set_option(\"display.width\", 1000)\n", + "\n", + "import cluster_display\n", + "\n", + "slow_link_ops_df = pd.read_csv(\"slow_link_ops.csv\")\n", + "slow_link_sum_df = pd.read_csv(\"slow_link_sum.csv\", index_col=\"opTypeRelatedRanksDataSize\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cluster_display.display_transmittime_bar(slow_link_ops_df, 0.05, 'hcom_allGather_', 5, 1024)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### 集群异常的opType_relatedRanks_dataSize分析\n", + "\n", + "统计集群异常的opType_relatedRanks_dataSize,时间单位为微秒(us)\n", + "\n", + "包含以下统计项:\n", + "- Count:算子数量\n", + "- Mean:平均耗时\n", + "- Std:标准差\n", + "- Min:最小值\n", + "- Q1:四分之一分位数\n", + "- Median:中位数\n", + "- Q3:四分之三分位数\n", + "- Max:最大值\n", + "- Sum:总耗时\n", + "- MinRank:耗时最少算子所在的Rank\n", + "- MaxRank:耗时最长算子所在的Rank" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "display(slow_link_sum_df)\n", + "fig_slow_link_ops = cluster_display.display_duration_boxplots(None, slow_link_sum_df, x_title=\"opTypeRelatedRanksDataSize\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a355e5a7f08206fc39dda4646817224c067f29f7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf7e2c80621f6756b2d9ad82051eefac263d141 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# 单边狄克逊检验表,995置信度 +DIXON_TABLE_995 = { + 3: 0.994, + 4: 0.920, + 5: 0.823, + 6: 0.744, + 7: 0.680, + 8: 0.723, + 9: 0.676, + 10: 0.638, + 11: 0.707, + 12: 0.675, + 13: 0.649, + 14: 0.672, + 15: 0.649, + 16: 0.629, + 17: 0.611, + 18: 0.595, + 19: 0.580, + 20: 0.568, + 21: 0.556, + 22: 0.545, + 23: 0.536, + 24: 0.526, + 25: 0.519, + 26: 0.510, + 27: 0.503, + 28: 0.496, + 29: 0.489, + 30: 0.484, + 31: 0.478, + 32: 0.473, + 33: 0.468, + 34: 0.463, + 35: 0.458, + 36: 0.454, + 37: 0.450, + 38: 0.446, + 39: 0.442, + 40: 0.439, + 41: 0.435, + 42: 0.432, + 43: 0.429, + 44: 0.425, + 45: 0.423, + 46: 0.420, + 47: 0.417, + 48: 0.414, + 49: 0.412, + 50: 0.409, + 51: 0.407, + 52: 0.405, + 53: 0.402, + 54: 0.400, + 55: 0.398, + 56: 0.396, + 57: 0.394, + 58: 0.392, + 59: 0.391, + 60: 0.388, + 61: 0.387, + 62: 0.385, + 63: 0.383, + 64: 0.382, + 65: 0.380, + 66: 0.379, + 67: 0.377, + 68: 0.376, + 69: 0.374, + 70: 0.372, + 71: 0.371, + 72: 0.370, + 73: 0.368, + 74: 0.368, + 75: 0.366, + 76: 0.365, + 77: 0.364, + 78: 0.363, + 79: 0.361, + 80: 0.360, + 81: 0.359, + 82: 0.358, + 83: 0.356, + 84: 0.356, + 85: 0.355, + 86: 0.353, + 87: 0.352, + 88: 0.352, + 89: 0.351, + 90: 0.350, + 91: 0.349, + 92: 0.348, + 93: 0.347, + 94: 0.346, + 95: 0.345, + 96: 0.344, + 97: 0.344, + 98: 0.343, + 99: 0.341, + 100: 0.341, +} \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..58dfde4018df422021e03a3091a24de3fe93c115 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py @@ -0,0 +1,190 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict + +import pandas as pd +import numpy as np + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.slow_rank_export import SlowRankExport +from msprof_analyze.cluster_analyse.recipes.slow_rank.dixon_table import DIXON_TABLE_995 + +logger = get_logger() + + +def judge_norm(time_list, threshold=3): + t_max = max(time_list) + t_min = min(time_list) + t_mean = np.mean(time_list) + t_std = np.std(time_list) + threshold_high = t_mean + threshold * t_std + threshold_low = t_mean - threshold * t_std + + # 耗时低于下阈值的卡认为是慢卡 + outliers_idx = [i for i, time in enumerate(time_list) if time < threshold_low] + + # 如果存在高于上阈值的卡,则将耗时最短的卡加到慢卡的list中 + if t_max > threshold_high: + if time_list.index(t_min) not in outliers_idx: + outliers_idx.append(time_list.index(t_min)) + return outliers_idx + + +def judge_dixon(time_list): + n = len(time_list) + if n in [0, 1, 2]: + return [] + sorted_list = sorted(time_list) + + # 判断计算检验指标时分母是否可能为0 + if len(set(sorted_list)) <= 3: + return [] + + # 计算狄克逊检验的检验指标,次小值和最小值差,比上最大值和最小值的差。根据数据数量改变次小值和最大值的选取 + if n <= Constant.MAX_DIXON_NUM: + if n <= Constant.DIXON_THRESHOLD_1: + flag = (sorted_list[1] - sorted_list[0]) / (sorted_list[-1] - sorted_list[0]) \ + if (sorted_list[-1] - sorted_list[0]) else 0 + elif n <= Constant.DIXON_THRESHOLD_2: + flag = (sorted_list[1] - sorted_list[0]) / (sorted_list[-2] - sorted_list[0]) \ + if (sorted_list[-2] - sorted_list[0]) else 0 + elif n <= Constant.DIXON_THRESHOLD_3: + flag = (sorted_list[2] - sorted_list[0]) / (sorted_list[-2] - sorted_list[0]) \ + if (sorted_list[-2] - sorted_list[0]) else 0 + else: + flag = (sorted_list[2] - sorted_list[0]) / (sorted_list[-3] - sorted_list[0]) \ + if (sorted_list[-3] - sorted_list[0]) else 0 + + # 根据数据数量查表,若计算的检验指标较大,则认为有异常值,耗时最短的卡是慢卡 + if flag > DIXON_TABLE_995[n]: + return [time_list.index(sorted_list[0])] + return [] + + +def judge_slow_rank(time_list): + """根据time list长度 选择狄克逊检验或三倍标准差""" + if len(time_list) <= Constant.MAX_DIXON_NUM: + return judge_dixon(time_list) + else: + return judge_norm(time_list) + + +class SlowRankAnalysis(BaseRecipeAnalysis): + def __init__(self, params): + super().__init__(params) + logger.info("Slow Rank Analysis init.") + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def reducer_func(self, mapper_res): + mapper_res = list(filter(lambda df: df is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None.") + return None + concated_df = pd.concat(mapper_res) + return concated_df + + def run(self, context): + if self._is_msprof: + logger.warning("Slow rank analysis do not support msprof db now.") + return + + mapper_res = self.mapper_func(context) + comm_ops_df = self.reducer_func(mapper_res) + if comm_ops_df is None: + return + + analyzer = SlowRankVoteAnalysis(comm_ops_df) + perpector_df = analyzer.run() + + if self._export_type == Constant.DB: + self.save_db(perpector_df) + elif self._export_type == "notebook": + self.save_notebook(perpector_df) + else: + logger.error("SlowRank analysis is not supported for notebook export type.") + + def save_db(self, perpector_df): + self.dump_data(perpector_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "SlowRank") + + def save_notebook(self, perpector_df): + self.dump_data(perpector_df, "rank_stats.csv") + self.create_notebook("stats.ipynb") + self.add_helper_file("cluster_display.py") + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + step_range = data_map.get(Constant.STEP_RANGE) + df = SlowRankExport(profiler_db_path, analysis_class, step_range).read_export_db() + return df + + +class SlowRankVoteAnalysis: + def __init__(self, comm_ops): + self.comm_ops = comm_ops + + def grouping_ops(self): + """按照通信域、算子名称对通信算子进行分组""" + grouped_ops_dict = defaultdict(lambda: defaultdict(list)) + self.comm_ops = self.comm_ops[~self.comm_ops["opName"].str.contains("send")] + self.comm_ops = self.comm_ops[~self.comm_ops["opName"].str.contains("receive")] + grouped_df = self.comm_ops.groupby("groupName") + exclude_groups = [] + for group_name in grouped_df.groups.keys(): + ops_groupby_group_name = grouped_df.get_group(group_name) + ops_num = ops_groupby_group_name.groupby("opName").size().values + if len(set(ops_num)) > 1: + exclude_groups.append(group_name) + for exclude_group in exclude_groups: + self.comm_ops.drop(self.comm_ops[self.comm_ops["groupName"] == exclude_group].index, inplace=True) + self.comm_ops.reset_index(drop=True, inplace=True) + n = len(self.comm_ops) + group_name_arr = self.comm_ops["groupName"].values + op_name_arr = self.comm_ops["opName"].values + for idx in range(n): + group_name = group_name_arr[idx] + op_name = op_name_arr[idx] + grouped_ops_dict[group_name][op_name].append(idx) + return grouped_ops_dict + + def run(self): + grouped_ops_dict = self.grouping_ops() + perpector_dict = self.analysis(grouped_ops_dict) + return perpector_dict + + def analysis(self, grouped_ops_dict): + rank_id_arr = self.comm_ops["rankId"].values + comm_time_arr = self.comm_ops["communication_time"].values + perpector_dict = defaultdict(lambda: 0) + for _, ops_same_group in grouped_ops_dict.items(): + for _, ops_list in ops_same_group.items(): + time_list = [comm_time_arr[op_idx] for op_idx in ops_list] + perpector_rank_idx = judge_slow_rank(time_list) + if perpector_rank_idx: + for rank_idx in perpector_rank_idx: + slow_rank = rank_id_arr[ops_list[rank_idx]] + perpector_dict[slow_rank] += 1 + + perpector_df = pd.DataFrame(columns=["rankId", "slowAffectCount"]) + for rank, perpector_times in perpector_dict.items(): + perpector_df.loc[len(perpector_df)] = [rank, perpector_times] + perpector_df.set_index(["rankId"], inplace=True) + return perpector_df diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/stats.ipynb b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/stats.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c63a50fd5dbca6a44bf516b38e49797e86981ee8 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/stats.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Slow Rank\n", + "集群场景通信算子快慢卡汇总分析\n", + "\n", + "1.根据卡粒度,统计每个Rank上的影响因子\n", + "\n", + "2.将统计的结果按柱状图呈现,TOP影响的极为慢卡候选" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据准备" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import plotly.offline as pyo\n", + "\n", + "from IPython.display import display, HTML\n", + "\n", + "import cluster_display\n", + "\n", + "display(HTML(\"\"))\n", + "pd.set_option('display.max_columns', None)\n", + "pd.set_option('display.max_rows', None)\n", + "pyo.init_notebook_mode()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "展示各Rank受影响程度的统计表" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(\"rank_stats.csv\", index_col=\"rankId\")\n", + "display(df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cluster_display.display_bar(x_axis=df.index, y_axes=df, title=\"Slow Rank\", y_index=\"slowAffectCount\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a355e5a7f08206fc39dda4646817224c067f29f7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/slow_rank_pp_stage.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/slow_rank_pp_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..8d26f849291da1a87cf51ef85b6ca0bc9320951f --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank_pp_stage/slow_rank_pp_stage.py @@ -0,0 +1,239 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +from collections import defaultdict + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.cluster_time_summary_export import CommunicationTimeExport +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class SlowRankPPStageAnalysis(BaseRecipeAnalysis): + def __init__(self, params): + super().__init__(params) + logger.info("SlowRank PPstage analysis init.") + + self.p2p_analysis_result = None + self.pp_analysis_result = None + self.p2p_vote_result = None + self.pp_vote_result = None + + self.distributed_args = self.load_distributed_args() + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @classmethod + def add_parser_argument(cls, parser): + parser.add_argument("--tp", type=int, help=cls.TP_SIZE, default=None) + parser.add_argument("--pp", type=int, help=cls.PP_SIZE, default=None) + parser.add_argument("--dp", type=int, help=cls.DP_SIZE, default=None) + + def reducer_func(self, mapper_res): + mapper_res = list(filter(lambda df: df is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None.") + return None + concated_df = pd.concat(mapper_res) + return concated_df + + def run(self, context): + if self.distributed_args is None: + return + mapper_res = self.mapper_func(context) + comm_ops_df = self.reducer_func(mapper_res) + if comm_ops_df is None: + return + + p2p_analysis_result_list = [] + p2p_vote_result_list = [] + pp_analysis_result_list = [] + pp_vote_result_list = [] + + pp_stage_rank_map = self.map_rank_pp_stage() + + for _, df_one_step in comm_ops_df.groupby("step"): + p2p_analysis_result, p2p_vote_result, pp_analysis_result, pp_vote_result = \ + SlowRankPPStageStepAnalysis(df_one_step).analysis(pp_stage_rank_map) + p2p_analysis_result_list.append(p2p_analysis_result) + p2p_vote_result_list.append(p2p_vote_result) + pp_analysis_result_list.append(pp_analysis_result) + pp_vote_result_list.append(pp_vote_result) + + for step_id, (p2p_analysis_result, p2p_vote_result, pp_analysis_result, pp_vote_result) in \ + enumerate( + zip( + p2p_analysis_result_list, + p2p_vote_result_list, + pp_analysis_result_list, + pp_vote_result_list + )): + p2p_analysis_result["step"] = step_id + p2p_vote_result["step"] = step_id + pp_analysis_result["step"] = step_id + pp_vote_result["step"] = step_id + + self.p2p_analysis_result = pd.concat(p2p_analysis_result_list) + self.p2p_vote_result = pd.concat(p2p_vote_result_list) + self.pp_analysis_result = pd.concat(pp_analysis_result_list) + self.pp_vote_result = pd.concat(pp_vote_result_list) + + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error("SlowRank PPstage is not supported for notebook export type.") + + def save_db(self): + self.dump_data(self.p2p_vote_result, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "P2PAnalysisResult") + self.dump_data(self.pp_vote_result, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "PPAnalysisResult") + + def map_rank_pp_stage(self): + tp_size = self.distributed_args.get(self.TP_SIZE, 1) + pp_size = self.distributed_args.get(self.PP_SIZE, 1) + dp_size = self.distributed_args.get(self.DP_SIZE, 1) + + rank_pp_stage_map = {} + rank = 0 + for i in range(pp_size): + for _ in range(tp_size * dp_size): + rank_pp_stage_map[rank] = i + rank += 1 + return rank_pp_stage_map + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + df = CommunicationTimeExport(profiler_db_path, analysis_class).read_export_db() + return df + + +class SlowRankPPStageStepAnalysis: + def __init__(self, comm_ops): + self.comm_ops = comm_ops + self.exclude_ranks = [] + + def grouping_pp_stage_ops(self, pp_stage_rank_map): + p2p_op_group = defaultdict(lambda: defaultdict(list)) + pp_op_group = defaultdict(lambda: defaultdict(list)) + + def divid_opname(op_name): + # op_name的格式:输入 OPTYPE__GORUPHASH_IDX_1 输出 OPTYPE_IDX + splited_name = op_name.split("__") + if len(splited_name) != 2: + return None + splited_num = splited_name[1].split("_") + if len(splited_num) != 3: + return None + return "_".join([splited_name[0], splited_num[1]]) + + ops_num = len(self.comm_ops) + op_name_arr = self.comm_ops["opName"].values + rank_id_arr = self.comm_ops["rank"].values + for idx in range(ops_num): + rank = rank_id_arr[idx] + op_name = op_name_arr[idx] + op_name_short = divid_opname(op_name) + if op_name_short is None: + continue + pp_stage_idx = pp_stage_rank_map[rank] + if rank in self.exclude_ranks: + continue + if "send" in op_name_short or "receive" in op_name_short: + p2p_op_group[pp_stage_idx][op_name_short].append(idx) + else: + pp_op_group[pp_stage_idx][op_name_short].append(idx) + + return p2p_op_group, pp_op_group + + def analysis_pp_stage(self, vote_group): + min_time_dict = defaultdict(lambda: defaultdict(lambda: 0)) + max_time_dict = defaultdict(lambda: defaultdict(lambda: 0)) + mean_time_dict = defaultdict(lambda: defaultdict(lambda: 0)) + count_dict = defaultdict(lambda: defaultdict(lambda: 0)) + rank_vote = defaultdict(lambda: 0) + perpetrator_dict = defaultdict(lambda: defaultdict(lambda: 0)) + minimum_rank_op_name = defaultdict(list) + + communication_time_arr = self.comm_ops["communication_time"].values + rank_id_arr = self.comm_ops["rank"].values + for pp_idx, ops_same_group in vote_group.items(): + for op_name, ops in ops_same_group.items(): + communication_time_list = [communication_time_arr[op_idx] for op_idx in ops] + min_time = min(communication_time_list) + min_op_idx = ops[communication_time_list.index(min_time)] + min_op_rank = rank_id_arr[min_op_idx] + rank_vote[min_op_rank] += 1 + perpetrator_dict[pp_idx][op_name] = min_op_rank + minimum_rank_op_name[min_op_rank].append(op_name) + + max_time = max(communication_time_list) + mean_time = sum(communication_time_list) // len(communication_time_list) + min_time_dict[pp_idx][op_name] = min_time + max_time_dict[pp_idx][op_name] = max_time + mean_time_dict[pp_idx][op_name] = mean_time + count_dict[pp_idx][op_name] = len(ops) + + analysis_result = pd.DataFrame( + columns=[ + "ppIdx", + "opName", + "minTime", + "maxTime", + "meanTime", + "count", + "perpetratorRank" + ] + ) + + for pp_idx in min_time_dict.keys(): + for op_name in min_time_dict[pp_idx].keys(): + analysis_result.loc[len(analysis_result)] = [ + pp_idx, op_name, + min_time_dict[pp_idx][op_name], + max_time_dict[pp_idx][op_name], + mean_time_dict[pp_idx][op_name], + count_dict[pp_idx][op_name], + perpetrator_dict[pp_idx][op_name] + ] + + vote_result = pd.DataFrame(columns=["rankId", "minimumTimes"]) + for rank, minimum_times in rank_vote.items(): + vote_result.loc[len(vote_result)] = [rank, minimum_times] + vote_result.set_index(["rankId"], inplace=True) + + return analysis_result, vote_result + + def analysis(self, pp_stage_rank_map): + self.select_exclude_ranks() + p2p_op_group, pp_op_group = self.grouping_pp_stage_ops(pp_stage_rank_map) + p2p_analysis_result, p2p_vote_result = self.analysis_pp_stage(p2p_op_group) + pp_analysis_result, pp_vote_result = self.analysis_pp_stage(pp_op_group) + return p2p_analysis_result, p2p_vote_result, pp_analysis_result, pp_vote_result + + def select_exclude_ranks(self): + grouped_df = self.comm_ops.groupby("rank") + for rank in grouped_df.groups.keys(): + ops_groupby_rank = grouped_df.get_group(rank) + ops_num = ops_groupby_rank.groupby("opName").size().values + if len(set(ops_num)) > 1: + self.exclude_ranks.append(rank) diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py index 9d813c23b63350ecc724dfea5cbdd36ac0579afd..ab12d640a1aad9478ca067c56db3bcc10a156a0c 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py @@ -193,7 +193,7 @@ class TraceEventBean: return self._args.get("name", "").find("Communication") != -1 def is_hccl_process_name(self) -> bool: - return self.process_name in ["Communication", "HCCL"] + return self.process_name == "HCCL" def is_overlap_process_name(self) -> bool: return self.process_name == "Overlap Analysis" diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py b/profiler/msprof_analyze/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py index cdca93a92767f169f4f4c014ed18f5aa7d407a7a..5c2590c723e646660b456acbdf3f114fb2726190 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py @@ -91,7 +91,7 @@ class SequencePreMatching: base_index += 1 comparison_index += 1 while comparison_index < comparison_data_len: - result_data.extend(self._match_torch_op([], comparison_data[comparison_index].get(Constant.OPS, []))) + result_data.extend(self._match_torch_op([], comparison_data[0].get(Constant.OPS, []))) comparison_index += 1 return result_data diff --git a/profiler/msprof_analyze/docs/custom_analysis_guide.md b/profiler/msprof_analyze/docs/custom_analysis_guide.md new file mode 100644 index 0000000000000000000000000000000000000000..5791f5a6651245ef83a298df962b456ebb594043 --- /dev/null +++ b/profiler/msprof_analyze/docs/custom_analysis_guide.md @@ -0,0 +1,155 @@ +### 自定义分析规则开发指导 +自定义分析规则是基于对Profiling的analysis.db和ascend_pytorch_profiler_{rank_id}.db文件进行性能数据分析而开发。与cann_api_sum、compute_op_sum、hccl_sum等参数功能实现类似,可自定义一套性能数据的分析规则,方法如下: + +1. 在mstt工具代码仓profiler/msprof_analyze/cluster_analyse/recipes目录下创建xxx目录和xxx.py文件。 + + 例如:profiler/msprof_analyze/cluster_analyse/recipes/cann_api_sum/cann_api_sum.py,其中目录名和文件名要保持一致,该目录名也会作为使用msprof-analyze cluster工具启动该自定义分析的开关参数。 + +2. 在xxx.py文件进行性能数据分析规则的开发,开发要求继承BaseRecipeAnalysis,实现run函数。 + + 典型的run函数实现: + + ```python + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + if self._export_type == "db": + self.save_db() + elif self._export_type == "notebook": + self.save_notebook() + else: + logger.error("Unknown export type.") + ``` + + 1. `mapper_func`函数:多卡数据查询并合并返回结果。由于集群数据每张卡的数据处理是同样的,因此采用context并行处理集群数据并将结果按序拼装返回。开发只需要实现单卡数据处理的函数`self._mapper_fun`。 + + ```python + def mapper_func(self, context): + return context.wait( + context.map( + self._mapper_func, + self._get_rank_db(), + analysis_class=self._recipe_name + ) + ) + ``` + + ```python + def _mapper_func(self, data_map, analysis_class): + """ + Extract the profiling data required for cluster analysis from each device, and then aggregate the + results from each device to be processed by a reduce function. + Params: + data_map: eg. {"RANK_ID": 1, "profiler_db_path": "xxxx/ascend_pytorch_profiler_1.db"} + analysis_class: hccl_sum, compute_op_sum, cann_api_sum, mstx_sum...... + """ + pass + ``` + + 2. `reducer_func`函数:对多卡结果分析处理。接收`mapper_func`函数的返回值,进行进一步的集群数据的汇总分析,数据结构采用dataframe。 + + 3. `save_db`函数:分析结果保存在cluster_analysis.db中。 + + 4. `save_notebook`函数:分析结果以csv和stats.ipynb的形式保存。 + +3. `self._mapper_fun`函数依赖单db数据查询,可通过可通过如下两种方式。 + + 1. 使用DatabaseService可配置单表的查询。 + + 可参考:https://gitee.com/ascend/mstt/blob/pre-research/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py + + 使用样例: + + ```Python + service = DatabaseService(profiler_db_path) + service.add_table_for_query("ENUM_HCCL_DATA_TYPE", ["id", "name"]) # 第一个参数:表名;第二个参数:字段列表,默认为None,当不填写时表明select * + service.add_table_for_query("STRING_IDS", ["id", "value"]) #可 以添加多个表 + df_dict = service.query_data() # 将配置的所有表按序查询,以dict形式返回,key为表名,value为数据库查询结果dataframe数据类型 + ``` + + 2. 维护在msprof_analyze/prof_exports目录下,新建一个py文件,需继承自BaseStatsExport(注:新增之前可以看现有的是否可用,避免重复)如下示例: + + ```Python + from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + + QUERY = """ + SELECT + NAME_IDS.value AS "OpName", + TYPE_IDS.value AS "OpType", + round(endNs - startNs) AS "Duration", + GROUP_NAME_IDS.value AS "GroupName" + FROM + COMMUNICATION_OP + LEFT JOIN + STRING_IDS AS TYPE_IDS + ON TYPE_IDS.id == COMMUNICATION_OP.opType + LEFT JOIN + STRING_IDS AS NAME_IDS + ON NAME_IDS.id == COMMUNICATION_OP.opName + LEFT JOIN + STRING_IDS AS GROUP_NAME_IDS + ON GROUP_NAME_IDS.id == COMMUNICATION_OP.groupName + """ + + + class HcclSumExport(BaseStatsExport): + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY + ``` + + 使用样例:df = HcclSumExport(profiler_db_path, analysis_class).read_export_db(),返回的数据类型是dataframe。 + +4. 分析规则增加拓展参数。 + + 实现函数add_parser_argument,样例如下: + + ```Python + @classmethod + def add_parser_argument(cls, parser): + parser.add_argument("--top_num", type=str, help="Duration cost top count", default=cls.DEFAULT_TOP_NUM) + ``` + + 从self._extra_args里获取对应的扩展参数: + + ```Python + def __init__(self, params): + super().__init__(params) + top_num = self._extra_args.get(self.TOP_NUM, self.DEFAULT_TOP_NUM) + self.top_num = int(top_num) if isinstance(top_num, str) and top_num.isdigit() else self.DEFAULT_TOP_NUM + ``` + +5. 执行自定义分析规则命令。 + + ```bash + msprof-analyze cluster -d {cluster profiling data path} --mode xxx --top_num 10 + ``` + +### 开发和上库流程规范 + +开发要遵守以下流程规范。 + +1. **需求澄清和串讲** + + 确定要做该需求后,首先要明确该需求的**迭代时间**,开发流程需要严格遵守我们的迭代时间,参加该需求的需求澄清以及串讲(我们会安排相应会议)。需求澄清可由DE完成(对齐输入输入以及流程图),需求串讲需要开发者来完成,串讲时需要准备**设计文档和测试用例**(有文档模版,可以跟SE或者DE联系拿到)。 + +2. **UT** + + 为了保证后面的开发者修改你的代码时不会影响你的功能,或者能够感知这次修改的影响,比如算法实现、字段变更等,需要在上库的同时添加UT。 + UT的编写可以参考已经上库的其他用例,建议四段式命名:test_{目标方法名}_should_{预期结果}_when_{分支条件}_given_{输入参数},可以灵活使用mock方式构造虚拟返回。 + +3. **资料编写** + + 目前,如果新增一个分析能力,需要在[操作步骤](#操作步骤)的第2小节的“--mode参数说明”中添加对应参数的说明,简洁说明该分析能力的作用以及输入输出。 + 另外,需要在[recipe结果和cluster_analysis.db交付件表结构说明](#recipe结果和cluster_analysisdb交付件表结构说明)中添加表结构说明,明确输入输出。可以详细说明你的分析能力的**主要场景、用途甚至是算法原理**,保证用户知道这个分析能力的能做什么,对调优有什么帮助。(参考[freq_analysis](#freq_analysis)的说明) + +4. **CI** + + 正常商发需求合入master分支;预研需求合入pre-research分支;poc需求合入poc分支。 + 提了PR之后,可以评论**compile**,触发线上CI,会跑cleancode和冒烟,只有全绿,才可以发起代码检视。PR合入需要lgtm标签和approve标签(群里有相应的committer可以加标签)。 + +5. **代码检视** + + 代码上库,需要经过检视,可以将链接发到**msprof-analyze代码检视群**,说明该PR的标题,然后@相关人进行检视。修改完检视意见后再次@commiter,合代码。 + 为了让结果可信以及方便其他开发或者测试使用这个分析能力,需要编写测试用例并提供**自验报告**作为凭证。 + 注:cluster_analysis.db里面的表格,统一遵守表名大驼峰,列名小驼峰的命名规则。 \ No newline at end of file diff --git a/profiler/msprof_analyze/docs/recipe_output_format.md b/profiler/msprof_analyze/docs/recipe_output_format.md new file mode 100644 index 0000000000000000000000000000000000000000..089287a112670556afb02aa444042fe38fea419e --- /dev/null +++ b/profiler/msprof_analyze/docs/recipe_output_format.md @@ -0,0 +1,614 @@ +## recipe结果和cluster_analysis.db交付件表结构说明 + +说明: + +msprof-analyze配置--mode参数时可分析并输出cluster_analysis.db交付件,本节介绍该交付件的表结构和字段说明。 + +注意:部分分析能力不会生成cluster_analysis.db。 + +### compute_op_sum + +设置-m compute_op_sum时,会生成以下表。 + +#### ComputeOpAllRankStats + +说明: + +基于db格式的集群性能数据,针对全部rank的数据,以OpType和TaskType分组,对计算算子的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpType | TEXT | 计算算子类型 | +| TaskType | TEXT | 算子执行的加速器类型 | +| Count | INTEGER | 以OpType和TaskType分组进行统计的算子数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | + +#### ComputeOpPerRankStatsByOpType + +说明: + +基于db格式的集群性能数据,针对每个rank的数据,以OpType和TaskType分组,对计算算子的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpType | TEXT | 计算算子类型 | +| TaskType | TEXT | 算子执行的加速器类型 | +| Count | INTEGER | 以OpType和TaskType分组进行统计的算子数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | +| Rank | INTEGER | rank_id | + +#### ComputeOpPerRankStatsByOpName + +说明: + +配置--exclude_op_name参数时不会生成该表; +基于db格式的集群性能数据,针对每个rank的数据,以OpName、OpType、TaskType和InputShapes分组,对计算算子的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpName | TEXT | 计算算子名字 | +| OpType | TEXT | 计算算子类型 | +| TaskType | TEXT | 算子执行的加速器类型 | +| InputShapes | TEXT | 算子的输入维度 | +| Count | INTEGER | 这个分组的算子数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | +| Rank | INTEGER | rank_id | + +### cann_api_sum + +设置-m cann_api_sum时,会生成以下表。 + +#### CannApiSum + +说明: + +基于db格式的集群性能数据,针对全部rank的数据,对每一种api(名字不同)的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| name | TEXT | API名字 | +| timeRatio | REAL | API的耗时占所有API总耗时的百分比 | +| totalTimeNs | INTEGER | API的总耗时 | +| totalCount | INTEGER | API的数量 | +| averageNs | REAL | 耗时的平均值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| medNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| minNs | REAL | 耗时的最小值 | +| maxNs | REAL | 耗时的最大值 | +| stdev | REAL | 耗时的标准差 | +| minRank | TEXT | minNs对应的rank的集合 | +| maxRank | TEXT | maxNs对应的rank的集合 | + +#### CannApiSumRank + +说明: + +基于db格式的集群性能数据,针对每个rank的数据,对每一种api(名字不同)的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| name | TEXT | API名字 | +| durationRatio | REAL | API的耗时占卡内所有API总耗时的百分比 | +| totalTimeNs | INTEGER | API的总耗时 | +| totalCount | INTEGER | API的数量 | +| averageNs | REAL | 耗时的平均值 | +| minNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| medNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| maxNs | REAL | 耗时的最大值 | +| stdev | REAL | 耗时的标准差 | +| rank | INTEGER | rank_id | + +### hccl_sum + +设置-m hccl_sum时,会生成以下表。 + +#### HcclAllRankStats + +说明: + +基于db格式的集群性能数据,针对全部rank的数据,对每一种通信算子类型(例如hcom_broadcast_)的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpType | TEXT | 通信算子类型 | +| Count | INTEGER | 数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | + +#### HcclPerRankStats + +说明: + +基于db格式的集群性能数据,针对每个rank的数据,对每一种通信算子类型(例如hcom_broadcast_)的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpType | TEXT | 通信算子类型 | +| Count | INTEGER | 数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | +| Rank | INTEGER | rank_id | + +#### HcclGroupNameMap + +说明: + +通信域内包含的rank。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| GroupName | TEXT | 通信域,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | +| GroupId | TEXT | 通信域的hash值的后三位 | +| Ranks | TEXT | 该通信域的所有rank | + +#### HcclTopOpStats + +说明: + +基于db格式的集群性能数据,对所有rank的通信算子的耗时进行分析,展示耗时平均值排名TOP N(默认为 15)的通信算子的数据。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| OpName | TEXT | 通信算子名,例如hcom_allReduce__606_0_1 | +| Count | INTEGER | 数量 | +| MeanNs | REAL | 耗时的平均值 | +| StdNs | REAL | 耗时的标准差 | +| MinNs | REAL | 耗时的最小值 | +| Q1Ns | REAL | 耗时的25%分位数 | +| MedianNs | REAL | 耗时的50%分位数 | +| Q3Ns | REAL | 耗时的75%分位数 | +| MaxNs | REAL | 耗时的最大值 | +| SumNs | REAL | 耗时的总和 | +| MinRank | INTEGER | 该通信算子耗时最小的rank | +| MaxRank | INTEGER | 该通信算子耗时最大的rank | + +### mstx_sum + +设置-m mstx_sum时,会生成以下表。 + +#### MSTXAllFrameworkStats + +说明: + +基于db格式的集群性能数据,分析mstx打点数据的框架侧耗时(不区分rank)。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| Name | TEXT | mstx打点数据携带信息 | +| Count | INTEGER | 该迭代内以Name为分组的打点的次数 | +| MeanNs | REAL | 平均值 | +| StdNs | REAL | 标准差 | +| MinNs | REAL | 最小值 | +| Q1Ns | REAL | 25%分位数 | +| MedianNs | REAL | 50%分位数 | +| Q3Ns | REAL | 75%分位数 | +| MaxNs | REAL | 最大值 | +| SumNs | REAL | 总和 | +| StepId | INTEGER | 迭代id | + +#### MSTXAllCannStats + +说明: + +基于db格式的集群性能数据,分析mstx打点数据的cann层耗时(不区分rank)。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| Name | TEXT | mstx打点数据携带信息 | +| Count | INTEGER | 该迭代内以Name为分组的打点的次数 | +| MeanNs | REAL | 平均值 | +| StdNs | REAL | 标准差 | +| MinNs | REAL | 最小值 | +| Q1Ns | REAL | 25%分位数 | +| MedianNs | REAL | 50%分位数 | +| Q3Ns | REAL | 75%分位数 | +| MaxNs | REAL | 最大值 | +| SumNs | REAL | 总和 | +| StepId | INTEGER | 迭代id | + +#### MSTXAllDeviceStats + +说明: + +基于db格式的集群性能数据,分析mstx打点数据的device侧耗时(不区分rank)。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| Name | TEXT | mstx打点数据携带信息 | +| Count | INTEGER | 该迭代内以Name为分组的打点的次数 | +| MeanNs | REAL | 平均值 | +| StdNs | REAL | 标准差 | +| MinNs | REAL | 最小值 | +| Q1Ns | REAL | 25%分位数 | +| MedianNs | REAL | 50%分位数 | +| Q3Ns | REAL | 75%分位数 | +| MaxNs | REAL | 最大值 | +| SumNs | REAL | 总和 | +| StepId | INTEGER | 迭代id | + +#### MSTXMarkStats + +说明: + +基于db格式的集群性能数据,针对每个rank的打点数据,以Rank,StepId分组,对mstx打点的耗时进行统计分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| Name | TEXT | mstx打点数据携带信息 | +| FrameworkDurationNs | REAL | 框架侧耗时 | +| CannDurationNs | REAL | CANN层耗时 | +| DeviceDurationNs | REAL | device侧耗时 | +| Rank | INTEGER | global rank | +| StepId | INTEGER | 迭代id | + +### communication_group_map + +设置-m communication_group_map,会生成以下表。 + +#### CommunicationGroupMapping + +说明: + +基于db格式的集群性能数据,生成通信域与并行策略的对应关系。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| type | TEXT | 算子类型,包含collective和p2p, 其中算子名包含"send","recv","receive"的算子被认为是p2p | +| rank_set | TEXT | 通信域内包含的rank(global rank)| +| group_name | TEXT | 通信域的hash值,可映射成group_id | +| group_id | TEXT | hccl内部定义的通信域名字,例如:10.170.22.98%enp67s0f5_60000_0_1708156014257149 | +| pg_name | TEXT | 业务定义的通信域名字,例如:"dp","dp_cp","mp"等等 | + +### cluster_time_summary + +设置-m cluster_time_summary时,会生成以下表。 + +说明:和cluster_step_trace_time.csv相似,之后考虑替代它。 + +#### ClusterTimeSummary + +说明: + +基于db格式的集群性能数据,针对全部rank的数据,对集群的一些耗时进行统计分析,可以用来定位性能问题。 + +格式: +**下表的时间单位都是us** + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| rank | INTEGER | global rank | +| step | INTEGER | 迭代id | +| stepTime | REAL | 整个迭代耗时 | +| computation | REAL | 计算时间的全部耗时 | +| communicationNotOverlapComputation | REAL | 未被计算掩盖的通信耗时 | +| communicationOverlapComputation | REAL | 计算与通信重叠的耗时 | +| communication | REAL | 通信时间的全部耗时 | +| free | REAL | 空闲时间,指device侧既不在通信也不在计算、并且不包含异步拷贝的总耗时 | +| communicationWaitStageTime | REAL | 通信等待总耗时 | +| communicationTransmitStageTime | REAL | 通信传输总耗时 | +| memory | REAL | 异步拷贝的总耗时 | +| memoryNotOverlapComputationCommunication | REAL | 不被计算和通信掩盖的异步拷贝的总耗时 | +| taskLaunchDelayAvgTime | REAL | 下发耗时,指所有task从host侧api的开始时间到device侧task的开始时间的平均耗时 | + +### cluster_time_compare_summary + +设置-m cluster_time_compare_summary时,会生成以下表。 + +说明:该分析能力需要基于cluster_time_summary的结果,集群数据和标杆集群数据都要有cluster_analysis.db,db里面要有ClusterTimeSummary这个表。 + +#### ClusterTimeCompareSummary + +说明:结果表示当前集群与标杆集群的比较结果,比如computationDiff表示当前集群与标杆集群的计算时间的差值,如果当前集群的计算时间比标杆集群多,则computationDiff为正数,反之为负数。 + +格式: +**下表的时间单位都是us** + +| 字段名 | 类型 | 含义 | +|--------------| ---- | ---- | +| rank | INTEGER | global rank | +| step | INTEGER | 迭代id | +| stepTime | REAL | 当前集群数据的迭代耗时 | +| stepTimeBase | REAL | 标杆集群数据的计算时间 | +| stepTimeDiff | REAL | 迭代耗时的差值 | +…… +| taskLaunchDelayAvgTime | REAL | 当前集群数据的下发耗时 | +| taskLaunchDelayAvgTimeBase | REAL | 标杆集群数据的下发耗时 | +| taskLaunchDelayAvgTimeDiff | REAL | 下发耗时的差值 | + +由于列过多,就不展示全部列了,对于ClusterTimeSummary的每一列,在这个表里面都会展示当前集群数据、标杆集群数据以及他们的差值。 + + +### freq_analysis + +说明: + +基于db格式的集群性能数据,分析aicore frequency,提供npu降频一键检测能力。频率分为三种情况: +* 正常情况下,应当稳定在1800MHz; +* 当npu空闲时间较长时,设备会自动降频,会掉到800MHz; +* 当npu因为各种原因,出现降频现象时,除了1800MHz,800MHz,还有出现其他异常频率。 + +设置-m freq_analysis时,如果发生降频,会生成以下表。 + +#### FreeFrequencyRanks + +说明: + +对应第二种情况。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| rankId | INTEGER | global rank | +| aicoreFrequency | TEXT | [800, 1800] | + +#### AbnormalFrequencyRanks + +说明: + +对应第三种情况。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| rankId | INTEGER | global rank | +| aicoreFrequency | TEXT | 异常频率列表;例如:[800, 1150, 1450, 1800] | + +### ep_load_balance + +说明: + +集群训练场景下,MOE负载不均指的是,在分布式环境下,不同的专家模型处理的任务量不均衡,导致某些专家过载(处理过多任务),而其他专家闲置。这种负载不均会降低系统的整体效率,甚至可能导致性能瓶颈。 + +设置-m ep_load_balance时,会生成以下表。 + +#### EPTokensSummary + +说明: + +基于db格式的集群性能数据,分析GroupedMatmul算子的shape信息。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| rank | INTEGER | global rank | +| epRanks | TEXT | 同一个ep(Expert Parallelism)的rank集合,例如0,1 | +| inputShapesSummary | INTEGER | 该rank的GroupedMatmul算子的inputshapes的第一个维度的总和 | + +#### TopEPTokensInfo + +说明: + +负载不均的ep。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| epRanks | TEXT | 负载不均的ep(Expert Parallelism)的rank集合,例如0,1 | +| tokensDiff | INTEGER | 同一个ep内最大值与最小值之间的差值 | + + +### filter_db + +设置-m filter_db时,不会生成cluster_analysis.db,会将原始db进行过滤,使得集群数据变小,有时候甚至能减少90%,将TB级数据过滤至GB级。过滤后的数据依旧可以在MindStudio Insight中呈现。 + +**说明:最好指定-o 参数,指定提取后的集群数据的目录。** + +示例:msprof-analyze cluster -d {cluster_profiling_data_path} -m filter_db -o {output_dir} + +结果示例: +``` +output_dir +├── cluster_analysis_output +├── ├── filter_db +├── ├── ├── xxx_ascend_pt +├── ├── ├── ├── ASCEND_PROFILER_OUTPUT +├── ├── ├── ├── ├── ascend_pytorch_profiler_x.db +├── ├── ├── xxx_ascend_pt +├── ├── ├── …… +``` + +过滤逻辑: + +1. 删除COMMUNICATION_TASK_INFO和TASK_PMU_INFO +2. 获取所有的COMMUNICATION_OP全保留,获取全部的connectionId集合 +3. 根据opType(TOP算子FA、MatMul、GroupedMatmul)筛选COMPUTE_TASK_INFO和TASK(通过globalTaskId进行关联)数据,同时获取筛选出来的connectionId集合。 +4. 逻辑2和逻辑3的connectionId集合取并集,CANN_API、Pytorch_API的connectionId相同的数据保留,其他去掉。 + +### mstx2commop + +设置-m mstx2commop时,不会生成cluster_analysis.db,会将通信内置打点数据转换成通信算子。 + +**说明:强烈建议在levelNone的情况下使用,会新生成COMMUNICATION_OP,否则会破坏原来的表结构。** + +结果: + +设置levelNone时,统一db里面没有COMMUNICATION_OP,该分析能力会将通信内置打点数据转换成通信算子,并且可以在MindStudio Insight中呈现。 + +### slow_rank + +设置-m slow_rank时,会生成以下表。 + +#### SlowRank + +说明: + +基于db格式的集群性能数据,进行慢卡分析。 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| rankId | INTEGER | 慢卡 | +| slowAffectCount | INTEGER | 该rank影响了多少次通信 | + +### p2p_pairing + +设置-m p2p_pairing时,不会生成cluster_analysis.db。 + +说明:该分析能力主要是为了显示P2P算子的连线,让用户看到发送和接收的src_rank和dst_rank。**目前MindStudio Insight暂时没有做这一块的适配。** + +结果: + +会在集群数据的ascend_pytorch_profiler_{rank_id}.db的Communication_OP表中新增一列opConnectionId。 根据这个opConnectionId可以把不同rank的P2P算子连线。 + + +### pp_chart + +说明: 这个能力需要首先要使用轻量化打点在前反向前后打点,然后使用mstt进行处理,最后用MindStudio Insight进行显示。 + +#### 打点 + +以DualpipeV2为例,找到前反向代码,在dualpipev_schedules.py里面添加如下代码(仅为示例,需要注意这段代码添加的位置): +``` +import torch_npu +def step_wrapper(func, msg: str): + def wrapper(*args, **kwargs): + new_msg = {"name": msg} + if msg = "forward_step_with_model_graph" and kwargs.get("extra_block_kwargs") is not None: + new_msg["name"] = "forward_backward_overlaping" + if "current_microbatch" in kwargs: + new_msg["current_microbatch"] = kwargs["current_microbatch"] + if msg == "WeightGradStore_pop" and len(WeightGradStore.cache) == 0: + mstx_state_step_range_id = None + else: + mstx_state_step_range_id = torch_npu.npu.mstx.range_start(str(new_msg), torch_npu.npu.current_stream()) + out = func(*args, **kwargs) + if mstx_state_step_range_id is not None: + torch_npu.npu.mstx.range_end(mstx_state_step_range_id) + mstx_state_step_range_id = None + return out + return wrapper + +forward_step_with_model_graph = step_wrapper(forward_step_with_model_graph, "forward_step_with_model_graph") +forward_step_no_model_graph = step_wrapper(forward_step_no_model_graph, "forward_step_no_model_graph") +backward_step_with_model_graph = step_wrapper(backward_step_with_model_graph, "backward_step_with_model_graph") +backward_step = step_wrapper(backward_step, "backward_step") +WeightGradStore.pop = step_wrapper(WeightGradStore.pop, "WeightGradStore.pop") +``` + +同时,采集profiling数据时,需要添加metadata: + +``` +prof.add_metadata('pp_info', json.dumps( + { + 'pp_type': 'dualpipev', + 'microbatch_num': 10, + } +)) +# microbatch_num需要替换成实际的值 +``` + +#### StepTaskInfo + +说明: + +基于上一章节打点后的db格式的集群性能数据,进行处理,生成表格,供可视化显示 + +格式: + +| 字段名 | 类型 | 含义 | +| ------ | ---- | ---- | +| name | TEXT | 前反向信息 | +| startNs | INTEGER | 在device上开始时间 | +| endNs | INTEGER | 在device上结束时间 | +| type | INTEGER | 类型,不同类型显示不同颜色 | + +#### 通信 + +当profiler_level设为Level_none时,COMMUNICATION_OP这个表不存在,需要使用mstx2commop这个分析能力将通信内置打点转换为通信算子,这样就会生成这个表。pp流水图也可以显示send和recv。 + +有了COMMUNICATION_OP这个表,需要使用分析能力p2p_pairing。这样pp流水图也可以显示send和recv的连线,但是这个能力需要level1及以上。 + +#### communication_group.json + +记录通信域信息,解析analysis.db生成的交付件,collective表示集合通信域,P2P表示点对点通信,用户无须关注该文件。 + +#### stats.ipynb + +- 数据解析模式为cann_api_sum时生成,保存在cluster_analysis_output/CannApiSum目录下。 + + 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群API耗时信息。 + +- 数据解析模式为compute_op_sum时生成,保存在cluster_analysis_output/ComputeOpSum目录下。 + + 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群计算算子耗时分析(将集群所有计算算子进行汇总并以图表展示),集群Rank计算算子耗时分析(将每个Rank的计算算子进行各自汇总)。 + +- 数据解析模式为hccl_sum时生成,保存在cluster_analysis_output/HcclSum目录下。 + + 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群通信算子耗时分析(将集群所有通信算子进行汇总并以图表展示),集群Rank通信算子耗时分析(将每个Rank的通信算子进行各自汇总)、Top通信算子信息展示。 + +- 数据解析模式为mstx_sum时生成,保存在cluster_analysis_output/MstxSum目录下。 + + 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群场景mstx打点信息,分为框架侧、CANN侧和Device侧三部分的打点信息。 + +- 数据解析模式为slow_link时生成,保存在cluster_analysis_output/SlowLink目录下。 + + 可使用jupyter notebook工具或MindStudio Insight工具打开,主要展示集群场景异常慢链路数据分析(将集群所有链路进行汇总并以图表展示),集群慢链路汇总耗时分析(展示检测到可能存在慢链路的数据)。 diff --git a/profiler/msprof_analyze/module_visualization/__init__.py b/profiler/msprof_analyze/module_visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/module_visualization/graph/__init__.py b/profiler/msprof_analyze/module_visualization/graph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/module_visualization/graph/prof_node.py b/profiler/msprof_analyze/module_visualization/graph/prof_node.py new file mode 100644 index 0000000000000000000000000000000000000000..1f39ee9bfa2dd86cc02ea74400de8bcebdd75e21 --- /dev/null +++ b/profiler/msprof_analyze/module_visualization/graph/prof_node.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.base_node import BaseNode +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean + + +class ProfNode(BaseNode): + + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._kernel_total_list = [] + self._communication_total_list = [] + self._precision_index = 1 + self._computing_time = 0 + self._uncovered_comm_time = 0 + self._free_time = 0 + self._step_id = None + self._micro_step_id = None + self._bwd_overall_data = {} + + @property + def node_id(self): + return self._event.unique_id + + @property + def node_type(self): + if self._event.event_type is None: + return Constant.VIRTUAL_TYPE + return self._event.event_type + + @property + def step_id(self): + return self._step_id + + @property + def micro_step_id(self): + return self._micro_step_id + + @property + def is_backward(self): + return self.node_id.startswith(Constant.BACKWARD_MODULE) + + @property + def fwd_bwd_id(self): + return self._event.fwd_bwd_id + + @property + def is_bwd(self): + return "BACKWARD" in self.node_id + + @property + def total_kernels(self): + if self.node_type == Constant.VIRTUAL_TYPE: + return [kernel for node in self.child_nodes for kernel in node.total_kernels] + return self._kernel_total_list + + @property + def total_communications(self): + if self.node_type == Constant.VIRTUAL_TYPE: + return [comm for node in self.child_nodes for comm in node.total_communications] + return self._communication_total_list + + @property + def host_total_dur(self): + if self.node_type == Constant.VIRTUAL_TYPE: + return sum((node.host_total_dur for node in self.child_nodes)) + return self._event.dur + + @property + def host_self_dur(self): + if self.node_type == Constant.VIRTUAL_TYPE: + return 0 + return self.host_total_dur - sum((node.host_total_dur for node in self.child_nodes)) + + @property + def device_total_dur(self): + return sum((kernel.dur for kernel in self.total_kernels)) + + @property + def device_self_dur(self): + if self.node_type == Constant.VIRTUAL_TYPE: + return 0 + return self.device_total_dur - sum((node.device_total_dur for node in self.child_nodes)) + + @property + def input_data(self) -> dict: + data = {} + input_dim = self._event.args.get("Input Dims") + if input_dim: + data["Input Dims"] = input_dim + input_type = self._event.args.get("Input type") + if input_type: + data["Input type"] = input_type + return data + + @property + def kernel_data(self) -> list: + return [kernel.kernel_info for kernel in self.total_kernels] + + @property + def communication_data(self) -> list: + return [[comm.name, comm.dur] for comm in self.total_communications] + + @property + def overall_data(self): + return {"Computing Time(us)": round(self._computing_time, 3), + "Uncovered Communication Time(us)": round(self._uncovered_comm_time, 3), + "Free Time(us)": round(self._free_time, 3)} + + @property + def data(self): + data = { + "Overall Metrics": self.overall_data} if self.node_type != Constant.OPERATOR_TYPE else {} + if self._bwd_overall_data: + data.update({"Backward Overall Metrics": self._bwd_overall_data}) + data.update({"Input Data": self.input_data, + "precision_index": self.precision_index, + "Host Self Duration(us)": round(self.host_self_dur, 3), + "Host Total Duration(us)": round(self.host_total_dur, 3), + "Device Self Duration(us)": round(self.device_self_dur, 3), + "Device Total Duration(us)": round(self.device_total_dur, 3), + "kernels": self.kernel_data, + "Communications": self.communication_data}) + return data + + @property + def info(self): + info = {"id": self.node_id, + "node_type": self.node_type, + "data": self.data, + "upnode": self.parent_node.node_id if self.parent_node else "None", + "subnodes": [node.node_id for node in iter(self.child_nodes)]} + if self.step_id is not None: + info.update({"step_id": self.step_id}) + if self.micro_step_id is not None: + info.update({"micro_step_id": self.micro_step_id}) + return info + + @property + def is_root_node(self): + return self.node_id == Constant.NPU_ROOT_ID + + @property + def precision_index(self): + return self._precision_index + + @precision_index.setter + def precision_index(self, precision_index): + self._precision_index = precision_index + + @step_id.setter + def step_id(self, step_id): + self._step_id = step_id + + @micro_step_id.setter + def micro_step_id(self, micro_step_id): + self._micro_step_id = micro_step_id + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def reset_child_nodes(self, nodes): + self._child_nodes = nodes + + def update_kernel_total_list(self, kernel_list: list): + self._kernel_total_list.extend(kernel_list) + + def update_communication_total_list(self, communication_list: list): + self._communication_total_list.extend(communication_list) + + def update_child_precision_index(self): + if not self.child_nodes: + return + max_dur = max((node.device_total_dur for node in self.child_nodes)) + min_dur = min((node.device_total_dur for node in self.child_nodes)) + diff_dur = max_dur - min_dur + for node in self.child_nodes: + node.precision_index = 1 - (node.device_total_dur - min_dur) / diff_dur if diff_dur else 1 + + def update_overall_metrics(self, overlap_analysis_event): + if not self.total_kernels and not self.total_communications: + return + device_events = [] + device_events.extend(self.total_kernels) + device_events.extend(self.total_communications) + device_events.sort(key=lambda x: x.start_time) + device_start = device_events[0].start_time + device_end = device_events[-1].end_time + for event in overlap_analysis_event: + if event.start_time >= device_end: + break + if event.end_time <= device_start: + continue + duration_us = float( + min(device_end, event.end_time) - max(device_start, event.start_time)) + if event.name == Constant.COMPUTING_EVENT: + self._computing_time += duration_us + elif event.name == Constant.FREE_EVENT: + self._free_time += duration_us + elif event.name == Constant.UNCOVERED_COMMUNICATION_EVENT: + self._uncovered_comm_time += duration_us + + def update_bwd_overall_metrics(self, overall_metrics): + self._bwd_overall_data = overall_metrics diff --git a/profiler/msprof_analyze/module_visualization/graph_build/__init__.py b/profiler/msprof_analyze/module_visualization/graph_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/module_visualization/graph_build/fwd_module_node.py b/profiler/msprof_analyze/module_visualization/graph_build/fwd_module_node.py new file mode 100644 index 0000000000000000000000000000000000000000..27bb52da7960ccb7f7ac51d92552cf461196903d --- /dev/null +++ b/profiler/msprof_analyze/module_visualization/graph_build/fwd_module_node.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.prof_common.base_node import BaseNode +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean + + +class FwdModuleNode(BaseNode): + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._bwd_op_list = [] + + @property + def bwd_op_list(self): + return self._bwd_op_list + + @property + def event(self): + return self._event + + def update_bwd_op(self, bwd_op_list: list): + self._bwd_op_list.extend(bwd_op_list) diff --git a/profiler/msprof_analyze/module_visualization/graph_build/prof_graph_builder.py b/profiler/msprof_analyze/module_visualization/graph_build/prof_graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c00ef92b3e15e2c2c7c81f76d451db5eacb183 --- /dev/null +++ b/profiler/msprof_analyze/module_visualization/graph_build/prof_graph_builder.py @@ -0,0 +1,237 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from decimal import Decimal + +from msprof_analyze.module_visualization.graph.prof_node import ProfNode +from msprof_analyze.module_visualization.graph_build.fwd_module_node import FwdModuleNode +from msprof_analyze.prof_common.tree_builder import TreeBuilder +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.module_visualization.prof_parse.prof_data_pre_process import ProfDataPreProcess + + +class ProfGraphBuilder: + + def __init__(self, prof_data_path: str): + self._prof_data_path = prof_data_path + self._prof_data = {} + self._fwd_bwd_id = 1 + + @classmethod + def _create_event_bean_from_ops(cls, op_list: list, name: str) -> TraceEventBean: + min_start = min((op.start_time for op in iter(op_list))) + max_end = max((op.end_time for op in iter(op_list))) + # 以反向算子的区间作为反向module的区间范围,为了module包含算子,做了-0.0001 +0.0001处理 + event = TraceEventBean( + {"ts": min_start - Decimal("0.0001"), "dur": float(max_end - min_start + Decimal("0.0001")), "name": name}) + event.event_type = Constant.MODULE_TYPE + return event + + @classmethod + def _trans_flow_to_dict(cls, flow_events: dict, end_events: list) -> dict: + end_event_dict = {} + for event in end_events: + end_event_dict[event.start_time] = event + result_data = {} + for flow in flow_events.values(): + start_point = flow.get("start") + end_point = flow.get("end") + if not start_point or not end_point: + continue + end_event = end_event_dict.get(end_point.start_time) + if end_event: + result_data.setdefault(start_point.start_time, []).append(end_event) + return result_data + + @classmethod + def _create_virtual_node(cls, all_nodes: list): + root_node = all_nodes[0] + virtual_nodes = [] + first_level_nodes = root_node.child_nodes + root_node.reset_child_nodes([]) + merged_nodes = [] + order_id = 1 + for node in first_level_nodes: + if node.node_type == Constant.OPERATOR_TYPE: + merged_nodes.append(node) + continue + if len(merged_nodes) >= 2: + virtual_node = ProfNode(TraceEventBean({"ts": min((node.start_time for node in merged_nodes))}, + f"Operators_Between_Modules_{order_id}"), root_node) + root_node.update_child_nodes(virtual_node) + order_id += 1 + for op_node in merged_nodes: + op_node.parent_node = virtual_node + virtual_node.update_child_nodes(op_node) + virtual_nodes.append(virtual_node) + elif len(merged_nodes) == 1: + root_node.update_child_nodes(merged_nodes[0]) + root_node.update_child_nodes(node) + merged_nodes = [] + if len(merged_nodes) >= 2: + virtual_node = ProfNode(TraceEventBean({"ts": min((node.start_time for node in merged_nodes))}, + f"Operators_Between_Modules_{order_id}"), root_node) + root_node.update_child_nodes(virtual_node) + for op_node in merged_nodes: + op_node.parent_node = virtual_node + virtual_node.update_child_nodes(op_node) + virtual_nodes.append(virtual_node) + elif len(merged_nodes) == 1: + root_node.update_child_nodes(merged_nodes[0]) + all_nodes.extend(virtual_nodes) + + @classmethod + def _set_event_order_id(cls, all_events: list): + name_dict = {} + for event in all_events: + order_id = name_dict.get(event.name, 0) + event.set_id(f"{event.name}_{order_id}") + name_dict[event.name] = order_id + 1 + + def build_graph(self): + self._prof_data = ProfDataPreProcess(self._prof_data_path).run() + all_data = [*self._prof_data.get(Constant.MODULE_EVENT, []), + *self.find_bwd_module(), + *self._prof_data.get(Constant.CPU_OP_EVENT, [])] + all_data.sort(key=lambda x: x.start_time) + self._set_event_order_id(all_data) + all_nodes = TreeBuilder.build_tree(all_data, ProfNode, TraceEventBean({}, Constant.NPU_ROOT_ID)) + if len(all_nodes) < 2: + msg = "Failed to build graph." + raise RuntimeError(msg) + self._update_kernel_details(all_nodes[0]) + self._update_communication_details(all_nodes[0]) + self._create_virtual_node(all_nodes) + self._update_precision_index_and_overall_metrics(all_nodes) + self._update_step_info(all_nodes[0]) + return all_nodes + + def find_bwd_module(self) -> list: + bwd_module_list = [] + fwdbwd_flow = self._prof_data.get(Constant.FWD_BWD_FLOW, {}) + fwdbwd_flow = {key: value for key, value in fwdbwd_flow.items() if + value.get("start") and value.get("end") and value.get("start").tid != value.get("end").tid} + module_list = self._prof_data.get(Constant.MODULE_EVENT, []) + cpu_op_list = self._prof_data.get(Constant.CPU_OP_EVENT, []) + if not fwdbwd_flow or not module_list or not cpu_op_list: + return bwd_module_list + fwd_tid = module_list[0].tid + bwd_tid = fwd_tid + for end_point in (flow.get("end") for flow in fwdbwd_flow.values()): + if end_point: + bwd_tid = end_point.tid + break + if fwd_tid == bwd_tid: + return bwd_module_list + # 将每一个反向包成一个module,名字叫“nn.Module: BACKWARD_0” + cpu_op_list.sort(key=lambda x: x.start_time) + pre_status = Constant.FWD_OR_OPT + bwd_op_list = [] + for op in cpu_op_list: + if op.tid == bwd_tid: + bwd_op_list.append(op) + pre_status = Constant.BACKWARD + continue + elif pre_status == Constant.BACKWARD: + bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, Constant.BACKWARD_MODULE)) + bwd_module_list.extend(self._match_fwd_module(module_list, fwdbwd_flow, bwd_op_list)) + bwd_op_list.clear() + pre_status = Constant.FWD_OR_OPT + if bwd_op_list: + bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, Constant.BACKWARD_MODULE)) + bwd_module_list.extend(self._match_fwd_module(module_list, fwdbwd_flow, bwd_op_list)) + bwd_op_list.clear() + return bwd_module_list + + def _match_fwd_module(self, module_list, fwdbwd_flow, bwd_op_list): + # 通过连线匹配正向module,构建出反向的整体module关系 + bwd_module_list = [] + all_nodes = TreeBuilder.build_tree(module_list, FwdModuleNode, TraceEventBean({})) + root_node = all_nodes[0] + fwdbwd_flow_dict = self._trans_flow_to_dict(fwdbwd_flow, bwd_op_list) + for start_time, end_events in fwdbwd_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_bwd_op(end_events) + matched_node = matched_node.binary_search(start_time) + for module_node in all_nodes: + if module_node.bwd_op_list: + module_node.event.fwd_bwd_id = self._fwd_bwd_id + bwd_module_list.append( + self._create_event_bean_from_ops(module_node.bwd_op_list, f"{module_node.name} [BACKWARD]")) + bwd_module_list[-1].fwd_bwd_id = self._fwd_bwd_id + self._fwd_bwd_id += 1 + return bwd_module_list + + def _update_kernel_details(self, root_node): + kernel_flow_dict = self._trans_flow_to_dict(self._prof_data.get(Constant.TORCH_TO_NPU_FLOW, {}), + self._prof_data.get(Constant.KERNEL_EVENT, [])) + for start_time, kernels in kernel_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_kernel_total_list(kernels) + matched_node = matched_node.binary_search(start_time) + + def _update_communication_details(self, root_node): + communication_flow_dict = self._trans_flow_to_dict(self._prof_data.get(Constant.TORCH_TO_NPU_FLOW, {}), + self._prof_data.get(Constant.HCCL_EVENT, [])) + for start_time, communications in communication_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_communication_total_list(communications) + matched_node = matched_node.binary_search(start_time) + + def _update_step_info(self, root_node): + first_level_nodes = root_node.child_nodes + step_events = self._prof_data.get(Constant.STEP_EVENT, []) + node_dict = {} + if not step_events: + node_dict[None] = first_level_nodes + else: + for node in first_level_nodes: + for step_event in step_events: + if step_event.start_time <= node.start_time <= step_event.end_time: + node.step_id = step_event.step_id + node_dict.setdefault(step_event.step_id, []).append(node) + break + for nodes in node_dict.values(): + micro_step_list = [] + micro_events = [] + for node in nodes: + micro_events.append(node) + if node.is_backward: + micro_step_list.append(micro_events) + micro_events = [] + if micro_step_list: + micro_step_list[-1].extend(micro_events) + else: + micro_step_list.append(micro_events) + for index, micro_events in enumerate(micro_step_list): + for node in micro_events: + node.micro_step_id = index + + def _update_precision_index_and_overall_metrics(self, all_nodes: list): + overlap_analysis_event = self._prof_data.get(Constant.OVERLAP_ANALYSIS_EVENT, []) + overlap_analysis_event.sort(key=lambda x: x.start_time) + bwd_infos = {} + for node in all_nodes: + node.update_child_precision_index() + if node.node_type != Constant.OPERATOR_TYPE: + node.update_overall_metrics(overlap_analysis_event) + if node.is_bwd and node.fwd_bwd_id: + bwd_infos[node.fwd_bwd_id] = node.overall_data + for node in all_nodes: + if node.node_type != Constant.OPERATOR_TYPE and not node.is_bwd: + node.update_bwd_overall_metrics(bwd_infos.get(node.fwd_bwd_id, {})) diff --git a/profiler/msprof_analyze/module_visualization/prof_graph_export.py b/profiler/msprof_analyze/module_visualization/prof_graph_export.py new file mode 100644 index 0000000000000000000000000000000000000000..acb178f7e7e60ea733f93ccbcb7bccdbae458442 --- /dev/null +++ b/profiler/msprof_analyze/module_visualization/prof_graph_export.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os.path +from datetime import datetime + +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.file_reader import FileReader +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.module_visualization.graph_build.prof_graph_builder import ProfGraphBuilder + + +class ProfGraphExport: + @classmethod + def export_to_json(cls, prof_data_path: str, output_path: str): + logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") + output_path = os.path.abspath(output_path) + prof_data_path = os.path.abspath(prof_data_path) + try: + PathManager.input_path_common_check(prof_data_path) + PathManager.check_input_directory_path(output_path) + PathManager.make_dir_safety(output_path) + PathManager.check_path_writeable(output_path) + except RuntimeError as err: + logging.error(err) + try: + cls.generate_graph_data(prof_data_path, output_path) + except RuntimeError as err: + logging.error(err) + + @classmethod + def generate_graph_data(cls, prof_data_path: str, output_path: str): + all_nodes = ProfGraphBuilder(prof_data_path).build_graph() + result_data = {"root": Constant.NPU_ROOT_ID, "node": {}} + for node in all_nodes: + result_data["node"][node.node_id] = node.info + step_list = list(set([node.step_id for node in all_nodes[0].child_nodes if node.step_id is not None])) + if step_list: + result_data["StepList"] = step_list + micro_steps = len( + set([node.micro_step_id for node in all_nodes[0].child_nodes if node.micro_step_id is not None])) + result_data["MicroSteps"] = micro_steps + file_name = "prof_graph_json_{}.vis".format(datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:-3]) + FileReader.write_json_file(output_path, result_data, file_name) + logging.info("Performance data has been converted into a graph-structured file: %s", + os.path.join(output_path, file_name)) diff --git a/profiler/msprof_analyze/module_visualization/prof_parse/__init__.py b/profiler/msprof_analyze/module_visualization/prof_parse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/module_visualization/prof_parse/prof_data_pre_process.py b/profiler/msprof_analyze/module_visualization/prof_parse/prof_data_pre_process.py new file mode 100644 index 0000000000000000000000000000000000000000..2d39649d58d543fc295bdc0296bd244c25f506f3 --- /dev/null +++ b/profiler/msprof_analyze/module_visualization/prof_parse/prof_data_pre_process.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os + +from msprof_analyze.prof_common.file_reader import FileReader +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.kernel_bean import KernelBean +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean + + +class ProfDataPreProcess: + def __init__(self, prof_data_path: str): + self._prof_data_path = prof_data_path + self._trace_path = "" + self._kernel_details_path = "" + self._kernel_pid = None + self._hccl_pid = None + self._overlap_analysis_pid = None + self._result_data = {Constant.CPU_OP_EVENT: [], Constant.MODULE_EVENT: [], Constant.KERNEL_EVENT: [], + Constant.TORCH_TO_NPU_FLOW: {}, Constant.FWD_BWD_FLOW: {}, Constant.HCCL_EVENT: [], + Constant.OVERLAP_ANALYSIS_EVENT: [], Constant.STEP_EVENT: []} + + @staticmethod + def _check_trace_data(trace_data): + if not isinstance(trace_data, list): + msg = f"Invalid profiling data path, this feature only supports performance data " \ + f"collected by Ascend PyTorch Profiler." + raise RuntimeError(msg) + + def run(self) -> dict: + self._check_trace_path() + self._parse_trace_events() + self._parse_kernel_details() + self._check_result_data() + return self._result_data + + def _check_trace_path(self): + if os.path.isfile(self._prof_data_path): + (split_file_path, split_file_name) = os.path.split(self._prof_data_path) + (shot_name, extension) = os.path.splitext(split_file_name) + if extension != ".json": + msg = f"Invalid profiling path suffix: {self._prof_data_path}. " \ + f"You should input in a json file path, such as trace_view.json." + raise RuntimeError(msg) + self._trace_path = self._prof_data_path + return + ascend_output = os.path.join(self._prof_data_path, "ASCEND_PROFILER_OUTPUT") + profiler_output = ascend_output if os.path.isdir(ascend_output) else self._prof_data_path + json_path = os.path.join(profiler_output, "trace_view.json") + if not os.path.isfile(json_path): + msg = f"Invalid profiling path: {self._prof_data_path}. The data path should be the " \ + f"folder that ends with the ascend_pt collected by the Ascend PyTorch Profiler." + raise RuntimeError(msg) + kernel_path = os.path.join(profiler_output, "kernel_details.csv") + if os.path.isfile(kernel_path): + self._kernel_details_path = kernel_path + self._trace_path = json_path + + def _parse_trace_events(self): + trace_data = FileReader.read_json_file(self._trace_path) + self._check_trace_data(trace_data) + iter_trace_data = [TraceEventBean(data) for data in trace_data] + for event in iter_trace_data: + if self._kernel_pid is not None and self._hccl_pid is not None and self._overlap_analysis_pid is not None: + break + if not event.is_meta(): + continue + if event.is_npu_process(): + self._kernel_pid = event.pid + elif event.is_hccl_process(): + self._hccl_pid = event.pid + elif event.is_overlap_analysis_process(): + self._overlap_analysis_pid = event.pid + if self._kernel_pid is None: + msg = "There is no operator on the NPU side for this data, please check whether the NPU switch is enabled." + raise RuntimeError(msg) + for event in iter_trace_data: + if event.is_optimizer(): + event.event_type = Constant.MODULE_TYPE + self._result_data[Constant.MODULE_EVENT].append(event) + elif event.is_cpu_op(): + if event.is_step(): + self._result_data[Constant.STEP_EVENT].append(event) + else: + event.event_type = Constant.OPERATOR_TYPE + self._result_data[Constant.CPU_OP_EVENT].append(event) + elif event.is_nn_module(): + event.event_type = Constant.MODULE_TYPE + self._result_data[Constant.MODULE_EVENT].append(event) + elif event.is_torch_to_npu(): + if event.is_flow_start(): + self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(event.id, {})["start"] = event + else: + self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(event.id, {})["end"] = event + elif event.is_fwd_bwd_flow(): + if event.is_flow_start(): + self._result_data[Constant.FWD_BWD_FLOW].setdefault(event.id, {})["start"] = event + else: + self._result_data[Constant.FWD_BWD_FLOW].setdefault(event.id, {})["end"] = event + elif event.is_kernel_event(self._kernel_pid): + self._result_data[Constant.KERNEL_EVENT].append(event) + elif event.is_hccl_event(self._hccl_pid): + self._result_data[Constant.HCCL_EVENT].append(event) + elif event.is_overlap_analysis_event(self._overlap_analysis_pid): + self._result_data[Constant.OVERLAP_ANALYSIS_EVENT].append(event) + + def _parse_kernel_details(self): + if not self._kernel_details_path: + return + try: + all_kernels = FileReader.read_csv_file(self._kernel_details_path, KernelBean) + except Exception as e: + logging.error(e) + kernels = list(filter(lambda x: x.is_computing_op, all_kernels)) + if kernels: + self._result_data[Constant.KERNEL_EVENT] = kernels + + def _check_result_data(self): + if not self._result_data.get(Constant.CPU_OP_EVENT): + msg = "This data does not have any aten operator, please make sure to enable the CPU switch." + raise RuntimeError(msg) + if not [event for event in self._result_data.get(Constant.MODULE_EVENT) if event.is_nn_module()]: + msg = "This data does not collect any modules, please make sure to enable the with_stack or with_modules." + raise RuntimeError(msg) diff --git a/profiler/msprof_analyze/osrt_trace/README.md b/profiler/msprof_analyze/osrt_trace/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ffb70415c60922fdff1c9de07313f8ee81f6aed --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/README.md @@ -0,0 +1,157 @@ +# MSOSRT Trace系统库函数耗时检测 + +OSRT(OS runtime libraries trace)是根据Linux操作系统运行时库采集用户层库函数API的调用信息。MSOSRT(MindStudio OSRT)则是采集Linux C库函数和POSIX线程(pthread)库中典型的高耗时接口,即可能阻塞用户进程的函数(如read、ioctl,pthread_mutex_lock等),统计其耗时信息,帮助用户分析进程阻塞的原因。 + +## 使用方法 + +1. 约束条件:仅支持Linux系统,拥有g++编译环境和glibc、pthread等标准库。 +2. 将mstt代码仓下载到本地,进入到profiler/msprof_analyze/osrt_trace目录,执行`bash build.sh`,生成`libmsosrt_trace.so`。 +3. 执行`export LD_PRELOAD=./libmsosrt_trace.so:$LD_PRELOAD`,将`libmsosrt_trace.so`加入到LD_PRELOAD环境变量中。 +4. 设置检测阈值和导出目录的环境变量: + + ```bash + # 检测阈值,正整数,只统计超过阈值的库函数,单位:ns,默认为10000000 + export MSOSRT_TRACE_THRESHOLD=10000000 + # 导出目录,字符串,设置检测结果导出的目录,默认为当前目录 + export MSOSRT_EXPORT_PATH="./osrt_trace_result" + ``` + +5. 执行用户进程,如`python main.py` + +6. 用户进程执行结束后,在MSOSRT_EXPORT_PATH路径下会生成检测结果,生成结果文件:msosrt_trace\_{进程号}\_{进程名}.csv,如`msosrt_trace_2328177_python3.csv`,文件内容包含pid、tid、函数名、开始执行时间和耗时等信息,如下所示: + + | Pid | Tid | Function | StartTime(ns) | Duration(ns) | + | ------: | ------: | ----------------: | ------------------: | -----------: | + | 2328177 | 2328280 | pthread_cond_wait | 1725398310787080000 | 3088062410 | + | 2328177 | 2328282 | pthread_cond_wait | 1725398310787170000 | 3087994240 | + | 2328177 | 2328480 | read | 1725398318916180000 | 100509970 | + | 2328177 | 2328440 | ioctl | 1725398319218640000 | 512040720 | + | 2328177 | 2328177 | free | 1725398330504550000 | 56386880 | + +## 检测接口 + +MSOSRT支持检测如下操作系统库函数: + +- 内存操作 + + ```c + malloc + realloc + free + mmap + munmap + mremap + msync + mprotect + brk + ``` + +- 文件操作 + + ```c + dup + dup2 + dup3 + tee + splice + fallocate + fdatasync + fsync + fcntl + flock + lockf + truncate + ftruncate + ioctl + open + openat + pipe + pipe2 + mkfifo + mkfifoat + read + pread + readv + preadv + preadv2 + write + pwrite + writev + pwritev + pwritev2 + copy_file_range + sync + syncfs + sync_file_range + vmsplice + process_vm_readv + process_vm_writev + fclose + fcloseall + fflush + fgetc + fgets + fputc + fputs + flockfile + ftrylockfile + funlockfile + fopen + freopen + fread + fwrite + getdelim + getline + getc + putc + getc_unlocked + putc_unlocked + fflush_unlocked + fgetc_unlocked + fputc_unlocked + fread_unlocked + fwrite_unlocked + fgets_unlocked + fputs_unlocked + ``` + +- 网络操作 + + ```c + socket + socketpair + epoll_ctl + epoll_wait + epoll_pwait + select + listen + accept + accept4 + bind + poll + ppoll + send + sendto + sendmsg + sendmmsg + sendfile + recv + recvfrom + recvmsg + recvmmsg + ``` + +- 线程操作 + + ```c + pthread_mutex_lock + pthread_mutex_timedlock + pthread_cond_signal + pthread_cond_broadcast + pthread_cond_wait + pthread_cond_timedwait + pthread_rwlock_rdlock + pthread_rwlock_timedrdlock + pthread_rwlock_wrlock + pthread_rwlock_timedwrlock + ``` \ No newline at end of file diff --git a/profiler/msprof_analyze/osrt_trace/build.sh b/profiler/msprof_analyze/osrt_trace/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..bb153e6247122c922dc5cea247be43bfec3d5430 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/build.sh @@ -0,0 +1 @@ +g++ ./src/*.cpp -std=c++11 -fPIC -fstack-protector-all -fno-strict-aliasing -fno-common -fvisibility=hidden -fvisibility-inlines-hidden -Wfloat-equal -Wextra -O2 -shared -lpthread -ldl -o libmsosrt_trace.so \ No newline at end of file diff --git a/profiler/msprof_analyze/osrt_trace/src/file_func.cpp b/profiler/msprof_analyze/osrt_trace/src/file_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..319dcb227b139adf158d55fe762f97afdfa5fdd8 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/file_func.cpp @@ -0,0 +1,664 @@ +#include "file_func.h" + +#include + +#include "msosrt_trace.h" + +void FileFuncProxy::loadFunc() +{ + LOAD_FUNC(dup, DupFunc); + LOAD_FUNC(dup2, Dup2Func); + LOAD_FUNC(dup3, Dup3Func); + LOAD_FUNC(tee, TeeFunc); + LOAD_FUNC(splice, SpliceFunc); + LOAD_FUNC(fallocate, FallocateFunc); + LOAD_FUNC(fdatasync, FdatasyncFunc); + LOAD_FUNC(fsync, FsyncFunc); + LOAD_FUNC(fcntl, FcntlFunc); + LOAD_FUNC(flock, FlockFunc); + LOAD_FUNC(lockf, LockfFunc); + LOAD_FUNC(truncate, TruncateFunc); + LOAD_FUNC(ftruncate, FtruncateFunc); + LOAD_FUNC(ioctl, IoctlFunc); + LOAD_FUNC(open, OpenFunc); + LOAD_FUNC(openat, OpenatFunc); + LOAD_FUNC(pipe, PipeFunc); + LOAD_FUNC(pipe2, Pipe2Func); + LOAD_FUNC(mkfifo, MkfifoFunc); + LOAD_FUNC(mkfifoat, MkfifoatFunc); + LOAD_FUNC(read, ReadFunc); + LOAD_FUNC(pread, PreadFunc); + LOAD_FUNC(readv, ReadvFunc); + LOAD_FUNC(preadv, PreadvFunc); + LOAD_FUNC(preadv2, Preadv2Func); + LOAD_FUNC(write, WriteFunc); + LOAD_FUNC(pwrite, PwriteFunc); + LOAD_FUNC(writev, WritevFunc); + LOAD_FUNC(pwritev, PwritevFunc); + LOAD_FUNC(pwritev2, Pwritev2Func); + LOAD_FUNC(copy_file_range, CopyFileRangeFunc); + LOAD_FUNC(sync, SyncFunc); + LOAD_FUNC(syncfs, SyncfsFunc); + LOAD_FUNC(sync_file_range, SyncFileRangeFunc); + LOAD_FUNC(vmsplice, VmspliceFunc); + LOAD_FUNC(process_vm_readv, ProcessVmReadvFunc); + LOAD_FUNC(process_vm_writev, ProcessVmWritevFunc); + LOAD_FUNC(fclose, FcloseFunc); + LOAD_FUNC(fcloseall, FcloseallFunc); + LOAD_FUNC(fflush, FflushFunc); + LOAD_FUNC(fgetc, FgetcFunc); + LOAD_FUNC(fgets, FgetsFunc); + LOAD_FUNC(fputc, FputcFunc); + LOAD_FUNC(fputs, FputsFunc); + LOAD_FUNC(flockfile, FlockfileFunc); + LOAD_FUNC(ftrylockfile, FtrylockfileFunc); + LOAD_FUNC(funlockfile, FunlockfileFunc); + LOAD_FUNC(fopen, FopenFunc); + LOAD_FUNC(freopen, FreopenFunc); + LOAD_FUNC(fread, FreadFunc); + LOAD_FUNC(fwrite, FwriteFunc); + LOAD_FUNC(getdelim, GetdelimFunc); + LOAD_FUNC(getline, GetlineFunc); + LOAD_FUNC(getc, GetcFunc); + LOAD_FUNC(putc, PutcFunc); + LOAD_FUNC(getc_unlocked, GetcUnlockedFunc); + LOAD_FUNC(putc_unlocked, PutcUnlockedFunc); + LOAD_FUNC(fflush_unlocked, FflushUnlockedFunc); + LOAD_FUNC(fgetc_unlocked, FgetcUnlockedFunc); + LOAD_FUNC(fputc_unlocked, FputcUnlockedFunc); + LOAD_FUNC(fread_unlocked, FreadUnlockedFunc); + LOAD_FUNC(fwrite_unlocked, FwriteUnlockedFunc); + LOAD_FUNC(fgets_unlocked, FgetsUnlockedFunc); + LOAD_FUNC(fputs_unlocked, FputsUnlockedFunc); +} + +int dup(int oldfd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_dup(oldfd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int dup2(int oldfd, int newfd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_dup2(oldfd, newfd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int dup3(int oldfd, int newfd, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_dup3(oldfd, newfd, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t tee(int fd_in, int fd_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_tee(fd_in, fd_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t splice(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_splice(fd_in, off_in, fd_out, off_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fallocate(int fd, int mode, off_t offset, off_t len) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fallocate(fd, mode, offset, len); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fdatasync(int fildes) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fdatasync(fildes); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fsync(int fd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fsync(fd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fcntl(int fd, int op, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, op); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fcntl(fd, op, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int flock(int fd, int op) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_flock(fd, op); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int lockf(int fd, int op, off_t len) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_lockf(fd, op, len); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int truncate(const char* path, off_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_truncate(path, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ftruncate(int fildes, off_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ftruncate(fildes, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ioctl(int fd, int op, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, op); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ioctl(fd, op, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int open(const char* pathname, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + mode_t arg = va_arg(args, mode_t); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = arg ? global_osrt_func.file_func.real_open(pathname, flags, arg) : global_osrt_func.file_func.real_open(pathname, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int openat(int dirfd, const char *pathname, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + mode_t arg = va_arg(args, mode_t); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = arg ? global_osrt_func.file_func.real_openat(dirfd, pathname, flags, arg) : global_osrt_func.file_func.real_openat(dirfd, pathname, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pipe(int pipefd[2]) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pipe(pipefd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pipe2(int pipefd[2], int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pipe2(pipefd, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mkfifo(const char* pathname, mode_t mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_mkfifo(pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mkfifoat(int dirfd, const char* pathname, mode_t mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_mkfifoat(dirfd, pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t read(int fd, void* buf, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_read(fd, buf, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pread(int fd, void* buf, size_t count, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pread(fd, buf, count, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t readv(int fd, const struct iovec* iov, int iovcnt) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_readv(fd, iov, iovcnt); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t preadv(int fd, const struct iovec* iov, int iovcnt, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_preadv(fd, iov, iovcnt, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t preadv2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_preadv2(fd, iov, iovcnt, offset, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t write(int fd, const void* buf, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_write(fd, buf, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwrite(int fd, const void* buf, size_t count, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pwrite(fd, buf, count, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t writev(int fd, const struct iovec* iov, int iovcnt) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_writev(fd, iov, iovcnt); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwritev(int fd, const struct iovec* iov, int iovcnt, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pwritev(fd, iov, iovcnt, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwritev2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pwritev2(fd, iov, iovcnt, offset, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t copy_file_range(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_copy_file_range(fd_in, off_in, fd_out, off_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void sync(void) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_sync(); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +int syncfs(int fd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_syncfs(fd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int sync_file_range(int fd, off_t offset, off_t nbytes, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_sync_file_range(fd, offset, nbytes, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t vmsplice(int fd, const struct iovec* iov, size_t nr_segs, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_vmsplice(fd, iov, nr_segs, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t process_vm_readv(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_process_vm_readv(pid, local_iov, liovcnt, remote_iov, riovcnt, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t process_vm_writev(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_process_vm_writev(pid, local_iov, liovcnt, remote_iov, riovcnt, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fclose(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fclose(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fcloseall(void) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fcloseall(); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fflush(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fflush(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fgetc(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fgetc(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +char* fgets(char* s, int size, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + char* ret = global_osrt_func.file_func.real_fgets(s, size, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputc(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputc(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputs(const char* s, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputs(s, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void flockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_flockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +int ftrylockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ftrylockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void funlockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_funlockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +FILE* fopen(const char* pathname, const char* mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fopen(pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +FILE* freopen(const char* pathname, const char* mode, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_freopen(pathname, mode, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fread(void* ptr, size_t size, size_t nmemb, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fread(ptr, size, nmemb, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fwrite(const void* ptr, size_t size, size_t nitems, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fwrite(ptr, size, nitems, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t getdelim(char** lineptr, size_t* n, int delimiter, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getdelim(lineptr, n, delimiter, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t getline(char** lineptr, size_t* n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getline(lineptr, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int getc(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getc(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int putc(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_putc(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int getc_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getc_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int putc_unlocked(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_putc_unlocked(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fflush_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fflush_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fgetc_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fgetc_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputc_unlocked(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputc_unlocked(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fread_unlocked(void* ptr, size_t size, size_t n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fread_unlocked(ptr, size, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fwrite_unlocked(const void* ptr, size_t size, size_t n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fwrite_unlocked(ptr, size, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +char* fgets_unlocked(char* s, int n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + char* ret = global_osrt_func.file_func.real_fgets_unlocked(s, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputs_unlocked(const char* s, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputs_unlocked(s, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} diff --git a/profiler/msprof_analyze/osrt_trace/src/file_func.h b/profiler/msprof_analyze/osrt_trace/src/file_func.h new file mode 100644 index 0000000000000000000000000000000000000000..23c6a25eeeddd734a1ab10ecfcb7d3035d2f9a6a --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/file_func.h @@ -0,0 +1,144 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include + +using DupFunc = int(*)(int); +using Dup2Func = int(*)(int, int); +using Dup3Func = int(*)(int, int, int); +using TeeFunc = ssize_t(*)(int, int, size_t, unsigned int); +using SpliceFunc = ssize_t(*)(int, off_t*, int, off_t*, size_t, unsigned int); +using FallocateFunc = int(*)(int, int, off_t, off_t); +using FdatasyncFunc = int(*)(int); +using FsyncFunc = int(*)(int); +using FcntlFunc = int(*)(int, int, ...); +using FlockFunc = int(*)(int, int); +using LockfFunc = int(*)(int, int, off_t); +using TruncateFunc = int(*)(const char*, off_t); +using FtruncateFunc = int(*)(int, off_t); +using IoctlFunc = int(*)(int, int, ...); +using OpenFunc = int(*)(const char*, int, ...); +using OpenatFunc = int(*)(int, const char*, int, ...); +using PipeFunc = int(*)(int*); +using Pipe2Func = int(*)(int*, int); +using MkfifoFunc = int(*)(const char*, mode_t); +using MkfifoatFunc = int(*)(int, const char*, mode_t); +using ReadFunc = ssize_t(*)(int, void*, size_t); +using PreadFunc = ssize_t(*)(int, void*, size_t, off_t); +using ReadvFunc = ssize_t(*)(int, const struct iovec*, int); +using PreadvFunc = ssize_t(*)(int, const struct iovec*, int, off_t); +using Preadv2Func = ssize_t(*)(int, const struct iovec*, int, off_t, int); +using WriteFunc = ssize_t(*)(int, const void*, size_t); +using PwriteFunc = ssize_t(*)(int, const void*, size_t, off_t); +using WritevFunc = ssize_t(*)(int, const struct iovec*, int); +using PwritevFunc = ssize_t(*)(int, const struct iovec*, int, off_t); +using Pwritev2Func = ssize_t(*)(int, const struct iovec*, int, off_t, int); +using CopyFileRangeFunc = ssize_t(*)(int, off_t*, int, off_t*, size_t, unsigned int); +using SyncFunc = void(*)(void); +using SyncfsFunc = int(*)(int); +using SyncFileRangeFunc = int(*)(int, off_t, off_t, unsigned int); +using VmspliceFunc = ssize_t(*)(int, const struct iovec*, size_t, unsigned int); +using ProcessVmReadvFunc = ssize_t(*)(pid_t, const struct iovec*, unsigned long, const struct iovec*, unsigned long, unsigned long); +using ProcessVmWritevFunc = ssize_t(*)(pid_t, const struct iovec*, unsigned long, const struct iovec*, unsigned long, unsigned long); +using FcloseFunc = int(*)(FILE*); +using FcloseallFunc = int(*)(void); +using FflushFunc = int(*)(FILE*); +using FgetcFunc = int(*)(FILE*); +using FgetsFunc = char*(*)(char*, int, FILE*); +using FputcFunc = int(*)(int, FILE*); +using FputsFunc = int(*)(const char*, FILE*); +using FlockfileFunc = void(*)(FILE*); +using FtrylockfileFunc = int(*)(FILE*); +using FunlockfileFunc = void(*)(FILE*); +using FopenFunc = FILE*(*)(const char*, const char*); +using FreopenFunc = FILE*(*)(const char*, const char*, FILE*); +using FreadFunc = size_t(*)(void*, size_t, size_t, FILE*); +using FwriteFunc = size_t(*)(const void*, size_t, size_t, FILE*); +using GetdelimFunc = ssize_t(*)(char**, size_t*, int, FILE*); +using GetlineFunc = ssize_t(*)(char**, size_t*, FILE*); +using GetcFunc = int(*)(FILE*); +using PutcFunc = int(*)(int, FILE*); +using GetcUnlockedFunc = int(*)(FILE*); +using PutcUnlockedFunc = int(*)(int, FILE*); +using FflushUnlockedFunc = int(*)(FILE*); +using FgetcUnlockedFunc = int(*)(FILE*); +using FputcUnlockedFunc = int(*)(int, FILE*); +using FreadUnlockedFunc = size_t(*)(void*, size_t, size_t, FILE*); +using FwriteUnlockedFunc = size_t(*)(const void*, size_t, size_t, FILE*); +using FgetsUnlockedFunc = char*(*)(char*, int, FILE*); +using FputsUnlockedFunc = int(*)(const char*, FILE*); + +struct FileFuncProxy +{ + DupFunc real_dup = nullptr; + Dup2Func real_dup2 = nullptr; + Dup3Func real_dup3 = nullptr; + TeeFunc real_tee = nullptr; + SpliceFunc real_splice = nullptr; + FallocateFunc real_fallocate = nullptr; + FdatasyncFunc real_fdatasync = nullptr; + FsyncFunc real_fsync = nullptr; + FcntlFunc real_fcntl = nullptr; + FlockFunc real_flock = nullptr; + LockfFunc real_lockf = nullptr; + TruncateFunc real_truncate = nullptr; + FtruncateFunc real_ftruncate = nullptr; + IoctlFunc real_ioctl = nullptr; + OpenFunc real_open = nullptr; + OpenatFunc real_openat = nullptr; + PipeFunc real_pipe = nullptr; + Pipe2Func real_pipe2 = nullptr; + MkfifoFunc real_mkfifo = nullptr; + MkfifoatFunc real_mkfifoat = nullptr; + ReadFunc real_read = nullptr; + PreadFunc real_pread = nullptr; + ReadvFunc real_readv = nullptr; + PreadvFunc real_preadv = nullptr; + Preadv2Func real_preadv2 = nullptr; + WriteFunc real_write = nullptr; + PwriteFunc real_pwrite = nullptr; + WritevFunc real_writev = nullptr; + PwritevFunc real_pwritev = nullptr; + Pwritev2Func real_pwritev2 = nullptr; + CopyFileRangeFunc real_copy_file_range = nullptr; + SyncFunc real_sync = nullptr; + SyncfsFunc real_syncfs = nullptr; + SyncFileRangeFunc real_sync_file_range = nullptr; + VmspliceFunc real_vmsplice = nullptr; + ProcessVmReadvFunc real_process_vm_readv = nullptr; + ProcessVmWritevFunc real_process_vm_writev = nullptr; + FcloseFunc real_fclose = nullptr; + FcloseallFunc real_fcloseall = nullptr; + FflushFunc real_fflush = nullptr; + FgetcFunc real_fgetc = nullptr; + FgetsFunc real_fgets = nullptr; + FputcFunc real_fputc = nullptr; + FputsFunc real_fputs = nullptr; + FlockfileFunc real_flockfile = nullptr; + FtrylockfileFunc real_ftrylockfile = nullptr; + FunlockfileFunc real_funlockfile = nullptr; + FopenFunc real_fopen = nullptr; + FreopenFunc real_freopen = nullptr; + FreadFunc real_fread = nullptr; + FwriteFunc real_fwrite = nullptr; + GetdelimFunc real_getdelim = nullptr; + GetlineFunc real_getline = nullptr; + GetcFunc real_getc = nullptr; + PutcFunc real_putc = nullptr; + GetcUnlockedFunc real_getc_unlocked = nullptr; + PutcUnlockedFunc real_putc_unlocked = nullptr; + FflushUnlockedFunc real_fflush_unlocked = nullptr; + FgetcUnlockedFunc real_fgetc_unlocked = nullptr; + FputcUnlockedFunc real_fputc_unlocked = nullptr; + FreadUnlockedFunc real_fread_unlocked = nullptr; + FwriteUnlockedFunc real_fwrite_unlocked = nullptr; + FgetsUnlockedFunc real_fgets_unlocked = nullptr; + FputsUnlockedFunc real_fputs_unlocked = nullptr; + + void loadFunc(); +}; diff --git a/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.cpp b/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3a88b05480193ce9bee6c26480e214f69e4ddf0 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.cpp @@ -0,0 +1,476 @@ +#include "msosrt_trace.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if !defined (__linux__) || !defined(__GLIBC__) +#error "This tool only works on Linux!" +#endif + +#ifdef __cplusplus +extern "C" { +#endif +static void setup_trace() __attribute ((constructor)); +static void end_trace() __attribute ((destructor)); +#ifdef __cplusplus +} +#endif + +// Special handling exit func +static void (*real_exit)(int status) __attribute__((noreturn)) = nullptr; +static void (*real__exit)(int status) __attribute__((noreturn)) = nullptr; +static void (*real__Exit)(int status) __attribute__((noreturn)) = nullptr; + +static __thread bool RECURSIVE = false; +static volatile bool INITIALIZED = false; + +namespace { +pid_t GetPid() +{ + static thread_local pid_t pid = getpid(); + return pid; +} + +pid_t GetTid() +{ + static thread_local pid_t tid = gettid(); + return tid; +} + +const char* DUMP_FILE = "msosrt_trace_"; +char EXPORT_PATH[PATH_MAX]; + +const size_t RECORD_LENGTH = 512 * 1024; // Default number of trace data records +struct { + OSRTRecord data_[RECORD_LENGTH]; + std::atomic index_{0}; + bool is_full_ = false; + + void recordData(const char* function, uint64_t start_time, uint64_t duration) + { + size_t index = index_.load(std::memory_order_relaxed); + if (index + 1 >= RECORD_LENGTH) { + index_.store(0, std::memory_order_relaxed); + is_full_ = true; + } else { + index_.fetch_add(1, std::memory_order_relaxed); + } + auto& record = data_[index]; + record.pid = GetPid(); + record.tid = GetTid(); + record.function = function; + record.start_time = start_time; + record.duration = duration; + } + + size_t size() + { + return is_full_ ? RECORD_LENGTH : index_.load(std::memory_order_relaxed); + } + + bool hasValidData() + { + pid_t pid = getpid(); + for (size_t i = 0, len = size(); i < len; ++i) { + if (data_[i].pid == pid && data_[i].function != nullptr) { + return true; + } + } + return false; + } +} OSRT_RECORD_QUEUE; +} + +OSRTFunc global_osrt_func; + +void OSRTFunc::loadFunc() +{ + static volatile bool loaded = false; + if (LIKELY(loaded)) { + return; + } + RECURSIVE = true; + LOAD_FUNC(malloc, MallocFunc); + LOAD_FUNC(realloc, ReallocFunc); + LOAD_FUNC(free, FreeFunc); + LOAD_FUNC(mmap, MmapFunc); + LOAD_FUNC(munmap, MunmapFunc); + LOAD_FUNC(mremap, MremapFunc); + LOAD_FUNC(msync, MsyncFunc); + LOAD_FUNC(mprotect, MprotectFunc); + LOAD_FUNC(brk, BrkFunc); + + LOAD_FUNC(pthread_mutex_lock, PthreadMutexLockFunc); + LOAD_FUNC(pthread_mutex_timedlock, PthreadMutexTimedlockFunc); + LOAD_FUNC(pthread_cond_signal, PthreadCondSignalFunc); + LOAD_FUNC(pthread_cond_broadcast, PthreadCondBroadcastFunc); + LOAD_FUNC(pthread_cond_wait, PthreadCondWaitFunc); + LOAD_FUNC(pthread_cond_timedwait, PthreadCondTimedwaitFunc); + LOAD_FUNC(pthread_rwlock_rdlock, PthreadRwlockRdlockFunc); + LOAD_FUNC(pthread_rwlock_timedrdlock, PthreadRwlockTimedrdlockFunc); + LOAD_FUNC(pthread_rwlock_wrlock, PthreadRwlockWrlockFunc); + LOAD_FUNC(pthread_rwlock_timedwrlock, PthreadRwlockTimedwrlockFunc); + + real_exit = reinterpret_cast(dlsym(RTLD_NEXT, "exit")); + real__exit = reinterpret_cast(dlsym(RTLD_NEXT, "_exit")); + real__Exit = reinterpret_cast(dlsym(RTLD_NEXT, "_Exit")); + + file_func.loadFunc(); + socket_func.loadFunc(); + + loaded = true; + RECURSIVE = false; +} + +void OSRTFunc::recordFunc(uint64_t start_time, uint64_t duration, const char* name) +{ + if (UNLIKELY(!INITIALIZED || RECURSIVE)) { + return; + } + if (UNLIKELY(duration >= threshold_)) { + RECURSIVE = true; + OSRT_RECORD_QUEUE.recordData(name, start_time, duration); + RECURSIVE = false; + } +} + +void OSRTFunc::dumpFunc() +{ + if (!INITIALIZED) { + return; + } + static std::mutex dump_mutex; + static bool dumped = false; + + std::lock_guard lock(dump_mutex); + if (!dumped) { + RECURSIVE = true; + if (OSRT_RECORD_QUEUE.hasValidData()) { + std::string dump_file; + pid_t pid = getpid(); + // The glibc program_invocation_short_name contains the basename that was used to invoke the calling program + if (program_invocation_short_name != nullptr) { + dump_file = std::string(EXPORT_PATH) + "/" + DUMP_FILE + std::to_string(pid) + "_" + program_invocation_short_name + ".csv"; + } else { + dump_file = std::string(EXPORT_PATH) + "/" + DUMP_FILE + std::to_string(pid) + ".csv"; + } + if (!PathUtils::IsFileExist(dump_file) && !PathUtils::CreateFile(dump_file)) { + fprintf(stderr, "[ERROR] Create msosrt trace file failed.\n"); + RECURSIVE = false; + return; + } + auto fd = fopen(dump_file.c_str(), "ab"); + if (fd == nullptr) { + RECURSIVE = false; + return; + } + fprintf(fd, "%s\n", "Pid,Tid,Function,StartTime(ns),Duration(ns)"); + for (size_t i = 0, len = OSRT_RECORD_QUEUE.size(); i < len; ++i) { + if (OSRT_RECORD_QUEUE.data_[i].pid == pid && OSRT_RECORD_QUEUE.data_[i].function != nullptr) { + fprintf(fd, "%" PRIdMAX ",%" PRIdMAX ",%s,%" PRIu64 ",%" PRIu64 "\n", + static_cast(pid), + static_cast(OSRT_RECORD_QUEUE.data_[i].tid), + OSRT_RECORD_QUEUE.data_[i].function, + OSRT_RECORD_QUEUE.data_[i].start_time, + OSRT_RECORD_QUEUE.data_[i].duration); + } + } + fclose(fd); + } + RECURSIVE = false; + } + dumped = true; +} + +static void setup_trace() +{ + if (LIKELY(INITIALIZED)) { + return; + } + global_osrt_func.loadFunc(); + INITIALIZED = true; + + RECURSIVE = true; + const char* threshold_env_val = getenv("MSOSRT_TRACE_THRESHOLD"); + int64_t threshold = 0; + if (threshold_env_val == nullptr || str_to_i64(threshold_env_val, threshold) != 0) { + fprintf(stderr, "[WARNING] Parse MSOSRT_TRACE_THRESHOLD failed, use default value\n"); + } else { + if (threshold > 0) { + global_osrt_func.threshold_ = threshold; + } else { + fprintf(stderr, "[WARNING] MSOSRT_TRACE_THRESHOLD must be a positive integer, use default value\n"); + } + } + + const char* export_path_env_val = getenv("MSOSRT_EXPORT_PATH"); + std::string dump_path; + if (export_path_env_val != nullptr) { + dump_path = export_path_env_val; + } + if (dump_path.empty()) { + fprintf(stderr, "[WARNING] MSOSRT_EXPORT_PATH is not set, data will export to current working directory\n"); + char cwd_path[PATH_MAX] = {0}; + if (getcwd(cwd_path, PATH_MAX) != nullptr) { + dump_path = cwd_path; + } + } + std::string abs_path = PathUtils::RelativeToAbsPath(dump_path); + if (PathUtils::DirPathCheck(abs_path)) { + std::string real_path = PathUtils::RealPath(abs_path); + strncpy(EXPORT_PATH, real_path.c_str(), real_path.size() < PATH_MAX ? real_path.size() : PATH_MAX); + fprintf(stderr, "[INFO] MSOSRT result export path is: %s\n", real_path.c_str()); + } else { + fprintf(stderr, "[ERROR] Invalid export path, data will not be exported.\n"); + } + RECURSIVE = false; +} + +static void end_trace() +{ + global_osrt_func.dumpFunc(); +} + +void* malloc(size_t size) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return (void*)global_osrt_func.real_malloc(size); + } + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_malloc(size); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void* realloc(void* ptr, size_t size) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return (void*)global_osrt_func.real_realloc(ptr, size); + } + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_realloc(ptr, size); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void free(void* ptr) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.real_free(ptr); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +void* mmap(void* addr, size_t length, int prot, int flags, int fd, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_mmap(addr, length, prot, flags, fd, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void* mremap(void* old_address, size_t old_size, size_t new_size, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_mremap(old_address, old_size, new_size, flags, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int munmap(void* addr, size_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_munmap(addr, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int msync(void* addr, size_t length, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_msync(addr, length, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mprotect(void* addr, size_t len, int prot) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_mprotect(addr, len, prot); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int brk(void* addr) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_brk(addr); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_mutex_lock(pthread_mutex_t* mutex) +{ + if (UNLIKELY(!INITIALIZED && RECURSIVE)) { + // During the initialization phase we might be called inside of dlsym(). + // Since we'd enter an endless loop if we tried to resolved the real + // pthread_mutex_lock() here then we simply fake the lock which should + // be safe since no thread can be running yet. + return 0; + } + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_mutex_lock(mutex); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_mutex_timedlock(pthread_mutex_t* mutex, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_mutex_timedlock(mutex, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_mutex_timedlock(mutex, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_signal(pthread_cond_t* cond) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_signal(cond); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_signal(cond); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_broadcast(pthread_cond_t* cond) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_broadcast(cond); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_broadcast(cond); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_wait(pthread_cond_t* cond, pthread_mutex_t* mutex) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_wait(cond, mutex); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_wait(cond, mutex); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_timedwait(pthread_cond_t* cond, pthread_mutex_t* mutex, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_timedwait(cond, mutex, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_timedwait(cond, mutex, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_rdlock(pthread_rwlock_t* rwlock) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_rdlock(rwlock); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_rdlock(rwlock); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_timedrdlock(pthread_rwlock_t* rwlock, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_timedrdlock(rwlock, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_timedrdlock(rwlock, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_wrlock(pthread_rwlock_t* rwlock) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + global_osrt_func.real_pthread_rwlock_wrlock(rwlock); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_wrlock(rwlock); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_timedwrlock(pthread_rwlock_t* rwlock, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_timedwrlock(rwlock, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_timedwrlock(rwlock, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real_exit(status); +} + +void _exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real__exit(status); +} + +void _Exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real__Exit(status); +} diff --git a/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.h b/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.h new file mode 100644 index 0000000000000000000000000000000000000000..e153ef5138883cd597c0a5a524adc5ec5b555ea4 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/msosrt_trace.h @@ -0,0 +1,207 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "file_func.h" +#include "socket_func.h" + +#define TRACE_API __attribute__((visibility("default"))) +#define LOAD_FUNC(name, func_type) \ + do { \ + (real_##name) = reinterpret_cast(dlsym(RTLD_NEXT, #name)); \ + } while (false) + +#ifdef __cplusplus +extern "C" { +#endif +// memory func +TRACE_API void* malloc(size_t size); +TRACE_API void* realloc(void* ptr, size_t size); +TRACE_API void free(void* ptr); +TRACE_API void* mmap(void* addr, size_t length, int prot, int flags, int fd, off_t offset); +TRACE_API int munmap(void* addr, size_t length); +TRACE_API void* mremap(void* old_address, size_t old_size, size_t new_size, int flags, ... /* void *new_address */); +TRACE_API int msync(void* addr, size_t length, int flags); +TRACE_API int mprotect(void* addr, size_t len, int prot); +TRACE_API int brk(void* addr); +// pthread func +TRACE_API int pthread_mutex_lock(pthread_mutex_t* mutex); +TRACE_API int pthread_mutex_timedlock(pthread_mutex_t* mutex, const struct timespec* abstime); +TRACE_API int pthread_cond_signal(pthread_cond_t* cond); +TRACE_API int pthread_cond_broadcast(pthread_cond_t* cond); +TRACE_API int pthread_cond_wait(pthread_cond_t* cond, pthread_mutex_t* mutex); +TRACE_API int pthread_cond_timedwait(pthread_cond_t* cond, pthread_mutex_t* mutex, const struct timespec* abstime); +TRACE_API int pthread_rwlock_rdlock(pthread_rwlock_t* rwlock); +TRACE_API int pthread_rwlock_timedrdlock(pthread_rwlock_t* rwlock, const struct timespec* abstime); +TRACE_API int pthread_rwlock_wrlock(pthread_rwlock_t* rwlock); +TRACE_API int pthread_rwlock_timedwrlock(pthread_rwlock_t* rwlock, const struct timespec* abstime); +// exit func +TRACE_API void exit(int status) __attribute__((noreturn)); +TRACE_API void _exit(int status) __attribute__((noreturn)); +TRACE_API void _Exit(int status) __attribute__((noreturn)); +// file func +TRACE_API int dup(int oldfd); +TRACE_API int dup2(int oldfd, int newfd); +TRACE_API int dup3(int oldfd, int newfd, int flags); +TRACE_API ssize_t tee(int fd_in, int fd_out, size_t len, unsigned int flags); +TRACE_API ssize_t splice(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags); +TRACE_API int fallocate(int fd, int mode, off_t offset, off_t len); +TRACE_API int fdatasync(int fildes); +TRACE_API int fsync(int fd); +TRACE_API int fcntl(int fd, int op, ...); +TRACE_API int flock(int fd, int op); +TRACE_API int lockf(int fd, int op, off_t len); +TRACE_API int truncate(const char* path, off_t length); +TRACE_API int ftruncate(int fildes, off_t length); +TRACE_API int ioctl(int fd, int op, ...); +TRACE_API int open(const char* pathname, int flags, ... /* mode_t mode */ ); +TRACE_API int openat(int dirfd, const char* pathname, int flags, ... /* mode_t mode */ ); +TRACE_API int pipe(int pipefd[2]); +TRACE_API int pipe2(int pipefd[2], int flags); +TRACE_API int mkfifo(const char* pathname, mode_t mode); +TRACE_API int mkfifoat(int dirfd, const char* pathname, mode_t mode); +TRACE_API ssize_t read(int fd, void* buf, size_t count); +TRACE_API ssize_t pread(int fd, void* buf, size_t count, off_t offset); +TRACE_API ssize_t readv(int fd, const struct iovec* iov, int iovcnt); +TRACE_API ssize_t preadv(int fd, const struct iovec* iov, int iovcnt, off_t offset); +TRACE_API ssize_t preadv2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags); +TRACE_API ssize_t write(int fd, const void* buf, size_t count); +TRACE_API ssize_t pwrite(int fd, const void* buf, size_t count, off_t offset); +TRACE_API ssize_t writev(int fd, const struct iovec* iov, int iovcnt); +TRACE_API ssize_t pwritev(int fd, const struct iovec* iov, int iovcnt, off_t offset); +TRACE_API ssize_t pwritev2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags); +TRACE_API ssize_t copy_file_range(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags); +TRACE_API void sync(void); +TRACE_API int syncfs(int fd); +TRACE_API int sync_file_range(int fd, off_t offset, off_t nbytes, unsigned int flags); +TRACE_API ssize_t vmsplice(int fd, const struct iovec* iov, size_t nr_segs, unsigned int flags); +TRACE_API ssize_t process_vm_readv(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags); +TRACE_API ssize_t process_vm_writev(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags); +TRACE_API int fclose(FILE* stream); +TRACE_API int fcloseall(void); +TRACE_API int fflush(FILE* stream); +TRACE_API int fgetc(FILE* stream); +TRACE_API char* fgets(char* s, int size, FILE* stream); +TRACE_API int fputc(int c, FILE* stream); +TRACE_API int fputs(const char* s, FILE* stream); +TRACE_API void flockfile(FILE* filehandle); +TRACE_API int ftrylockfile(FILE* filehandle); +TRACE_API void funlockfile(FILE* filehandle); +TRACE_API FILE* fopen(const char* pathname, const char* mode); +TRACE_API FILE* freopen(const char* pathname, const char* mode, FILE* stream); +TRACE_API size_t fread(void* ptr, size_t size, size_t nmemb, FILE* stream); +TRACE_API size_t fwrite(const void* ptr, size_t size, size_t nitems, FILE* stream); +TRACE_API ssize_t getdelim(char** lineptr, size_t* n, int delimiter, FILE* stream); +TRACE_API ssize_t getline(char** lineptr, size_t* n, FILE* stream); +TRACE_API int getc(FILE* stream); +TRACE_API int putc(int c, FILE* stream); +TRACE_API int getc_unlocked(FILE* stream); +TRACE_API int putc_unlocked(int c, FILE* stream); +TRACE_API int fflush_unlocked(FILE* stream); +TRACE_API int fgetc_unlocked(FILE* stream); +TRACE_API int fputc_unlocked(int c, FILE* stream); +TRACE_API size_t fread_unlocked(void* ptr, size_t size, size_t n, FILE* stream); +TRACE_API size_t fwrite_unlocked(const void* ptr, size_t size, size_t n, FILE* stream); +TRACE_API char* fgets_unlocked(char* s, int n, FILE* stream); +TRACE_API int fputs_unlocked(const char* s, FILE* stream); +// socket func +TRACE_API int socket(int domain, int type, int protocol); +TRACE_API int socketpair(int domain, int type, int protocol, int sv[2]); +TRACE_API int epoll_ctl(int epfd, int op, int fd, struct epoll_event* event); +TRACE_API int epoll_wait(int epfd, struct epoll_event* events, int maxevents, int timeout); +TRACE_API int epoll_pwait(int epfd, struct epoll_event* events, int maxevents, int timeout, const sigset_t* sigmask); +TRACE_API int select(int nfds, fd_set* readfds, fd_set* writefds, fd_set* exceptfds, struct timeval* timeout); +TRACE_API int listen(int sockfd, int backlog); +TRACE_API int accept(int sockfd, struct sockaddr* addr, socklen_t* addrlen); +TRACE_API int accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags); +TRACE_API int bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen); +TRACE_API int poll(struct pollfd* fds, nfds_t nfds, int timeout); +TRACE_API int ppoll(struct pollfd* fds, nfds_t nfds, const struct timespec* tmo_p, const sigset_t* sigmask); +TRACE_API ssize_t send(int sockfd, const void* buf, size_t len, int flags); +TRACE_API ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, const struct sockaddr* dest_addr, socklen_t addrlen); +TRACE_API ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags); +TRACE_API int sendmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags); +TRACE_API ssize_t sendfile(int out_fd, int in_fd, off_t* offset, size_t count); +TRACE_API ssize_t recv(int sockfd, void* buf, size_t len, int flags); +TRACE_API ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, struct sockaddr* src_addr, socklen_t* addrlen); +TRACE_API ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags); +TRACE_API int recvmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags, struct timespec* timeout); +#ifdef __cplusplus +} +#endif + +using MallocFunc = void*(*)(size_t); +using ReallocFunc = void*(*)(void*, size_t); +using FreeFunc = void(*)(void*); +using MmapFunc = void*(*)(void*, size_t, int, int, int, off_t); +using MunmapFunc = int(*)(void*, size_t); +using MremapFunc = void*(*)(void*, size_t, size_t, int, ...); +using MsyncFunc = int(*)(void*, size_t, int); +using MprotectFunc = int(*)(void*, size_t, int); +using BrkFunc = int(*)(void*); +using PthreadMutexLockFunc = int(*)(pthread_mutex_t*); +using PthreadMutexTimedlockFunc = int(*)(pthread_mutex_t*, const struct timespec*); +using PthreadCondSignalFunc = int(*)(pthread_cond_t*); +using PthreadCondBroadcastFunc = int(*)(pthread_cond_t*); +using PthreadCondWaitFunc = int(*)(pthread_cond_t*, pthread_mutex_t*); +using PthreadCondTimedwaitFunc = int(*)(pthread_cond_t*, pthread_mutex_t*, const struct timespec*); +using PthreadRwlockRdlockFunc = int(*)(pthread_rwlock_t*); +using PthreadRwlockTimedrdlockFunc = int(*)(pthread_rwlock_t*, const struct timespec*); +using PthreadRwlockWrlockFunc = int(*)(pthread_rwlock_t*); +using PthreadRwlockTimedwrlockFunc = int(*)(pthread_rwlock_t*, const struct timespec*); + +struct OSRTRecord { + pid_t pid = 0; + pid_t tid = 0; + const char* function = nullptr; + uint64_t start_time = 0; + uint64_t duration = 0; +}; + +const uint64_t DEFAULT_THRESHOLD = 10 * 1000 * 1000; // 10ms + +struct OSRTFunc { + uint64_t threshold_ = DEFAULT_THRESHOLD; + + MallocFunc real_malloc = nullptr; + ReallocFunc real_realloc = nullptr; + FreeFunc real_free = nullptr; + MmapFunc real_mmap = nullptr; + MunmapFunc real_munmap = nullptr; + MremapFunc real_mremap = nullptr; + MsyncFunc real_msync = nullptr; + MprotectFunc real_mprotect = nullptr; + BrkFunc real_brk = nullptr; + PthreadMutexLockFunc real_pthread_mutex_lock = nullptr; + PthreadMutexTimedlockFunc real_pthread_mutex_timedlock = nullptr; + PthreadCondSignalFunc real_pthread_cond_signal = nullptr; + PthreadCondBroadcastFunc real_pthread_cond_broadcast = nullptr; + PthreadCondWaitFunc real_pthread_cond_wait = nullptr; + PthreadCondTimedwaitFunc real_pthread_cond_timedwait = nullptr; + PthreadRwlockRdlockFunc real_pthread_rwlock_rdlock = nullptr; + PthreadRwlockTimedrdlockFunc real_pthread_rwlock_timedrdlock = nullptr; + PthreadRwlockWrlockFunc real_pthread_rwlock_wrlock = nullptr; + PthreadRwlockTimedwrlockFunc real_pthread_rwlock_timedwrlock = nullptr; + + FileFuncProxy file_func; + SocketFuncProxy socket_func; + + void loadFunc(); + void recordFunc(uint64_t start_time, uint64_t duration, const char* name); + void dumpFunc(); +}; + +extern OSRTFunc global_osrt_func; diff --git a/profiler/msprof_analyze/osrt_trace/src/socket_func.cpp b/profiler/msprof_analyze/osrt_trace/src/socket_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2863c6a515f3d5159eb5e7e1212499d78301df9 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/socket_func.cpp @@ -0,0 +1,217 @@ +#include "socket_func.h" + +#include "msosrt_trace.h" + +void SocketFuncProxy::loadFunc() +{ + LOAD_FUNC(socket, SocketFunc); + LOAD_FUNC(socketpair, SocketpairFunc); + LOAD_FUNC(epoll_ctl, EpollCtlFunc); + LOAD_FUNC(epoll_wait, EpollWaitFunc); + LOAD_FUNC(epoll_pwait, EpollPwaitFunc); + LOAD_FUNC(select, SelectFunc); + LOAD_FUNC(listen, ListenFunc); + LOAD_FUNC(accept, AcceptFunc); + LOAD_FUNC(accept4, Accept4Func); + LOAD_FUNC(bind, BindFunc); + LOAD_FUNC(poll, PollFunc); + LOAD_FUNC(ppoll, PpollFunc); + LOAD_FUNC(send, SendFunc); + LOAD_FUNC(sendto, SendtoFunc); + LOAD_FUNC(sendmsg, SendmsgFunc); + LOAD_FUNC(sendmmsg, SendmmsgFunc); + LOAD_FUNC(sendfile, SendfileFunc); + LOAD_FUNC(recv, RecvFunc); + LOAD_FUNC(recvfrom, RecvfromFunc); + LOAD_FUNC(recvmsg, RecvmsgFunc); + LOAD_FUNC(recvmmsg, RecvmmsgFunc); +} + +int socket(int domain, int type, int protocol) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_socket(domain, type, protocol); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int socketpair(int domain, int type, int protocol, int sv[2]) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_socketpair(domain, type, protocol, sv); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int epoll_ctl(int epfd, int op, int fd, struct epoll_event* event) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_epoll_ctl(epfd, op, fd, event); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int epoll_wait(int epfd, struct epoll_event* events, int maxevents, int timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_epoll_wait(epfd, events, maxevents, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int epoll_pwait(int epfd, struct epoll_event* events, int maxevents, int timeout, const sigset_t* sigmask) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_epoll_pwait(epfd, events, maxevents, timeout, sigmask); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int select(int nfds, fd_set* readfds, fd_set* writefds, fd_set* exceptfds, struct timeval* timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_select(nfds, readfds, writefds, exceptfds, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int listen(int sockfd, int backlog) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_listen(sockfd, backlog); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int accept(int sockfd, struct sockaddr* addr, socklen_t* addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_accept(sockfd, addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_accept4(sockfd, addr, addrlen, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_bind(sockfd, addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int poll(struct pollfd* fds, nfds_t nfds, int timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_poll(fds, nfds, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ppoll(struct pollfd* fds, nfds_t nfds, const struct timespec* tmo_p, const sigset_t* sigmask) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_ppoll(fds, nfds, tmo_p, sigmask); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t send(int sockfd, const void* buf, size_t len, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_send(sockfd, buf, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, const struct sockaddr* dest_addr, socklen_t addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendto(sockfd, buf, len, flags, dest_addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendmsg(sockfd, msg, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int sendmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendmmsg(sockfd, msgvec, vlen, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendfile(int out_fd, int in_fd, off_t* offset, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendfile(out_fd, in_fd, offset, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recv(int sockfd, void* buf, size_t len, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recv(sockfd, buf, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, struct sockaddr* src_addr, socklen_t* addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvfrom(sockfd, buf, len, flags, src_addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvmsg(sockfd, msg, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int recvmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags, struct timespec* timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvmmsg(sockfd, msgvec, vlen, flags, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} diff --git a/profiler/msprof_analyze/osrt_trace/src/socket_func.h b/profiler/msprof_analyze/osrt_trace/src/socket_func.h new file mode 100644 index 0000000000000000000000000000000000000000..361ce1d6382eada6cd942d74c2f3e0e7cd8621a0 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/socket_func.h @@ -0,0 +1,60 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include +#include +#include + +using SocketFunc = int(*)(int, int, int); +using SocketpairFunc = int(*)(int, int, int, int* sv); +using EpollCtlFunc = int(*)(int, int, int, struct epoll_event*); +using EpollWaitFunc = int(*)(int, struct epoll_event*, int, int); +using EpollPwaitFunc = int(*)(int, struct epoll_event*, int, int, const sigset_t*); +using SelectFunc = int(*)(int, fd_set*, fd_set*, fd_set*, struct timeval*); +using ListenFunc = int(*)(int, int); +using AcceptFunc = int(*)(int, struct sockaddr*, socklen_t*); +using Accept4Func = int(*)(int, struct sockaddr*, socklen_t*, int); +using BindFunc = int(*)(int, const struct sockaddr*, socklen_t); +using PollFunc = int(*)(struct pollfd*, nfds_t, int); +using PpollFunc = int(*)(struct pollfd*, nfds_t, const struct timespec*, const sigset_t*); +using SendFunc = ssize_t(*)(int, const void*, size_t, int); +using SendtoFunc = ssize_t(*)(int, const void*, size_t, int, const struct sockaddr*, socklen_t); +using SendmsgFunc = ssize_t(*)(int, const struct msghdr*, int); +using SendmmsgFunc = int(*)(int, struct mmsghdr*, unsigned int, int); +using SendfileFunc = ssize_t(*)(int, int, off_t*, size_t); +using RecvFunc = ssize_t(*)(int, void*, size_t, int); +using RecvfromFunc = ssize_t(*)(int, void*, size_t, int, struct sockaddr*, socklen_t*); +using RecvmsgFunc = ssize_t(*)(int, struct msghdr*, int); +using RecvmmsgFunc = int(*)(int, struct mmsghdr*, unsigned int, int, struct timespec*); + +struct SocketFuncProxy +{ + SocketFunc real_socket = nullptr; + SocketpairFunc real_socketpair = nullptr; + EpollCtlFunc real_epoll_ctl = nullptr; + EpollWaitFunc real_epoll_wait = nullptr; + EpollPwaitFunc real_epoll_pwait = nullptr; + SelectFunc real_select = nullptr; + ListenFunc real_listen = nullptr; + AcceptFunc real_accept = nullptr; + Accept4Func real_accept4 = nullptr; + BindFunc real_bind = nullptr; + PollFunc real_poll = nullptr; + PpollFunc real_ppoll = nullptr; + SendFunc real_send = nullptr; + SendtoFunc real_sendto = nullptr; + SendmsgFunc real_sendmsg = nullptr; + SendmmsgFunc real_sendmmsg = nullptr; + SendfileFunc real_sendfile = nullptr; + RecvFunc real_recv = nullptr; + RecvfromFunc real_recvfrom = nullptr; + RecvmsgFunc real_recvmsg = nullptr; + RecvmmsgFunc real_recvmmsg = nullptr; + + void loadFunc(); +}; diff --git a/profiler/msprof_analyze/osrt_trace/src/utils.cpp b/profiler/msprof_analyze/osrt_trace/src/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82382d23039e63c7ab2d4475d0dcf7fe2aec9fad --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/utils.cpp @@ -0,0 +1,159 @@ +#include "utils.h" + +#include +#include +#include +#include +#include +#include + +int str_to_i64(const std::string& str, int64_t& num) +{ + if (str.empty()) { + return -1; + } + size_t pos = 0; + try { + num = std::stoll(str, &pos); + } catch (...) { + return -1; + } + if (pos != str.size()) { + return -1; + } + return 0; +} + +bool PathUtils::IsFileExist(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + return (access(path.c_str(), F_OK) == 0) ? true : false; +} + +bool PathUtils::IsFileWritable(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + return (access(path.c_str(), W_OK) == 0) ? true : false; +} + +bool PathUtils::IsDir(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + struct stat st{}; + int ret = lstat(path.c_str(), &st); + if (ret != 0) { + return false; + } + return S_ISDIR(st.st_mode) ? true : false; +} + +bool PathUtils::CreateDir(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + if (IsFileExist(path)) { + return IsDir(path) ? true : false; + } + size_t pos = 0; + while ((pos = path.find_first_of('/', pos)) != std::string::npos) { + std::string base_dir = path.substr(0, ++pos); + if (IsFileExist(base_dir)) { + if (IsDir(base_dir)) { + continue; + } else { + return false; + } + } + if (mkdir(base_dir.c_str(), DATA_DIR_AUTHORITY) != 0) { + return false; + } + } + return (mkdir(path.c_str(), DATA_DIR_AUTHORITY) == 0) ? true : false; +} + +std::string PathUtils::RealPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + char realPath[PATH_MAX] = {0}; + if (realpath(path.c_str(), realPath) == nullptr) { + return ""; + } + return std::string(realPath); +} + +std::string PathUtils::RelativeToAbsPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + if (path[0] != '/') { + char pwd_path[PATH_MAX] = {0}; + if (getcwd(pwd_path, PATH_MAX) != nullptr) { + return std::string(pwd_path) + "/" + path; + } + return ""; + } + return std::string(path); +} + +std::string PathUtils::DirName(const std::string &path) +{ + if (path.empty()) { + return ""; + } + char temp_path[PATH_MAX] = {0}; + strncpy(temp_path, path.c_str(), path.size() < PATH_MAX ? path.size() : PATH_MAX); + char* path_c = dirname(temp_path); + return path_c ? std::string(path_c) : ""; +} + +bool PathUtils::CreateFile(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !CreateDir(DirName(path))) { + return false; + } + int fd = creat(path.c_str(), DATA_FILE_AUTHORITY); + return (fd < 0 || close(fd) != 0) ? false : true; +} + +bool PathUtils::IsSoftLink(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !IsFileExist(path)) { + return false; + } + struct stat st{}; + if (lstat(path.c_str(), &st) != 0) { + return false; + } + return S_ISLNK(st.st_mode); +} + +bool PathUtils::DirPathCheck(const std::string& abs_path) +{ + if (abs_path.empty() || abs_path.size() > PATH_MAX) { + fprintf(stderr, "[ERROR] The length of Path %s is invalid.\n", abs_path.c_str()); + return false; + } + if (IsSoftLink(abs_path)) { + fprintf(stderr, "[ERROR] Path %s is soft link.\n", abs_path.c_str()); + return false; + } + if (!IsFileExist(abs_path) && !CreateDir(abs_path)) { + fprintf(stderr, "[ERROR] Path %s not exist and create failed.\n", abs_path.c_str()); + return false; + } + if (!IsDir(abs_path) || !IsFileWritable(abs_path)) { + fprintf(stderr, "[ERROR] %s is not a directory or is not writable.\n", abs_path.c_str()); + return false; + } + return true; +} diff --git a/profiler/msprof_analyze/osrt_trace/src/utils.h b/profiler/msprof_analyze/osrt_trace/src/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..129c062d5f2898d0b33db33f4716ae497c6ad8d1 --- /dev/null +++ b/profiler/msprof_analyze/osrt_trace/src/utils.h @@ -0,0 +1,50 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#define LIKELY(x) (__builtin_expect(!!(x), 1)) +#define UNLIKELY(x) (__builtin_expect(!!(x), 0)) + +const mode_t DATA_FILE_AUTHORITY = 0640; +const mode_t DATA_DIR_AUTHORITY = 0750; + +inline uint64_t nsec_now() +{ + static const uint64_t S_TO_NS = 1000 * 1000 * 1000; + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + return static_cast(ts.tv_sec * S_TO_NS + ts.tv_nsec); +} + +int str_to_i64(const std::string& str, int64_t& num); + + +struct PathUtils { + static bool IsFileExist(const std::string &path); + static bool IsFileWritable(const std::string &path); + static bool IsDir(const std::string &path); + static bool CreateDir(const std::string &path); + static std::string RealPath(const std::string &path); + static std::string RelativeToAbsPath(const std::string &path); + static std::string DirName(const std::string &path); + static bool CreateFile(const std::string &path); + static bool IsSoftLink(const std::string &path); + static bool DirPathCheck(const std::string &path); +}; diff --git a/profiler/msprof_analyze/precheck/README.md b/profiler/msprof_analyze/precheck/README.md new file mode 100644 index 0000000000000000000000000000000000000000..882cc4e12fff0b64c905d6146d9b7a0d98e7bad9 --- /dev/null +++ b/profiler/msprof_analyze/precheck/README.md @@ -0,0 +1,393 @@ +# Profiler Precheck 用户指南 + +欢迎使用 Profiler Precheck 工具!本指南将详细介绍该工具的功能、使用方法以及内部实现原理,帮助您快速上手并充分利用其性能分析能力。 + +## 目录 +- [1. 概述](#1-概述) +- [2. 整体架构](#2-整体架构) +- [3. 使用方法](#3-使用方法) + - [3.1 云容器场景](#31-云容器场景) + - [3.2 裸机场景](#32-裸机场景) +- [4. 命令参数说明](#4-命令参数说明) +- [5. 常见问题](#5-常见问题) + +## 1. 概述 + +Profiler Precheck 是一个用于分布式训练任务的性能分析工具。它可以自动采集集群中各节点的硬件与软件环境信息,并基于历史数据和专家知识,对当前训练任务的配置与资源使用情况进行分析,给出优化建议与预警,帮助用户发现并解决潜在的性能瓶颈,提升训练效率。 + +## 2. 整体架构 + +Profiler Precheck 采用主从架构,由一个主节点(master节点)和多个从节点(slave节点)组成: + +- **主节点(master节点)**: + - 负责接收用户的任务请求 + - 将 Precheck 相关代码分发到各从节点 + - 协调各节点的分析过程 + - 汇总分析结果生成最终报告 + - 通常是集群训练中rank=0的设备所在的主机节点 + +- **从节点(slave节点)**: + - 负责在本节点上执行用户的训练脚本 + - 运行 Profiler 采集各项性能指标 + - 将结果回传给主节点 + +### 预检流程 +1. **准备阶段**:用户在master节点上提交预检请求,master节点将代码分发到各slave节点 +2. **采集阶段**:各节点启动训练脚本,同时运行 Profiler 采集性能数据 +3. **汇总阶段**:master节点汇总各slave节点上报的性能数据 +4. **分析阶段**:主节点对汇总数据进行分析,生成分析报告 + +### 典型场景 + +## 3. 使用方法 + +### 3.1 云容器场景 +详细预检流程请参考:[云场景预检流程](assert/code_structure_startnode_docker.svg) + +#### 3.1.1 部署流程 + +1. **准备基础环境** +```bash +# 下载并加载基础镜像 +docker load -i user_image.tar + +# 创建训练容器 +docker run -it --name user_container \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend:/usr/local/Ascend \ + -v /path/to/data:/data \ + -v /path/to/model:/model \ + user_image:latest +``` + +2. **构建预检环境** +```bash +# 安装预检工具 +pip install msprof-analyze-xx.whl + +# 创建预检启动脚本 +cat > /usr/local/bin/run_node_precheck.sh << 'EOF' +#!/bin/bash +msprof-analyze precheck start_node \ + --node_rank ${NODE_RANK:-0} \ + --master_addr ${MASTER_ADDR:-"127.0.0.1"} \ + --master_port ${MASTER_PORT:-29500} \ + --nnodes ${NNODES:-1} \ + --nproc_per_node ${NPUS_PER_NODE:-8} \ + --task_name ${TASK_NAME:-"container_test"} \ + --profiling_cmd ${PROFILING_CMD:-"run.sh"} +EOF +chmod +x /usr/local/bin/run_node_precheck.sh + +# 保存预检镜像 +docker commit user_container precheck_image:latest +docker save -o precheck_image.tar precheck_image:latest +``` + +3. **分发和启动** +```bash +# 在每个节点上加载镜像 +docker load -i precheck_image.tar + +# 启动主节点容器 +docker run -d --name precheck_master \ + --network host \ + --device=/dev/davinci* \ + -v /usr/local/Ascend:/usr/local/Ascend \ + -v /path/to/data:/data \ + -e MASTER_ADDR=192.168.0.1 \ + -e MASTER_PORT=29500 \ + -e NNODES=2 \ + -e NODE_RANK=0 \ + -e NPUS_PER_NODE=8 \ + -e TASK_NAME=container_test \ + -e PROFILING_CMD="run.sh" \ + precheck_image:latest \ + /usr/local/bin/run_node_precheck.sh + +# 启动从节点容器 +docker run -d --name precheck_worker \ + --network host \ + --device=/dev/davinci* \ + -v /usr/local/Ascend:/usr/local/Ascend \ + -v /path/to/data:/data \ + -e MASTER_ADDR=192.168.0.1 \ + -e MASTER_PORT=29500 \ + -e NNODES=2 \ + -e NODE_RANK=1 \ + -e NPUS_PER_NODE=8 \ + -e TASK_NAME=container_test \ + -e PROFILING_CMD="run.sh" \ + precheck_image:latest \ + /usr/local/bin/run_node_precheck.sh +``` + +#### 3.1.2 配置说明 + +##### 容器环境变量 +| 变量名 | 说明 | 默认值 | +|--------|------|--------| +| MASTER_ADDR | 主节点IP地址 | 127.0.0.1 | +| MASTER_PORT | 主节点端口 | 29500 | +| NNODES | 总节点数 | 1 | +| NODE_RANK | 节点序号 | 0 | +| NPUS_PER_NODE | 每节点NPU数量 | 8 | +| TASK_NAME | 预检任务名称 | container_test | +| PROFILING_CMD | 训练命令 | run.sh | + +##### 容器挂载说明 +| 挂载点 | 说明 | 必需 | +|--------|------|------| +| /usr/local/Ascend | CANN工具包 | 是 | +| /data | 训练数据目录 | 否 | +| /model | 模型文件目录 | 否 | +| /output | 输出目录 | 否 | + +### 3.2 裸机场景 +详细预检流程请参考:[裸机场景预检流程](assert/code_structure_startall.svg) + +#### 3.2.1 环境配置验证 + +在开始使用预检工具前,需要确保集群环境配置正确。我们提供了一系列验证脚本帮助您快速检查环境: + +##### 1. SSH 免密配置 +```bash +# 1. 生成SSH密钥(如果已存在则跳过) +[ ! -f ~/.ssh/id_rsa ] && ssh-keygen -t rsa -N '' -f ~/.ssh/id_rsa + +# 2. 复制密钥到其他节点(替换用户名和IP) +ssh-copy-id user@192.168.0.2 +``` + +##### 2. 环境检查 +我们提供两个脚本帮助验证集群配置: + +1. **SSH连通性检查** +```bash +# 基础检查(默认5秒超时) +HOST_IPS="192.168.0.1,192.168.0.2" bash test_hosts_ssh.sh + +# 自定义超时时间 +HOST_IPS="192.168.0.1,192.168.0.2" TIMEOUT=10 bash test_hosts_ssh.sh +``` + +2. **集群环境一致性检查** +```bash +# 基础环境检查(Python、PyTorch等) +HOST_IPS="192.168.0.1,192.168.0.2" bash test_hosts_env.sh + +# 完整环境检查(包含CANN环境, developing) +HOST_IPS="192.168.0.1,192.168.0.2" CHECK_CANN=1 bash test_hosts_env.sh +``` + +示例输出: +``` +🔍 Cluster Environment Checker +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +📊 Step 1: Collecting local environment info... +Detecting Python environment... +Checking installed packages... +Checking CANN environment... + +📌 Local Environment Summary: + • Python Path: /usr/bin/python3 + • Python Version: Python 3.8.10 + • Msprof-analyze: v1.3.0 + • Torch: v2.4.0 + • Torch_NPU: v2.2.0 +``` + +#### 3.2.2 启动预检 + +预检工具支持两种使用模式:内置基准测试和自定义训练脚本。 + +##### 方式一:使用内置ResNet基准测试 + +```bash +# 使用IP列表方式 +msprof-analyze precheck start_all \ + --host_ips "192.168.0.1,192.168.0.2" \ + --master_addr 192.168.0.1 \ + --nnodes 2 \ + --nproc_per_node 8 \ + --task_name resnet_test \ + --profiling_cmd "[resnet]" + +# 使用配置文件方式 +msprof-analyze precheck start_all \ + --host_config_file hosts.csv \ + --master_addr 192.168.0.1 \ + --nnodes 2 \ + --nproc_per_node 8 \ + --profiling_cmd "[resnet]" +``` + +##### 方式二:使用自定义训练脚本 + +1. **准备训练脚本** +```python +# train.py +import torch_npu +from torch_npu.profiler import profile, ProfilerActivity + +def train(node_prof_save_dir): + # 配置性能分析器 + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.NPU], + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(node_prof_save_dir) + ) as prof: + # 训练代码 + for epoch in range(num_epochs): + for batch in dataloader: + # 训练逻辑 + ... + prof.step() # 记录性能数据 +``` + +2. **创建启动脚本** + +run.sh示例如下: +```bash +#!/bin/bash + +# 设置性能数据保存目录 +NODE_PROF_SAVE_DIR=${NODE_PROF_SAVE_DIR:-"./output/prof_data"} +mkdir -p "$NODE_PROF_SAVE_DIR" + +# 启动训练 +python3 train.py \ + --prof_dir "$NODE_PROF_SAVE_DIR" \ + --batch-size 32 \ + --epochs 10 \ + "$@" # 支持传入额外参数 +``` + +3. **启动预检分析** +```bash +# 设置执行权限 +chmod +x run.sh + +# 使用相对路径 +msprof-analyze precheck start_all \ + --host_ips "192.168.0.1,192.168.0.2" \ + --master_addr 192.168.0.1 \ + --nnodes 2 \ + --nproc_per_node 8 \ + --task_name custom_test \ + --profiling_cmd "./run.sh --extra-args value" + +# 使用绝对路径(推荐) +msprof-analyze precheck start_all \ + --host_ips "192.168.0.1,192.168.0.2" \ + --master_addr 192.168.0.1 \ + --nnodes 2 \ + --nproc_per_node 8 \ + --task_name custom_test \ + --profiling_cmd "/path/to/run.sh --extra-args value" +``` + +#### 3.2.3 使用注意事项 + +1. **路径设置** + - 建议使用绝对路径指定脚本位置 + - 确保所有节点的脚本路径一致 + - 检查目录和文件的读写权限 + +2. **环境变量** + - `NODE_PROF_SAVE_DIR`: 性能数据保存目录 + - 可通过 `"$@"` 传递额外的训练参数 + +3. **常见问题** + - 确保 run.sh 有执行权限 + - 验证工作目录的正确性 + - 检查性能数据目录是否可写 + +## 4. 命令参数说明 + +### 基本用法 +```bash +msprof-analyze precheck [options] + +Commands: + start_all 启动所有节点的预检 + start_node 启动单个节点的预检 + stop 停止预检(todo) + status 查看预检状态 (todo) +``` + +### 通用参数 +| 参数名 | 类型 | 必需 | 默认值 | 说明 | +|--------|------|------|-----------------------------------------------------------------------------------------------------------------|------| +| master_addr | str | 是 | - | 主节点IP地址 | +| master_port | int | 否 | 29500 | 主节点通信端口 | +| nnodes | int | 是 | - | 总节点数 | +| nproc_per_node | int | 是 | - | 每节点进程数 | +| task_name | str | 否 | auto_timestamp | 任务名称 | +| output_dir | str | 否 | ./output | 输出目录 | +| node_prof_save_dir | str | 否 | {output_dir}/{task_name}/node_prof_save_dir | 节点性能数据保存目录 | +| master_prof_gather_dir | str | 否 | {output_dir}/{task_name}/master_prof_gather_dir | 主节点数据汇总目录 | +| static | bool | 否 | False | 是否使用静态profiler采集模式 | +| prof_in_shared_storage | bool | 否 | False | 是否使用共享存储(跳过数据收集) | +| profiling_cmd | str | 是 | 训练命令说明:
    - `[resnet]`: 运行ResNet基准测试
    - `python train.py [args]`: 自定义训练脚本
    - `bash run.sh [args]`: 自定义训练脚本 | 要求用户自定义脚需要将profiler数据保存到{node_prof_save_dir} + +### start_all 专用参数 +| 参数名 | 类型 | 必需 | 说明 | +|--------|------|------|------| +| host_ips | str | 是* | 节点IP列表,逗号分隔 | +| host_config_file | str | 是* | SSH配置文件路径 | + +*注:host_ips 和 host_config_file 必须提供其中之一 + +### start_node 专用参数 +| 参数名 | 类型 | 必需 | 说明 | +|--------|------|------|------| +| node_rank | int | 是 | 当前节点序号(0 到 nnodes-1) | + +## 5. 常见问题 + +### 5.1 容器场景常见问题 + +1. **容器启动失败** +```bash +# 检查设备挂载 +ls -l /dev/davinci* + +# 检查日志 +docker logs precheck_container +``` + +2. **网络连接问题** +```bash +# 检查网络配置 +docker network inspect precheck_net + +# 测试容器间连接 +docker exec precheck_master ping precheck_worker +``` + +### 5.2 裸机场景常见问题 + +1. **SSH连接超时** +```bash +# 增加连接超时时间 +HOST_IPS="192.168.0.1,192.168.0.2" TIMEOUT=10 bash test_hosts_ssh.sh +``` + +2. **环境不一致** +```bash +# 详细检查环境 +HOST_IPS="192.168.0.1,192.168.0.2" CHECK_CANN=1 bash test_hosts_env.sh +``` + +3. **CANN环境问题** +```bash +# 检查CANN工具 +npu-smi info + +# 检查环境变量 +echo $LD_LIBRARY_PATH | grep Ascend +``` \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/__init__.py b/profiler/msprof_analyze/precheck/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/__main__.py b/profiler/msprof_analyze/precheck/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..deb0a713199c5629195ed16d54d1ae67c8df3d78 --- /dev/null +++ b/profiler/msprof_analyze/precheck/__main__.py @@ -0,0 +1,98 @@ +import os +from copy import deepcopy +import logging + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.common.logger import add_file_handler, create_logger +from msprof_analyze.precheck.common.utils import cn_now +from msprof_analyze.precheck.manager.args_manager import PrecheckArgsManager +from msprof_analyze.precheck.tools.ssh_utils import run_remote_command +from msprof_analyze.prof_common.path_manager import PathManager + + +def get_command_tpl(): + cwd = os.getcwd() + from msprof_analyze.precheck.runner.__main__ import get_conda_envs_info + _, conda_activate_cmd = get_conda_envs_info() + + EXECUTOR = f'source ~/.bashrc && {conda_activate_cmd} && cd {cwd} && {Constant.MS_PROF_PRECHECK_CMD} start_node' + ARGS = ('--nnodes={nnodes}', '--nproc_per_node={nproc_per_node}', + '--node_rank={node_rank}', '--master_addr={master_addr}', + '--master_port={master_port}', + '--nproc_per_node={nproc_per_node}', + '--node_prof_save_dir={node_prof_save_dir}', + '--master_prof_gather_dir={master_prof_gather_dir}', + '--task_name={task_name}', + '--profiling_cmd="{profiling_cmd}"', + '--output_dir={output_dir}', + ) + TPL = EXECUTOR + " " + " ".join(ARGS) + return TPL + + +def start_precheck(args: PrecheckArgsManager, logger): + config = dict( + nnodes=args.nnodes, + node_rank=-1, + nproc_per_node=args.nproc_per_node, + master_addr=args.master_addr, + master_port=args.master_port, + node_prof_save_dir=args.node_prof_save_dir, + master_prof_gather_dir=args.master_prof_gather_dir, + static=args.static, + task_name=args.task_name, + python_path=args.python_path, + output_dir=args.output_dir, + profiling_cmd=args.profiling_cmd, + prof_in_shared_storage=args.prof_in_shared_storage, + ) + + hosts_info = [] + for node_id, host in enumerate(args.host_ips): + node_config = deepcopy(config) + node_config['node_rank'] = node_id + + TPL = get_command_tpl() + cmd = TPL.format(**node_config) + if node_config.get('static', False) is True: + cmd += ' --static' + if node_config.get('prof_in_shared_storage', False) is True: + cmd += ' --prof_in_shared_storage' + + host_info = { + "host": host, + "username": os.getenv('USER'), + "key_filename": "~/.ssh/id_rsa", + "command": cmd, + "port": 22 + } + + if args.host_config_file: + host_info.update(args.ssh_remote_hosts[host]) + + hosts_info.append(host_info) + + logger.info("Starting remote command execution on %d hosts", len(hosts_info)) + run_remote_command(hosts_info) + logger.info("Precheck main processes have been started on all hosts") + + +def main(args=None): + logger = create_logger("profiler.precheck", Constant.LOGGING_LEVEL, use_memory_handler=True) + + PathManager.make_dir_safety(args.task_output_dir) + + timestamp = cn_now().strftime('%Y%m%d_%H%M%S') + log_filename = f'precheck_{timestamp}.log' + log_file_path = os.path.join(args.task_output_dir, log_filename) + PathManager.create_file_safety(log_file_path) + PathManager.check_path_writeable(log_file_path) + + logger = add_file_handler(logger, log_file_path) + logger.info("Starting precheck, Precheck log file will be saved at %s", log_file_path) + logger.info("Precheck arguments: %s", args) + + try: + start_precheck(args, logger) + except Exception as e: + logger.error("Precheck runner failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) diff --git a/profiler/msprof_analyze/precheck/analyze/__init__.py b/profiler/msprof_analyze/precheck/analyze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/analyze/advisor_adaptor.py b/profiler/msprof_analyze/precheck/analyze/advisor_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..491969e804f2622c4077d9a2abb52b9915b38ca7 --- /dev/null +++ b/profiler/msprof_analyze/precheck/analyze/advisor_adaptor.py @@ -0,0 +1,56 @@ +import sys +import os +import logging +from pathlib import Path + +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "compare_tools")) +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "cluster_analyse")) + +from msprof_analyze.advisor.analyzer.analyzer_controller import AnalyzerController +from msprof_analyze.advisor.interface.interface import Interface +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +class advisor_adaptor: + def __init__(self): + pass + + @staticmethod + def _check_profiling_path_valid(profiling_path): + PathManager.input_path_common_check(profiling_path) + PathManager.check_path_owner_consistent(profiling_path) + if not Path(profiling_path).exists(): + logger.error(" Invalid profiling path: %s", profiling_path) + return False + return True + + @staticmethod + def _check_output_path_valid(output_path): + if not output_path: + return False + + if not os.path.exists(output_path): + return PathManager.make_dir_safety(output_path) + + PathManager.check_input_directory_path(output_path) + PathManager.input_path_common_check(output_path) + PathManager.check_path_owner_consistent(output_path) + return True + + def analyze(self, input_profiling_path, output_path): + if self._check_profiling_path_valid(input_profiling_path) and self._check_output_path_valid(output_path): + try: + reduced_dimensions = Interface.all_dimension[:-1] #advisor 默认调用全部功能,此方法不需要compare功能,故对列表进行处理 + AnalyzerController().do_analysis(dimensions=reduced_dimensions, + profiling_path=input_profiling_path, + benchmark_profiling_path=None, + output_path=output_path, + ) + except RuntimeError as e: + logger.error("RuntimeError during analysis: %s", e) + except Exception as e: + logger.error("Unexpected error during analysis: %s", e) + else: + logger.error("Invalid paths provided; analysis aborted.") diff --git a/profiler/msprof_analyze/precheck/assert/code_structure_startall.svg b/profiler/msprof_analyze/precheck/assert/code_structure_startall.svg new file mode 100644 index 0000000000000000000000000000000000000000..9502f093c35d4a05eef603a0c9e3089075d6ea5b --- /dev/null +++ b/profiler/msprof_analyze/precheck/assert/code_structure_startall.svg @@ -0,0 +1 @@ +Launch LayerPrecheck Control LayerPrecheck Execution LayerData Collection & Analysis LayerUserUserrun_llama2_precheck.sh/run_precheck.shrun_llama2_precheck.sh/run_precheck.shprecheck_cli.pyprecheck_cli.pyprecheck/_ _main_ _.pyprecheck/_ _main_ _.pySSH RunnerSSH Runnerprecheck_cli.py(start_node)precheck_cli.py(start_node)runner/_ _main_ _.pyrunner/_ _main_ _.pyUser Training ScriptUser Training Scripttrain_with_profiler.pytrain_with_profiler.pyCollectorRunnerCollectorRunnerAdvisorRunnerAdvisorRunnerExecute scriptmsprof-analyze precheck start_allConfiguration:1. Node IPs2. Master node settings3. Distributed parameters4. Output directoriesstart_precheck()run_remote_command()loop[for each host]Execute on remote nodestart_precheck_runner()get_conda_envs_info()Auto-detect conda/python envalt[profiling_cmd == "[resnet]"]Execute example modelInitialize profilerTraining loop1. Load model & dataset2. Configure optimizer3. Execute training steps4. Collect metricsComplete training[profiling_cmd == custom command]Prepare environmentSet distributed env vars:- MASTER_ADDR- MASTER_PORT- NNODES- NODE_RANK- NPROC_PER_NODEExecute via bashExample:torchrun $DISTRIBUTED_ARGS \pretrain_gpt.py \$MODEL_ARGS \$PROFILE_ARGS \...Training completealt[not prof_in_shared_storage]Package profiling datazip_directory()1. Compress profiling data2. Filter by whitelist patterns3. Check archive size limitstransport()1. Transfer to master node2. Handle node rank specific logicCollection completealt[rank == 0]Analyze collected datarun_analyzer()1. Extract archives2. Process ascend_pt files3. Generate reportsAnalysis completeExecution completeNode completeAll nodes completePrecheck completeCommand completeDisplay completion \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/assert/code_structure_startnode_docker.svg b/profiler/msprof_analyze/precheck/assert/code_structure_startnode_docker.svg new file mode 100644 index 0000000000000000000000000000000000000000..a3bcca97fefddc8b3fa3123452c879a58d074e6c --- /dev/null +++ b/profiler/msprof_analyze/precheck/assert/code_structure_startnode_docker.svg @@ -0,0 +1 @@ +Cloud Platform LayerLaunch LayerPrecheck Execution LayerData Collection & Analysis LayerUserUserCloud PlatformCloud PlatformDocker ContainersDocker Containersrun_node_precheck.shrun_node_precheck.shprecheck_cli.pyprecheck_cli.pyrunner/_ _main_ _.pyrunner/_ _main_ _.pyUser Training ScriptUser Training Scripttrain_with_profiler.pytrain_with_profiler.pyCollectorRunnerCollectorRunnerAdvisorRunnerAdvisorRunnerPlatform Configuration1. Upload Docker image2. Configure cluster settings(nodes, NPUs per node)3. Set training parameters(model, dataset, etc.)Container DeploymentDeploy containers across cluster nodesPrecheck Executionloop[For each container in parallel]Execute with env vars(MASTER_ADDR, NODES,NODE_RANK, etc.)msprof-analyze precheck start_nodeInitialize precheck sessionget_conda_envs_info()1. Detect conda environment2. Get activation command3. Setup environment varsalt[profiling_cmd == "[resnet]"]Execute example modelInitialize profilerTraining loop1. Load model & dataset2. Configure optimizer3. Execute training steps4. Collect metricsComplete training[profiling_cmd == custom command]Prepare environmentSet distributed env vars:- MASTER_ADDR- MASTER_PORT- NNODES- NODE_RANK- NPROC_PER_NODEExecute via bashExample:torchrun $DISTRIBUTED_ARGS \custom_training.py \$MODEL_ARGS \$PROFILE_ARGS \...Initialize profilerTraining loop1. Load custom configuration2. Setup distributed env3. Execute training steps4. Collect profiling dataTraining completealt[not prof_in_shared_storage]Package profiling datazip_directory()1. Compress profiling data2. Filter by whitelist patterns3. Check archive size limitstransport()1. Transfer to master node2. Handle node rank specific logicCollection completealt[rank == 0]Analyze collected datarun_analyzer()1. Extract archives2. Process ascend_pt files3. Generate reportsAnalysis completePrecheck completeCommand finishedContainer task completeAll containers finishedResultsReturn profiling resultsand analysis report \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/collect/__init__.py b/profiler/msprof_analyze/precheck/collect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/collect/collector.py b/profiler/msprof_analyze/precheck/collect/collector.py new file mode 100644 index 0000000000000000000000000000000000000000..ca74b3e6769106ec23c8e789ca8eb170fbef600a --- /dev/null +++ b/profiler/msprof_analyze/precheck/collect/collector.py @@ -0,0 +1,458 @@ +import sys +import os +from typing import Any, Dict + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +import logging +from pathlib import Path +import argparse +import time +import math + +import torch +import torch_npu +import torch.distributed as dist +import numpy as np +import torch.multiprocessing as mp + +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.precheck.manager.group_manager import GroupManager, EnvGroup, SubGroup +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.common.time_stat import TimeStat +from msprof_analyze.precheck.common.utils import create_npu_event, event_elaspe_second, parition_sub_group_ranks, \ + get_master_rank_collect_dir, get_slave_rank_collect_dir, cat_files, is_equal_file_hash, get_quick_hash, \ + compress_directory +from msprof_analyze.precheck.manager.disk_manager import DiskManager + + +class Collector: + + def __init__(self): + self.stream = None + self.time_stat = None + self.world_size = None + self.device = None + self.local_rank = None + self.rank = None + self.logger = logging.getLogger(__name__) + + def init(self, slave_env: EnvGroup): + self.rank = slave_env.rank + self.local_rank = slave_env.local_rank + torch.npu.set_device(self.local_rank) + self.device = torch.device('npu:%d' % self.local_rank) + self.world_size = slave_env.world_size + self.time_stat = TimeStat() + self.stream = torch_npu.npu.current_stream() + + def gather_rank_data(self, group, gather_tensor, all_gather=False, dst_rank=None) -> tuple: + cur_group_size = dist.get_world_size(group) + self.logger.debug( + "[Rank %d] Local rank %d, gather data from %d ranks" % (self.rank, self.local_rank, cur_group_size)) + wait_event = create_npu_event(self.stream) + dist.barrier(group=group) + start_event = create_npu_event(self.stream) + wait_time = event_elaspe_second(self.stream, wait_event, start_event) + if all_gather: + gather_list = [] + for _ in range(cur_group_size): + gather_list.append(torch.zeros_like(gather_tensor, dtype=gather_tensor.dtype, device=self.device)) + dist.all_gather(gather_list, gather_tensor, group=group) + else: + if self.rank == dst_rank: + gather_list = [] + for _ in range(cur_group_size): + gather_list.append(torch.zeros_like(gather_tensor, dtype=gather_tensor.dtype, device=self.device)) + else: + gather_list = None + dist.gather(gather_tensor, gather_list=gather_list, dst=dst_rank, group=group) + end_event = create_npu_event(self.stream) + transfer_time = event_elaspe_second(self.stream, start_event, end_event) + + return gather_list, wait_time, transfer_time + + def create_sub_group(self, file_sizes_hash, master_rank_num): + # 需要根据file_sizes来划分sub_group ranks + file_sizes = [item[0] for item in file_sizes_hash[master_rank_num:]] + partitions = parition_sub_group_ranks(master_rank_num, file_sizes) + self.logger.debug("[Rank %d] subgroup partiitons %s" % (self.rank, partitions)) + + wait_time = 0 + transfer_time = 0 + for ranks in partitions: + if len(ranks) > 1: + wait_event = create_npu_event(self.stream) + dist.barrier() + start_event = create_npu_event(self.stream) + wait_time = event_elaspe_second(self.stream, wait_event, start_event) + sub_group = dist.new_group(ranks=ranks, backend='hccl') + end_event = create_npu_event(self.stream) + transfer_time = event_elaspe_second(self.stream, start_event, end_event) + + self.logger.info( + '[Rank %d] after new group, ranks: %s, file_sizes_hash %s' % (self.rank, ranks, file_sizes_hash)) + cur_file_sizes = [file_sizes_hash[r].cpu().tolist()[0] for r in ranks[1:]] + cur_file_hashes = [file_sizes_hash[r].cpu().tolist()[1:] for r in ranks[1:]] + + GroupManager().add_rank_sub_group(sub_group=sub_group, ranks=ranks, file_sizes=cur_file_sizes, + file_hashes=cur_file_hashes) + else: + self.logger.debug('[Rank %d] ranks %s not enough for creating subgroup' % (self.rank, ranks)) + self.time_stat.init_pg_stat.sub_group_init = [wait_time, transfer_time] + + def bd_split_file_size(self, sub_group, split_size=None): + split_size_bd = torch.tensor([split_size], dtype=torch.int64, device=self.device) \ + if self.rank == sub_group.master_rank else torch.zeros(1, dtype=torch.int64, device=self.device) + wait_event = create_npu_event(self.stream) + dist.barrier(group=sub_group.group) + start_event = create_npu_event(self.stream) + wait_time = event_elaspe_second(self.stream, wait_event, start_event) + self.logger.info("[Rank %d] after split size barrier" % self.rank) + dist.broadcast(split_size_bd, group=sub_group.group, src=sub_group.master_rank) + end_event = create_npu_event(self.stream) + transfer_time = event_elaspe_second(self.stream, start_event, end_event) + self.logger.info("[Rank %d] after split size bd, %s" % (self.rank, split_size_bd)) + + self.time_stat.com_stat.broad_splits = [wait_time, transfer_time] + return split_size_bd.cpu().item() + + def gather_file_split(self, sub_group, tensor, master_rank_num, output_file_dir=None): + for i in range(sub_group.max_splits): + # is master node + if self.rank < master_rank_num: + cur_tensor = torch.zeros(sub_group.split_file_size, dtype=torch.uint8, device=self.device) + else: + start_time = time.perf_counter() + cur_tensor = tensor[i * sub_group.split_file_size: (i + 1) * sub_group.split_file_size] + if len(cur_tensor) < sub_group.split_file_size: + cur_tensor = np.pad(cur_tensor, (0, sub_group.split_file_size - len(cur_tensor)), 'constant', + constant_values=0) + cur_tensor = torch.tensor(cur_tensor, dtype=torch.uint8, device=self.device) + end_time = time.perf_counter() + self.time_stat.disk_stat.read_input_file_splits.append(end_time - start_time) + + # gather rank data内部有barrier与计时 + file_tensor_list, wait_time, transfer_time = self.gather_rank_data(dst_rank=sub_group.master_rank, + group=sub_group.group, + gather_tensor=cur_tensor) + self.logger.debug("[Rank %d] gather file split %d, wait time: %f, gather time: %f seconds" % ( + self.rank, i, wait_time, transfer_time)) + self.time_stat.com_stat.gather_file_splits.append([wait_time, transfer_time]) + + # 记录从memory_on_chip刷到硬盘中的耗时 + if file_tensor_list: + master_rank_collect_dir = get_master_rank_collect_dir(output_file_dir, self.rank) + memory_on_chip_ram_times = [] + ram_disk_times = [] + for rank_i, rank in enumerate(sub_group.ranks): + if rank != sub_group.master_rank: + group_rank = rank - master_rank_num + rank_dir = get_slave_rank_collect_dir(master_rank_collect_dir, group_rank) + if not os.path.exists(rank_dir): + os.makedirs(rank_dir, exist_ok=True) + rank_file = os.path.join(rank_dir, 'split_%d' % i) + cur_split_size = sub_group.splits[rank_i - 1][i] + if cur_split_size > 0: + start_time = time.perf_counter() + data = file_tensor_list[rank_i][:cur_split_size].cpu().numpy().tobytes() + ram_time = time.perf_counter() + with open(rank_file, 'wb') as f: + f.write(data) + end_time = time.perf_counter() + memory_on_chip_ram_times.append(ram_time - start_time) + ram_disk_times.append(end_time - ram_time) + + self.time_stat.disk_stat.memory_on_chip.append(memory_on_chip_ram_times) + self.time_stat.disk_stat.ram_disk.append(ram_disk_times) + + for tensor in file_tensor_list: + del tensor + del file_tensor_list + torch.npu.empty_cache() + + def concat_file_split(self, output_file_dir: str, sub_group: SubGroup, master_rank_num): + cur_rank_collect_dir = get_master_rank_collect_dir(output_file_dir, self.rank) + concat_times = [] + verify_hash_times = [] + for rank_i, rank in enumerate(sub_group.ranks): + # 只提取slave rank的case + if rank == self.rank: + continue + group_rank = rank - master_rank_num + rank_dir = get_slave_rank_collect_dir(cur_rank_collect_dir, group_rank) + output_file_name = os.path.join(rank_dir, 'merge.zip') + file_split_names = [] + start_time = time.perf_counter() + with open(output_file_name, 'wb') as output_file: + for split_i in range(sub_group.max_splits): + file_split = os.path.join(rank_dir, 'split_%d' % split_i) + if not os.path.exists(file_split): + self.logger.error('[Rank %d] not exist file split %s' % (self.rank, file_split)) + else: + file_split_names.append(file_split) + cat_files(output_file_name, input_files=file_split_names) + for file_split in file_split_names: + os.remove(file_split) + + end_time = time.perf_counter() + concat_times.append(end_time - start_time) + self.logger.debug( + '[Rank %d] concatenate slave rank %s, time: %f seconds' % (self.rank, rank, end_time - start_time)) + + start_time = time.perf_counter() + output_file_hash = get_quick_hash(output_file_name) + self.logger.debug('[Rank %d] rank_i %d, file_hashs:%s' % (self.rank, rank_i, sub_group.file_hashes)) + if not is_equal_file_hash(output_file_hash, sub_group.file_hashes[rank_i - 1]): + self.logger.error('[Rank %d] Not equal merge file hash. %s. %s' % ( + self.rank, output_file_hash, sub_group.file_hashes[rank_i - 1])) + end_time = time.perf_counter() + verify_hash_times.append(end_time - start_time) + + self.time_stat.disk_stat.hash_output_file = verify_hash_times + self.time_stat.disk_stat.concat_file = concat_times + + def master_node_run(self, master_env: EnvGroup, output_file_dir, split_file_size=None): + try: + # 设置环境变量,这些会在torch.dist中用到 + # 因为master node rank为0, 所以global rank直接等于local rank + master_env.set_env() + self.init(master_env) + + start_event = create_npu_event(self.stream) + self.logger.info('[Rank %d] Start master node process' % self.rank) + torch.npu.set_device(self.device) + init_process_group_event = create_npu_event(self.stream) + elp_time = event_elaspe_second(self.stream, start_event, init_process_group_event) + self.logger.debug('[Rank %d] init process group time %f seconds' % (self.rank, elp_time)) + self.time_stat.init_pg_stat.global_group_init = elp_time + + self.logger.info("[Rank %d] master node run" % (self.rank)) + # Step 2. Gather tensor size from slave node. + gather_tensor = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int64, device=self.device) + # 分为 (file_size, file_hash) + dist.init_process_group(backend='hccl', rank=self.rank, world_size=self.world_size) + if not (dist.is_available() and dist.is_initialized()): + raise RuntimeError("Distributed environment is not available") + + file_sizes_hash, wait_time, transfer_time = self.gather_rank_data(group=dist.group.WORLD, + gather_tensor=gather_tensor, + all_gather=True) + self.time_stat.com_stat.gather_file_size = [wait_time, transfer_time] + + self.logger.debug('[Rank %d] gather file size time %f seconds' % (self.rank, transfer_time)) + + # 判断硬盘空间是否足够,解压的过程中需要额外的空间存储临时文件与原压缩包 + file_sizes = [item[0] for item in file_sizes_hash[master_env.local_world_size:]] + total_file_size = sum(file_sizes) + + total_size_gb = Constant.UNZIP_DISK_SIZE_RAIO * total_file_size / (1024 * 1024 * 1024) + + self.logger.debug( + '[Rank %d] collect file sizes %s, total size %fgb' % (self.rank, file_sizes, total_file_size)) + DiskManager.check_disk_space(output_file_dir, total_size_gb) + + # Step 3. broadcast子通信域配置,建立子通信域 + self.logger.info("[Rank %d] creating sub group %s" % (self.rank, file_sizes_hash)) + self.create_sub_group(file_sizes_hash, master_env.local_world_size) + sub_group = GroupManager().get_rank_sub_group(self.rank) + + # 以下进入每个子通信域特定的逻辑 + if sub_group: + self.logger.info("[Rank %d] Subgroup ranks %s, file_sizes %s" % ( + self.rank, sub_group.ranks, sub_group.file_sizes)) + + # 未指定split file size的话,根据memory_on_chip/rank_num计算 + if not split_file_size: + if len(sub_group.ranks) > 0: + split_file_size = math.floor(Constant.MASTER_RANK_MEMORY_ON_CHIP / (len(sub_group.ranks))) + else: + logger.error("Value of sub_group.ranks is invalid, %d.", len(sub_group.ranks)) + self.bd_split_file_size(sub_group, split_file_size) + sub_group.split_size(split_file_size) + self.logger.info("[Rank %d] Subgroup split file size %s, splits %s" % ( + self.rank, sub_group.split_file_size, sub_group.splits)) + self.gather_file_split(sub_group=sub_group, tensor=None, master_rank_num=master_env.local_world_size, + output_file_dir=output_file_dir) + self.logger.debug("[Rank %d] start concat file split" % self.rank) + self.concat_file_split(output_file_dir, sub_group, master_env.local_world_size) + if len(sub_group.ranks) > 1: + self.logger.info(self.time_stat.to_string()) + else: + self.logger.info("[Rank %d] master rank not in sub group" % self.rank) + dist.barrier() + except Exception as e: + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise e + finally: + dist.destroy_process_group() + + def slave_node_run(self, slave_env: EnvGroup, input_file_dir, master_rank_num): + try: + self.logger.debug('Enter slave node run wrapper') + # 设置环境变量,这些会在torch.dist中用到 + slave_env.set_env() + self.init(slave_env) + torch.npu.set_device(self.device) + start_event = create_npu_event(self.stream) + init_process_group_event = create_npu_event(self.stream) + elp_time = event_elaspe_second(self.stream, start_event, init_process_group_event) + self.time_stat.init_pg_stat.global_group_init = elp_time + + self.logger.debug('[Rank %d] init process group time %f seconds' % (self.rank, elp_time)) + self.logger.info('[Rank %d] Start slave node process' % self.rank) + + # Step2. 先压缩文件,统计文件大小,再进入到gather逻辑里 + if os.path.isfile(input_file_dir): + file_path = input_file_dir + else: + PathManager.check_path_writeable(input_file_dir) + file_path = os.path.join(str(Path(input_file_dir).parent), 'compress.tar') + start_time = time.perf_counter() + compress_directory(input_file_dir, file_path) + end_time = time.perf_counter() + self.time_stat.disk_stat.compress_input_file = end_time - start_time + self.logger.info("[Rank %d] Compress directory time: %f seconds" % (self.rank, end_time - start_time)) + file_size = os.path.getsize(file_path) + start_time = time.perf_counter() + file_hash_chunks = get_quick_hash(file_path) + end_time = time.perf_counter() + self.time_stat.disk_stat.hash_input_file = end_time - start_time + self.logger.info("[Rank %d] Hash input file time: %f seconds" % (self.rank, end_time - start_time)) + file_hash_chunks.insert(0, file_size) + self.logger.info( + "[Rank %d] File hash chunks (first element is file size): %s" % (self.rank, file_hash_chunks)) + gather_tensor = torch.tensor(file_hash_chunks, dtype=torch.int64, device=self.device) + + dist.init_process_group(backend='hccl', rank=self.rank, world_size=self.world_size) + if not (dist.is_available() and dist.is_initialized()): + raise RuntimeError("Distributed environment is not available") + + file_sizes_hash, wait_time, transfer_time = self.gather_rank_data(group=dist.group.WORLD, + gather_tensor=gather_tensor, + all_gather=True) + self.time_stat.com_stat.gather_file_size = [wait_time, transfer_time] + self.logger.info("[Rank %d] Gather file size - wait time: %f seconds, transfer time: %f seconds" % ( + self.rank, wait_time, transfer_time)) + # Step3. 建立子通信域 + self.logger.debug("[Rank %d] creating sub group %s" % (self.rank, file_sizes_hash)) + self.create_sub_group(file_sizes_hash, master_rank_num) + sub_group = GroupManager().get_rank_sub_group(self.rank) + + # 进入每个子通信域特定的逻辑 + if sub_group: + # Step4. broacast split size大小 + self.logger.info("[Rank %d] Subgroup ranks %s, file_sizes %s" % ( + self.rank, sub_group.ranks, sub_group.file_sizes)) + split_file_size = self.bd_split_file_size(sub_group) + sub_group.split_size(split_file_size) + file_tensor = np.memmap(file_path, dtype=np.uint8, mode='r') + self.gather_file_split(sub_group=sub_group, tensor=file_tensor, master_rank_num=master_rank_num) + self.logger.info(self.time_stat.to_string()) + else: + self.logger.warning("[Rank %d] slave rank not in sub group" % (self.rank)) + dist.barrier() + except Exception as e: + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise e + finally: + dist.destroy_process_group() + + def run(self, args_dict: Dict[str, Any]): + input_file_dir = args_dict.get("input_file_dir") + output_file_dir = args_dict.get("output_file_dir") + nnodes = args_dict.get("nnodes") + node_rank = args_dict.get("node_rank") + master_addr = args_dict.get("master_addr") + master_port = args_dict.get("master_port") + master_rank_num = args_dict.get("master_rank_num") + split_file_size = args_dict.get("split_file_size") + time_out = args_dict.get("time_out") + log_file = args_dict.get("log_file") + + logging.basicConfig( + filename=log_file, # File to write logs to + level=logging.DEBUG, # Minimum logging level to write to the file + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' # Log message format + ) + self.logger.info({"message": "Run method arguments", + "class": self.__class__.__name__, + "method": sys._getframe().f_code.co_name, + "args": args_dict}) + + # 计算calculate world size + world_size = nnodes + master_rank_num - 1 + # master node的逻辑 + if node_rank == 0: + processes = [] + for i in range(master_rank_num): + master_env = EnvGroup(rank=i, local_rank=i, world_size=world_size, master_addr=master_addr, + master_port=master_port, group_rank=0, local_world_size=master_rank_num) + process = mp.Process(target=self.master_node_run, args=(master_env, output_file_dir, split_file_size)) + self.logger.info("Start master node subprocess %d." % i) + process.start() + processes.append(process) + start_time = time.perf_counter() + try: + while True: + all_done = all(not process.is_alive() for process in processes) + if all_done: + self.logger.info("All subprocesses finished successfully.") + break + elapsed_time = time.perf_counter() - start_time + time.sleep(5) + if elapsed_time > time_out: + raise TimeoutError("Timeout reached. Terminating all subprocesses.") + + except TimeoutError as e: + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + for process in processes: + if process.is_alive(): + process.terminate() + process.join() + finally: + # 确保Ensure all processes are cleaned up + for process in processes: + process.join() + # slave node的逻辑 + else: + rank = node_rank + master_rank_num - 1 + slave_env = EnvGroup(rank=rank, local_rank=0, world_size=world_size, master_addr=master_addr, + master_port=master_port, group_rank=node_rank, local_world_size=1) + self.slave_node_run(slave_env, input_file_dir, master_rank_num) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_file_dir", type=str, help='input profiling data dir') + parser.add_argument("--output_file_dir", type=str, help='input profiling data dir') + parser.add_argument("--nnodes", type=int, help='the total node number') + parser.add_argument("--node_rank", type=int, help='node rank in the cluster') + parser.add_argument("--master_addr", type=str, help='master address') + parser.add_argument("--master_port", type=int, default=29501, help='master port') + parser.add_argument("--master_rank_num", type=int, default=8, help='master rank nums') + + parser.add_argument("--split_file_size", type=int, default=None, help='split file size') + + # master node整体time out的时间 + parser.add_argument("--time_out", type=int, default=Constant.DEFAULT_TIME_OUT, + help='totoal process time out in seconds') + parser.add_argument("--log_file", type=str, default=None, help='logging file') + args = parser.parse_args() + + logging.basicConfig( + filename=args.log_file, # File to write logs to + level=logging.DEBUG, # Minimum logging level to write to the file + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' # Log message format + ) + logger = logging.getLogger(__name__) + + collector = Collector() + logger.debug(vars(args)) + args_dict = vars(args) + + try: + collector.run(args_dict) + except Exception as e: + logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise e diff --git a/profiler/msprof_analyze/precheck/common/__init__.py b/profiler/msprof_analyze/precheck/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/common/constant.py b/profiler/msprof_analyze/precheck/common/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..1fc724e7524917c4c5abb21b83e26e573d0f956b --- /dev/null +++ b/profiler/msprof_analyze/precheck/common/constant.py @@ -0,0 +1,45 @@ +import logging +import os +import stat +from datetime import timezone, timedelta + + +class Constant: + DEFAULT_SPLIT_FILE_SIZE = 15 * 1024 # 便于测试多文件split,默认split size设为15k + MASTER_RANK_MEMORY_ON_CHIP = 10 * 1024 * 1024 * 1024 # 10GB 片上内存可用显存来传输数据 + UNZIP_DISK_SIZE_RAIO = 1.0 # 需要x倍压缩文件的空间进行解压操作 + DEFAULT_TIME_OUT = 1200 + + ARG_MAX_LEN = 255 # 参数最大长度 + ARG_MIN_INT_VALUE = - (1 << 31) # 32位整数最小值 + ARG_MAX_INT_VALUE = (1 << 31) - 1 # 32位整数最大值 + ARG_MIN_PORT_VALUE = 0 + ARG_MAX_PORT_VALUE = 65535 + + PROFILER_FILE_PATTERNS = [r'profiler_metadata\.json', r'profiler_info_\d{1,10}\.json', r'ASCEND_PROFILER_OUTPUT/.*'] + + COLLECTOR_MASTER_RANK_NUM = 4 + COLLECTOR_DEFAULT_TIMEOUT = 1200 # seconds + COLLECTOR_SPLIT_FILE_SIZE = None # 文件传输的split块大小,默认split size设为根据显存自动计算 + LOCALHOST_ADDRESSES = {'localhost', '127.0.0.1'} + + MAX_ARCHIVE_SIZE = 20 * 1024 * 1024 * 1024 # 20 GB + MAX_ARCHIVE_FILE_COUNT = 10000 + MAX_ARCHIVE_RATIO = 10 + + DEFAULT_PROFILING_COMMANDS = { + "[resnet]": "resnet", + } + + MS_PROF_PRECHECK_CMD = "msprof-analyze precheck" + + ENABLE_STACKTRACE_LOGGING = False + LOGGING_LEVEL = logging.INFO + + +class TimeConstant: + """Time related constants""" + UTC = timezone.utc + CHINA_OFFSET = timedelta(hours=8) + CHINA_TIMEZONE = timezone(CHINA_OFFSET, name='Asia/Shanghai') + MS_TO_S = 1 / 1000 # Milliseconds to seconds conversion factor diff --git a/profiler/msprof_analyze/precheck/common/logger.py b/profiler/msprof_analyze/precheck/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..04346a80343098491e7610610730f9ea52fded8a --- /dev/null +++ b/profiler/msprof_analyze/precheck/common/logger.py @@ -0,0 +1,103 @@ +import logging +import logging.handlers + + +def create_logger(name: str, level: int = logging.DEBUG, use_memory_handler: bool = True) -> logging.Logger: + """ + Create a logger with optional memory handler for buffering logs. + + Args: + name: The name of the logger. recommend to use the module name: __name__. + level: The logging level, default is DEBUG. + use_memory_handler: Whether to add a memory handler for buffering logs, default is True. + + Returns: + A configured logger instance. + + Examples: + # Create a logger with memory handler + logger = create_logger("my_logger", logging.INFO, use_memory_handler=True) + + # Create a logger without memory handler + logger = create_logger("my_logger", logging.INFO, use_memory_handler=False) + + Notes: + When use_memory_handler is True, a memory handler is added to buffer logs until a specific log level + (default is ERROR) is reached, then logs are flushed to the target handler. This can avoid frequent + file writes and improve performance. Buffered logs can be manually flushed by calling logger.handlers[1].flush() + if no file handler is created yet. + + When use_memory_handler is False, no memory handler is added, and logs are written to the target handler + (e.g., console or file) in real-time. + """ + logger = logging.getLogger(name) + logger.handlers.clear() + + logger.setLevel(level) + logger.propagate = False + + console_handler = logging.StreamHandler() + console_handler.setLevel(level) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + if use_memory_handler: + memory_handler = logging.handlers.MemoryHandler(capacity=1000, flushLevel=logging.ERROR) + memory_handler.setLevel(level) + memory_handler.setFormatter(formatter) + logger.addHandler(memory_handler) + + return logger + + +def add_file_handler(logger: logging.Logger, log_file: str) -> logging.Logger: + """ + Add a file handler to an existing logger and handle the memory handler if present. + + Args: + logger: An existing logger instance. + log_file: The path to the log file. + + Returns: + The updated logger instance. + + Example: + # Initialize a logger + logger = create_logger("my_logger", logging.DEBUG, use_memory_handler=True) + + # Add a file handler to the logger + logger = add_file_handler(logger, "output.log") + + Notes: + This function adds a file handler to the given logger, inheriting the log level from the logger. + If a memory handler was previously added to the logger, its target handler is set to the new file handler, + buffered logs are flushed to the file, and then the memory handler is removed. + This ensures that both buffered logs and subsequent logs are written to the file after using the file handler. + """ + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logger.level) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + for handler in logger.handlers: + if isinstance(handler, logging.handlers.MemoryHandler): + handler.setTarget(file_handler) + handler.flush() + logger.removeHandler(handler) + + return logger + + +if __name__ == "__main__": + logger = create_logger("test_logger", logging.DEBUG, use_memory_handler=True) + logger.info("This is an info message from initial logger with memory handler") + + import tempfile + + with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file: + temp_file_path = temp_file.name + add_file_handler(logger, temp_file_path) + logger.info("This is an info message from logger with file handler") + logger.info("The log file is {}".format(temp_file_path)) diff --git a/profiler/msprof_analyze/precheck/common/singleton.py b/profiler/msprof_analyze/precheck/common/singleton.py new file mode 100644 index 0000000000000000000000000000000000000000..b645f284d642d3ba84c6e0cde374d865f16b7105 --- /dev/null +++ b/profiler/msprof_analyze/precheck/common/singleton.py @@ -0,0 +1,9 @@ +def singleton(cls: any) -> any: + _instance = {} + + def _singleton(*args: any, **kw: any) -> any: + if cls not in _instance: + _instance[cls] = cls(*args, **kw) + return _instance.get(cls) + + return _singleton diff --git a/profiler/msprof_analyze/precheck/common/time_stat.py b/profiler/msprof_analyze/precheck/common/time_stat.py new file mode 100644 index 0000000000000000000000000000000000000000..69df61cb797310adb4262c1261eec651770c08d8 --- /dev/null +++ b/profiler/msprof_analyze/precheck/common/time_stat.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass, field +from typing import List + +@dataclass +class InitProcessGroupStat: + global_group_init: float = None + sub_group_init: List[float] = field(default_factory=list) #wait time, transfer time + def sum_transfer_time(self): + return self.global_group_init + self.sub_group_init[1] + + def to_str_list(self): + str_list = ['[InitPGStat]:'] + str_list.append(' global group init: %f seconds:' % self.global_group_init) + str_list.append(' sub group init: %f seconds:' % self.sub_group_init[1]) + return str_list + +@dataclass +class ComStat: + gather_file_size: List[float] = field(default_factory=list) + broad_splits: List[float] = field(default_factory=list) + gather_file_splits: List[List[float]] = field(default_factory=list) + def sum_transfer_time(self): + return self.gather_file_size[1] + self.broad_splits[1] + + def to_str_list(self): + str_list = ['[ComStat]:'] + str_list.append(' gather file size: %f seconds:' % self.gather_file_size[1]) + str_list.append(' broad splits: %f seconds:' % self.broad_splits[1]) + file_split_times = [t[1] for t in self.gather_file_splits] + str_list.append(' gather file splits: %s seconds:' % file_split_times) + return str_list + +@dataclass +class DiskStat: + memory_on_chip: List[List[float]] = field(default_factory=list) + ram_disk: List[List[float]] = field(default_factory=list) + + concat_file: List[float] = field(default_factory=list) + hash_output_file: List[float] = field(default_factory=list) + + read_input_file_splits: List[float] = field(default_factory=list) + hash_input_file: float = None + + def to_str_list(self): + str_list = ['[DiskStat]:'] + if len(self.memory_on_chip) > 0: + for memory_on_chip, ram_disk in zip(self.memory_on_chip, self.ram_disk): + str_list.append(' File Split: ') + str_list.append(' hdm_ram time: %s' % memory_on_chip) + str_list.append(' ram_disk time: %s' % ram_disk) + str_list.append(' concat file time for slave ranks: %s' % self.concat_file) + str_list.append(' verify file hash time for slave ranks: %s' % self.hash_output_file) + + #slave node + else: + str_list.append(' hash file time: %s' % self.hash_input_file) + str_list.append(' read file split times: %s' % self.read_input_file_splits) + + return str_list + + +@dataclass +class TimeStat: + init_pg_stat: InitProcessGroupStat = field(default_factory=InitProcessGroupStat) + com_stat: ComStat = field(default_factory=ComStat) + disk_stat: DiskStat = field(default_factory=DiskStat) + + #print it for logging, 应当区分master node rank与slave node。 + def to_string(self): + str_list = ['[TimeStat]:'] + str_list.extend(' %s' %s for s in self.init_pg_stat.to_str_list()) + str_list.extend(' %s' %s for s in self.com_stat.to_str_list()) + str_list.extend(' %s' %s for s in self.disk_stat.to_str_list()) + return '\n'.join(str_list) diff --git a/profiler/msprof_analyze/precheck/common/utils.py b/profiler/msprof_analyze/precheck/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03f4e7d7ee31032067298351c25f426505cfc2ff --- /dev/null +++ b/profiler/msprof_analyze/precheck/common/utils.py @@ -0,0 +1,193 @@ +import os +import sys +import hashlib +import subprocess +import logging +from datetime import datetime + +import torch_npu +from msprof_analyze.precheck.common.constant import TimeConstant +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +def get_file_md5(filepath, chunk_size=4096, split_hash_size=4): + PathManager.check_input_file_path(filepath) + PathManager.check_path_readable(filepath) + md5_hash = hashlib.md5() + with open(filepath, "rb") as file: + for chunk in iter(lambda: file.read(chunk_size), b""): + md5_hash.update(chunk) + hash_bytes = int(md5_hash.hexdigest(), 16).to_bytes(16, 'big') + + chunks = [] + for i in range(0, 16, split_hash_size): + chunks.append(int.from_bytes(hash_bytes[i:i + split_hash_size], 'big')) + return chunks + + +def get_quick_hash(file_path, sample_size=65536, hash_spilt_size=4): + PathManager.check_input_file_path(file_path) + PathManager.check_path_readable(file_path) + file_size = os.path.getsize(file_path) + if file_size < sample_size * 5: + return get_file_md5(file_path) + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + hash_md5.update(f.read(sample_size)) + f.seek(max(0, (os.path.getsize(file_path) // 2) - (sample_size // 2))) + hash_md5.update(f.read(sample_size)) + f.seek(-sample_size, 2) + hash_md5.update(f.read(sample_size)) + hash_bytes = int(hash_md5.hexdigest(), 16).to_bytes(16, 'big') + + chunks = [] + for i in range(0, 16, hash_spilt_size): + chunks.append(int.from_bytes(hash_bytes[i:i + hash_spilt_size], 'big')) + return chunks + + +def is_equal_file_hash(chunks1, chunks2): + for chunk1, chunk2 in zip(chunks1, chunks2): + if chunk1 != chunk2: + return False + return True + + +def cat_files(output_file, input_files): + """ + Concatenate multiple binary input files into a single output file using cat command. + + Args: + output_file (str): Path to the output file + input_files (list): List of input file paths to concatenate + + Returns: + bool: True if concatenation was successful + + Raises: + subprocess.CalledProcessError: If the cat command fails + """ + PathManager.check_input_file_path(output_file) + cmd = ["cat"] + list(input_files) + + try: + with open(output_file, 'wb') as outfile: + result = subprocess.run(cmd, stdout=outfile, stderr=subprocess.PIPE) + + if result.returncode == 0: + return True + else: + logger.error("Error occurred during concatenation: %s", + result.stderr.decode('utf-8', errors='replace')) + raise subprocess.CalledProcessError(result.returncode, cmd, + output=None, + stderr=result.stderr) + + except OSError as e: + logger.error("OS error occurred during file operation: %s", str(e)) + raise + + +def compress_directory(src_dir, output_file): + PathManager.check_input_directory_path(src_dir) + PathManager.check_path_readable(src_dir) + if not os.path.isdir(src_dir): + raise FileNotFoundError(f"The directory '{src_dir}' does not exist.") + try: + result = subprocess.run( + ["/bin/tar", "-czf", output_file, "-C", src_dir, "."], + check=True, # Raise an error if the command fails + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Failed to compress directory '{src_dir}' into '{output_file}'. " + f"Error: {e.stderr.decode('utf-8')}" + ) from e + + +def get_master_rank_collect_dir(output_file_dir, master_rank_i): + return os.path.join(output_file_dir, 'rank_%d_collect' % master_rank_i) + + +def get_slave_rank_collect_dir(master_rank_collect_dir, group_rank): + return os.path.join(master_rank_collect_dir, 'node_%d' % group_rank) + + +def parition_sub_group_ranks(master_rank_num, file_sizes): + master_rank_num = int(master_rank_num) + indexed_lst = sorted(enumerate(file_sizes), key=lambda x: x[1]) + sorted_indices = [index + master_rank_num for index, value in indexed_lst] + if master_rank_num != 0: + base_size = len(file_sizes) // master_rank_num + else: + logging.error("%s value can not be 0", master_rank_num) + extra_items = len(file_sizes) % master_rank_num + partitions = [] + start = 0 + for i in range(master_rank_num): + end = start + base_size + (1 if i < extra_items else 0) + partition_indices = [i] + partition_indices.extend(sorted_indices[start:end]) + partitions.append(partition_indices) + start = end + return partitions + + +def get_split_file_size(memory_on_chip_size, sub_group_rank_num): + if sub_group_rank_num != 0: + return memory_on_chip_size // sub_group_rank_num + else: + logging.error("%s value can not be 0", sub_group_rank_num) + return None + + +def create_npu_event(stream): + event = torch_npu.npu.Event(enable_timing=True) + stream.record_event(event) + return event + + +def event_elaspe_second(stream, event1, event2): + stream.synchronize() + return event1.elapsed_time(event2) * TimeConstant.MS_TO_S + + +def cn_now() -> datetime: + """ + Get current time in China timezone as a formatted string. + + Returns: + datetime: Current time in China timezone + """ + return datetime.now(tz=TimeConstant.UTC).astimezone(TimeConstant.CHINA_TIMEZONE) + + +def check_file_owner_and_permission(file_path): + """ + Check if the file belongs to current user and only owner has write permission. + + Args: + file_path: Path to the file to check + + Raises: + RuntimeError: If file not found, not owned by current user, or has wrong permissions + """ + PathManager.check_path_readable(file_path) + + if not os.path.isfile(file_path): + raise RuntimeError(f"File not found at {file_path}") + + # Check file owner + if os.stat(file_path).st_uid != os.getuid(): + raise RuntimeError(f"File {file_path} is not owned by current user") + + # Check file permissions (only owner should have write permission) + current_mode = os.stat(file_path).st_mode + desired_mode = 0o700 # rwx------ (only owner has read/write/execute) + if (current_mode & 0o777) != desired_mode: + os.chmod(file_path, desired_mode) + logger.warning("File %s has wrong permissions, has been changed to %o", file_path, desired_mode) diff --git a/profiler/msprof_analyze/precheck/distributed_cluster/__init__.py b/profiler/msprof_analyze/precheck/distributed_cluster/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/profiler/msprof_analyze/precheck/distributed_cluster/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py b/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py new file mode 100644 index 0000000000000000000000000000000000000000..7ccd1e542eee2050542a08df62e1720a9cdf4dcb --- /dev/null +++ b/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class DistributedClusterBase: + def __init__(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/__init__.py b/profiler/msprof_analyze/precheck/env_check/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/env_check/check_item_factory.py b/profiler/msprof_analyze/precheck/env_check/check_item_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea14bfe0d37768828291a1e9c71a1b890c7bd0d --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/check_item_factory.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_variable_check import EnvironmentVariableCheck +from msprof_analyze.precheck.env_check.python_library_check import PythonLibraryCheck +from msprof_analyze.precheck.env_check.cpu_check import CPUCheck +from msprof_analyze.precheck.env_check.npu_check import NPUCheck +from msprof_analyze.precheck.env_check.communication_check import CommunicationCheck +from msprof_analyze.precheck.env_check.io_check import IOCheck + + +HARDWARE_CHECK_LIST = [ + CPUCheck, + NPUCheck, + CommunicationCheck, + IOCheck, +] + +SOFTWARE_CHECK_LIST = [ + EnvironmentVariableCheck, + PythonLibraryCheck, +] + + +class CheckItemFactory: + CHECK_ITEMS = { + check_item.CHECK_TYPE: check_item + for check_item in SOFTWARE_CHECK_LIST + HARDWARE_CHECK_LIST + } + + @staticmethod + def get_check_item(check_type: str) -> list: + if check_type == "all": + return SOFTWARE_CHECK_LIST + HARDWARE_CHECK_LIST + if check_type == "software": + return SOFTWARE_CHECK_LIST + if check_type == "hardware": + return HARDWARE_CHECK_LIST + check_type_list = check_type.split("|") + check_items = [] + for check_type in check_type_list: + check_item = CheckItemFactory.CHECK_ITEMS.get(check_type) + if not check_item: + continue + check_items.append(check_item) + return check_items diff --git a/dynolog_npu/plugin/setup.py b/profiler/msprof_analyze/precheck/env_check/communication_check.py similarity index 42% rename from dynolog_npu/plugin/setup.py rename to profiler/msprof_analyze/precheck/env_check/communication_check.py index 151b9b3fb3fa1a42e147685f632163c8b3f5a564..807d4008115422ff312d2877495273bc25312eea 100644 --- a/dynolog_npu/plugin/setup.py +++ b/profiler/msprof_analyze/precheck/env_check/communication_check.py @@ -1,42 +1,25 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from setuptools import setup -from pybind11.setup_helpers import Pybind11Extension - -BASE_DIR = os.path.dirname(os.path.realpath(__file__)) - -# Define the extension module -ext_modules = [ - Pybind11Extension( - "IPCMonitor", # Name of the Python module - sources=["bindings.cpp", - "ipc_monitor/utils.cpp", - "ipc_monitor/DynoLogNpuMonitor.cpp", - "ipc_monitor/NpuIpcClient.cpp", - ], # Source files - include_dirs=[os.path.join(BASE_DIR, "ipc_monitor")], # Include Pybind11 headers - language="c++", # Specify the language - ), -] - -# Set up the package -setup( - name="dynolog_npu_plugin", - version="0.1", - description="dynolog npu plugins", - ext_modules=ext_modules, - install_requires=["pybind11"], -) \ No newline at end of file +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import HardwareCheck + + +class CommunicationCheck(HardwareCheck): + CHECK_TYPE = "communication" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/cpu_check.py b/profiler/msprof_analyze/precheck/env_check/cpu_check.py new file mode 100644 index 0000000000000000000000000000000000000000..e3765c71ebc3a9ee4700fe59cfc84727c68c3417 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/cpu_check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import HardwareCheck + + +class CPUCheck(HardwareCheck): + CHECK_TYPE = "cpu" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/environment_check.py b/profiler/msprof_analyze/precheck/env_check/environment_check.py new file mode 100644 index 0000000000000000000000000000000000000000..98d54ac506400ec53348cdaeea613d555dc81290 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/environment_check.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + + +class EnvironmentCheck(ABC): + CHECK_TYPE = "" + + def __init__(self, **kwargs): + self.output = kwargs.get("output", "./output") + + def init(self): + pass + + def uninit(self): + pass + + @abstractmethod + def check(self): + pass + + +class HardwareCheck(EnvironmentCheck): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + def check(self): + pass + + +class SoftwareCheck(EnvironmentCheck): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/environment_variable_check.py b/profiler/msprof_analyze/precheck/env_check/environment_variable_check.py new file mode 100644 index 0000000000000000000000000000000000000000..58d2becb23266ff085b80d2acd9c17a229e8420d --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/environment_variable_check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import SoftwareCheck + + +class EnvironmentVariableCheck(SoftwareCheck): + CHECK_TYPE = "env_variable" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/io_check.py b/profiler/msprof_analyze/precheck/env_check/io_check.py new file mode 100644 index 0000000000000000000000000000000000000000..5cfd5c425f0d18d7021c8ef8dca7447c9df6dfc6 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/io_check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import HardwareCheck + + +class IOCheck(HardwareCheck): + CHECK_TYPE = "io" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/npu_check.py b/profiler/msprof_analyze/precheck/env_check/npu_check.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ffa4997da7a75f566461c70af22393e9b97fb1 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/npu_check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import HardwareCheck + + +class NPUCheck(HardwareCheck): + CHECK_TYPE = "npu" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/env_check/python_library_check.py b/profiler/msprof_analyze/precheck/env_check/python_library_check.py new file mode 100644 index 0000000000000000000000000000000000000000..81de7000ce7cdf37c1a6c52ff6d650df95d86d9b --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/python_library_check.py @@ -0,0 +1,25 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.environment_check import SoftwareCheck + + +class PythonLibraryCheck(SoftwareCheck): + CHECK_TYPE = "python_lib" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def check(self): + pass diff --git a/profiler/msprof_analyze/precheck/examples/__init__.py b/profiler/msprof_analyze/precheck/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/examples/profiler/__init__.py b/profiler/msprof_analyze/precheck/examples/profiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/examples/profiler/dynamic_prof.py b/profiler/msprof_analyze/precheck/examples/profiler/dynamic_prof.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b1e9b849b32380978b45661d18c03447ee6482 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/profiler/dynamic_prof.py @@ -0,0 +1,71 @@ +import json +import os +import logging +from copy import deepcopy + +logger = logging.getLogger(__name__) + +DEFAULT_DP_CONFIG = { + "activities": ["CPU", "NPU"], + "prof_dir": "./prof_result", + "analyse": False, + "record_shapes": False, + "profile_memory": False, + "with_stack": False, + "with_flops": False, + "with_modules": False, + "active": 1, + "is_rank": False, + "rank_list": [], + "experimental_config": { + "profiler_level": "Level0", + "aic_metrics": "AiCoreNone", + "l2_cache": False, + "op_attr": False, + "gc_detect_threshold": None, + "data_simplification": True, + "record_op_args": False, + "export_type": "text", + "msprof_tx": False + } +} + + +def _get_prof_config_json(prof_dp_path): + prof_config_json = os.path.join(prof_dp_path, "profiler_config.json") + return prof_config_json + + +def _set_default_prof_config(prof_config_json): + with open(prof_config_json, "w") as f: + json.dump(DEFAULT_DP_CONFIG, f, indent=4) + + +def get_dynamic_prof_config_path(): + cwd = os.path.dirname(os.path.realpath(__file__)) + prof_dp_path = os.path.join(cwd, './local_config/config_dynamic') + + prof_config_json = _get_prof_config_json(prof_dp_path) + os.makedirs(os.path.dirname(prof_config_json), exist_ok=True) + + if not os.path.exists(prof_config_json): + _set_default_prof_config(prof_config_json) + logger.info("Created default dynamic profiler config file at {}".format(prof_config_json)) + + return prof_dp_path + + +def start_dynamic_profiler(prof_dp_path, prof_save_dir): + prof_config_json = _get_prof_config_json(prof_dp_path) + if prof_save_dir is not None: + if not os.path.exists(prof_config_json): + data = deepcopy(DEFAULT_DP_CONFIG) + else: + with open(prof_config_json, 'r') as f: + data = json.load(f) + data['prof_dir'] = prof_save_dir + + with open(prof_config_json, 'w') as f: + json.dump(data, f, indent=4) + + logger.info('has started dynamic profiling') diff --git a/profiler/msprof_analyze/precheck/examples/profiler/models.py b/profiler/msprof_analyze/precheck/examples/profiler/models.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0f8cc0de62efcd92081a632fb9786188f05de3 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/profiler/models.py @@ -0,0 +1,67 @@ +import logging +from typing import Dict, Any, Tuple + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +# ============= Models ============= +class SimpleResNet(nn.Module): + def __init__(self, num_classes: int = 10): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.fc = nn.Linear(64 * 56 * 56, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + +# ============= Datasets ============= +class DummyImageDataset(Dataset): + def __init__(self, input_shape: Tuple[int, ...], num_samples: int = 100): + self.input_shape = input_shape + self.num_samples = num_samples + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + x = torch.randn(self.input_shape) + y = torch.randint(0, 10, ()) + return x, y + + +# ============= Example Registry ============= +class ExampleRegistry: + @staticmethod + def get_example_config(example_name: str) -> Dict[str, Any]: + configs = { + "resnet": { + "model_class": SimpleResNet, + "model_args": {"num_classes": 10}, + "dataset_class": DummyImageDataset, + "dataset_args": {"input_shape": (3, 224, 224), "num_samples": 800}, + "batch_size": 8, + }, + } + + if example_name not in configs: + available_models = ", ".join(configs.keys()) + raise ValueError( + f"Unknown example name: {example_name}. " + f"Available models are: {available_models}" + ) + + return configs[example_name] diff --git a/profiler/msprof_analyze/precheck/examples/profiler/train_with_profiler.py b/profiler/msprof_analyze/precheck/examples/profiler/train_with_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e6eb482c4cc10e4f31026b36738654d305409f2 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/profiler/train_with_profiler.py @@ -0,0 +1,286 @@ +""" +Example Usage: +1. Single node training examples: +torchrun --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr="127.0.0.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + +2. Distributed training examples: + + # Multiple nodes (2 nodes, 8 GPUs each) + # On node 0 (master node): + torchrun --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=0 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + + # On node 1: + torchrun --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=1 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + +Distributed Training Parameters: +--nproc_per_node: Number of processes per node (typically number of GPUs) +--nnodes: Total number of nodes +--node_rank: Rank of current node (0 to nnodes-1) +--master_addr: IP address of master node +--master_port: Port for master node communication + +Available Models: +- resnet: ResNet model implementation + +Environment Variables (automatically set by torchrun): +- RANK: Global rank of the process +- WORLD_SIZE: Total number of processes +- LOCAL_RANK: Local rank within the current node +- MASTER_ADDR: Master node address +- MASTER_PORT: Master node port +""" + +import os +import argparse +import ipaddress +import datetime +import logging +from typing import Optional, List + +import torch +import torch_npu +import torch.nn as nn +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +try: + from torch_npu.profiler import dynamic_profile as dp +except ImportError: + dp = None + +from msprof_analyze.precheck.examples.profiler.models import ExampleRegistry +from msprof_analyze.precheck.examples.profiler.dynamic_prof import get_dynamic_prof_config_path +from msprof_analyze.precheck.common.constant import Constant + +logger = logging.getLogger(__name__) + + +class ProfilerCallback: + """Callback for handling profiling operations""" + + def __init__(self, prof_save_dir, + is_dynamic=False, dynamic_prof_path=None): + self.profiler = None + self.is_dynamic = is_dynamic + if is_dynamic: + self.dynamic_prof_path = dynamic_prof_path if dynamic_prof_path else get_dynamic_prof_config_path() + self.prof_save_dir = prof_save_dir + + def on_train_begin(self): + if self.is_dynamic: + dp.init(self.dynamic_prof_path) + dist.barrier() + if dist.get_rank() == 0: + from msprof_analyze.precheck.examples.profiler.dynamic_prof import start_dynamic_profiler + start_dynamic_profiler(self.dynamic_prof_path, + self.prof_save_dir) + self.profiler = dp + else: + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level2, + l2_cache=False, + data_simplification=False + ) + self.profiler = torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + with_stack=True, + record_shapes=True, + profile_memory=True, + schedule=torch_npu.profiler.schedule( + wait=5, warmup=5, active=20, repeat=1, skip_first=10), + experimental_config=experimental_config, + with_flops=True, + with_modules=True, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( + self.prof_save_dir) + ) + self.profiler.__enter__() + + def on_step_end(self): + if self.profiler: + self.profiler.step() + + def on_train_end(self): + if not self.is_dynamic and self.profiler: + self.profiler.__exit__(None, None, None) + + +class Trainer: + def __init__( + self, + model: nn.Module, + dataloader: Optional[Dataset] = None, + callbacks: Optional[List[ProfilerCallback]] = None, + criterion: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + ): + self.model = model + self.dataloader = dataloader + self.callbacks = callbacks or [] + + # Setup loss and optimizer with defaults + self.criterion = criterion or nn.CrossEntropyLoss() + self.optimizer = optimizer or torch.optim.Adam(self.model.parameters()) + + # get dist config from env + self.rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.device = f"npu:{self.local_rank}" + + # Setup device and distributed training + self.setup_distributed(self.rank, self.world_size, self.local_rank) + + # Move model and criterion to device + self.model = self.model.to(self.device) + self.criterion = self.criterion.to(self.device) + + @staticmethod + def setup_distributed(rank, world_size, local_rank): + if dist.is_initialized(): + return + + torch.npu.set_device(local_rank) + dist.init_process_group( + backend='hccl', + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=1800) + ) + logger.info(f"[Rank {rank}] Initialized distributed training") + + def cleanup(self): + """Explicitly cleanup distributed training resources""" + if dist.is_initialized(): + dist.destroy_process_group() + logger.info(f"[Rank {self.rank}] Destroyed distributed training") + + def train(self, epoch: int = 1): + # Call training start callbacks + for callback in self.callbacks: + callback.on_train_begin() + + # Training loop + for epoch_idx in range(epoch): + if self.rank == 0: + pbar = tqdm( + total=len(self.dataloader), + desc=f'Epoch {epoch_idx + 1}/{epoch}', + unit='batch' + ) + + for step, (inputs, labels) in enumerate(self.dataloader): + # Move data to device + inputs = inputs.to(self.device) + labels = labels.to(self.device) + + # Forward pass + self.optimizer.zero_grad() + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + + # Backward pass + loss.backward() + self.optimizer.step() + + if self.rank == 0: + pbar.update(1) + pbar.set_postfix({ + 'step': f'{step + 1}/{len(self.dataloader)}', + 'loss': f'{loss.item():.4f}' + }) + + dist.barrier() + + # Call step end callbacks + for callback in self.callbacks: + callback.on_step_end() + + if self.rank == 0: + pbar.close() + + # Call training end callbacks + for callback in self.callbacks: + callback.on_train_end() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--example_name', default='resnet', + choices=['resnet'], + help='Name of the example to run') + parser.add_argument('--prof_output_dir', required=True) + parser.add_argument('--static', action='store_true', required=False, default=False) + args = parser.parse_args() + + # Get example configuration + example_config = ExampleRegistry.get_example_config(args.example_name) + + # Create model and dataset + model = example_config["model_class"](**example_config["model_args"]) + dataset = example_config["dataset_class"](**example_config["dataset_args"]) + + # Create loss and optimizer (可选,使用默认值也可以) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + # Create profiler callback + profiler_callback = ProfilerCallback( + args.prof_output_dir, + is_dynamic=(not args.static) + ) + + dataloader = DataLoader(dataset, batch_size=example_config["batch_size"]) + + # Initialize trainer + trainer = Trainer( + model=model, + dataloader=dataloader, + callbacks=[profiler_callback], + criterion=criterion, # 可选 + optimizer=optimizer, # 可选 + ) + + try: + trainer.train() + finally: + trainer.cleanup() + + +if __name__ == '__main__': + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + try: + main() + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise diff --git a/profiler/msprof_analyze/precheck/examples/scripts/precheck_run_llama2.sh b/profiler/msprof_analyze/precheck/examples/scripts/precheck_run_llama2.sh new file mode 100644 index 0000000000000000000000000000000000000000..e3bf0859e7565ecbea7857bb1601fc9e58812b57 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/scripts/precheck_run_llama2.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True + +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +MASTER_ADDR=${MASTER_ADDR:-"192.168.0.1"} +MASTER_PORT=${MASTER_PORT:-6000} +NNODES=${NNODES:-2} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR=${CKPT_SAVE_DIR:-"./ckpt/llama-2-7b"} +CKPT_LOAD_DIR=${CKPT_LOAD_DIR:-"./model_weights/llama-2-7b-legacy"} +TOKENIZER_MODEL=${TOKENIZER_MODEL:-"./model_from_hf/llama-2-7b-hf/tokenizer.model"} +DATA_PATH=${DATA_PATH:-"./dataset/enwiki_text_document"} + +TP=${TP:-2} +PP=${PP:-4} + +# Result directory +OUTPUT_DIR=${OUTPUT_DIR:-"./result/precheck/llama2-1129-2130"} + +PROF_NODE_RES_DIR="$OUTPUT_DIR/node_prof_save_dir" +LOG_FILE="$OUTPUT_DIR/precheck.log" + +# Check if profiling output directory exists before running training +# This prevents starting a long training job if the directory is missing +if [ ! -d "$OUTPUT_DIR" ]; then + echo "Error: Result directory $OUTPUT_DIR does not exist." \ + "Please create the directory before running training" \ + "(in ${BASH_SOURCE[0]})" >&2 + exit 1 +fi + +# Get the directory of the current script and cd into it +# SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# echo "Script directory: $SCRIPT_DIR" +# cd "$SCRIPT_DIR" +# echo "Changed working directory to: $(pwd)" + + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --sequence-parallel \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 11008 \ + --num-attention-heads 32 \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --micro-batch-size 1 \ + --global-batch-size 256 \ + --make-vocab-size-divisible-by 1 \ + --lr 1.25e-6 \ + --train-iters 5 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.01 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rmsnorm \ + --swiglu \ + --use-flash-attn \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.25e-7 \ + --weight-decay 1e-1 \ + --lr-warmup-fraction 0.01 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --initial-loss-scale 65536 \ + --adam-beta2 0.95 \ + --no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --use-distributed-optimizer \ + --use-fused-swiglu \ + --use-fused-rotary-pos-emb \ + --overlap-grad-reduce \ + --bf16" + +DATA_ARGS=" \ + --data-path $DATA_PATH \ + --split 949,50,1" + +PROFILE_ARGS=" \ + --profile \ + --profile-step-start 2 \ + --profile-step-end 4 \ + --profile-ranks -1 \ + --profile-level level0 \ + --profile-with-cpu \ + --profile-save-path $PROF_NODE_RES_DIR" + +OUTPUT_ARGS=" \ + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 0" + +# Add precheck arguments +# PRECHECK_ARGS=" \ +# --do_precheck" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $PROFILE_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee $LOG_FILE diff --git a/profiler/msprof_analyze/precheck/examples/scripts/run_llama2_precheck.sh b/profiler/msprof_analyze/precheck/examples/scripts/run_llama2_precheck.sh new file mode 100644 index 0000000000000000000000000000000000000000..495dab8ca6fdeaab6ca87df61a6be0d4d7830f6c --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/scripts/run_llama2_precheck.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# You should set the IP addresses of the nodes in the NODES_IP variable +# Change the IP addresses to the actual IP addresses of your nodes +NODES_IP="${NODES_IP:-192.168.0.1,192.168.0.2}" + +# Convert comma-separated NODES_IP to an array nodes_ip +IFS=',' read -r -a nodes_ip <<< "$NODES_IP" + + +echo "Starting distributed precheck with ${#nodes_ip[@]} nodes" +echo "Master node: ${nodes_ip[0]}" +echo "All nodes: ${nodes_ip[*]}" + +output_dir_base="./result/demo_precheck" + +# Add timestamp to task name +timestamp=$(date +"%Y%m%d_%H%M%S") +task_name="llama2-demo_${timestamp}" + +output_dir="${output_dir_base}/${task_name}" +node_prof_save_dir="${output_dir}/node_prof_save_dir" + +# Join array elements with commas +host_ips=$(IFS=,; echo "${nodes_ip[*]}") + +# Run precheck with distributed configuration +msprof-analyze precheck start_all \ + --host_ips "${host_ips}" \ + --master_addr "${nodes_ip[0]}" \ + --master_port 29500 \ + --nnodes ${#nodes_ip[@]} \ + --nproc_per_node 8 \ + --output_dir ${output_dir_base} \ + --task_name ${task_name} \ + --node_prof_save_dir ${node_prof_save_dir} \ + --profiling_cmd "OUTPUT_DIR=${output_dir} bash ./examples/scripts/precheck_run_llama2.sh" \ + --static + +echo "Precheck completed" diff --git a/profiler/msprof_analyze/precheck/examples/scripts/run_precheck.sh b/profiler/msprof_analyze/precheck/examples/scripts/run_precheck.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf5b3b89cff5e945af557ed07c997174ac19a78b --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/scripts/run_precheck.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# You should set the IP addresses of the nodes in the NODES_IP variable +# Change the IP addresses to the actual IP addresses of your nodes +NODES_IP="${NODES_IP:-192.168.0.1,192.168.0.2}" + + +# Convert comma-separated NODES_IP to an array nodes_ip +IFS=',' read -r -a nodes_ip <<< "$NODES_IP" + +timestamp=$(date +"%Y%m%d_%H%M%S") +task_name="task_demo_${timestamp}" + +echo "Starting distributed precheck with ${#nodes_ip[@]} nodes" +echo "Master node: ${nodes_ip[0]}" +echo "All nodes: ${nodes_ip[@]}" + +output_dir=./output_test + +PROFILING_CMD="[resnet]" + +# Join array elements with commas +host_ips=$(IFS=,; echo "${nodes_ip[*]}") + +# Run precheck with distributed configuration +msprof-analyze precheck start_all \ + --host_ips "${host_ips}" \ + --master_addr ${nodes_ip[0]} \ + --master_port 29500 \ + --nnodes ${#nodes_ip[@]} \ + --nproc_per_node 8 \ + --output_dir "${output_dir}" \ + --task_name ${task_name} \ + --profiling_cmd "${PROFILING_CMD}" \ + --static + +echo "Precheck completed" diff --git a/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_env.sh b/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_env.sh new file mode 100644 index 0000000000000000000000000000000000000000..68aa4b33ddce4cfaeee0b2b5e1008901b02809e8 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_env.sh @@ -0,0 +1,166 @@ +#!/bin/bash + +# 默认值设置 +HOST_IPS=${HOST_IPS:-""} +TIMEOUT=${TIMEOUT:-5} + +# ANSI 颜色代码 +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color +BOLD='\033[1m' + +# 检查必需参数 +if [ -z "$HOST_IPS" ]; then + echo -e "${RED}Error: HOST_IPS environment variable is not set${NC}" + echo -e "Usage: ${BOLD}HOST_IPS='192.168.0.1,192.168.0.2' [CHECK_CANN=1] [TIMEOUT=5] bash $0${NC}" + exit 1 +fi + +# 获取CANN信息的函数 +get_cann_info() { + # 尝试多种方式获取CANN信息 + if command -v npu-smi &>/dev/null; then + npu_info=$(npu-smi info 2>/dev/null) + driver_version=$(echo "$npu_info" | grep "Driver Version" | awk -F':' '{print $2}' | tr -d ' ') + firmware_version=$(echo "$npu_info" | grep "Firmware Version" | awk -F':' '{print $2}' | tr -d ' ') + echo "Driver:$driver_version;Firmware:$firmware_version" + else + echo "NPU-SMI Not Found" + fi +} + +# 打印标题 +echo -e "\n${BOLD}🔍 Cluster Environment Checker${NC}" +echo -e "Usage: ${BOLD}HOST_IPS='192.168.0.1,192.168.0.2' [CHECK_CANN=1] [TIMEOUT=5] bash $0${NC}" +echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" + +# 获取本机环境信息 +echo -e "\n${BOLD}📊 Step 1: Collecting local environment info...${NC}" +echo -e "${BLUE}Detecting Python environment...${NC}" +LOCAL_PYTHON_PATH=$(which python3) +LOCAL_PYTHON_VERSION=$($LOCAL_PYTHON_PATH -V 2>&1) +echo -e "${BLUE}Checking installed packages...${NC}" +LOCAL_MSPROF_VERSION=$($LOCAL_PYTHON_PATH -m pip show msprof-analyze | grep Version | awk '{print $2}') +LOCAL_TORCH_VERSION=$($LOCAL_PYTHON_PATH -m pip show torch | grep Version | awk '{print $2}') +LOCAL_TORCH_NPU_VERSION=$($LOCAL_PYTHON_PATH -m pip show torch_npu | grep Version | awk '{print $2}') + +echo -e "\n${BOLD}📌 Local Environment Summary:${NC}" +echo -e " • Python Path: ${GREEN}$LOCAL_PYTHON_PATH${NC}" +echo -e " • Python Version: ${GREEN}$LOCAL_PYTHON_VERSION${NC}" +echo -e " • Msprof-analyze: ${GREEN}v$LOCAL_MSPROF_VERSION${NC}" +echo -e " • Torch: ${GREEN}v$LOCAL_TORCH_VERSION${NC}" +echo -e " • Torch_NPU: ${GREEN}v$LOCAL_TORCH_NPU_VERSION${NC}" + +# 构建远程检查命令 +CHECK_CMD=$(cat << EOF +echo "=== Python Path Check ===" && \ +test -f $LOCAL_PYTHON_PATH && \ +echo "=== Python Version ===" && \ +$LOCAL_PYTHON_PATH -V && \ +echo "=== Msprof-analyze Version ===" && \ +$LOCAL_PYTHON_PATH -m pip show msprof-analyze | grep Version | awk '{print \$2}' && \ +echo "=== Torch Version ===" && \ +$LOCAL_PYTHON_PATH -m pip show torch | grep Version | awk '{print \$2}' && \ +echo "=== Torch_NPU Version ===" && \ +$LOCAL_PYTHON_PATH -m pip show torch_npu | grep Version | awk '{print \$2}' && \ +echo "=== TMUX Check ===" && \ +which tmux +EOF +) + +# 检查每个远程主机 +echo -e "\n${BOLD}🔄 Step 2: Checking cluster nodes...${NC}" +IFS=',' read -ra HOSTS <<< "$HOST_IPS" +total_hosts=${#HOSTS[@]} +current_host=0 +failed_hosts=() + +for host in "${HOSTS[@]}"; do + ((current_host++)) + echo -e "\n${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" + echo -e "${BOLD}📡 Checking host [$current_host/$total_hosts]: ${YELLOW}$host${NC}" + + # 检查ssh连接 + echo -e " ⏳ Testing SSH connection..." + if ! ssh -o BatchMode=yes -o ConnectTimeout=$TIMEOUT $host "exit 0" &>/dev/null; then + echo -e " ${RED}❌ SSH connection failed${NC}" + failed_hosts+=("$host [SSH Failed]") + continue + fi + echo -e " ${GREEN}✓ SSH connection successful${NC}" + + # 检查Python解释器 + echo -e " ⏳ Verifying Python interpreter..." + if ! ssh -o BatchMode=yes -o ConnectTimeout=$TIMEOUT $host "test -f $LOCAL_PYTHON_PATH" &>/dev/null; then + echo -e " ${RED}❌ Python interpreter not found at: $LOCAL_PYTHON_PATH${NC}" + failed_hosts+=("$host [Python Not Found]") + continue + fi + echo -e " ${GREEN}✓ Python interpreter verified${NC}" + + # 检查环境 + echo -e " ⏳ Checking environment..." + remote_output=$(ssh -o BatchMode=yes -o ConnectTimeout=$TIMEOUT $host "$CHECK_CMD" 2>&1) + if [ $? -ne 0 ]; then + echo -e " ${RED}❌ Environment check failed${NC}" + echo -e " Error details: $remote_output" + failed_hosts+=("$host [Check Failed]") + continue + fi + + # 解析远程输出 + remote_python_version=$(echo "$remote_output" | awk '/=== Python Version ===/{getline; print}') + remote_msprof_version=$(echo "$remote_output" | awk '/=== Msprof-analyze Version ===/{getline; print}') + remote_torch_version=$(echo "$remote_output" | awk '/=== Torch Version ===/{getline; print}') + remote_torch_npu_version=$(echo "$remote_output" | awk '/=== Torch_NPU Version ===/{getline; print}') + remote_tmux_path=$(echo "$remote_output" | awk '/=== TMUX Check ===/{getline; print}') + + # 检查结果 + errors=() + + [ "$remote_python_version" != "$LOCAL_PYTHON_VERSION" ] && \ + errors+=("Python version mismatch: Local=$LOCAL_PYTHON_VERSION Remote=$remote_python_version") + + [ "$remote_msprof_version" != "$LOCAL_MSPROF_VERSION" ] && \ + errors+=("Msprof version mismatch: Local=$LOCAL_MSPROF_VERSION Remote=$remote_msprof_version") + + [ "$remote_torch_version" != "$LOCAL_TORCH_VERSION" ] && \ + errors+=("Torch version mismatch: Local=$LOCAL_TORCH_VERSION Remote=$remote_torch_version") + + [ "$remote_torch_npu_version" != "$LOCAL_TORCH_NPU_VERSION" ] && \ + errors+=("Torch_NPU version mismatch: Local=$LOCAL_TORCH_NPU_VERSION Remote=$remote_torch_npu_version") + + [ -z "$remote_tmux_path" ] && \ + errors+=("TMUX not found") + + if [ ${#errors[@]} -eq 0 ]; then + echo -e " ${GREEN}✓ All environment checks passed${NC}" + else + echo -e " ${RED}❌ Environment check failed:${NC}" + for error in "${errors[@]}"; do + echo -e " • ${RED}$error${NC}" + done + failed_hosts+=("$host [Version Mismatch]") + fi +done + +# 总结报告 +echo -e "\n${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" +echo -e "${BOLD}📋 Final Report${NC}" +if [ ${#failed_hosts[@]} -eq 0 ]; then + echo -e "${GREEN}✅ All $total_hosts hosts passed environment checks!${NC}" + exit 0 +else + echo -e "${RED}❌ Environment check failed for ${#failed_hosts[@]} out of $total_hosts hosts:${NC}" + for failed_host in "${failed_hosts[@]}"; do + echo -e " • ${RED}$failed_host${NC}" + done + echo -e "\n${YELLOW}💡 Tips:${NC}" + echo -e " • Ensure all hosts have the same Python environment" + echo -e " • Check if tmux is installed: ${BOLD}sudo apt-get install tmux${NC}" + echo -e " • Verify SSH connectivity: ${BOLD}ssh-copy-id user@host${NC}" + exit 1 +fi diff --git a/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_ssh.sh b/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_ssh.sh new file mode 100644 index 0000000000000000000000000000000000000000..7489bb601ceaff09acf43aadd2b36e40a6682fb5 --- /dev/null +++ b/profiler/msprof_analyze/precheck/examples/scripts/test_hosts_ssh.sh @@ -0,0 +1,61 @@ +### SSH 连通性测试 +# 保存为 test_hosts_ssh.sh +#!/bin/bash + +# 默认值设置 +HOST_IPS=${HOST_IPS:-""} +TIMEOUT=${TIMEOUT:-5} + +# 检查必需参数 +if [ -z "$HOST_IPS" ]; then + echo "Error: HOST_IPS environment variable is not set" + echo "Usage: HOST_IPS='192.168.0.1,192.168.0.2' TIMEOUT=5 bash $0" + exit 1 +fi + +echo "Testing SSH connections with timeout ${TIMEOUT}s..." +echo "Host list: $HOST_IPS" +echo "-----------------------------------" + +# 测试每个主机的SSH连接 +failed_hosts=() +IFS=',' read -ra HOSTS <<< "$HOST_IPS" +for host in "${HOSTS[@]}"; do + echo -n "Testing SSH connection to $host... " + if ssh -o BatchMode=yes -o ConnectTimeout=$TIMEOUT $host "exit 0" &> /dev/null; then + echo "Success ✓" + else + echo "Failed ✗" + failed_hosts+=($host) + fi +done + +# 如果有失败的主机,输出设置建议 +if [ ${#failed_hosts[@]} -ne 0 ]; then + echo -e "\n❌ Some hosts are not accessible via SSH" + echo "Please run these commands to set up passwordless SSH:" + echo "-----------------------------------" + for host in "${failed_hosts[@]}"; do + echo "# 1. If ~/.ssh/id_rsa doesn't exist, generate it" + echo "[ ! -f ~/.ssh/id_rsa ] && ssh-keygen -t rsa -N '' -f ~/.ssh/id_rsa" + echo "" + echo "# 2. Copy your key to remote host" + echo "ssh-copy-id $USER@$host" + echo "" + echo "# 3. Set correct permissions" + echo "chmod 600 ~/.ssh/id_rsa" + echo "-----------------------------------" + done + exit 1 +else + echo -e "\n✅ All SSH connections successful!" +fi + +# 使用方法: +# ```bash +# # 方式1:直接运行(使用默认超时时间5秒) +# HOST_IPS="192.168.0.1,192.168.0.2" bash test_hosts_ssh.sh + +# # 方式2:指定超时时间 +# HOST_IPS="192.168.0.1,192.168.0.2" TIMEOUT=3 bash test_hosts_ssh.sh +# ``` diff --git a/profiler/msprof_analyze/precheck/manager/__init__.py b/profiler/msprof_analyze/precheck/manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/manager/args_manager.py b/profiler/msprof_analyze/precheck/manager/args_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..252a51bae5257e750541fde452db69e9e88eb8bb --- /dev/null +++ b/profiler/msprof_analyze/precheck/manager/args_manager.py @@ -0,0 +1,446 @@ +import argparse +import ipaddress +import os +import re +import shlex +import shutil +import sys +import logging +from typing import List, Union +from collections import OrderedDict + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.common.utils import cn_now +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +class BaseArgsManager: + def __init__(self, args): + self._args = args + + def __repr__(self): + return str(self.to_dict()) + + @property + def master_addr(self): + return self._args.master_addr + + @property + def master_port(self): + return self._args.master_port + + @property + def nnodes(self): + return self._args.nnodes + + @property + def nproc_per_node(self): + return self._args.nproc_per_node + + @property + def node_prof_save_dir(self): + return self._args.node_prof_save_dir or os.path.join(self.task_output_dir, 'node_prof_save_dir') + + @property + def master_prof_gather_dir(self): + return self._args.master_prof_gather_dir or os.path.join(self.task_output_dir, 'master_prof_gather_dir') + + @property + def output_dir(self): + return self._args.output_dir + + @property + def task_name(self): + if self._args.task_name: + return self._args.task_name + return "task_" + cn_now().strftime("%Y%m%d-%H%M%S") + + @property + def static(self): + return self._args.static + + @property + def task_output_dir(self): + return os.path.join(self.output_dir, self.task_name) + + @property + def profiling_cmd(self): + return self._args.profiling_cmd + + @property + def prof_in_shared_storage(self): + return getattr(self._args, 'prof_in_shared_storage', False) + + @staticmethod + def escape_special_chars(text): + ESCAPE_CHARS_MAP = { + '\n': '\\n', + '\t': '\\t', + '\r': '\\r', + '\\': '\\\\', + '\"': '\\\"', + '\'': '\\\'' + } + return re.sub(r'([\n\t\r\\\'"])', lambda match: ESCAPE_CHARS_MAP[match.group()], text) + + @staticmethod + def _check_output_path_valid(output_path: str) -> Union[Exception, None]: + try: + if not os.path.exists(output_path): + PathManager.check_input_directory_path(output_path) + else: + PathManager.check_input_directory_path(output_path) + PathManager.check_path_owner_consistent(output_path) + except Exception as e: + return e + return None + + @staticmethod + def _check_ip_valid(ip: str) -> Union[Exception, None]: + try: + ipaddress.ip_address(ip) + except ValueError as e: + return e + return None + + @staticmethod + def _check_int_range( + value: int, min_value: int = Constant.ARG_MIN_INT_VALUE, max_value: int = Constant.ARG_MAX_INT_VALUE + ) -> Union[Exception, None]: + if not (min_value <= value <= max_value): + return ValueError(f"The value must be between {min_value} and {max_value}.") + return None + + @staticmethod + def _check_executable_path_valid(executable_path: str) -> Union[Exception, None]: + try: + PathManager.check_path_owner_consistent(executable_path) + if not os.path.isfile(executable_path): + raise ValueError("The path is not a valid executable file.") + if not os.access(executable_path, os.X_OK): + raise ValueError("The file at the path is not executable.") + except Exception as e: + return e + return None + + @staticmethod + def _check_identifier_valid(identifier: str) -> Union[Exception, None]: + pattern = r'^[a-zA-Z_][a-zA-Z0-9_-]*$' + if not re.match(pattern, identifier): + return ValueError(f"It must start with a letter or underscore, " + f"followed by any number of letters, digits, underscores, or dashes.") + return None + + @staticmethod + def _check_command_injection(cmd: str) -> Union[Exception, None]: + dangerous_chars = [';', '&&', '||', '|', '>', '<', '`', '$', '\\'] + for char in dangerous_chars: + if char in cmd: + return ValueError( + f"Command contains dangerous character '{char}'. " + "Command injection is not allowed." + ) + return None + + @staticmethod + def _check_dangerous_commands(cmd: str) -> Union[Exception, None]: + dangerous_commands = [ + 'rm', 'mv', 'cp', 'chmod', 'chown', 'dd', + 'mkfs', 'mount', 'umount', 'sudo', 'su', + 'reboot', 'shutdown', 'poweroff', 'init', + 'passwd', 'adduser', 'deluser', 'useradd', + 'userdel', 'groupadd', 'groupdel' + ] + + cmd_parts = shlex.split(cmd) + if not cmd_parts: + return ValueError("Empty command is not allowed") + + base_cmd = os.path.basename(cmd_parts[0]) + if base_cmd in dangerous_commands: + return ValueError( + f"Command '{base_cmd}' is not allowed for security reasons" + ) + return None + + @classmethod + def safe_format(cls, format_str: str, *args, max_len=Constant.ARG_MAX_LEN): + """ + Safely formats a string by truncating arguments longer than a specified maximum length and escaping special characters. + + This function is designed to create user-friendly error messages by ensuring that all arguments are displayed in a safe and concise manner. + It truncates any argument that exceeds the maximum length and appends an ellipsis to indicate the truncation. + Additionally, it escapes special characters in the arguments to prevent formatting errors or injection issues. + + Args: + format_str (str): The format string into which the arguments are inserted. + *args: Variable length argument list to be formatted into the format_str. + max_len (int): The maximum allowed length of any argument string after which it will be truncated. + Defaults to Constant.MAX_ARG_LEN. + + Returns: + str: A formatted string with all arguments safely inserted. + """ + + def _str(x): + x_str = str(x) + if len(x_str) > max_len: + x_str = x_str[:max_len] + "..." + return cls.escape_special_chars(x_str) + + args = [_str(arg) for arg in args] + return format_str.format(*args) + + @classmethod + def raise_error(cls, error_format_msg, *args): + """ + Raises a RuntimeError with a formatted message that includes special character escaping and length limitation. + + This method is designed to handle untrusted external parameters `*args` by ensuring that the error message is user-friendly. + It applies special character escaping and truncates arguments to a predefined maximum length to prevent formatting errors or injection issues. + + Args: + error_format_msg (str): The format string into which the arguments are inserted. + *args: Variable length argument list to be formatted into the error_format_msg. + """ + err_msg = cls.safe_format(error_format_msg, *args) + raise RuntimeError(err_msg) + + def to_dict(self): + """Automatically convert all properties to a dictionary.""" + properties_dict = {} + for prop in dir(self): + if isinstance(getattr(type(self), prop, None), property): + properties_dict[prop] = getattr(self, prop) + return properties_dict + + def check_args(self): + + error = self._check_ip_valid(self.master_addr) + if error: + self.raise_error('Master address {} is not valid: {}', self.master_addr, error) + + error = self._check_int_range(self.master_port, + min_value=Constant.ARG_MIN_PORT_VALUE, max_value=Constant.ARG_MAX_PORT_VALUE) + if error: + self.raise_error('Master port {} is not valid: {}', self.master_port, error) + + error = self._check_int_range(self.nnodes, min_value=1) + if error: + self.raise_error('Total number of nodes {} is not valid: {}', self.nnodes, error) + + error = self._check_int_range(self.nproc_per_node, min_value=1) + if error: + self.raise_error('Number of processes per node {} is not valid: {}', self.nproc_per_node, error) + + error = self._check_output_path_valid(self.output_dir) + if error: + self.raise_error('Output directory {} is not valid: {}', self.output_dir, error) + + error = self._check_identifier_valid(self.task_name) + if error: + self.raise_error('Task name {} is not valid: {}', self.task_name, error) + + error = self._check_output_path_valid(self.node_prof_save_dir) + if error: + self.raise_error('Node prof save directory {} is not valid: {}', self.node_prof_save_dir, error) + + error = self._check_output_path_valid(self.master_prof_gather_dir) + if error: + self.raise_error('Master prof gather directory {} is not valid: {}', self.master_prof_gather_dir, error) + + self._check_profiling_cmd_valid(self.profiling_cmd) + + def _check_profiling_cmd_valid(self, profiling_cmd: str) -> None: + if not profiling_cmd.strip(): + logger.error('Profiling command should not be empty.') + + if profiling_cmd in Constant.DEFAULT_PROFILING_COMMANDS: + logger.info(self.safe_format('Using default profiling command for {}', profiling_cmd)) + return + + if len(self.profiling_cmd) > Constant.ARG_MAX_LEN: + self.raise_error( + 'The profiling command is too long, it must be less than {} characters', Constant.ARG_MAX_LEN) + + error = self._check_command_injection(self.profiling_cmd) + if error: + self.raise_error('Profiling command {} is not valid: {}', self.profiling_cmd, error) + + error = self._check_dangerous_commands(self.profiling_cmd) + if error: + self.raise_error('Profiling command {} is not valid: {}', self.profiling_cmd, error) + + +class PrecheckArgsManager(BaseArgsManager): + def __init__(self, args): + super().__init__(args) + + self._args = args + self._ssh_remote_hosts = {} + self._host_ips = [] + + self.check_args() + + @property + def host_ips(self): + return self._host_ips + + @property + def host_config_file(self): + return self._args.host_config_file + + @property + def ssh_remote_hosts(self): + return self._ssh_remote_hosts + + @property + def python_path(self): + if not self._args.python_path: + return sys.executable + + if os.path.exists(self._args.python_path): + return self._args.python_path + + python_path = shutil.which(self._args.python_path) + return python_path + + @classmethod + def _check_host_ips_valid(cls, host_ips: List[str]) -> Union[Exception, None]: + if not host_ips: + return None + + for i, ip in enumerate(host_ips): + if not ipaddress.ip_address(ip): + return ValueError(f"The {i}-th host ip is not valid.") + + if len(host_ips) != len(set(host_ips)): + return ValueError("Host IPs must be unique.") + + return None + + def try_to_parse_host_config_file(self, host_config_file: str) -> Union[Exception, OrderedDict]: + if not host_config_file: + logger.info("SSH config file is not provided.") + logger.info("Use default ssh settings for all nodes: ssh_key_file, user, port = ~/.ssh/id_rsa, $USER, 22") + return {} + + if not os.path.isfile(host_config_file): + return FileNotFoundError(f"SSH config file {host_config_file} does not exist.") + + PathManager.check_path_readable(host_config_file) + PathManager.check_file_size(host_config_file) + + ssh_remote_hosts = [] + required_fields = ['host_ip', 'ssh_key_file', 'user', 'port'] + with open(host_config_file, 'r') as f: + header = f.readline().strip().split(',') + if any(field not in header for field in required_fields): + return ValueError(f"Host config file {host_config_file} is missing required fields: {required_fields}") + + for line in f: + values = line.strip().split(',') + if len(values) != len(required_fields): + return ValueError( + f"Host config file {host_config_file} has invalid number of fields in line: {line}") + + host_ip, ssh_key_file, user, port = values + ssh_key_file = PathManager.expanduser_for_argumentparser(ssh_key_file) + port = int(port) + + exception = None + try: + PathManager.check_path_readable(ssh_key_file) + if os.stat(ssh_key_file).st_mode & 0o777 != 0o600: + raise ValueError(f"SSH key file {ssh_key_file} must have permissions set to 600") + + exception = self._check_int_range(port, min_value=Constant.ARG_MIN_PORT_VALUE, + max_value=Constant.ARG_MAX_PORT_VALUE) \ + or self._check_identifier_valid(user) \ + or self._check_ip_valid(host_ip) + + except Exception as e: + exception = e + + if exception: + return RuntimeError( + f"Host config file {host_config_file} is not valid, invalid line: {line}, error: {exception}") + + ssh_remote_hosts.append({ + 'host': host_ip, + 'username': user, + 'key_filename': ssh_key_file, + 'port': int(port) + }) + + ssh_remote_hosts = OrderedDict({item['host']: item for item in ssh_remote_hosts}) + return ssh_remote_hosts + + def check_args(self): + super().check_args() + + error = self._check_executable_path_valid(self.python_path) + if error: + self.raise_error('Python path {} is not valid: {}', self.python_path, error) + + # Ensure either host_ips or host_config_file is provided + if not self.host_config_file and not self._args.host_ips: + self.raise_error('Either host config file or host ips must be provided') + + # If host_ips is provided, validate it first + if self._args.host_ips: + error = self._check_host_ips_valid(self._args.host_ips) + if error: + self.raise_error('Host ips {} is not valid: {}', self._args.host_ips, error) + + # Set the validated host_ips + self._host_ips = self._args.host_ips + + # If config file is provided, parse and validate it + if self.host_config_file: + res = self.try_to_parse_host_config_file(self.host_config_file) + if isinstance(res, Exception): + self.raise_error('Host config file {} is not valid: {}', self.host_config_file, res) + self._ssh_remote_hosts = res + config_file_ips = list(self._ssh_remote_hosts.keys()) + + # If host_ips is also provided, verify they match + if self._args.host_ips: + if not set(self._args.host_ips) == set(config_file_ips): + self.raise_error('Host ips does not match the IPs in host config file. Given: {}, In file: {}', + self._args.host_ips, config_file_ips) + else: + # If only config file is provided, use IPs from the config file + self._host_ips = config_file_ips + + # Validate number of nodes and master node configuration + if self.nnodes != len(self.host_ips): + self.raise_error( + 'The number of nodes {} is not equal to the number of host ips {}', + self.nnodes, len(self.host_ips)) + + if self.master_addr != self.host_ips[0]: + self.raise_error( + 'The master address {} is not the first host ip {}', + self.master_addr, self.host_ips[0]) + + +class PrecheckRunnerArgsManager(BaseArgsManager): + def __init__(self, args): + super().__init__(args) + + self._args = args + self.check_args() + + @property + def node_rank(self): + return self._args.node_rank + + def check_args(self): + super().check_args() + + error = self._check_int_range(self.node_rank, min_value=0, max_value=self.nnodes - 1) + if error: + self.raise_error('Node rank {} is not valid: {}', self.node_rank, error) diff --git a/profiler/msprof_analyze/precheck/manager/disk_manager.py b/profiler/msprof_analyze/precheck/manager/disk_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a497c992cbe895e2dcaf115a8ad3469a687a0759 --- /dev/null +++ b/profiler/msprof_analyze/precheck/manager/disk_manager.py @@ -0,0 +1,24 @@ +import os +import logging + +logger = logging.getLogger(__name__) + + +class DiskManager: + @staticmethod + def check_disk_space(input_prof_path, prof_data_size_gb): + if not os.path.exists(input_prof_path): + logger.error(f"路径不存在: {input_prof_path}") + raise FileNotFoundError(f"路径不存在: {input_prof_path}") + + if not os.access(input_prof_path, os.R_OK): + logger.error(f"无读取权限: {input_prof_path}") + raise PermissionError(f"无读取权限: {input_prof_path}") + + statvfs = os.statvfs(input_prof_path) + disk_free_gb = statvfs.f_bavail * statvfs.f_frsize / (1024 ** 3) + + if disk_free_gb - prof_data_size_gb <= 50: + logger.error(f"磁盘空间不足: {disk_free_gb:.2f}GB, 输入数据大小: {prof_data_size_gb:.2f}GB") + raise BufferError(f"磁盘空间不足: {disk_free_gb:.2f}GB, 输入数据大小: {prof_data_size_gb:.2f}GB") + diff --git a/profiler/msprof_analyze/precheck/manager/distribute_manager.py b/profiler/msprof_analyze/precheck/manager/distribute_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f35fdf45c6ad6245ebf3e1225faec41c6a0382c8 --- /dev/null +++ b/profiler/msprof_analyze/precheck/manager/distribute_manager.py @@ -0,0 +1,52 @@ +from copy import deepcopy + + +class DistributeManager: + def __init__(self, args): + self.master_addr = args.master_addr + self.master_port = args.master_port + self.nnodes = args.nnodes + self.nproc_per_node = args.nproc_per_node + self.node_rank = args.node_rank + + self.local_rank = 0 + self.rank = self.local_rank + self.node_rank * self.nproc_per_node + + self.world_size = self.nnodes * self.nproc_per_node + self.local_world_size = self.nproc_per_node + + self.group_rank = -1 + + def __repr__(self): + """ + Custom __repr__ method to print out the object in a human-readable format + """ + return (f"DistributeManager(master_addr='{self.master_addr}', " + f"master_port='{self.master_port}', nnodes={self.nnodes}, " + f"nproc_per_node={self.nproc_per_node}, node_rank={self.node_rank}, " + f"local_rank={self.local_rank}, rank={self.rank}, " + f"world_size={self.world_size}, local_world_size={self.local_world_size}, " + f"group_rank={self.group_rank})") + + def update_local_rank(self, local_rank: int): + self.local_rank = local_rank + self.rank = self.local_rank + self.node_rank * self.nproc_per_node + return deepcopy(self) + + def get_dist_env_data(self): + self.rank = self.local_rank + self.node_rank * self.nproc_per_node + + data = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": self.master_port, + "LOCAL_RANK": self.local_rank, + "GROUP_RANK": self.group_rank, + "NODE_RANK": self.node_rank, + "RANK": self.rank, + "WORLD_SIZE": self.world_size, + "LOCAL_WORLD_SIZE": self.local_world_size, + } + + for k in data: + data[k] = str(data[k]) + return data diff --git a/profiler/msprof_analyze/precheck/manager/group_manager.py b/profiler/msprof_analyze/precheck/manager/group_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fd492bd3d8b54612ccf8401d2e1997f5a0908081 --- /dev/null +++ b/profiler/msprof_analyze/precheck/manager/group_manager.py @@ -0,0 +1,123 @@ +import math +import os +import torch.distributed as dist + +from msprof_analyze.advisor.utils.utils import singleton + + +class EnvGroup: + def __init__(self, rank, local_rank, world_size, master_addr, master_port, group_rank, local_world_size): + self.rank = rank + self.local_rank = local_rank + self.world_size = world_size + self.master_addr = master_addr + self.master_port = master_port + self.group_rank = group_rank + self.local_world_size = local_world_size + self.check_all_attribute() + + def check_all_attribute(self): + if not isinstance(self.rank, int): + raise ValueError('rank must be an integer') + + if not isinstance(self.local_rank, int): + raise ValueError('local_rank must be an integer') + + if not isinstance(self.world_size, int): + raise ValueError('world_size must be an integer') + + if not isinstance(self.master_addr, str): + raise ValueError('master_addr must be an string') + + if not isinstance(self.master_port, int): + raise ValueError('master_port must be an integer') + + if not isinstance(self.group_rank, int): + raise ValueError('group_rank must be an integer') + + if not isinstance(self.local_world_size, int): + raise ValueError('local_world_size must be an integer') + + def set_env(self): + os.environ["RANK"] = str(self.rank) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["GROUP_RANK"] = str(self.group_rank) + os.environ["LOCAL_WORLD_SIZE"] = str(self.local_world_size) + + +class SubGroup: + def __init__(self, group, master_rank, ranks, file_sizes, file_hashes): + self.group = group + self.master_rank = master_rank + self.ranks = ranks + self.file_sizes = file_sizes + self.file_hashes = file_hashes + self.max_file_sizes = max(file_sizes) + self.split_file_size = None + self.splits = None + self.max_splits = None + + def split_size(self, split_file_size): + self.split_file_size = split_file_size + self.splits = [] + self.max_splits = math.ceil(self.max_file_sizes / split_file_size) + for file_size in self.file_sizes: + cur_splits = [] + for _ in range(self.max_splits): + if file_size > 0: + cur_splits.append(min(split_file_size, file_size)) + else: + cur_splits.append(0) + file_size -= split_file_size + self.splits.append(cur_splits) + + +@singleton +class GroupManager: + _initialized = False + + def __init__(self): + if not self._initialized: + self._rank = int(os.environ['RANK']) + self._local_rank = int(os.environ['LOCAL_RANK']) + self._world_size = int(os.environ['WORLD_SIZE']) + self._group_rank = int(os.environ['GROUP_RANK']) + self._rank_size = int(os.environ['LOCAL_WORLD_SIZE']) + self._local_group = None + self._node_group = None + self._sub_group_dict = {} + + def get_rank(self): + return self._rank + + def get_local_rank(self): + return self._local_rank + + def get_world_size(self): + return self._world_size + + def get_rank_size(self): + return self._rank_size + + def get_group_rank(self): + return self._group_rank + + def get_local_group(self): + if self._local_group is None: + groups = [x for x in range(self._group_rank * self._rank_size, (self._group_rank + 1) * self._rank_size)] + self._local_group = dist.new_group(ranks=groups) + return self._local_group + + def add_rank_sub_group(self, sub_group, ranks, file_sizes, file_hashes): + for rank in ranks: + self._sub_group_dict[rank] = SubGroup(group=sub_group, master_rank=ranks[0], ranks=ranks, + file_sizes=file_sizes, file_hashes=file_hashes) + + def get_rank_sub_group(self, rank): + if rank in self._sub_group_dict: + return self._sub_group_dict[rank] + else: + return None diff --git a/profiler/msprof_analyze/precheck/manager/task_manager.py b/profiler/msprof_analyze/precheck/manager/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d08f7afdad2c624e210e506ae5897480c0e6ced7 --- /dev/null +++ b/profiler/msprof_analyze/precheck/manager/task_manager.py @@ -0,0 +1,79 @@ +import os +import logging +import argparse + +from msprof_analyze.precheck.analyze.advisor_adaptor import advisor_adaptor +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +class TaskManager: + ADVISOR = 'advisor' + supported_analyzer = { + ADVISOR: advisor_adaptor, + } + + all_analyzer = list(supported_analyzer.keys()) + + @staticmethod + def add_analyzer(analyzer_name, analyzer_class): + + if analyzer not in TaskManager.supported_analyzer: + TaskManager.supported_analyzer[analyzer_name] = analyzer_class + + @staticmethod + def get_analyzer(analyzer_name): + return TaskManager.supported_analyzer.get(analyzer_name) + + @staticmethod + def get_result(analyzer_name, input_path, output): + + if analyzer_name not in TaskManager.all_analyzer: + logger.error("Error analyzer %s, supported analyzer are %s", analyzer_name, TaskManager.all_analyzer) + raise ValueError("Error analyzer %s, supported analyzer are %s", analyzer_name, TaskManager.all_analyzer) + + input_profiling_path_real = PathManager.get_realpath(input_path) + output_path_real = PathManager.get_realpath(output) + try: + analyze = TaskManager.get_analyzer(analyzer_name) + analyzer_instance = analyze() + result = analyzer_instance.analyze(input_profiling_path=input_profiling_path_real, + output_path=output_path_real) + + except Exception as e: + logger.error("%s is skipped when an exception is encountered. The exception is as follows: %s", + analyzer_name, e) + + +def get_args(): + parser = argparse.ArgumentParser(description="Profiler task manager") + + # Add command-line arguments + parser.add_argument('--input_profiling_path', type=str, + default=os.path.abspath("./result/"), + help="Path to the input profiling data") + parser.add_argument('--output_path', type=str, default=os.path.abspath('../result'), + help="Path to store the output results") + + return parser.parse_args() + + +if __name__ == "__main__": + try: + # Get arguments from the command line + args = get_args() + + # Use the command-line arguments or the default values + input_profiling_path = args.input_profiling_path + output_path = args.output_path + # Access all analyzers from the TaskManager + all_analyzer = TaskManager.all_analyzer + + # Loop through all analyzers and fetch the results + for analyzer in all_analyzer: + TaskManager.get_result(analyzer=analyzer, input_profiling_path=input_profiling_path, + output_path=output_path) + + except Exception as error: + logger.error("%s", error) diff --git a/profiler/msprof_analyze/precheck/precheck.py b/profiler/msprof_analyze/precheck/precheck.py new file mode 100644 index 0000000000000000000000000000000000000000..64c938983813ab35e07d4ef5208fbe5423808100 --- /dev/null +++ b/profiler/msprof_analyze/precheck/precheck.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.precheck.env_check.check_item_factory import CheckItemFactory + + +class Precheck: + + @staticmethod + def env_precheck(**kwargs): + check_type = kwargs.get("check_type") + if not check_type: + return + check_items = CheckItemFactory.get_check_item(check_type) + for check_item in check_items: + check_obj = check_item(**kwargs) + check_obj.check() + return + + +if __name__ == '__main__': + Precheck.env_precheck(check_type="env_variable") diff --git a/profiler/msprof_analyze/precheck/requirements.txt b/profiler/msprof_analyze/precheck/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8203bbe24f892e6d71b116939629c3582b2b1582 --- /dev/null +++ b/profiler/msprof_analyze/precheck/requirements.txt @@ -0,0 +1,41 @@ +absl-py==2.1.0 +attrs==24.2.0 +auto-tune +cloudpickle==3.0.0 +decorator==5.1.1 +filelock==3.15.4 +fsspec==2024.6.1 +MarkupSafe==2.1.5 +ml-dtypes==0.2.0 +mpmath==1.3.0 +networkx==3.1 +numpy==1.24.4 +psutil==6.0.0 +scipy==1.10.1 +sympy==1.13.2 +te +torch_npu==2.4.0 +tornado==6.4.1 +typing_extensions==4.12.2 + + + +## requirements for mstt advisor +click +tabulate +jinja2 +PyYAML +tqdm +prettytable +ijson +requests +xlsxwriter +SQLAlchemy +urllib3<2.0 +# bottleneck >= 1.3.6 # 注释行没有问题 +pandas + +# 如果你想要确保下载所有包的完整版本和所有的依赖项(包括子依赖), +# pip download -r requirements.txt -d pip_cache --no-deps +# 在离线环境中使用缓存安装依赖 +# pip install --no-index --find-links=file:///path/to/pip_cache -r requirements.txt diff --git a/profiler/msprof_analyze/precheck/runner/__init__.py b/profiler/msprof_analyze/precheck/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/runner/__main__.py b/profiler/msprof_analyze/precheck/runner/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f031ae14c2b3610799e2824f4a2d7212ae19eae --- /dev/null +++ b/profiler/msprof_analyze/precheck/runner/__main__.py @@ -0,0 +1,160 @@ +import subprocess +import sys +import os +import logging + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.common.logger import add_file_handler, create_logger +from msprof_analyze.precheck.common.utils import check_file_owner_and_permission, cn_now +from msprof_analyze.precheck.manager.args_manager import PrecheckRunnerArgsManager +from msprof_analyze.precheck.runner.runners import CollectorRunner, AdvisorRunner +from msprof_analyze.precheck.manager.distribute_manager import DistributeManager +from msprof_analyze.prof_common.path_manager import PathManager + +logging.basicConfig(level=Constant.LOGGING_LEVEL) +logger = create_logger("msprof_analyze.precheck", Constant.LOGGING_LEVEL, use_memory_handler=True) + + +def get_conda_envs_info(python_path=sys.executable): + """ + Get the conda environment activation command based on Python executable path. + For non-conda environments, returns source ~/.bashrc command. + + Args: + python_path (str): The path to the Python executable. + + Returns: + tuple: A tuple containing (env_name, activation_command). + For conda: (env_name, "source /path/to/conda/bin/activate env_name") + For non-conda: (None, "source ~/.bashrc") + """ + try: + # Check if we're in a conda environment using CONDA_PREFIX + conda_prefix = os.environ.get('CONDA_PREFIX') + if conda_prefix: + conda_env = os.path.basename(conda_prefix) + conda_base = os.path.dirname(os.path.dirname(conda_prefix)) if 'envs' in conda_prefix else conda_prefix + activate_script = os.path.join(conda_base, "bin", "activate") + + if os.path.exists(activate_script): + check_file_owner_and_permission(activate_script) + return conda_env, f"source {activate_script} {conda_env}" + + # Fallback to path-based detection + CONDA_ENV_BASE_BIAS = 4 + path_splits = python_path.rsplit(os.path.sep, CONDA_ENV_BASE_BIAS) + + if len(path_splits) == CONDA_ENV_BASE_BIAS + 1: + conda_base_path, envs_str, conda_env, _, _ = path_splits + + if envs_str == 'envs': + activate_script = os.path.join(conda_base_path, "bin", "activate") + if os.path.exists(activate_script): + check_file_owner_and_permission(activate_script) + return conda_env, f"source {activate_script} {conda_env}" + + return None, "source ~/.bashrc" + + except Exception as e: + logger.warning("Failed to get conda environment info: %s. Falling back to source ~/.bashrc", str(e)) + return None, "source ~/.bashrc" + + +def start_precheck_runner(args: PrecheckRunnerArgsManager): + logger.info("Starting precheck runner with arguments: %s", args) + + dist_config = DistributeManager(args) + logger.info("Command line arguments: %s", sys.argv) + logger.info("Distributed configuration: %s", dist_config) + + profiler_res_dir_base = args.node_prof_save_dir + transporter_res_dir_base = args.master_prof_gather_dir + advisor_res_dir_base = args.master_prof_gather_dir + + PathManager.make_dir_safety(profiler_res_dir_base) + PathManager.make_dir_safety(transporter_res_dir_base) + PathManager.make_dir_safety(advisor_res_dir_base) + + prof_node_res_dir = profiler_res_dir_base + logger.info("Profiler results directory: %s", prof_node_res_dir) + + # start profiling + logger.info("Starting profiler runner") + env_name, conda_activate_cmd = get_conda_envs_info() + if env_name is None: + logger.warning("No conda environment found. Using system environment.") + else: + logger.info("Using conda environment: %s", env_name) + + profiler_example_name = Constant.DEFAULT_PROFILING_COMMANDS.get(args.profiling_cmd, None) + if profiler_example_name is None: + profiling_cmd = [ + "/bin/bash", "-ic", + f"{conda_activate_cmd} && cd {os.getcwd()} && " + f"MASTER_ADDR={dist_config.master_addr} MASTER_PORT={dist_config.master_port} " + f"NNODES={dist_config.nnodes} NODE_RANK={dist_config.node_rank} " + f"NPROC_PER_NODE={dist_config.nproc_per_node} " + f"{args.profiling_cmd}" + ] + else: + profiler_example_base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples", "profiler", ) + + profiling_cmd = [ + "/bin/bash", "-ic", + f"{conda_activate_cmd} && cd {os.getcwd()} && " + f"torchrun " + f"--master_addr={dist_config.master_addr} " + f"--master_port={dist_config.master_port} " + f"--nproc_per_node={dist_config.nproc_per_node} " + f"--nnodes={dist_config.nnodes} " + f"--node_rank={dist_config.node_rank} " + f"{os.path.join(profiler_example_base, 'train_with_profiler.py')} " + f"--example_name {profiler_example_name} " + f"--prof_output_dir {prof_node_res_dir}" + + (" --static" if args.static else "") + ] + + logger.info("Using custom profiling command: %s", ' '.join(profiling_cmd)) + try: + logger.info("Executing profiling command...") + subprocess.run(profiling_cmd, check=True, capture_output=False, text=True) + logger.info("Profiling command completed successfully") + except subprocess.CalledProcessError as e: + logger.error("Profiling command failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise + + # zip and transport to master + if args.prof_in_shared_storage: + logger.info("Skipping data collection as profiling data is in shared storage") + prof_gather_dir = prof_node_res_dir + else: + logger.info("Starting collector runner") + CollectorRunner(src_dir=prof_node_res_dir, des_dir=transporter_res_dir_base, config=dist_config).run() + prof_gather_dir = transporter_res_dir_base + + # analyse the gathered files + if dist_config.rank == 0: + logger.info("Starting advisor runner") + AdvisorRunner( + src_dir=prof_gather_dir, + des_dir=advisor_res_dir_base, + config=dist_config, + is_shared_storage=args.prof_in_shared_storage + ).run() + + logger.info("Completed precheck runner execution") + + +def main(args=None): + global logger + output_dir = os.path.join(args.output_dir, args.task_name) + PathManager.make_dir_safety(output_dir) + + timestamp = cn_now().strftime('%Y%m%d_%H%M%S') + log_file_path = os.path.join(output_dir, f'precheck_runner_{timestamp}.log') + logger = add_file_handler(logger, log_file_path) + + try: + start_precheck_runner(args) + except Exception as e: + logger.error("Precheck runner failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) diff --git a/profiler/msprof_analyze/precheck/runner/runners.py b/profiler/msprof_analyze/precheck/runner/runners.py new file mode 100644 index 0000000000000000000000000000000000000000..f46dc398a7fe1a5d02733be3428c4fb30f649f43 --- /dev/null +++ b/profiler/msprof_analyze/precheck/runner/runners.py @@ -0,0 +1,151 @@ +import os +import subprocess +import zipfile +import glob +import logging + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.manager.distribute_manager import DistributeManager +from msprof_analyze.precheck.tools.archive_utils import create_archive, extract_archive, ArchiveConfig, \ + compare_directory_with_archive +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +class AdvisorRunner: + def __init__(self, src_dir, des_dir, config: DistributeManager, *args, **kwargs): + self.src_dir = src_dir + self.dest_dir = des_dir + self.config = config + self.is_shared_storage = kwargs.get('is_shared_storage', False) + + logger.info('%s init, args: %s, kwargs: %s', self.__class__.__name__, args, kwargs) + self.archive_extract_dir = os.path.join(self.dest_dir, 'prof_unzipped') + + def prepare_analysis_dir(self): + """Prepare directory for analysis, either by extracting archives or using source directly""" + if self.is_shared_storage: + logger.info("Using shared storage directory directly: %s", self.src_dir) + return self.src_dir + + logger.info("Preparing analysis directory by extracting archives") + PathManager.make_dir_safety(self.archive_extract_dir) + + archives_found = False + for root, _, files in os.walk(self.src_dir): + for file in files: + if any(file.endswith(ext) for ext in ['.zip', '.tar', '.tar.gz', '.tgz', ]): + archives_found = True + archive_path = os.path.join(root, file) + logger.info("Extracting archive: %s", archive_path) + extract_archive(archive_path, self.archive_extract_dir) + + if not archives_found: + logger.info("No archives found in %s, using source directory directly", self.src_dir) + return self.src_dir + + return self.archive_extract_dir + + def run(self): + if self.config.node_rank == 0 and self.config.local_rank == 0: + analysis_dir = self.prepare_analysis_dir() + self.run_analyzer(analysis_dir) + + def run_analyzer(self, analysis_dir): + """Find and process ascend_pt files in the analysis directory""" + + def call_analyzer(input_profiling_path, output_path): + from msprof_analyze.precheck.manager.task_manager import TaskManager + all_analyzer = TaskManager.all_analyzer + for analyzer in all_analyzer: + TaskManager.get_result(analyzer_name=analyzer, + input_path=input_profiling_path, + output=output_path) + + ascend_pt_dirs = glob.glob(os.path.join(analysis_dir, "*_ascend_pt"), recursive=False) + + if ascend_pt_dirs: + logger.info("Found %d ascend_pt directories in %s:", len(ascend_pt_dirs), analysis_dir) + for ascend_pt_dir in ascend_pt_dirs: + logger.debug("Found ascend_pt directory: %s", ascend_pt_dir) + + call_analyzer(analysis_dir, self.dest_dir) + else: + logger.warning("No ascend_pt files found in %s", analysis_dir) + + +class CollectorRunner: + def __init__(self, src_dir, des_dir, config: DistributeManager): + self.src_dir = os.path.abspath(src_dir) + self.des_dir = os.path.abspath(des_dir) + self.config = config + + logger.info('%s init', self.__class__.__name__) + + @staticmethod + def zip_directory(src_dir): + """Zip the specified directory.""" + zip_file_path = f"{src_dir}.zip" + + logger.info('Start zipping directory %s to %s', src_dir, zip_file_path) + + # Check if zip file already exists and contents match + if os.path.exists(zip_file_path): + logger.info('Found existing zip file: %s', zip_file_path) + logger.info('Comparing contents with source directory...') + + if compare_directory_with_archive(src_dir, zip_file_path): + logger.info('Existing zip matches source - reusing zip file') + return zip_file_path + + logger.info('Existing zip differs from source - creating new zip') + + # Create new zip file + create_archive(ArchiveConfig( + src_dir=src_dir, + output_path=zip_file_path, + whitelist=Constant.PROFILER_FILE_PATTERNS, + use_regex=True, + regex_fullmatch=False, + )) + + logger.info('Successfully created new zip file %s', zip_file_path) + + return zip_file_path + + def run(self): + zip_file = self.zip_directory(self.src_dir) + + self.transport(zip_file) + + def transport(self, zip_file): + """Transport the zip file to the destination.""" + + def run_collector(input_file_dir, output_file_dir: str, config: DistributeManager): + args_dict = { + "input_file_dir": input_file_dir, + "output_file_dir": output_file_dir, + "nnodes": config.nnodes, + "node_rank": config.node_rank, + "master_addr": config.master_addr, + "master_port": config.master_port, + "master_rank_num": Constant.COLLECTOR_MASTER_RANK_NUM, + "split_file_size": Constant.COLLECTOR_SPLIT_FILE_SIZE, + "time_out": Constant.COLLECTOR_DEFAULT_TIMEOUT, + "log_file": None + } + + from msprof_analyze.precheck.collect.collector import Collector + Collector().run(args_dict) + + run_collector(zip_file, self.des_dir, self.config) + + if self.config.node_rank == 0 or self.config.master_addr in Constant.LOCALHOST_ADDRESSES: + mv_command = ['cp', zip_file, self.des_dir] + logger.info("[rank=%s] %s", self.config.rank, mv_command) + subprocess.run(mv_command, check=True) + else: + pass + + logger.info("[rank=%s] Successfully transferred %s to %s", self.config.rank, zip_file, self.des_dir) diff --git a/profiler/msprof_analyze/precheck/tools/__init__.py b/profiler/msprof_analyze/precheck/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/precheck/tools/archive_utils.py b/profiler/msprof_analyze/precheck/tools/archive_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..236413a3052174d07b7463ffb0d42042b50daa59 --- /dev/null +++ b/profiler/msprof_analyze/precheck/tools/archive_utils.py @@ -0,0 +1,274 @@ +import glob +import os +import zipfile +import tarfile +import logging +import re +import fnmatch +from dataclasses import dataclass +from typing import List, Optional + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +@dataclass +class ArchiveConfig: + src_dir: str + output_path: str + use_tar: bool = False + whitelist: Optional[List[str]] = None + blacklist: Optional[List[str]] = None + use_regex: bool = False + regex_fullmatch: bool = True + + +def create_archive(archive_args: ArchiveConfig): + """ + Create a zip or tar archive from a source directory. + + The archive will contain files from the source directory that match the whitelist + patterns (if specified) and don't match the blacklist patterns (if specified). + Patterns can be either glob patterns or regular expressions based on the use_regex flag. + + For regex patterns: + - If regex_fullmatch is True, the entire path must match the pattern + - If regex_fullmatch is False, the pattern can match anywhere in the path + + For glob patterns: + - Standard glob syntax is used (*, ?, [seq], [!seq]) + - Patterns are matched against the full relative path + + Args: + archive_args: Configuration object containing: + src_dir: Source directory to archive + output_path: Output path for the archive file + use_tar: If True create tar.gz, if False create zip + whitelist: List of patterns to include + blacklist: List of patterns to exclude + use_regex: If True use regex patterns, if False use glob + regex_fullmatch: If True require full regex match + + """ + + if not os.path.exists(archive_args.src_dir): + raise ValueError(f"Source directory '{archive_args.src_dir}' does not exist") + + save_dir = os.path.dirname(archive_args.output_path) + if not os.path.exists(save_dir): + raise ValueError(f"Destination directory '{save_dir}' does not exist") + + logger.info("Creating %s archive: %s", 'tar' if archive_args.use_tar else 'zip', archive_args.output_path) + logger.debug("Source directory: %s", archive_args.src_dir) + + if archive_args.use_regex: + if archive_args.whitelist: + whitelist = [re.compile(pattern) for pattern in archive_args.whitelist] + else: + whitelist = None + if archive_args.blacklist: + blacklist = [re.compile(pattern) for pattern in archive_args.blacklist] + else: + blacklist = None + else: + whitelist = archive_args.whitelist + blacklist = archive_args.blacklist + + def should_include_file(relative_path): + # Define pattern matching functions + def regex_fullmatch(pattern): + return pattern.fullmatch(relative_path) + + def regex_search(pattern): + return pattern.search(relative_path) + + def glob_match(pattern): + return fnmatch.fnmatch(relative_path, pattern) + + # Choose pattern matcher based on args + if archive_args.use_regex: + if archive_args.regex_fullmatch: + pattern_matcher = regex_fullmatch + else: + pattern_matcher = regex_search + else: + pattern_matcher = glob_match + + # Check blacklist first + if blacklist and any(map(pattern_matcher, blacklist)): + return False + + # If no whitelist, include all non-blacklisted files + if not whitelist: + return True + + # Check whitelist + return any(map(pattern_matcher, whitelist)) + + # Get all files in source directory recursively + abs_files = glob.glob(os.path.join(archive_args.src_dir, '**', '*'), recursive=True) + files = [os.path.relpath(file, archive_args.src_dir) for file in abs_files] + + files_to_add = [ + file for file_abs_path, file in zip(abs_files, files) + if should_include_file(file) and os.path.isfile(file_abs_path) + ] + + logger.info("Has found %d files to add at path: %s", len(files_to_add), archive_args.src_dir) + + # Process files based on archive type (tar or zip) + def add_files_to_tar(files_to_add): + with tarfile.open(archive_args.output_path, 'w:gz') as f: + for file in files_to_add: + file_path = os.path.join(archive_args.src_dir, file) + f.add(file_path, arcname=file) + + def add_files_to_zip(files_to_add): + with zipfile.ZipFile(archive_args.output_path, 'w', zipfile.ZIP_DEFLATED) as f: + for file in files_to_add: + file_path = os.path.join(archive_args.src_dir, file) + f.write(file_path, arcname=file) + + if archive_args.use_tar: + add_files_to_tar(files_to_add) + else: + add_files_to_zip(files_to_add) + + logger.info("Archive created successfully: %s", archive_args.output_path) + + +def _check_safe_zip(archive_file, max_archive_ratio=None, + max_size=Constant.MAX_ARCHIVE_SIZE, + max_file_count=Constant.MAX_ARCHIVE_FILE_COUNT, + ): + PathManager.check_path_readable(archive_file) + + archive_size = os.path.getsize(archive_file) + if max_archive_ratio is not None: + max_size = max(max_size, max_archive_ratio * archive_size) + + try: + with zipfile.ZipFile(archive_file, 'r') as zip_ref: + total_size = 0 + total_file_count = 0 + for info in zip_ref.infolist(): + total_size += info.file_size + total_file_count += 1 + if total_size > max_size: + raise RuntimeError("Archive size exceeds the limit") + if total_file_count > max_file_count: + raise RuntimeError("Archive file count exceeds the limit") + except (zipfile.BadZipFile, OSError) as e: + logger.error("Error reading zip file %s: %s", archive_file, e) + raise + + +def _check_safe_tar(archive_file, max_archive_ratio=None, + max_size=Constant.MAX_ARCHIVE_SIZE, + max_file_count=Constant.MAX_ARCHIVE_FILE_COUNT, + ): + PathManager.check_path_readable(archive_file) + + archive_size = os.path.getsize(archive_file) + if max_archive_ratio is not None: + max_size = max(max_size, max_archive_ratio * archive_size) + + try: + with tarfile.open(archive_file, 'r:*') as tar_ref: + total_size = 0 + total_file_count = 0 + for member in tar_ref.getmembers(): + total_size += member.size + total_file_count += 1 + if total_size > max_size: + raise RuntimeError("Archive size exceeds the limit") + if total_file_count > max_file_count: + raise RuntimeError("Archive file count exceeds the limit") + except (tarfile.TarError, OSError) as e: + logger.error("Error reading tar file %s: %s", archive_file, e) + raise + + +def _unzip(zip_file, extract_dir): + """Extract contents from a zip archive""" + + _check_safe_zip(zip_file, max_archive_ratio=Constant.MAX_ARCHIVE_RATIO) + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(extract_dir) + logger.info("Unzipped %s to %s", zip_file, extract_dir) + + +def _untar(tar_file, extract_dir): + """Extract contents from a tar/tar.gz/tgz archive""" + + _check_safe_tar(tar_file, max_archive_ratio=Constant.MAX_ARCHIVE_RATIO) + with tarfile.open(tar_file, 'r:*') as tar_ref: # Auto-detect compression type + tar_ref.extractall(extract_dir) + logger.info("Untarred %s to %s", tar_file, extract_dir) + + +def extract_archive(archive_file, extract_dir): + """Extract contents from zip or tar archive files""" + + if archive_file.endswith('.zip'): + _unzip(archive_file, extract_dir) + elif archive_file.endswith('.tar') or archive_file.endswith('.tar.gz') or archive_file.endswith('.tgz'): + _untar(archive_file, extract_dir) + else: + logger.warning("Unsupported archive type: %s", archive_file) + + +def compare_directory_with_archive(src_dir: str, zip_file_path: str) -> bool: + """ + Compare contents of source directory with existing zip file. + + Args: + src_dir: Source directory path + zip_file_path: Path to zip file + + Returns: + bool: True if contents match, False otherwise + """ + # Get source files info + src_files = {} + for file_path in glob.glob(os.path.join(src_dir, "**"), recursive=True): + if os.path.isfile(file_path): + rel_path = os.path.relpath(file_path, src_dir) + src_files[rel_path] = os.path.getsize(file_path) + + # Compare with zip contents + with zipfile.ZipFile(zip_file_path, 'r') as existing_zip: + zip_files = { + info.filename: info.file_size + for info in existing_zip.filelist + } + + return src_files == zip_files + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + # Example usage with fnmatch whitelist, blacklist + config = ArchiveConfig( + src_dir="profiler/msprof_analyze/precheck/runner", + output_path="profiler/msprof_analyze/precheck/runner.zip", + whitelist=[r"tools/*", r"profiler/*", r"tests/*"], # Only include files in these directories + blacklist=[r"*.pyc"], # Exclude .pyc files + use_regex=False, + ) + + create_archive(config) + + # Example usage with regex whitelist, blacklist + config = ArchiveConfig( + src_dir="profiler/msprof_analyze/precheck/runner", + output_path="profiler/msprof_analyze/precheck/runner_regex.zip", + whitelist=[r"tools/.*", r"profiler/.*", r"tests/.*"], + blacklist=[r".*\.pyc$"], + use_regex=True, + ) + + create_archive(config) diff --git a/profiler/msprof_analyze/precheck/tools/ssh_utils.py b/profiler/msprof_analyze/precheck/tools/ssh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c99c828d15ecda1ff13ad3ad7a2af885431a64e7 --- /dev/null +++ b/profiler/msprof_analyze/precheck/tools/ssh_utils.py @@ -0,0 +1,264 @@ +import getpass +import ipaddress +import os +import logging +import re +import subprocess +import shlex +from dataclasses import dataclass +from typing import List, Union + +from msprof_analyze.precheck.common.constant import Constant +from msprof_analyze.precheck.common.utils import cn_now +from msprof_analyze.prof_common.path_manager import PathManager + +logger = logging.getLogger(__name__) + + +@dataclass +class SSHConfig: + host: str + username: str + key_file: str + port: int = 22 + timeout: int = 3 + + def __post_init__(self): + """ Validate all fields after initialization """ + error = _check_ip_valid(self.host) + if error: + raise RuntimeError(f"Invalid host {self.host}: {error}") + + error = _check_int_range(self.port, min_value=1, max_value=Constant.ARG_MAX_INT_VALUE) + if error: + raise RuntimeError(f"Invalid port {self.port}: {error}") + + error = _check_ssh_key_file_valid(self.key_file) + if error: + raise RuntimeError(f"Invalid SSH key file {self.key_file}: {error}") + + error = _check_identifier_valid(self.username) + if error: + raise RuntimeError(f"Invalid username {self.username}: {error}") + + error = _check_int_range(self.timeout, min_value=1) + if error: + raise RuntimeError(f"Invalid timeout {self.timeout}: {error}") + + +def _check_ip_valid(ip: str) -> Union[Exception, None]: + try: + ipaddress.ip_address(ip) + except ValueError as e: + return e + return None + + +def _check_int_range( + value: int, min_value: int = Constant.ARG_MIN_INT_VALUE, max_value: int = Constant.ARG_MAX_INT_VALUE +) -> Union[Exception, None]: + if not (min_value <= value <= max_value): + return ValueError(f"The value must be between {min_value} and {max_value}.") + return None + + +def _check_identifier_valid(identifier: str) -> Union[Exception, None]: + pattern = r'^[a-zA-Z_][a-zA-Z0-9_-]*$' + if not re.match(pattern, identifier): + return ValueError(f"It must start with a letter or underscore, " + f"followed by any number of letters, digits, underscores, or dashes.") + return None + + +def _check_ssh_key_file_valid(ssh_key_file: str) -> Union[Exception, None]: + try: + expanded_path = os.path.expanduser(ssh_key_file) + stat_info = os.stat(expanded_path) + current_uid = os.getuid() + + # check file owner + if stat_info.st_uid != current_uid: + return ValueError(f"SSH key file {ssh_key_file} must be owned by the current user") + # check permissions to only read and write by owner + if stat_info.st_mode & 0o777 != 0o600: + return ValueError(f"SSH key file {ssh_key_file} must have permissions set to 600") + + return None + + except FileNotFoundError: + return ValueError(f"SSH key file {ssh_key_file} does not exist") + except PermissionError: + return ValueError(f"Permission denied when accessing SSH key file {ssh_key_file}") + + +def execute_ssh_command(config: SSHConfig, command: str) -> dict: + """ + Execute a command directly on a remote host using SSH without using tmux. + + Args: + config (SSHConfig): SSH configuration + command (str): Command to run on the remote host + + Returns: + dict: Dict containing command execution status and output with keys: + - success (bool): Whether the command was executed successfully + - output (str): Output from the command execution + """ + if not isinstance(config, SSHConfig): + raise ValueError("config must be an instance of SSHConfig") + + ssh_prefix = f"ssh -o ConnectTimeout={config.timeout} -p {config.port} {config.username}@{config.host}" + if config.key_file: + ssh_prefix += f" -i {config.key_file}" + + try: + result = subprocess.run([*shlex.split(ssh_prefix), command], capture_output=True, text=True, check=True) + return { + 'success': True, + 'output': result.stdout + } + except subprocess.CalledProcessError as e: + logger.error("SSH command failed on %s: %s", config.host, e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + return { + 'success': False, + 'output': e.stderr + } + + +def execute_ssh_command_in_tmux(config: SSHConfig, session_name: str, command: str) -> dict: + """ + Connect to remote host using system ssh command, start or update tmux session and run command + + Args: + config (SSHConfig): SSH configuration + session_name (str): Base name for tmux session + command (str): Command to run in tmux session + + Returns: + dict: Dict containing session info with keys: + - session_name (str): Name of tmux session + - win_name (str): Name of tmux window + - attach_cmd (str): Command to attach to tmux session + """ + if not isinstance(config, SSHConfig): + raise ValueError("config must be an instance of SSHConfig") + + error = _check_identifier_valid(session_name) + if error: + raise RuntimeError(f"Invalid session name {session_name}: {error}") + + win_name = cn_now().strftime("%H%M") + attach_cmd = "" + + try: + ssh_prefix = f"ssh -o ConnectTimeout={config.timeout} -p {config.port} {config.username}@{config.host}" + if config.key_file: + ssh_prefix += f" -i {config.key_file}" + + check_cmd = f"{ssh_prefix} 'tmux list-sessions | grep -q \"^{session_name}:\" && echo exists || echo new'" + result = subprocess.run(shlex.split(check_cmd), capture_output=True, text=True) + session_status = result.stdout.strip() + + escaped_command = command.replace("'", "\\'").replace('"', '\\"') + + tmux_cmd_suffix = f"script -f /tmp/tmux_output_{win_name} -c \"{escaped_command}\"; bash -i" + if session_status == "exists": + logger.info("Session '%s' exists on %s. Creating a new window with name '%s'.", + session_name, config.host, win_name) + tmux_cmd = f"tmux new-window -t {session_name} -n '{win_name}' '{tmux_cmd_suffix}'" + else: + logger.info( + "Session '%s' does not exist on %s. Creating a new session with name '%s'. " + "Creating a new window with name '%s'.", session_name, config.host, session_name, win_name) + tmux_cmd = f"tmux new-session -d -s {session_name} -n '{win_name}' '{tmux_cmd_suffix}'" + + logger.info("Running command to start session: %s", tmux_cmd) + + result = subprocess.run(shlex.split(ssh_prefix) + [tmux_cmd], capture_output=True, text=True, check=True) + + if result.stdout.strip(): + logger.info("Output from %s:\n%s", config.host, result.stdout) + + attach_cmd = f"tmux attach -t {session_name}:{win_name}" + logger.info('Session started. To attach to the session, run: "%s" in terminal on %s@%s', + attach_cmd, config.username, config.host) + + except Exception as e: + logger.error("Failed to connect to %s: %s", config.host, e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise RuntimeError(f"Fail to start host {config.host}") from e + + return dict( + session_name=session_name, + win_name=win_name, + attach_cmd=attach_cmd, + ) + + +def run_remote_command(hosts_info: List[dict], session_name: str = None, using_tmux: bool = True) -> List[dict]: + """ + Execute specified commands on remote hosts using SSH, optionally within a tmux session. + + This function supports executing commands directly via SSH or within a tmux session for + better management of long-running processes. + + Args: + hosts_info (list of dict): Information about the hosts on which commands will be executed. + Each dictionary should contain: + - host (str): Hostname or IP address of the remote machine. + - username (str): SSH username for the remote host. + - key_filename (str, optional): Path to the SSH private key file. Defaults to '~/.ssh/id_rsa'. + - command (str): Command to be executed on the remote host. + - port (int, optional): SSH port number. Defaults to 22. + session_name (str, optional): Name to be used for the tmux session, if using tmux. Automatically generated + if not provided. + using_tmux (bool): Whether to execute the command within a tmux session. Defaults to True. + + Returns: + list of dict: Results from each host, with each dictionary containing: + - session_name (str): Name of the tmux session (if used). + - win_name (str): Name of the tmux window (if used). + - attach_cmd (str): Command to attach to the tmux session (if used). + """ + user = getpass.getuser() + if session_name is None: + session_name = f"auto_{user}_{cn_now().strftime('%m%d')}" + + results = [] + + for host_info in hosts_info: + config = SSHConfig( + host=host_info["host"], + username=host_info["username"], + key_file=host_info.get("key_filename", "~/.ssh/id_rsa"), + port=host_info.get("port", 22) + ) + config.key_file = PathManager.expanduser_for_argumentparser(config.key_file) + if using_tmux: + results.append(execute_ssh_command_in_tmux(config, session_name, host_info["command"])) + else: + results.append(execute_ssh_command(config, host_info["command"])) + + return results + + +def main(): + hosts = [{ + "host": "127.0.0.1", + "username": os.getenv("USER"), + "key_filename": "~/.ssh/id_ed25519", + "command": f"echo Hello!", + "port": 22 + }] + + run_remote_command(hosts) + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler(), + ] + ) + main() diff --git a/profiler/msprof_analyze/prof_common/__init__.py b/profiler/msprof_analyze/prof_common/__init__.py index c2764ec2a520567abc0c7d119b222f5fea7c3b72..8b7e7544bb1bd466a9b223cb1f706422bcab9435 100644 --- a/profiler/msprof_analyze/prof_common/__init__.py +++ b/profiler/msprof_analyze/prof_common/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. import os import sys -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) \ No newline at end of file diff --git a/profiler/msprof_analyze/prof_common/base_node.py b/profiler/msprof_analyze/prof_common/base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..e96c5521ca11b778e277df1d17fb26a88f9f988f --- /dev/null +++ b/profiler/msprof_analyze/prof_common/base_node.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from math import ceil +from queue import Queue + +from decimal import Decimal + +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean + + +class BaseNode: + def __init__(self, event: TraceEventBean, parent_node=None): + self._event = event + self._parent_node = parent_node + self._child_nodes = [] + + @property + def parent_node(self): + return self._parent_node + + @property + def child_nodes(self): + return self._child_nodes + + @property + def name(self): + return self._event.name + + @property + def start_time(self) -> Decimal: + return self._event.start_time + + @property + def end_time(self) -> Decimal: + return self._event.end_time + + @parent_node.setter + def parent_node(self, parent_node): + self._parent_node = parent_node + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def binary_search(self, ts_time): + if not self.child_nodes: + return Constant.INVALID_RETURN + right = len(self.child_nodes) - 1 + left = 0 + while right > left: + mid = left + ceil((right - left) / 2) + if ts_time >= self.child_nodes[mid].start_time: + left = mid + else: + right = mid - 1 + if self.child_nodes[left].start_time < ts_time < self.child_nodes[left].end_time: + return self.child_nodes[left] + return Constant.INVALID_RETURN + + def find_all_child_nodes(self) -> list: + result_data = [] + node_queue = Queue() + for child_node in self.child_nodes: + node_queue.put(child_node) + while not node_queue.empty(): + tree_node = node_queue.get() + result_data.append(tree_node) + for child_node in tree_node.child_nodes: + node_queue.put(child_node) + return result_data diff --git a/profiler/msprof_analyze/prof_common/constant.py b/profiler/msprof_analyze/prof_common/constant.py index 5353fc6d40f25cee9f1a10e9b734dc95573b3b2c..5601a9f77fe0332d8258664a40aa449fb6aafbd7 100644 --- a/profiler/msprof_analyze/prof_common/constant.py +++ b/profiler/msprof_analyze/prof_common/constant.py @@ -114,6 +114,37 @@ class Constant(object): DB = "db" INVALID = "invalid" + # profiler db tables + TABLE_AICORE_FREQ = "AICORE_FREQ" + TABLE_CANN_API = "CANN_API" + TABLE_COMMUNICATION_OP = "COMMUNICATION_OP" + TABLE_COMMUNICATION_TASK_INFO = "COMMUNICATION_TASK_INFO" + TABLE_COMPUTE_TASK_INFO = "COMPUTE_TASK_INFO" + TABLE_CONNECTION_IDS = "CONNECTION_IDS" + TABLE_CONNECTION_CATS = "connectionCats" + TABLE_ENUM_API_TYPE = "ENUM_API_TYPE" + TABLE_ENUM_HCCL_DATA_TYPE = "ENUM_HCCL_DATA_TYPE" + TABLE_ENUM_HCCL_LINK_TYPE = "ENUM_HCCL_LINK_TYPE" + TABLE_ENUM_HCCL_RDMA_TYPE = "ENUM_HCCL_RDMA_TYPE" + TABLE_ENUM_TRANSPORT_TYPE = "ENUM_TRANSPORT_TYPE" + TABLE_ENUM_MODULE = "ENUM_MODULE" + TABLE_MSTX_EVENT_TYPE = "MSTX_EVENT_TYPE" + TABLE_HOST_INFO = "HOST_INFO" + TABLE_META_DATA = "META_DATA" + TABLE_NPU_INFO = "NPU_INFO" + TABLE_OVERLAP_ANALYSIS = "OVERLAP_ANALYSIS" + TABLE_PYTORCH_API = "PYTORCH_API" + TABLE_RANK_DEVICE_MAP = "RANK_DEVICE_MAP" + TABLE_SESSION_TIME_INFO = "SESSION_TIME_INFO" + TABLE_STATUS_INFO = "status_info" + TABLE_STEP_TIME = "STEP_TIME" + TABLE_STRING_IDS = "STRING_IDS" + TABLE_TASK = "TASK" + TABLE_TASK_MPU_INFO = "TASK_MPU_INFO" + + # export_type + NOTEBOOK = "notebook" + # db name DB_COMMUNICATION_ANALYZER = "analysis.db" DB_CLUSTER_COMMUNICATION_ANALYZER = "cluster_analysis.db" @@ -126,14 +157,25 @@ class Constant(object): TABLE_HOST_INFO = "HostInfo" TABLE_RANK_DEVICE_MAP = "RankDeviceMap" TABLE_CLUSTER_BASE_INFO = "ClusterBaseInfo" + TABLE_CLUSTER_TIME_SUMMARY = "ClusterTimeSummary" + TABLE_COMMUNICATION_GROUP_MAPPING = "CommunicationGroupMapping" # data config key CONFIG = "config" EXPER_CONFIG = "experimental_config" EXPER_EXPORT_TYPE = "_export_type" + EXPORT_TYPE = "_export_type" # metadata key DISTRIBUTED_ARGS = "distributed_args" + PARALLEL_GROUP_INFO = "parallel_group_info" + + # parallel_info_key + GROUP_NAME = "group_name" + GLOBAL_RANKS = "global_ranks" + + # group name value + PP = "pp" # mode ALL = "all" @@ -180,6 +222,7 @@ class Constant(object): COMPARISON_PROFILING = 'Comparison Profiling: ' WAIT_TIME = "wait" TRANSMIT_TIME = "transmit" + DURATION_TIME = "duration" # compare type OPERATOR_COMPARE = "OperatorCompare" @@ -247,6 +290,10 @@ class Constant(object): VOID_STEP = -1 + # communication task type + NOTIFY_RECORD = "Notify_Record" + NOTIFY_WAIT = "Notify_Wait" + # advisor # timeline @@ -352,6 +399,7 @@ class Constant(object): PT_PROF_SUFFIX = "ascend_pt" ASCEND_PROFILER_OUTPUT = "ASCEND_PROFILER_OUTPUT" + KERNEL_DETAILS_CSV = "kernel_details.csv" CLUSTER_STEP_TIME_CSV = "cluster_step_trace_time.csv" CLUSTER_COMM_JSON = "cluster_communication.json" COMMUNICATION_JSON = "communication.json" @@ -379,6 +427,7 @@ class Constant(object): # Unit Conversion COMMUNICATION_B_TO_GB = 0.001 ** 3 US_TO_S = 0.001 ** 2 + TIME_UNIT_SCALE = 1000 WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC @@ -396,9 +445,9 @@ class Constant(object): OPERATOR_TYPE = 1 VIRTUAL_TYPE = 9 - # json trace bar + # trace bar NPU_BAR = "Ascend Hardware" - COMM_BAR = "Communication" + HCCL_BAR = "HCCL" OVERLAP_BAR = "Overlap Analysis" # overlap_analysis event @@ -421,13 +470,13 @@ class Constant(object): CONCURRENT_MODE = "concurrent" PROFILER_DB_PATH = "profiler_db_path" + ANALYSIS_DB_PATH = "analysis_db_path" RANK_LIST = "rank_list" EXPORT_TYPE = "export_type" EXTRA_ARGS = "args" - STEP_RANGE = "step_range" - START_NS = "startNs" - END_NS = "endNs" # hccl_sum UINT32_BITS = 32 - UINT32_MASK = 0xffffffff \ No newline at end of file + UINT32_MASK = 0xffffffff + + INVALID_RANK_NUM = 4294967295 diff --git a/profiler/msprof_analyze/prof_common/database_service.py b/profiler/msprof_analyze/prof_common/database_service.py index 6b776d4d957a9491aeb5690cf456038c114c3590..1e51b787dcb3e2911f3d0795fefd95cd34bb68af 100644 --- a/profiler/msprof_analyze/prof_common/database_service.py +++ b/profiler/msprof_analyze/prof_common/database_service.py @@ -16,37 +16,13 @@ import pandas as pd from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.prof_common.logger import get_logger -from msprof_analyze.prof_common.constant import Constant logger = get_logger() class DatabaseService: - TABLE_TS_DICT = { - "TASK": "startNs", - "COMMUNICATION_OP": "startNs", - "CANN_API": "startNs", - "PYTORCH_API": "startNs", - "MSTX_EVENTS": "startNs", - "GC_RECORD": "startNs", - "ACC_PMU": "timestampNs", - "NIC": "timestampNs", - "RoCE": "timestampNs", - "LLC": "timestampNs", - "SAMPLE_PMU_TIMELINE": "timestampNs", - "NPU_MEM": "timestampNs", - "NPU_MODULE_MEM": "timestampNs", - "NPU_OP_MEM": "timestampNs", - "HBM": "timestampNs", - "DDR": "timestampNs", - "HCCS": "timestampNs", - "PCIE": "timestampNs", - "AICORE_FREQ": "timestampNs" - } - - def __init__(self, db_path, step_range): + def __init__(self, db_path): self._db_path = db_path - self._step_range = step_range self._table_info = {} def add_table_for_query(self, table_name: str, columns=None): @@ -72,12 +48,7 @@ class DatabaseService: logger.warning(f"This table {table_name} does not exist in this database {self._db_path}.") continue columns_str = "*" if not columns else ",".join(columns) - if table_name in self.TABLE_TS_DICT and self._step_range: - where_str = f"where {self.TABLE_TS_DICT.get(table_name)} >= {self._step_range.get(Constant.START_NS)}" \ - f" and {self.TABLE_TS_DICT.get(table_name)} <= {self._step_range.get(Constant.END_NS)}" - else: - where_str = "" - query_sql = f"select {columns_str} from {table_name} {where_str}" + query_sql = f"select {columns_str} from {table_name}" try: data = pd.read_sql(query_sql, conn) result_data[table_name] = data diff --git a/profiler/msprof_analyze/prof_common/db_manager.py b/profiler/msprof_analyze/prof_common/db_manager.py index ac24ec8144f7a67c1796906d7e75ab25a7a7f71c..8740499c27edc9562ad2861b5da8d1a21f02dd0c 100644 --- a/profiler/msprof_analyze/prof_common/db_manager.py +++ b/profiler/msprof_analyze/prof_common/db_manager.py @@ -143,6 +143,41 @@ class DBManager: logger.error("conn is invalid param") return False + @staticmethod + def execute_sql(conn: any, sql: str, params: any = None) -> bool: + """ + execute sql + """ + try: + if isinstance(conn, sqlite3.Connection): + if params: + conn.cursor().execute(sql, params) + else: + conn.cursor().execute(sql) + conn.commit() + return True + except sqlite3.Error as err: + logger.error(err) + return False + logger.error("conn is invalid param") + return False + + @staticmethod + def executemany_sql(conn: any, sql: str, params: any) -> bool: + """ + execute many sql once + """ + try: + if isinstance(conn, sqlite3.Connection): + conn.cursor().executemany(sql, params) + conn.commit() + return True + except sqlite3.Error as err: + logger.error(err) + return False + logger.error("conn is invalid param") + return False + @classmethod def check_tables_in_db(cls, db_path: any, *tables: any) -> bool: if check_db_path_valid(db_path): @@ -249,6 +284,21 @@ class DBManager: cls.insert_data_into_table(conn, table_name, data) cls.destroy_db_connect(conn, curs) + @classmethod + def check_columns_exist(cls, curs: any, table_name: str, columns: set) -> any: + """ + check columns exist in table, return empty set if none of them exist, else return the set of existing columns + """ + if not isinstance(curs, sqlite3.Cursor): + return None + try: + curs.execute(f"PRAGMA table_info({table_name})") + table_columns = {col[1] for col in curs.fetchall()} + return columns & table_columns + except sqlite3.Error as err: + logger.error(err) + return None + class CustomizedDictFactory: @staticmethod diff --git a/profiler/msprof_analyze/prof_common/file_reader.py b/profiler/msprof_analyze/prof_common/file_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..313933ba7f9334d8ce9273824aeba565c379a1cc --- /dev/null +++ b/profiler/msprof_analyze/prof_common/file_reader.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import csv +import json +import logging +import os + +from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.prof_common.constant import Constant + + +class FileReader: + DATA_FILE_AUTHORITY = 0o640 + DATA_DIR_AUTHORITY = 0o750 + + @classmethod + def read_json_file(cls, file_path: str) -> any: + PathManager.check_path_readable(file_path) + if not os.path.isfile(file_path): + raise FileNotFoundError("File not exists.") + file_size = os.path.getsize(file_path) + if file_size <= 0: + return [] + if file_size > Constant.MAX_FILE_SIZE_5_GB: + msg = f"The file({file_path}) size exceeds the preset max value, failed to read the file." + raise RuntimeError(msg) + try: + with open(file_path, "rt") as file: + json_data = json.loads(file.read()) + except Exception as e: + msg = f"Can't read file: {file_path}" + raise RuntimeError(msg) from e + return json_data + + @classmethod + def write_json_file(cls, output_path: str, data: dict, file_name: str, format_json: bool = False) -> None: + if not data: + return + output_file = os.path.join(output_path, file_name) + PathManager.check_path_writeable(output_path) + try: + with os.fdopen( + os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), 'w' + ) as file: + indent = 4 if format_json else None + file.write(json.dumps(data, indent=indent)) + except Exception as e: + raise RuntimeError(f"Can't create the file: {output_file}") from e + + @classmethod + def read_csv_file(cls, file_path: str, bean_class: any = None) -> any: + PathManager.check_path_readable(file_path) + if not os.path.isfile(file_path): + raise FileNotFoundError("File not exists.") + file_size = os.path.getsize(file_path) + if file_size <= 0: + return [] + if file_size > Constant.MAX_FILE_SIZE_5_GB: + check_msg = input( + f"The file({file_path}) size exceeds the preset max value. Continue reading the file? [y/n]") + if check_msg.lower() != "y": + logging.warning(f"The user choose not to read the file: %s", file_path) + return [] + result_data = [] + try: + with open(file_path, newline="") as csv_file: + reader = csv.DictReader(csv_file) + for row in reader: + row_data = bean_class(row) if bean_class else row + result_data.append(row_data) + except Exception as e: + msg = f"Failed to read the file: {file_path}" + raise RuntimeError(msg) from e + return result_data diff --git a/profiler/msprof_analyze/prof_common/kernel_bean.py b/profiler/msprof_analyze/prof_common/kernel_bean.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c90895fc4bf78dc6b7c98bc6d7d781b7308b38 --- /dev/null +++ b/profiler/msprof_analyze/prof_common/kernel_bean.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.prof_common.utils import convert_to_decimal + + +class KernelBean: + def __init__(self, data: dict): + self._name = data.get("Name", "") + self._op_type = data.get("Type", "") + self._core_type = data.get("Accelerator Core", "") + self._input_shape = data.get("Input Shapes", "").replace("\"", "") + self._input_type = data.get("Input Data Types", "") + self._input_format = data.get("Input Formats", "") + self._duration = data.get("Duration(us)", 0) + self._ts = data.get("Start Time(us)", "") + + @property + def start_time(self): + return convert_to_decimal(self._ts) + + @property + def end_time(self): + return self.start_time + convert_to_decimal(self.dur) + + @property + def is_computing_op(self): + return self._core_type != "HCCL" + + @property + def dur(self): + return float(self._duration) + + @property + def kernel_info(self): + return [self._name, self._op_type, self._core_type, self._input_shape, self._input_type, self.dur] diff --git a/profiler/msprof_analyze/prof_common/trace_event_bean.py b/profiler/msprof_analyze/prof_common/trace_event_bean.py new file mode 100644 index 0000000000000000000000000000000000000000..ea78b54df57f8a1d72517baf2c48748b13ab7847 --- /dev/null +++ b/profiler/msprof_analyze/prof_common/trace_event_bean.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from decimal import Decimal + +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.utils import convert_to_decimal +from msprof_analyze.prof_common.analyze_dict import AnalyzeDict + + +class TraceEventBean(AnalyzeDict): + def __init__(self, data: dict, unique_id: str = None): + super().__init__(data) + self._id = unique_id + self._type = None + self._start_time = convert_to_decimal(self.ts) if self.ts else 0 + self._end_time = self._start_time + convert_to_decimal(self.dur) if self.dur else 0 + self._fwd_bwd_id = None + + @property + def unique_id(self): + return self._id + + @property + def start_time(self) -> Decimal: + return self._start_time + + @property + def step_id(self) -> int: + return self.name.split("#")[-1] + + @property + def end_time(self) -> Decimal: + return self._end_time + + @property + def kernel_info(self): + return [self.name, self.args.get("Task Type", ""), self.dur] + + @property + def event_type(self): + return self._type + + @property + def fwd_bwd_id(self): + return self._fwd_bwd_id + + @event_type.setter + def event_type(self, event_type): + self._type = event_type + + @fwd_bwd_id.setter + def fwd_bwd_id(self, fwd_bwd_id): + self._fwd_bwd_id = fwd_bwd_id + + def set_id(self, name_id): + self._id = name_id + + def is_cpu_op(self): + return self.cat == "cpu_op" + + def is_optimizer(self): + return self.cat == "cpu_op" and self.name.lower().startswith("optimizer") + + def is_nn_module(self): + return self.cat == "python_function" and self.name.lower().startswith("nn.module") + + def is_step(self): + return self.name.lower().startswith("profilerstep#") + + def is_torch_to_npu(self): + return self.cat == "async_npu" + + def is_fwd_bwd_flow(self): + return self.cat == "fwdbwd" + + def is_flow_start(self): + return self.ph == "s" + + def is_flow_end(self): + return self.ph == "f" + + def is_meta(self): + return self.ph == "M" + + def is_kernel_event(self, kernel_pid): + return self.ph == "X" and self.pid == kernel_pid + + def is_hccl_event(self, hccl_pid): + return self.ph == "X" and self.pid == hccl_pid and self.name.startswith("hcom_") + + def is_overlap_analysis_event(self, overlap_analysis_pid): + return self.ph == "X" and self.pid == overlap_analysis_pid + + def is_npu_process(self): + return self.ph == "M" and self.name == "process_name" and self.args.get("name", "") == Constant.NPU_BAR + + def is_hccl_process(self): + return self.ph == "M" and self.name == "process_name" and self.args.get("name", "") == Constant.HCCL_BAR + + def is_overlap_analysis_process(self): + return self.ph == "M" and self.name == "process_name" and self.args.get("name", "") == Constant.OVERLAP_BAR diff --git a/profiler/msprof_analyze/prof_common/tree_builder.py b/profiler/msprof_analyze/prof_common/tree_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..34b056e71bd9880cea0e3402da699a4fbadd150a --- /dev/null +++ b/profiler/msprof_analyze/prof_common/tree_builder.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from msprof_analyze.prof_common.trace_event_bean import TraceEventBean + + +class TreeBuilder: + @staticmethod + def build_tree(event_list: list, node_class: any, root_bean: any): + root_node = node_class(root_bean) + all_nodes = [root_node] + [None] * len(event_list) + event_list.sort(key=lambda x: x.start_time) + last_node = root_node + index = 1 + for event in event_list: + while last_node: + if last_node != root_node and event.start_time > last_node.end_time: + last_node = last_node.parent_node + continue + tree_node = node_class(event, last_node) + last_node.update_child_nodes(tree_node) + all_nodes[index] = tree_node + last_node = tree_node + index += 1 + break + return all_nodes diff --git a/profiler/msprof_analyze/prof_common/utils.py b/profiler/msprof_analyze/prof_common/utils.py index 005d8505c9ccd750d4856518961c62b4407eea1e..284c17c86e36b8fb87d2ea73ed7e3089f44fcbb6 100644 --- a/profiler/msprof_analyze/prof_common/utils.py +++ b/profiler/msprof_analyze/prof_common/utils.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import configparser import os from email.utils import parseaddr from typing import Dict, List from urllib.parse import urlparse +from decimal import Decimal + from msprof_analyze.prof_common.logger import get_logger from msprof_analyze.prof_common.path_manager import PathManager @@ -85,6 +86,15 @@ def convert_to_float(num): return 0 +def convert_to_decimal(data: any) -> Decimal: + try: + decimal_value = Decimal(data) + except Exception: + logger.error('Invalid profiling data which failed to convert data to decimal.') + return 0.0 + return decimal_value + + def convert_to_int(num): try: return int(num) diff --git a/profiler/msprof_analyze/prof_exports/base_stats_export.py b/profiler/msprof_analyze/prof_exports/base_stats_export.py index 65ccd69ecde0acb296308e0c37bec3323468ae34..59d58bdff5485a6ace0f2c12dadbf543ecd4b978 100644 --- a/profiler/msprof_analyze/prof_exports/base_stats_export.py +++ b/profiler/msprof_analyze/prof_exports/base_stats_export.py @@ -24,11 +24,11 @@ logger = get_logger() class BaseStatsExport: - def __init__(self, db_path, analysis_class, step_range): + def __init__(self, db_path, analysis_class): self._db_path = db_path self._analysis_class = analysis_class - self._step_range = step_range self._query = None + self.mode = Constant.ANALYSIS def get_query(self): return self._query @@ -39,10 +39,10 @@ class BaseStatsExport: if query is None: logger.error("query is None.") return None - conn, cursor = DBManager.create_connect_db(self._db_path, Constant.ANALYSIS) + conn, cursor = DBManager.create_connect_db(self._db_path, self.mode) data = pd.read_sql(query, conn) DBManager.destroy_db_connect(conn, cursor) return data except Exception as e: logger.error(f"File {self._db_path} read failed error: {e}") - return None + return None \ No newline at end of file diff --git a/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py b/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py index 0d3da94a001609cdbaed7d3f4646dc908d2b8c23..efdba81e94360e7f8e88801711fb2ff72fa5b47f 100644 --- a/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py +++ b/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py @@ -14,7 +14,6 @@ # limitations under the License. from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport -from msprof_analyze.prof_common.constant import Constant QUERY = """ WITH @@ -32,7 +31,6 @@ WITH upper_quartile(endNs - startNs) AS upper_quartile_duration FROM CANN_API - {} GROUP BY name ), totals AS ( @@ -62,14 +60,6 @@ ORDER BY 2 DESC; class CannApiSumExport(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) - self._query = self.get_query_statement() - - def get_query_statement(self): - if self._step_range: - filter_statement = f"WHERE CANN_API.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and CANN_API.startNs <= {self._step_range.get(Constant.END_NS)}" - else: - filter_statement = "" - return QUERY.format(filter_statement) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py b/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py new file mode 100644 index 0000000000000000000000000000000000000000..840359618f383b4606544832788b886fba6cad4b --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + + +class CommunicationTimeExport(BaseStatsExport): + QUERY = """ + SELECT + rdm.rankid AS rank, + si.value AS groupName, + (co.endNs - co.startNs) / 1000.0 AS communication_time, + sii.value AS opName, + step_time.id AS step + FROM + COMMUNICATION_OP co + CROSS JOIN + RANK_DEVICE_MAP rdm + JOIN + STRING_IDS si ON co.groupName = si.id + JOIN + STRING_IDS sii ON co.opName = sii.id + LEFT JOIN STEP_TIME step_time + ON co.startNs >= step_time.startNs + AND co.endNs <= step_time.endNs + """ + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = self.QUERY + + +class MemoryAndDispatchTimeExport(BaseStatsExport): + QUERY = """ + + WITH + computing AS ( + SELECT TASK.startNs, TASK.endNs, CANN_API.startNs as apiStartNs, 0 AS type + FROM COMPUTE_TASK_INFO + JOIN TASK + ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId + AND TASK.startNs != TASK.endNs + JOIN CANN_API + ON CANN_API.connectionId = TASK.connectionId + ), + communication AS ( + SELECT COMMUNICATION_OP.startNs, COMMUNICATION_OP.endNs, CANN_API.startNs as apiStartNs, 1 AS type + FROM COMMUNICATION_OP + JOIN CANN_API + ON CANN_API.connectionId = COMMUNICATION_OP.connectionId + ), + memory AS ( + SELECT TASK.startNs, TASK.endNs, TASK.startNs as apiStartNs, 4 AS type + FROM TASK + WHERE + taskType = ( + SELECT id + FROM STRING_IDS + WHERE value='MEMCPY_ASYNC' + ) + ), + overlap AS ( + SELECT startNs, endNs, apiStartNs, type + FROM computing + UNION ALL + SELECT startNs, endNs, apiStartNs, type + FROM communication + UNION ALL + SELECT startNs, endNs, apiStartNs, type + FROM memory + ) + SELECT + overlap.startNs AS start, + overlap.endNs AS end, + (overlap.startNs - overlap.apiStartNs) / 1000.0 AS dispatch, + overlap.type, + step_time.id AS step + FROM overlap + LEFT JOIN STEP_TIME step_time + ON overlap.apiStartNs >= step_time.startNs + AND overlap.apiStartNs <= step_time.endNs + ORDER BY overlap.startNs, overlap.endNs + """ + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = self.QUERY + self.mode = None diff --git a/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py b/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py index f337925dc36ff8e26c782ab1ea1c00618ebf271c..ed41d128056368b7a0e35f51e78edd95ce746486 100644 --- a/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py +++ b/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py @@ -14,7 +14,6 @@ # limitations under the License. from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport -from msprof_analyze.prof_common.constant import Constant QUERY = """ SELECT @@ -39,7 +38,6 @@ LEFT JOIN LEFT JOIN STRING_IDS AS INPUTSHAPES_IDS ON INPUTSHAPES_IDS.id == COMPUTE_TASK_INFO.inputShapes -{} """ QUERY_EXCLUDE_OPNAME = """ @@ -61,35 +59,18 @@ LEFT JOIN LEFT JOIN STRING_IDS AS INPUTSHAPES_IDS ON INPUTSHAPES_IDS.id == COMPUTE_TASK_INFO.inputShapes -{} """ class ComputeOpSumExport(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) - self._query = self.get_query_statement() - - def get_query_statement(self): - if self._step_range: - filter_statement = f"WHERE TASK.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and TASK.startNs <= {self._step_range.get(Constant.END_NS)}" - else: - filter_statement = "" - return QUERY.format(filter_statement) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY class ComputeOpSumExportExcludeOpName(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) - self._query = self.get_query_statement() - - def get_query_statement(self): - if self._step_range: - filter_statement = f"WHERE TASK.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and TASK.startNs <= {self._step_range.get(Constant.END_NS)}" - else: - filter_statement = "" - return QUERY_EXCLUDE_OPNAME.format(filter_statement) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY_EXCLUDE_OPNAME \ No newline at end of file diff --git a/profiler/msprof_analyze/prof_exports/ep_load_balance_export.py b/profiler/msprof_analyze/prof_exports/ep_load_balance_export.py new file mode 100644 index 0000000000000000000000000000000000000000..59acd6bdde7d23ad6d645d420c3f6f9fcf7cd2b6 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/ep_load_balance_export.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport +from msprof_analyze.prof_common.constant import Constant + +GROUPED_MATMUL_QUERY = """ +SELECT + InputShapes_IDS.value AS "InputShapes" +FROM COMPUTE_TASK_INFO +JOIN TASK + ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId +LEFT JOIN STRING_IDS AS InputShapes_IDS + ON InputShapes_IDS.id = COMPUTE_TASK_INFO.inputShapes +WHERE COMPUTE_TASK_INFO.opType = ( + SELECT id + FROM STRING_IDS + WHERE value = 'GroupedMatmul' +) +{} + """ + + +class InputShapeExport(BaseStatsExport): + + def __init__(self, db_path, recipe_name, step_range): + super().__init__(db_path, recipe_name, step_range) + filter_statement = "And TASK.startNs >= ? And TASK.endNs <= ?" if step_range else "" + self._query = GROUPED_MATMUL_QUERY.format(filter_statement) diff --git a/profiler/msprof_analyze/prof_exports/filter_db_export.py b/profiler/msprof_analyze/prof_exports/filter_db_export.py new file mode 100644 index 0000000000000000000000000000000000000000..048b20a260d25ec48c17bd2dc85d70f15b177910 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/filter_db_export.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport +from msprof_analyze.prof_common.logger import get_logger + +logger = get_logger() + +FILTER_TABLES = ["MatMulV3", "MatMulV2", "GroupedMatmul", "FlashAttentionScore", "FlashAttentionScoreGrad"] +values_str = ', '.join([f"'{op_type}'" for op_type in FILTER_TABLES]) + +OP_QUERY = f""" +SELECT COMPUTE_TASK_INFO.* +FROM COMPUTE_TASK_INFO + WHERE + opType IN ( + SELECT + id + FROM + STRING_IDS + WHERE + value IN ({values_str}) + ) +""" + +TASK_QUERY = """ +SELECT TASK.* +FROM TASK +INNER JOIN COMPUTE_TASK_INFO +ON TASK.globalTaskId = COMPUTE_TASK_INFO.globalTaskId; +""" + +CANN_QUERY = """ +WITH all_connection_ids AS ( + SELECT connectionId + FROM TASK + UNION + SELECT connectionId + FROM COMMUNICATION_OP +) + +SELECT CANN_API.* +FROM CANN_API +INNER JOIN all_connection_ids +ON CANN_API.connectionId = all_connection_ids.connectionId; +""" + +PYTORCH_QUERY = """ +WITH all_connection_ids AS ( + SELECT connectionId + FROM TASK + UNION + SELECT connectionId + FROM COMMUNICATION_OP +) + +SELECT PYTORCH_API.* +FROM PYTORCH_API + +INNER JOIN all_connection_ids +ON PYTORCH_API.connectionId = all_connection_ids.connectionId; +""" + + +class OPFilter(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = OP_QUERY + + +class TaskFilter(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = TASK_QUERY + + +class CANNFilter(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = CANN_QUERY + + +class PYTORCHFilter(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = PYTORCH_QUERY \ No newline at end of file diff --git a/profiler/msprof_analyze/prof_exports/hccl_sum_export.py b/profiler/msprof_analyze/prof_exports/hccl_sum_export.py index c577d40c0f5ae1289d196bdd6d7cd306ebcbf01e..2470e059ffcfb116f1dad657de53d5aa7ddd865b 100644 --- a/profiler/msprof_analyze/prof_exports/hccl_sum_export.py +++ b/profiler/msprof_analyze/prof_exports/hccl_sum_export.py @@ -14,7 +14,6 @@ # limitations under the License. from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport -from msprof_analyze.prof_common.constant import Constant QUERY = """ SELECT @@ -33,20 +32,11 @@ LEFT JOIN LEFT JOIN STRING_IDS AS GROUP_NAME_IDS ON GROUP_NAME_IDS.id == COMMUNICATION_OP.groupName -{} """ class HcclSumExport(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) - self._query = self.get_query_statement() - - def get_query_statement(self): - if self._step_range: - filter_statement = f"WHERE COMMUNICATION_OP.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and COMMUNICATION_OP.startNs <= {self._step_range.get(Constant.END_NS)}" - else: - filter_statement = "" - return QUERY.format(filter_statement) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/mstx2commop_export.py b/profiler/msprof_analyze/prof_exports/mstx2commop_export.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed239603e13a556ab888aaf1a8b65fc74adb29f --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/mstx2commop_export.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + +QUERY = """ +SELECT + ta.startNs, + ta.endNs, + ta.connectionId, + si.value +FROM + MSTX_EVENTS ms +JOIN + TASK ta + ON ms.connectionId == ta.connectionId +JOIN + STRING_IDS si + ON ms.message == si.id +WHERE + si.value LIKE '%"streamId":%' + AND si.value LIKE '%"count":%' + AND si.value LIKE '%"dataType":%' + AND si.value LIKE '%"groupName":%' + AND si.value LIKE '%"opName":%' + """ + + +class Mstx2CommopExport(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/mstx_mark_export.py b/profiler/msprof_analyze/prof_exports/mstx_mark_export.py index 6a7f8d0c6d2f1b4cbbceb9323157215421b58464..9b561d9f066687efa373fcbba6dcaaae2e492eff 100644 --- a/profiler/msprof_analyze/prof_exports/mstx_mark_export.py +++ b/profiler/msprof_analyze/prof_exports/mstx_mark_export.py @@ -14,7 +14,6 @@ # limitations under the License. from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport -from msprof_analyze.prof_common.constant import Constant QUERY = """ WITH @@ -27,7 +26,6 @@ WITH LEFT JOIN CONNECTION_IDS ON PYTORCH_API.connectionId == CONNECTION_IDS.id - {} ) SELECT MSG_IDS.value AS "msg", @@ -46,7 +44,6 @@ LEFT JOIN LEFT JOIN STRING_IDS AS MSG_IDS ON MSTX_EVENTS.message == MSG_IDS.id -{} ORDER BY MSTX_EVENTS.startNs """ @@ -54,16 +51,6 @@ ORDER BY class MstxMarkExport(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) - self._query = self.get_query_statement() - - def get_query_statement(self): - if self._step_range: - filter_statement_1 = f"WHERE PYTORCH_API.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and PYTORCH_API.startNs <= {self._step_range.get(Constant.END_NS)}" - filter_statement_2 = f"WHERE MSTX_EVENTS.startNs >= {self._step_range.get(Constant.START_NS)} " \ - f"and MSTX_EVENTS.startNs <= {self._step_range.get(Constant.END_NS)}" - else: - filter_statement_1, filter_statement_2 = "", "" - return QUERY.format(filter_statement_1, filter_statement_2) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/mstx_step_export.py b/profiler/msprof_analyze/prof_exports/mstx_step_export.py index c8aec91b7e5ce5fb29fffebeb8668fec723e3fa8..3051a280ccb1c9eb2a83933c357948bcf59b4d1f 100644 --- a/profiler/msprof_analyze/prof_exports/mstx_step_export.py +++ b/profiler/msprof_analyze/prof_exports/mstx_step_export.py @@ -29,6 +29,6 @@ ORDER BY class MstxStepExport(BaseStatsExport): - def __init__(self, db_path, recipe_name, step_range): - super().__init__(db_path, recipe_name, step_range) + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/p2p_pairing_export.py b/profiler/msprof_analyze/prof_exports/p2p_pairing_export.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6a73942619e1bad19eb5978893363acb6cca73 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/p2p_pairing_export.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from string import Template + +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + + +QUERY = Template(""" +SELECT + co.opName AS "$opNameId", + siii.value AS "$opName", + co.startNs AS "$startTime", + co.endNs AS "$endTime", + rdm.rankId AS "$globalRank", + cti.srcRank AS "$srcRank", + cti.dstRank AS "$dstRank", + siiii.value AS "$taskType", + sii.value AS "$coGroupName", + si.value AS "$ctiGroupName" +FROM + COMMUNICATION_TASK_INFO cti + LEFT JOIN COMMUNICATION_OP co on cti.opId = co.opId + CROSS JOIN RANK_DEVICE_MAP rdm + JOIN STRING_IDS si on cti.groupName = si.id + JOIN STRING_IDS sii on co.groupName = sii.id + JOIN STRING_IDS siii on co.opName = siii.id + JOIN STRING_IDS siiii on cti.taskType = siiii.id +""") + + +class P2PPairingExport(BaseStatsExport): + + CO_OP_NAME = "opNameId" + OP_NAME = "opName" + START_TIME = "startTime" + END_TIME = "endTime" + GLOBAL_RANK = "globalRank" + SRC_RANK = "srcRank" + DST_RANK = "dstRank" + TASK_TYPE = "taskType" + CO_GROUP_NAME = "coGroupName" + CTI_GROUP_NAME = "ctiGroupName" + + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY.safe_substitute( + opNameId=self.CO_OP_NAME, + opName=self.OP_NAME, + startTime=self.START_TIME, + endTime=self.END_TIME, + globalRank=self.GLOBAL_RANK, + srcRank=self.SRC_RANK, + dstRank=self.DST_RANK, + taskType=self.TASK_TYPE, + coGroupName=self.CO_GROUP_NAME, + ctiGroupName=self.CTI_GROUP_NAME + ) diff --git a/profiler/msprof_analyze/prof_exports/pp_chart_export.py b/profiler/msprof_analyze/prof_exports/pp_chart_export.py new file mode 100644 index 0000000000000000000000000000000000000000..40103f04ccf5826832076ab53d9255c0484e5fff --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/pp_chart_export.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + + +class PPChartExport(BaseStatsExport): + QUERY = """ + SELECT + {} + MSG_IDS.value AS msg, + TASK.startNs, + TASK.endNs + FROM + MSTX_EVENTS + JOIN + TASK ON MSTX_EVENTS.connectionId = TASK.connectionId + JOIN + STRING_IDS AS MSG_IDS ON MSTX_EVENTS.message = MSG_IDS.id + {} + WHERE + msg LIKE '%forward%' + OR msg LIKE '%backward%' + OR msg LIKE '%WeightGradStore_pop%' + ORDER BY + TASK.startNs + """ + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = self._build_query(db_path) + + def _build_query(self, db_path): + str1 = "0 AS step," + str2 = "" + if DBManager.check_tables_in_db(db_path, Constant.TABLE_STEP_TIME): + str1 = "step_time.id AS step," + str2 = """ + LEFT JOIN STEP_TIME step_time + ON TASK.startNs >= step_time.startNs + AND TASK.endNs <= step_time.endNs + """ + return self.QUERY.format(str1, str2) diff --git a/profiler/msprof_analyze/prof_exports/slow_calc_export.py b/profiler/msprof_analyze/prof_exports/slow_calc_export.py new file mode 100644 index 0000000000000000000000000000000000000000..b5851c072b29a60d44d1cf26f1b97e40fcfce523 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/slow_calc_export.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from token import NAME +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport +from msprof_analyze.prof_common.constant import Constant + + +list_op_analysis = ["MatMulV3", "MatMulV2"] +values_str = ", ".join([f"'{i}'" for i in list_op_analysis]) + +SIMPLE_QUERY = f""" + select + COMPUTE_TASK_INFO.name as "opName", + COMPUTE_TASK_INFO.opType as "opType", + inputShapes, + outputShapes, + endNs - startNs as {Constant.DURATION_TIME}, + deviceId + from TASK + left join + COMPUTE_TASK_INFO + on COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId + inner join STRING_IDS + on STRING_IDS.id = COMPUTE_TASK_INFO.opType + where STRING_IDS.value in ({values_str}) +""" + + +class SlowCalcExport(BaseStatsExport): + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = SIMPLE_QUERY diff --git a/profiler/msprof_analyze/prof_exports/slow_link_export.py b/profiler/msprof_analyze/prof_exports/slow_link_export.py new file mode 100644 index 0000000000000000000000000000000000000000..c584ceb2b2afbbe89c180b5887a6b99e961d96e6 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/slow_link_export.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport + +QUERY = """ + SELECT + si.value AS groupName, + co.endNs - co.startNs AS communicationTime, + sii.value AS opName, + op.value AS opType, + et.name AS dataType, + CASE + WHEN et.name = 'INT8' THEN 1 * co.count + WHEN et.name = 'INT16' THEN 2 * co.count + WHEN et.name = 'INT32' THEN 4 * co.count + WHEN et.name = 'INT64' THEN 8 * co.count + WHEN et.name = 'UINT64' THEN 8 * co.count + WHEN et.name = 'UINT8' THEN 1 * co.count + WHEN et.name = 'UINT16' THEN 2 * co.count + WHEN et.name = 'UINT32' THEN 4 * co.count + WHEN et.name = 'FP16' THEN 2 * co.count + WHEN et.name = 'FP32' THEN 4 * co.count + WHEN et.name = 'FP64' THEN 8 * co.count + WHEN et.name = 'BFP16' THEN 2 * co.count + WHEN et.name = 'INT128' THEN 16 * co.count + END AS dataSize + FROM + COMMUNICATION_OP co + CROSS + JOIN STRING_IDS si ON co.groupName = si.id + JOIN STRING_IDS sii ON co.opName = sii.id + JOIN ENUM_HCCL_DATA_TYPE et ON co.dataType = et.id + JOIN STRING_IDS op ON co.opType = op.id +""" + + +class SlowLinkExport(BaseStatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY diff --git a/profiler/msprof_analyze/prof_exports/slow_rank_export.py b/profiler/msprof_analyze/prof_exports/slow_rank_export.py new file mode 100644 index 0000000000000000000000000000000000000000..6565ea9244f4fe61770f21311d2174a20267e118 --- /dev/null +++ b/profiler/msprof_analyze/prof_exports/slow_rank_export.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport +from msprof_analyze.prof_common.constant import Constant + + +class SlowRankExport(BaseStatsExport): + QUERY = """ + SELECT + RANK_DEVICE_MAP.rankId, + si_group.value AS groupName, + si_op.value AS opName, + (COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0 AS communication_time + FROM COMMUNICATION_OP + CROSS JOIN RANK_DEVICE_MAP + JOIN STRING_IDS si_group ON COMMUNICATION_OP.groupName = si_group.id + JOIN STRING_IDS si_op ON COMMUNICATION_OP.opName = si_op.id + JOIN CANN_API ON CANN_API.connectionId = COMMUNICATION_OP.connectionId + {} + """ + + def __init__(self, db_path, recipe_name, step_range): + super().__init__(db_path, recipe_name, step_range) + filter_statement = "WHERE CANN_API.startNs >= ? and CANN_API.startNs <= ?" if step_range else "" + self._query = self.QUERY.format(filter_statement) diff --git a/profiler/msprof_analyze/requirements/build.txt b/profiler/msprof_analyze/requirements/build.txt index 3ef20e787be3bad76de0ccde4dc3e3a1dbe63efb..9bb3af4b2a9cdb8401a8c9c44bc6140fc5dc80ec 100644 --- a/profiler/msprof_analyze/requirements/build.txt +++ b/profiler/msprof_analyze/requirements/build.txt @@ -7,7 +7,7 @@ tqdm prettytable ijson requests -xlsxwriter>=3.0.6 +xlsxwriter sqlalchemy urllib3<2.0 numpy<=1.26.4 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/__init__.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/computation/__init__.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/computation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/computation/test_operator_checker.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/computation/test_operator_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..85b55e2313ba3483c7c37262460df936df54b75c --- /dev/null +++ b/profiler/msprof_analyze/test/ut/advisor/analyzer/computation/test_operator_checker.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +from msprof_analyze.advisor.analyzer.computation.operator_checker import OperatorChecker +from msprof_analyze.advisor.dataset.profiling.info_collection import OpInfo +from msprof_analyze.advisor.dataset.profiling.profiling_dataset import ProfilingDataset +from msprof_analyze.advisor.result.item import OptimizeRecord + + +class TestOperatorChecker(unittest.TestCase): + def setUp(self): + self.cann_version = "8.0.RC1" + self.checker = OperatorChecker(self.cann_version) + self.profiling_data = MagicMock(spec=ProfilingDataset) + self.op_info = MagicMock(spec=OpInfo) + self.op_info.has_attr.return_value = True + self.op_info.get_attr.return_value = "10" + self.op_info.task_duration = "10" + self.op_info.op_name = "" + + def test_init(self): + self.assertEqual(self.checker.cann_version, self.cann_version) + self.assertEqual(len(self.checker._op_list), 0) + self.assertEqual(len(self.checker._tune_op_list), 0) + + def test_get_ratio(self): + result = OperatorChecker.get_ratio(self.op_info, "attr") + self.assertEqual(result, 10) + + self.op_info.has_attr.return_value = False + result = OperatorChecker.get_ratio(self.op_info, "attr") + self.assertEqual(result, 0) + + self.op_info.has_attr.return_value = True + self.op_info.get_attr.return_value = None + result = OperatorChecker.get_ratio(self.op_info, "attr") + self.assertEqual(result, 0) + + def test_get_name(self): + self.checker._problem = "Test Problem" + result = self.checker.get_name() + self.assertEqual(result, "Test Problem") + + @patch.object(OperatorChecker, '_check_data') + @patch.object(OperatorChecker, '_check_operator') + def test_check(self, mock_check_operator, mock_check_data): + mock_check_data.return_value = True + mock_check_operator.return_value = True + self.profiling_data.op_summary = MagicMock() + self.profiling_data.op_summary.op_list = [self.op_info] + self.profiling_data.op_summary.get_total_task_duration.return_value = 100 + + result = self.checker.check(self.profiling_data) + self.assertTrue(result) + + @patch.object(OperatorChecker, 'get_incomes') + @patch.object(OperatorChecker, 'get_op_type_list') + @patch.object(OperatorChecker, '_get_description') + def test_make_record(self, mock_get_description, mock_get_op_type_list, mock_get_incomes): + mock_get_incomes.return_value = 100 + mock_get_op_type_list.return_value = ["OpType1"] + mock_get_description.return_value = "Test Description" + self.profiling_data.op_summary = MagicMock() + self.profiling_data.op_summary.get_total_task_duration.return_value = 1000 + + record = self.checker.make_record(self.profiling_data) + self.assertIsInstance(record, OptimizeRecord) + + def test_pre_check(self): + result = self.checker.pre_check(self.profiling_data) + self.assertTrue(result) + + @patch('msprof_analyze.advisor.analyzer.computation.operator_checker.EnumParamsParser.get_options') + def test_is_dynamic_shape(self, mock_get_options): + mock_get_options.return_value = ["7.0.RC1"] + self.checker.cann_version = "7.0.RC1" + self.profiling_data.ge_info = MagicMock() + self.profiling_data.ge_info.get_static_shape_operators.return_value = [] + + result = self.checker.is_dynamic_shape(self.profiling_data) + self.assertTrue(result) + + @patch.object(OperatorChecker, 'group_by') + def test_format_operator_result(self, mock_group_by): + mock_record = MagicMock() + mock_record.optimization_item.suggestion = [self.checker.pytorch_op_tune_suggestion] + mock_group_by.return_value = [] + + result = self.checker.format_operator_result(mock_record, 10) + self.assertIsInstance(result, dict) + + def test_group_by(self): + op_list = [self.op_info] + result = self.checker.group_by(op_list) + self.assertIsInstance(result, list) + + def test_get_tune_op_list(self): + self.checker._tune_op_list = ["Op1", "Op2"] + result = self.checker.get_tune_op_list() + self.assertEqual(result, ["Op1", "Op2"]) + + def test_get_views(self): + result = self.checker.get_views(None) + self.assertEqual(result, []) + + @patch.object(OperatorChecker, '_get_income') + def test_get_incomes(self, mock_get_income): + mock_get_income.return_value = 10 + self.checker._op_list = [self.op_info] + result = self.checker.get_incomes() + self.assertEqual(result, 10) + + def test_get_op_type_list(self): + self.op_info.op_type = "OpType1" + op_list = [self.op_info] + result = self.checker.get_op_type_list(op_list) + self.assertEqual(result, ["OpType1"]) + + def test_get_details(self): + self.checker._op_list = [self.op_info] + self.checker._ITEMS = ["attr"] + self.checker.STACK_INFO_ITEMS = "" + result = self.checker.get_details() + self.assertIsInstance(result, list) + + @patch('msprof_analyze.advisor.analyzer.computation.operator_checker.EnumParamsParser') + def test_format_suggestion_content(self, mock_enum_parser): + mock_enum_parser().profiling_type.ascend_pytorch_profiler = "pytorch" + self.profiling_data.prof_type = "pytorch" + + self.checker.format_suggestion_content(self.profiling_data) + self.assertIn(self.checker.pytorch_op_tune_suggestion, self.checker._suggestion) + + def test__check_data(self): + result = self.checker._check_data(self.profiling_data) + self.assertTrue(result) + + def test__check_operator(self): + result = self.checker._check_operator(self.op_info) + self.assertFalse(result) + + def test__get_income(self): + result = self.checker._get_income(self.op_info) + self.assertEqual(result, 0) + + def test__check_summary(self): + self.profiling_data.op_summary = None + result = self.checker._check_summary(self.profiling_data) + self.assertTrue(result) + + self.profiling_data.op_summary = MagicMock() + result = self.checker._check_summary(self.profiling_data) + self.assertTrue(result) + + def test__get_description(self): + description = "Test Description" + op_type_list = ["OpType1", "OpType2", "OpType3"] + result = self.checker._get_description(description, op_type_list) + self.assertIn("OpType1", result) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/__init__.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusible_ops/__init__.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusible_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusible_ops/test_fusible_operator_checker.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusible_ops/test_fusible_operator_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..2d93eef90d501defe9ec6973c782470220e75260 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusible_ops/test_fusible_operator_checker.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest.mock import MagicMock, patch + +from msprof_analyze.advisor.analyzer.schedule.fusible_ops.fusible_operator_checker import FusibleOperatorChecker +from msprof_analyze.advisor.result.result import OptimizeResult + + +class TestFusibleOperatorChecker(unittest.TestCase): + def setUp(self): + self.checker = FusibleOperatorChecker() + + def test_get_mte_time(self): + task = MagicMock() + task.aic_mte2_time = '10' + task.aiv_mte2_time = '20' + task.aic_fixpipe_time = '30' + task.aiv_mte3_time = '40' + result = FusibleOperatorChecker.get_mte_time(task) + self.assertEqual(result, 60.0) + + def test_check_hccl(self): + task1 = MagicMock(task_type='COMMUNICATION', op_name='hcom_op') + task2 = MagicMock(task_type='OTHER', op_name='normal_op') + self.assertTrue(FusibleOperatorChecker.check_hccl(task1)) + self.assertFalse(FusibleOperatorChecker.check_hccl(task2)) + + def test_calculate_total_time(self): + pre_timestamp = '10' + timestamp = '20' + duration = '5' + result, flag = FusibleOperatorChecker.calculate_total_time(pre_timestamp, timestamp, duration) + self.assertEqual(result, 15) + self.assertTrue(flag) + + def test_check_sequence_ratio(self): + detail = [100, 0, 0, 0, False, False] + self.checker.step_duration = 50 + self.checker.sequence_duration_threshold = 0.1 + result = self.checker.check_sequence_ratio(detail) + self.assertTrue(result) + + def test_check_sequence_num(self): + detail = [0, 0, 0, 10, False, False] + self.checker.sequence_count_threshold = 5 + result = self.checker.check_sequence_num(detail) + self.assertTrue(result) + + def test_check_bound(self): + detail1 = [100, 0, 0, 0, True, False] + detail2 = [10, 0, 0, 0, False, False] + self.checker.step_duration = 50 + self.checker.sequence_duration_threshold = 0.3 + self.assertTrue(self.checker.check_bound(detail1)) + self.assertFalse(self.checker.check_bound(detail2)) + + def test_add_detail(self): + task_name = 'test_task' + details = [] + detail = [100, 50, 30, 2, True, False] + self.checker.index_dict[task_name] = (0, 1) + self.checker.add_detail(task_name, details, detail) + self.assertEqual(len(details), 1) + + def test_generate_key(self): + task = MagicMock(op_name='test_op', input_shapes='[1,2]', output_shapes='[2,3]') + result = self.checker.generate_key(task) + self.assertEqual(result, 'test_op-[1,2]-[2,3]') + + def test_compute_priority(self): + self.checker.host_details = [[100, 0, 0, 0, False, False]] + self.checker.mte_details = [] + self.checker.step_duration = 50 + result = self.checker.compute_priority() + from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor + self.assertEqual(result, PriorityBackgroundColor.high) + + def test_check_tasks(self): + profiling_dataset = MagicMock() + profiling_dataset.op_summary.op_list = [MagicMock()] + with patch('msprof_analyze.advisor.analyzer.schedule.fusible_ops.fusible_operator_checker.' \ + 'FusibleOperatorChecker.calculate_total_time') as mock_calculate: + mock_calculate.return_value = (100, True) + result = self.checker.check_tasks(profiling_dataset) + self.assertTrue(result) + + def test_make_record(self): + result = OptimizeResult() + self.checker.problem = 'test_problem' + self.checker.desc = 'test_desc' + self.checker.suggestions = ['test_suggestion'] + self.checker.host_details = [[1, 2, 100, 50, 30, 2, True, False]] + self.checker.mte_details = [[3, 4, 200, 100, 60, 3, False, True]] + self.checker.make_record(result) + self.assertTrue(result.page_dict) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusion_ops/__init__.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusion_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusion_ops/test_fusion_ops_analyzer.py b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusion_ops/test_fusion_ops_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..39ca57680ef38d2b43df67ce7bc580fb1f935b58 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/advisor/analyzer/schedule/fusion_ops/test_fusion_ops_analyzer.py @@ -0,0 +1,188 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest.mock import patch, MagicMock + +from msprof_analyze.advisor.analyzer.schedule.fusion_ops.fusion_ops_analyzer import TimelineFusionOpsAnalyzer +from msprof_analyze.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from msprof_analyze.advisor.display.html.priority_background_color import PriorityBackgroundColor +from msprof_analyze.prof_common.constant import Constant + + +class TestTimelineFusionOpsAnalyzer(unittest.TestCase): + def setUp(self): + self.collection_path = 'test_path' + self.analyzer = TimelineFusionOpsAnalyzer(self.collection_path) + self.analyzer.dataset_list = [MagicMock()] + self.analyzer.result = MagicMock() + self.analyzer.html_render = MagicMock() + + def test_init(self): + self.assertIsInstance(self.analyzer, TimelineFusionOpsAnalyzer) + self.assertEqual(len(self.analyzer.dataset_cls_list), 1) + self.assertEqual(self.analyzer.dataset_cls_list[0], ScheduleAnalysisDataset) + + def test_get_priority(self): + result = self.analyzer.get_priority() + self.assertEqual(result, PriorityBackgroundColor.low) + + @patch('os.getenv') + def test_optimize_skip_affinity_api(self, mock_getenv): + mock_getenv.return_value = 'true' + result = self.analyzer.optimize() + self.assertEqual(result, self.analyzer.result) + + def test_find_fusion_ops_no_regex(self): + event_dataset = MagicMock() + ops = 'permute-reshape' + npu_api = 'torch_npu_api' + mode = 'aten' + + with patch.object(self.analyzer, '_format_rule_to_pattern') as mock_format, \ + patch.object(self.analyzer, '_match_ops') as mock_match: + mock_format.return_value = (ops, False) + self.analyzer.find_fusion_ops(event_dataset, ops, npu_api, mode) + + mock_format.assert_called_once_with(ops) + mock_match.assert_called_once_with(event_dataset, ops, npu_api, mode) + + def test_find_fusion_ops_with_regex(self): + event_dataset = MagicMock() + ops = 'add-mul{0,10}' + npu_api = 'torch_npu_api' + mode = 'aten' + + with patch.object(self.analyzer, '_format_rule_to_pattern') as mock_format, \ + patch.object(self.analyzer, '_match_ops_with_regex') as mock_match: + mock_format.return_value = (ops, True) + self.analyzer.find_fusion_ops(event_dataset, ops, npu_api, mode) + + mock_format.assert_called_once_with(ops) + mock_match.assert_called_once_with(event_dataset, ops, npu_api, mode) + + def test_find_fusion_ops_with_regex_exception(self): + event_dataset = MagicMock() + ops = 'add-mul{0,10}' + npu_api = 'torch_npu_api' + mode = 'aten' + + with patch.object(self.analyzer, '_format_rule_to_pattern') as mock_format, \ + patch.object(self.analyzer, '_match_ops_with_regex') as mock_match: + mock_format.return_value = (ops, True) + mock_match.side_effect = Exception('Test exception') + self.analyzer.find_fusion_ops(event_dataset, ops, npu_api, mode) + + mock_format.assert_called_once_with(ops) + mock_match.assert_called_once_with(event_dataset, ops, npu_api, mode) + + def test_make_record_no_stacks(self): + self.analyzer.matched_op_stacks = {} + self.analyzer.make_record() + self.analyzer.result.add.assert_not_called() + + @patch('msprof_analyze.advisor.display.prompt.base_prompt.BasePrompt.get_prompt_class') + def test_make_record_with_stacks(self, mock_get_prompt_class): + mock_prompt_class = MagicMock() + mock_prompt_class.DESCRIPTION = 'Description {0} {1} {2}' + mock_prompt_class.SUGGESTION = 'Suggestion' + mock_prompt_class.PROBLEM = 'Problem' + mock_prompt_class.EMPTY_STACK_DESCRIPTION = '' + mock_prompt_class.EMPTY_STACKS_SUGGESTION = '' + mock_get_prompt_class.return_value = mock_prompt_class + + self.analyzer.matched_op_stacks = {'api_name': {'stack': 1}} + self.analyzer.make_record() + + self.analyzer.result.add.assert_called_once() + self.analyzer.result.add_detail.assert_called() + + def test_make_render(self): + self.analyzer.matched_op_stacks = {'api_name': {'stack': 1}} + self.analyzer.make_render(rank=1) + + self.analyzer.html_render.render_template.assert_called_once() + + def test_query_stack_no_matches(self): + self.analyzer._matched_op_index = {'op_rule': []} + event_dataset = MagicMock() + self.analyzer.query_stack(event_dataset) + event_dataset.parse_data_with_generator.assert_not_called() + + def test__match_ops(self): + event_dataset = MagicMock() + ops = 'permute-reshape' + npu_api = 'torch_npu_api' + mode = 'aten' + + with patch.object(self.analyzer, '_replace_op_name_prefix') as mock_replace: + mock_replace.side_effect = ['permute', 'permute', 'reshape', 'reshape', 'permute', 'reshape'] + event_dataset.aten = [MagicMock(name='permute', dataset_index=1), MagicMock(name='reshape')] + self.analyzer._match_ops(event_dataset, ops, npu_api, mode) + + self.assertEqual(self.analyzer._matched_op_index[npu_api + f':{ops}'], {1}) + + def test__match_ops_with_regex(self): + event_dataset = MagicMock() + op_rule_pattern = '(-add-)(-mul-)*' + npu_api = 'torch_npu_api' + mode = 'aten' + event_dataset.aten = [MagicMock(name='-add-'), MagicMock(name='-mul-')] + + with patch("builtins.sorted", return_value=[1]): + self.analyzer._match_ops_with_regex(event_dataset, op_rule_pattern, npu_api, mode) + self.assertIn(npu_api + f':{op_rule_pattern}', self.analyzer._matched_op_index) + + def test__query_stack_by_matched_index(self): + index = 1 + event = {'args': {Constant.CALL_STACKS: 'stack'}} + self.analyzer._matched_op_index = {'op_rule': {1}} + + result = self.analyzer._query_stack_by_matched_index(index, event) + + self.assertEqual(result, {'op_rule': 'stack'}) + + def test__replace_op_name_prefix_dequeue(self): + event_name = 'Dequeue@op_name' + mode = Constant.DEQUEUE.lower() + result = self.analyzer._replace_op_name_prefix(event_name, mode) + self.assertEqual(result, 'op_name') + + def test__replace_op_name_prefix_aten(self): + event_name = 'aten::op_name' + mode = Constant.ATEN + result = self.analyzer._replace_op_name_prefix(event_name, mode) + self.assertEqual(result, 'op_name') + + def test__replace_op_name_prefix_optimizer(self): + event_name = 'Optimizer.step#op_name' + mode = 'optimizer' + result = self.analyzer._replace_op_name_prefix(event_name, mode) + self.assertEqual(result, 'op_name') + + def test__format_rule_to_pattern_no_regex(self): + op_rule = 'permute-reshape' + pattern, enable_regex = self.analyzer._format_rule_to_pattern(op_rule) + self.assertEqual(pattern, op_rule) + self.assertFalse(enable_regex) + + def test__format_rule_to_pattern_with_regex(self): + op_rule = '(mul){0,1}-(add|neg){0,2}-dropout-(softmax)*' + pattern, enable_regex = self.analyzer._format_rule_to_pattern(op_rule) + self.assertTrue(enable_regex) + self.assertIn('(-mul-){0,1}', pattern) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/advisor/cluster_advice/test_slow_rank_analyzer.py b/profiler/msprof_analyze/test/ut/advisor/cluster_advice/test_slow_rank_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..f19663e2bb26d322e17fe078c4666c0da3de4e06 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/advisor/cluster_advice/test_slow_rank_analyzer.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock + +from msprof_analyze.advisor.analyzer.cluster.slow_rank_analyzer import SlowRankAnalyzer +from msprof_analyze.advisor.dataset.cluster.cluster_dataset import ClusterStepTraceTimeDataset + + +class TestSlowRankAnalyzer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.collection_path = "test_collection_path" + cls.mock_step_trace_dict = { + '-1_0': [59862.52, 115034.42, 1252571.94], + '-1_1': [60029.1, 971241.44, 396442.14], + '-1_2': [60123.45, 116789.32, 1256789.12], + '-1_3': [59987.32, 115678.90, 1253456.78] + } + cls.mock_stages = [[0, 1, 2, 3]] + + # Create mock dataset + cls.mock_dataset = MagicMock() + cls.mock_dataset.get_key.return_value = ClusterStepTraceTimeDataset.get_key() + cls.mock_dataset.get_data.return_value = cls.mock_step_trace_dict + cls.mock_dataset.get_stages.return_value = cls.mock_stages + + def setUp(self): + # Create analyzer instance for each test + with patch('msprof_analyze.advisor.analyzer.base_analyzer.BaseAnalyzer.init_dataset_list'), \ + patch('msprof_analyze.advisor.analyzer.base_analyzer.BaseAnalyzer.get_first_data_by_key', + return_value=self.mock_dataset): + self.analyzer = SlowRankAnalyzer(self.collection_path) + + def test_init_with_valid_data_then_initialize_correctly(self): + # Verify initialization + self.assertEqual(self.analyzer.step_trace_dict, self.mock_step_trace_dict) + self.assertEqual(self.analyzer.stages, self.mock_stages) + self.assertEqual(self.analyzer.bottelneck, '') + self.assertEqual(self.analyzer.suggestion, '') + self.assertEqual(self.analyzer._steps, set(['-1'])) + + def test_compute_max_gap_ratio_with_non_zero_mean_then_return_correct_ratio(self): + # Test with non-zero mean + data = [14242056.739999993, 14311412.460000006] # min and max compute times + mean = (14242056.739999993 + 14311412.460000006) / 2 + expected_ratio = (14311412.460000006 - 14242056.739999993) / mean + self.assertAlmostEqual(self.analyzer.compute_max_gap_ratio(data, mean), expected_ratio) + + def test_compute_max_gap_ratio_with_zero_mean_then_return_zero(self): + # Test with zero mean + data = [0, 0, 0, 0] + mean = 0 + self.assertEqual(self.analyzer.compute_max_gap_ratio(data, mean), 0) + + def test_format_details_with_valid_data_then_return_formatted_details(self): + details = self.analyzer.format_details() + + # Verify headers + expected_headers = ["step", "rank_id", "compute(us)", "communication(us)", "free(us)"] + self.assertEqual(details["headers"], expected_headers) + + # Verify data format + self.assertEqual(len(details["data"]), 4) # 4 entries in mock data + self.assertEqual(len(details["data"][0]), 5) # 5 columns per row + + # Verify steps are collected + self.assertEqual(self.analyzer._steps, {'-1'}) + + def test_get_step_duration_with_valid_rank_then_return_correct_duration(self): + # Test with valid rank and step + duration = self.analyzer.get_step_duration(0, -1) + expected_duration = 59862.52 + 115034.42 + 1252571.94 + self.assertAlmostEqual(duration, expected_duration) + + def test_get_step_duration_with_invalid_rank_then_return_zero(self): + # Test with invalid rank + duration = self.analyzer.get_step_duration(999) + self.assertEqual(duration, 0.0) + + def test_get_global_step_rank_with_free_dimension_then_return_rank_info(self): + # Test with free dimension + result = self.analyzer.get_global_step_rank("free(us)") + self.assertIn("maximum", result) + self.assertIn("minimum", result) + + self.assertEqual(result["maximum"]["rank_id"], 2) + self.assertEqual(result["minimum"]["rank_id"], 1) + + def test_process_with_significant_differences_then_identify_bottlenecks(self): + # Test process method with mock data that has significant differences + self.analyzer.process() + + # Verify bottleneck message contains expected content + self.assertIn("通信", self.analyzer.bottelneck) + self.assertIn("空闲", self.analyzer.bottelneck) + + # Verify specific bottleneck messages + self.assertIn("集群中的通信有问题", self.analyzer.bottelneck) + self.assertIn("因为通信时间的最大差距已经达到", self.analyzer.bottelneck) + self.assertIn("856.207ms", self.analyzer.bottelneck) + + self.assertIn("集群中的空闲有问题", self.analyzer.bottelneck) + self.assertIn("因为空闲时间的最大差距已经达到", self.analyzer.bottelneck) + self.assertIn("860.347ms", self.analyzer.bottelneck) + + def test_process_with_no_significant_differences_then_report_no_issues(self): + # Test with data that has no significant differences + mock_no_diff = { + '-1_0': [100, 100, 100], + '-1_1': [100, 100, 100], + '-1_2': [100, 100, 100], + '-1_3': [100, 100, 100] + } + self.analyzer.step_trace_dict = mock_no_diff + self.analyzer.bottelneck = '' + self.analyzer.process() + self.assertIn("没有慢节点问题", self.analyzer.bottelneck) + + def test_optimize_with_valid_data_then_return_optimize_result(self): + expected_problem_header = ['category', 'description', 'suggestion', 'problem count', 'total_time(us)', + 'time ratio', 'income(us)', 'income ratio'] + expected_details_header = ['step', 'rank_id', 'compute(us)', 'communication(us)', 'free(us)'] + # Test optimize with valid data + result = self.analyzer.optimize(template_key="overall") + slow_rank_res = dict(result.data) + problems = slow_rank_res.get("问题综述", {}) + self.assertEqual(len(problems), 2) + self.assertEqual(problems.get("headers"), expected_problem_header) + + + details = slow_rank_res.get("慢卡分析", {}) + self.assertEqual(len(details), 2) + self.assertEqual(details.get("headers"), expected_details_header) + + def test_get_stage_step_rank_with_free_dimension_then_return_stage_rank_info(self): + # Test with free dimension + details = self.analyzer.format_details() + result = self.analyzer.get_stage_step_rank("free(us)") + # Verify result structure + self.assertIn("stage-0", result) + stage_result = result["stage-0"] + + # Verify maximum and minimum entries exist + self.assertIn("maximum", stage_result) + self.assertIn("minimum", stage_result) + + # Verify rank_id and step are present + self.assertIn("rank_id", stage_result["maximum"]) + self.assertIn("step", stage_result["maximum"]) + self.assertIn("rank_id", stage_result["minimum"]) + self.assertIn("step", stage_result["minimum"]) + + # Verify specific values + self.assertEqual(stage_result["maximum"]["rank_id"], 2) + self.assertEqual(stage_result["maximum"]["step"], -1) + self.assertEqual(stage_result["minimum"]["rank_id"], 1) + self.assertEqual(stage_result["minimum"]["step"], -1) + + + def test_get_stage_step_rank_with_invalid_dimension_then_return_empty_dict(self): + # Test with invalid dimension + result = self.analyzer.get_stage_step_rank("invalid_dimension") + self.assertEqual(result, {}) + + def test_get_stage_step_rank_with_empty_format_datas_then_return_empty_dict(self): + # Test with empty format_datas + self.analyzer.format_datas = {} + result = self.analyzer.get_stage_step_rank("compute(us)") + self.assertEqual(result, {}) + + def test_get_stage_step_rank_no_significant_difference(self): + # Create mock data with no significant differences + mock_no_diff = { + '-1_0': [100, 100, 100], + '-1_1': [100, 100, 100], + '-1_2': [100, 100, 100], + '-1_3': [100, 100, 100] + } + self.analyzer.step_trace_dict = mock_no_diff + self.analyzer.format_datas = self.analyzer.format_details() + + # Test with compute dimension + result = self.analyzer.get_stage_step_rank("compute(us)") + + # Verify empty result when no significant differences + self.assertEqual(result, {}) diff --git a/profiler/msprof_analyze/test/ut/advisor/common/test_enum_params_parser.py b/profiler/msprof_analyze/test/ut/advisor/common/test_enum_params_parser.py index 5d11af12781eb5845d82e5415253fa93be99cf6b..608a007a60286c6f13312d3d7879b92387034ec9 100644 --- a/profiler/msprof_analyze/test/ut/advisor/common/test_enum_params_parser.py +++ b/profiler/msprof_analyze/test/ut/advisor/common/test_enum_params_parser.py @@ -1,63 +1,63 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from msprof_analyze.advisor.common.enum_params_parser import EnumParamsParser -from msprof_analyze.test.ut.advisor.advisor_backend.tools.tool import recover_env - - -class TestEnumParamsParser(unittest.TestCase): - @classmethod - def tearDownClass(cls) -> None: - recover_env() - - def setUp(self) -> None: - self.enum_params_parser = EnumParamsParser() - self.argument_keys = sorted(["cann_version", "torch_version", "analysis_dimensions", "profiling_type", "mindspore_version"]) - self.env_keys = ["ADVISOR_ANALYZE_PROCESSES", "DISABLE_PROFILING_COMPARISON", "DISABLE_AFFINITY_API"] - - def test_get_keys(self): - total_keys = sorted(self.argument_keys + self.env_keys) - keys = sorted(self.enum_params_parser.get_keys()) - self.assertTrue(isinstance(keys, list)) - self.assertEqual(keys, total_keys) - - def test_get_argument_keys(self): - argument_keys = sorted(self.enum_params_parser.get_arguments_keys()) - self.assertTrue(isinstance(argument_keys, list)) - self.assertEqual(argument_keys, self.argument_keys) - - def test_get_env_keys(self): - env_keys = sorted(self.enum_params_parser.get_envs_keys()) - self.assertTrue(isinstance(env_keys, list)) - self.assertEqual(env_keys, sorted(self.env_keys)) - - def test_get_default(self): - self.assertTrue(self.enum_params_parser.get_default("cann_version"), "8.0.RC1") - self.assertTrue(self.enum_params_parser.get_default("torch_version"), "2.1.0") - self.assertTrue(self.enum_params_parser.get_default("analysis_dimensions"), - ["computation", "communication", "schedule", "memory"]) - self.assertTrue(self.enum_params_parser.get_default("profiling_type"), "ascend_pytorch_profiler") - self.assertTrue(self.enum_params_parser.get_default("ADVISOR_ANALYZE_PROCESSES"), 1) - - def test_get_options(self): - self.assertTrue(self.enum_params_parser.get_options("cann_version"), ["6.3.RC2", "7.0.RC1", "7.0.0", "8.0.RC1"]) - self.assertTrue(self.enum_params_parser.get_options("torch_version"), ["1.11.0", "2.1.0"]) - self.assertTrue(self.enum_params_parser.get_options("analysis_dimensions"), - [["computation", "communication", "schedule", "memory"], ["communication"], ["schedule"], - ["computation"], ["memory"]]) - self.assertTrue(self.enum_params_parser.get_options("profiling_type"), - ["ascend_pytorch_profiler", "mslite", "msprof"]) - self.assertTrue(self.enum_params_parser.get_options("ADVISOR_ANALYZE_PROCESSES"), list(range(1, 9))) +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from msprof_analyze.advisor.common.enum_params_parser import EnumParamsParser +from msprof_analyze.test.ut.advisor.advisor_backend.tools.tool import recover_env + + +class TestEnumParamsParser(unittest.TestCase): + @classmethod + def tearDownClass(cls) -> None: + recover_env() + + def setUp(self) -> None: + self.enum_params_parser = EnumParamsParser() + self.argument_keys = sorted(["cann_version", "torch_version", "analysis_dimensions", "profiling_type", "mindspore_version"]) + self.env_keys = ["ADVISOR_ANALYZE_PROCESSES", "DISABLE_PROFILING_COMPARISON", "DISABLE_AFFINITY_API"] + + def test_get_keys(self): + total_keys = sorted(self.argument_keys + self.env_keys) + keys = sorted(self.enum_params_parser.get_keys()) + self.assertTrue(isinstance(keys, list)) + self.assertEqual(keys, total_keys) + + def test_get_argument_keys(self): + argument_keys = sorted(self.enum_params_parser.get_arguments_keys()) + self.assertTrue(isinstance(argument_keys, list)) + self.assertEqual(argument_keys, self.argument_keys) + + def test_get_env_keys(self): + env_keys = sorted(self.enum_params_parser.get_envs_keys()) + self.assertTrue(isinstance(env_keys, list)) + self.assertEqual(env_keys, sorted(self.env_keys)) + + def test_get_default(self): + self.assertTrue(self.enum_params_parser.get_default("cann_version"), "8.0.RC1") + self.assertTrue(self.enum_params_parser.get_default("torch_version"), "2.1.0") + self.assertTrue(self.enum_params_parser.get_default("analysis_dimensions"), + ["computation", "communication", "schedule", "memory"]) + self.assertTrue(self.enum_params_parser.get_default("profiling_type"), "ascend_pytorch_profiler") + self.assertTrue(self.enum_params_parser.get_default("ADVISOR_ANALYZE_PROCESSES"), 1) + + def test_get_options(self): + self.assertTrue(self.enum_params_parser.get_options("cann_version"), ["6.3.RC2", "7.0.RC1", "7.0.0", "8.0.RC1"]) + self.assertTrue(self.enum_params_parser.get_options("torch_version"), ["1.11.0", "2.1.0"]) + self.assertTrue(self.enum_params_parser.get_options("analysis_dimensions"), + [["computation", "communication", "schedule", "memory"], ["communication"], ["schedule"], + ["computation"], ["memory"]]) + self.assertTrue(self.enum_params_parser.get_options("profiling_type"), + ["ascend_pytorch_profiler", "mslite", "msprof"]) + self.assertTrue(self.enum_params_parser.get_options("ADVISOR_ANALYZE_PROCESSES"), list(range(1, 9))) diff --git a/profiler/msprof_analyze/test/ut/advisor/compute_advice/data/kernel_details.csv b/profiler/msprof_analyze/test/ut/advisor/compute_advice/data/kernel_details.csv deleted file mode 100644 index 8a255e939ae2ff4e781c7a356b342815838e2ff3..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/test/ut/advisor/compute_advice/data/kernel_details.csv +++ /dev/null @@ -1,30 +0,0 @@ -Step Id,Model ID,Task ID,Stream ID,Name,Type,OP State,Accelerator Core,Start Time(us),Duration(us),Wait Time(us),Block Dim,Mix Block Dim,HF32 Eligible,Input Shapes,Input Data Types,Input Formats,Output Shapes,Output Data Types,Output Formats,Context ID,aicore_time(us),aic_total_cycles,aic_mac_time(us),aic_mac_ratio,aic_scalar_time(us),aic_scalar_ratio,aic_mte1_time(us),aic_mte1_ratio,aic_mte2_time(us),aic_mte2_ratio,aic_fixpipe_time(us),aic_fixpipe_ratio,aic_icache_miss_rate,aiv_time(us),aiv_total_cycles,aiv_vec_time(us),aiv_vec_ratio,aiv_scalar_time(us),aiv_scalar_ratio,aiv_mte2_time(us),aiv_mte2_ratio,aiv_mte3_time(us),aiv_mte3_ratio,aiv_icache_miss_rate,cube_utilization(%) -19,4294967295,61653,2,aclnnMatmul_MatMulCommon_MatMulV2,MatMulV2,dynamic,AI_CORE,"1736413971558972.912 ",185.504,1.087,16,0,NO,"""81920,4096;8192,512""",DT_BF16;DT_BF16,ND;ND,"""4096,512""",DT_BF16,ND,N/A,183.87,5295467,151.425,0.824,88.03,0.479,119.148,0.648,177.314,0.964,5.736,0.031,0.001,0,0,0,0,0,0,0,0,0,0,0,79.295 -19,4294967295,61669,2,aclnnMatmul_MatMulV3Common_MatMulV3,MatMulV3,dynamic,AI_CORE,"1736413971560588.764 ",501.17,2.2,20,0,NO,"""81920,1536;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,478.701,17233251,356.349,0.744,118.087,0.247,296.009,0.618,452.112,0.944,35.833,0.075,0.001,0,0,0,0,0,0,0,0,0,0,0,95.517 -19,4294967295,61694,2,aclnnMatmul_MatMulCommon_MatMulV2,MatMulV2,dynamic,AI_CORE,"1736413971565213.257 ",186.823,1.178,16,0,NO,"""81920,4096;8192,512""",DT_BF16;DT_BF16,ND;ND,"""4096,512""",DT_BF16,ND,N/A,183.728,5291376,151.502,0.825,87.902,0.478,118.519,0.645,177.654,0.967,5.773,0.031,0.001,0,0,0,0,0,0,0,0,0,0,0,78.675 -19,4294967295,61710,2,aclnnMatmul_MatMulV3Common_MatMulV3,MatMulV3,dynamic,AI_CORE,"1736413971566843.489 ",516.991,2.33,20,0,NO,"""81920,1536;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,491.775,17703905,356.249,0.724,118.59,0.241,295.046,0.6,463.696,0.943,37.671,0.077,0.001,0,0,0,0,0,0,0,0,0,0,0,95.123 -19,4294967295,61735,2,aclnnMatmul_MatMulCommon_MatMulV2,MatMulV2,dynamic,AI_CORE,"1736413971571596.404 ",187.724,0.766,16,0,NO,"""81920,4096;8192,512""",DT_BF16;DT_BF16,ND;ND,"""4096,512""",DT_BF16,ND,N/A,184.904,5325221,151.489,0.819,87.893,0.475,118.63,0.642,178.815,0.967,5.77,0.031,0.001,0,0,0,0,0,0,0,0,0,0,0,78.798 -19,4294967295,61751,2,aclnnMatmul_MatMulV3Common_MatMulV3,MatMulV3,dynamic,AI_CORE,"1736413971573223.437 ",514.87,2.15,20,0,NO,"""81920,1536;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,486.931,17529512,356.117,0.731,118.847,0.244,295.529,0.607,457.002,0.939,37.938,0.078,0.001,0,0,0,0,0,0,0,0,0,0,0,94.574 -19,4294967295,61776,2,aclnnMatmul_MatMulCommon_MatMulV2,MatMulV2,dynamic,AI_CORE,"1736413971577931.851 ",190.544,1.367,16,0,NO,"""81920,4096;8192,512""",DT_BF16;DT_BF16,ND;ND,"""4096,512""",DT_BF16,ND,N/A,187.073,5387702,151.741,0.811,87.935,0.47,117.467,0.628,181.043,0.968,5.803,0.031,0.001,0,0,0,0,0,0,0,0,0,0,0,78.543 -19,4294967295,61792,2,aclnnMatmul_MatMulV3Common_MatMulV3,MatMulV3,dynamic,AI_CORE,"1736413971579566.403 ",504.071,2.28,20,0,NO,"""81920,1536;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,485.542,17479517,356.283,0.734,117.755,0.243,296.421,0.61,455.064,0.937,37.75,0.078,0.001,0,0,0,0,0,0,0,0,0,0,0,96.324 -19,4294967295,13792,2,aclnnMatmul_MatMulV3Common_MatMulV5,MatMulV3,dynamic,AI_CORE,"1736413974248200.543 ",521.31,2.22,20,0,NO,"""8192,15365;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,499.234,17972434,356.364,0.714,117.639,0.236,295.58,0.592,471.784,0.945,35.825,0.072,0.001,0,0,0,0,0,0,0,0,0,0,0,95.765 -19,4294967295,13792,2,aclnnMatmul_MatMulV3Common_MatMulV5,MatMulV3,dynamic,AI_CORE,"1736413974248200.543 ",521.31,2.22,20,0,NO,"""8192,15365;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,499.234,17972434,356.364,0.714,117.639,0.236,295.58,0.592,471.784,0.945,35.825,0.072,0.001,0,0,0,0,0,0,0,0,0,0,0,95.765 -19,4294967295,13792,2,aclnnMatmul_MatMulV3Common_MatMulV5,MatMulV3,dynamic,AI_CORE,"1736413974248200.543 ",521.31,2.22,20,0,NO,"""8192,15365;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,499.234,17972434,356.364,0.714,117.639,0.236,295.58,0.592,471.784,0.945,35.825,0.072,0.001,0,0,0,0,0,0,0,0,0,0,0,95.765 -19,4294967295,13792,2,aclnnMatmul_MatMulV3Common_MatMulV5,MatMulV3,dynamic,AI_CORE,"1736413974248200.543 ",521.31,2.22,20,0,NO,"""8192,15365;8192,4096""",DT_BF16;DT_BF16,ND;ND,"""1536,4096""",DT_BF16,ND,N/A,499.234,17972434,356.364,0.714,117.639,0.236,295.58,0.592,471.784,0.945,35.825,0.072,0.001,0,0,0,0,0,0,0,0,0,0,0,95.765 -19,4294967295,60679,2,aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore,FlashAttentionScore,dynamic,MIX_AIC,"1736413971411629.128 ",410.188,1.53,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;;;;4096,4096;;;;;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;UINT8;DT_BF16;BOOL;INT64;INT64;INT64;INT64;INT64,NCL;NCL;NCL;ND;ND;ND;ND;ND;ND;ND;ND;ND,"""2,4,4096,8;2,4,4096,8;;4096,2,512""",FLOAT;FLOAT;DT_BF16;DT_BF16,ND;ND;ND;ND,0,366.147,13181275,129.055,0.352,352.275,0.962,108.364,0.296,172.86,0.872,216.141,0.59,0.003,365.782,26336326,228.687,0.625,137.979,0.377,118.603,0.324,71.448,0.195,0.013,89.263 -19,4294967295,60707,2,aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore,FlashAttentionScore,dynamic,MIX_AIC,"1736413971415611.468 ",406.128,1.279,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;;;;4096,4096;;;;;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;UINT8;DT_BF16;BOOL;INT64;INT64;INT64;INT64;INT64,NCL;NCL;NCL;ND;ND;ND;ND;ND;ND;ND;ND;ND,"""2,4,4096,8;2,4,4096,8;;4096,2,512""",FLOAT;FLOAT;DT_BF16;DT_BF16,ND;ND;ND;ND,0,358.77,12915719,128.96,0.359,345.096,0.962,108.337,0.302,168.284,0.869,209.057,0.583,0.003,358.308,25798146,228.693,0.638,137.809,0.385,108.679,0.303,70.099,0.196,0.013,88.339 -19,4294967295,60735,2,aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore,FlashAttentionScore,dynamic,MIX_AIC,"1736413971420248.800 ",407.008,0.84,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;;;;4096,4096;;;;;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;UINT8;DT_BF16;BOOL;INT64;INT64;INT64;INT64;INT64,NCL;NCL;NCL;ND;ND;ND;ND;ND;ND;ND;ND;ND,"""2,4,4096,8;2,4,4096,8;;4096,2,512""",FLOAT;FLOAT;DT_BF16;DT_BF16,ND;ND;ND;ND,0,359.702,12949284,128.975,0.359,346.306,0.963,108.43,0.301,166.899,0.864,209.018,0.581,0.003,359.274,25867705,228.693,0.637,138.438,0.385,107.723,0.3,70.146,0.195,0.013,88.377 -19,4294967295,60763,2,aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore,FlashAttentionScore,dynamic,MIX_AIC,"1736413971424592.447 ",405.228,1.35,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;;;;4096,4096;;;;;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;UINT8;DT_BF16;BOOL;INT64;INT64;INT64;INT64;INT64,NCL;NCL;NCL;ND;ND;ND;ND;ND;ND;ND;ND;ND,"""2,4,4096,8;2,4,4096,8;;4096,2,512""",FLOAT;FLOAT;DT_BF16;DT_BF16,ND;ND;ND;ND,0,359.793,12952532,128.923,0.358,345.768,0.961,108.411,0.301,167.379,0.865,208.79,0.58,0.003,359.294,25869164,228.691,0.637,138.411,0.385,107.868,0.3,70.163,0.195,0.013,88.788 -19,4294967295,61655,2,aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad,FlashAttentionScoreGrad,dynamic,MIX_AIC,"1736413971559180.676 ",762.215,1.37,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;4096,2,512;4096,4096;2,4,4096,8;2,4,4096,8;;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;BOOL;FLOAT;FLOAT;DT_BF16;DT_BF16;INT64,NCL;NCL;NCL;NCL;ND;NCHW;NCHW;ND;NCL;ND,"""4096,2,512;4096,2,512;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16,ND;ND;ND;ND,0,755.664,27203907,344.023,0.455,592.472,0.784,266.388,0.353,397.091,0.525,589.726,0.525,0.004,755.04,54362915,318.452,0.422,184.623,0.245,206.78,0.274,152.973,0.203,0.006,99.141 -19,4294967295,61696,2,aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad,FlashAttentionScoreGrad,dynamic,MIX_AIC,"1736413971565420.821 ",763.215,1.189,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;4096,2,512;4096,4096;2,4,4096,8;2,4,4096,8;;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;BOOL;FLOAT;FLOAT;DT_BF16;DT_BF16;INT64,NCL;NCL;NCL;NCL;ND;NCHW;NCHW;ND;NCL;ND,"""4096,2,512;4096,2,512;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16,ND;ND;ND;ND,0,757.83,27281885,344.047,0.454,595.954,0.786,266.123,0.351,389.105,0.513,576.226,0.513,0.004,757.046,54507345,318.443,0.421,188.292,0.249,200.176,0.264,162.113,0.214,0.006,99.294 -19,4294967295,61737,2,aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad,FlashAttentionScoreGrad,dynamic,MIX_AIC,"1736413971571804.228 ",757.095,0.88,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;4096,2,512;4096,4096;2,4,4096,8;2,4,4096,8;;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;BOOL;FLOAT;FLOAT;DT_BF16;DT_BF16;INT64,NCL;NCL;NCL;NCL;ND;NCHW;NCHW;ND;NCL;ND,"""4096,2,512;4096,2,512;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16,ND;ND;ND;ND,0,750.605,27021778,343.983,0.458,586.708,0.782,266.304,0.355,392.522,0.523,584.432,0.523,0.004,749.913,53993736,318.436,0.425,188.508,0.251,207.668,0.277,152.634,0.204,0.006,99.143 -19,4294967295,61778,2,aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad,FlashAttentionScoreGrad,dynamic,MIX_AIC,"1736413971578144.095 ",755.915,1.22,20,40,NO,"""4096,2,512;4096,2,512;4096,2,512;4096,2,512;4096,4096;2,4,4096,8;2,4,4096,8;;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;BOOL;FLOAT;FLOAT;DT_BF16;DT_BF16;INT64,NCL;NCL;NCL;NCL;ND;NCHW;NCHW;ND;NCL;ND,"""4096,2,512;4096,2,512;4096,2,512;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16,ND;ND;ND;ND,0,750.152,27005467,344.115,0.459,579.317,0.772,266.08,0.355,398.019,0.531,587.37,0.531,0.004,749.348,53953058,318.444,0.425,186.908,0.249,207.068,0.276,151.329,0.202,0.006,99.238 -19,4294967295,60763,2,aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore_varlen,FlashAttentionScore,dynamic,MIX_AIC,"1736413971424592.447 ",405.228,1.35,20,40,NO,"""4096,2,511;4096,2,512;4096,2,512;;;;4096,4096;;;;;""",DT_BF16;DT_BF16;DT_BF16;DT_BF16;UINT8;DT_BF16;BOOL;INT64;INT64;INT64;INT64;INT64,NCL;NCL;NCL;ND;ND;ND;ND;ND;ND;ND;ND;ND,"""2,3,4096,8;2,4,4096,8;;4096,2,512""",FLOAT;FLOAT;DT_BF16;DT_BF16,ND;ND;ND;ND,0,359.793,12952532,128.923,0.358,345.768,0.961,108.411,0.301,167.379,0.465,208.79,0.58,0.003,359.294,25869164,228.691,0.637,138.411,0.385,107.868,0.3,70.163,0.195,0.013,88.788 -19,4294967295,60683,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413971412768.871 ",26.78,0.485,40,0,NO,"""512,2,4096;512,2,4096""",DT_BF16;DT_BF16,NCL;NCL,"""512,2,4096""",DT_BF16,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,24.19,1741674,5.986,0.247,1.352,0.056,20.363,0.842,3.195,0.132,0.027,0 -19,4294967295,60690,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413971414677.549 ",31.201,0.664,40,0,NO,"""512,2,4096;512,2,4096""",DT_BF16;DT_BF16,NCL;NCL,"""512,2,4096""",DT_BF16,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,28.617,2060443,5.986,0.209,1.444,0.05,25.005,0.874,3.336,0.117,0.026,0 -19,4294967295,60711,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413971416743.250 ",27.021,1.246,40,0,NO,"""512,2,4096;512,2,4096""",DT_BF16;DT_BF16,NCL;NCL,"""512,2,4096""",DT_BF16,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,24.304,1749862,5.986,0.246,1.258,0.052,20.424,0.84,3.23,0.133,0.027,0 -19,4294967295,60718,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413971419318.962 ",25.08,0.984,40,0,NO,"""512,2,4096;512,2,4096""",DT_BF16;DT_BF16,NCL;NCL,"""512,2,4096""",DT_BF16,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,22.47,1617840,5.989,0.267,2.009,0.089,18.809,0.837,3.191,0.142,0.024,0 -19,4294967295,13907,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413974268377.206 ",1.38,31.48,1,0,NO,""";""",FLOAT;FLOAT,ND;ND,"""""",FLOAT,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,0.883,1589,0.027,0.03,0.265,0.3,0.18,0.204,0.108,0.123,0.182,0 -19,4294967295,13910,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413974268502.128 ",1.46,17.48,1,0,NO,""";""",FLOAT;FLOAT,ND;ND,"""""",FLOAT,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,0.948,1706,0.027,0.028,0.276,0.291,0.217,0.229,0.127,0.134,0.174,0 -19,4294967295,13913,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413974268605.410 ",1.5,0.09,1,0,NO,""";""",FLOAT;FLOAT,ND;ND,"""""",FLOAT,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,0.96,1728,0.027,0.028,0.268,0.28,0.221,0.23,0.132,0.137,0.145,0 -19,4294967295,13916,2,aclnnAdd_AddAiCore_Add,Add,dynamic,AI_VECTOR_CORE,"1736413974268747.953 ",1.58,28.28,1,0,NO,""";""",FLOAT;FLOAT,ND;ND,"""""",FLOAT,ND,N/A,0,0,0,0,0,0,0,0,0,0,0,0,0,1.107,1993,0.027,0.024,0.426,0.384,0.201,0.181,0.118,0.106,0.162,0 \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py b/profiler/msprof_analyze/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py deleted file mode 100644 index c8196f5eefdee0c1f3819916b261a002017ba987..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import shutil - -import unittest -from msprof_analyze.advisor.interface.interface import Interface -from msprof_analyze.advisor.common.analyzer_scopes import SupportedScopes - - -class TestAICorePerformanceAdvice(unittest.TestCase): - TMP_DIR = "./ascend_pt" - OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT" - interface = None - err_interface = None - - @classmethod - def clear_htmls(cls): - current_path = os.path.dirname(os.path.abspath(__file__)) - for filename in os.listdir(current_path): - # 检查文件是否以“mstt”开头 - if filename.startswith("mstt"): - # 构建文件的完整路径 - file_path = os.path.join(current_path, filename) - # 删除文件 - os.remove(file_path) - - @classmethod - def copy_kernel_details(cls, path): - # Define source and destination paths - source_csv_path = os.path.join(os.path.dirname(__file__), 'data', path) - destination_csv_path = f"{TestAICorePerformanceAdvice.OUTPUT_DIR}/kernel_details.csv" - - # Check if source CSV file exists - if not os.path.exists(source_csv_path): - raise FileNotFoundError(f"test data file not found:{source_csv_path}") - - # Ensure the output directory exists - if not os.path.exists(TestAICorePerformanceAdvice.OUTPUT_DIR): - os.makedirs(TestAICorePerformanceAdvice.OUTPUT_DIR) - - # Copy the CSV file from source to destination - shutil.copyfile(source_csv_path, destination_csv_path) - - def tearDown(self): - if os.path.exists(TestAICorePerformanceAdvice.TMP_DIR): - shutil.rmtree(TestAICorePerformanceAdvice.TMP_DIR) - self.clear_htmls() - - def setUp(self): - if os.path.exists(TestAICorePerformanceAdvice.TMP_DIR): - shutil.rmtree(TestAICorePerformanceAdvice.TMP_DIR) - if not os.path.exists(TestAICorePerformanceAdvice.TMP_DIR): - os.makedirs(TestAICorePerformanceAdvice.TMP_DIR) - if not os.path.exists(TestAICorePerformanceAdvice.OUTPUT_DIR): - os.makedirs(TestAICorePerformanceAdvice.OUTPUT_DIR) - self.clear_htmls() - - def test_ai_core_performance_total(self): - file_path = "kernel_details.csv" - self.copy_kernel_details(file_path) - interface = Interface(profiling_path=self.TMP_DIR) - dimension = Interface.COMPUTATION - scope = SupportedScopes.AICORE_PERFORMANCE_ANALYSIS - result = interface.get_result(dimension, scope, render_html=1, output_dict=False, profiling_path=self.TMP_DIR) - self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[0])) - self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[1])) - self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[2])) - self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[0])) - self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[1])) - self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[2])) - self.assertLess(1, len(result.data.get("Vector算子性能分析").get("data")[0])) - self.assertLess(1, len(result.data.get("Vector算子性能分析").get("data")[1])) - result.clear() \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/advisor/dataset/test_ai_core_freq_dataset.py b/profiler/msprof_analyze/test/ut/advisor/dataset/test_ai_core_freq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..61bb145dd2f528b526f698292b82cd40dcd384c8 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/advisor/dataset/test_ai_core_freq_dataset.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import shutil +import stat +import json +from unittest.mock import patch, MagicMock + +from msprof_analyze.advisor.dataset.ai_core_freq.ai_core_freq_dataset import AICoreFreqDataset + + +class TestAICoreFreqDataset(unittest.TestCase): + TMP_DIR = "./ascend_pt" + OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT" + + @classmethod + def create_trace_view_json(cls): + trace_view_data = [ + { + "name": "ProfilerStep#1", + "ts": "1000", + "dur": "100", + "args": {} + }, + { + "name": "Matmul", + "ts": "1100", + "dur": "50", + "args": {"Task Type": "AI_CORE"} + }, + { + "name": "AI Core Freq", + "ts": "1000", + "args": {"MHz": "1000"} + }, + { + "name": "Conv2D", + "ts": "1200", + "dur": "100", + "args": {"Task Type": "AI_CORE"} + }, + { + "name": "AI Core Freq", + "ts": "1250", + "args": {"MHz": "800"} + } + ] + + with os.fdopen(os.open(f"{TestAICoreFreqDataset.OUTPUT_DIR}/trace_view.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + fp.write(json.dumps(trace_view_data)) + + def setUp(self): + if os.path.exists(TestAICoreFreqDataset.TMP_DIR): + shutil.rmtree(TestAICoreFreqDataset.TMP_DIR) + if not os.path.exists(TestAICoreFreqDataset.TMP_DIR): + os.makedirs(TestAICoreFreqDataset.TMP_DIR) + if not os.path.exists(TestAICoreFreqDataset.OUTPUT_DIR): + os.makedirs(TestAICoreFreqDataset.OUTPUT_DIR) + + def tearDown(self): + if os.path.exists(TestAICoreFreqDataset.TMP_DIR): + shutil.rmtree(TestAICoreFreqDataset.TMP_DIR) + + @patch('msprof_analyze.advisor.dataset.ai_core_freq.ai_core_freq_dataset.Config') + def test_aicire_freq_dataset(self, mock_config_class): + # Mock Config singleton instance + mock_config_instance = MagicMock() + mock_config_class.return_value = mock_config_instance + mock_config_instance.get_config.return_value = True + + self.create_trace_view_json() + data = {} + dataset = AICoreFreqDataset(self.OUTPUT_DIR, data) + + # Verify initialization + self.assertEqual(dataset.timeline_dir, self.OUTPUT_DIR) + self.assertEqual(len(dataset.profiler_step), 1) + self.assertEqual(len(dataset.ai_core_ops), 2) # Now we have 2 operators + self.assertEqual(len(dataset.ai_core_freq), 2) # Now we have 2 frequency events + self.assertEqual(dataset.get_key(), "ai_core_freq_dataset") + + # Verify op_freq data + self.assertIn("Matmul", dataset.op_freq) + self.assertIn("Conv2D", dataset.op_freq) + + # Verify Matmul frequency info + matmul_freq = dataset.op_freq["Matmul"] + self.assertEqual(matmul_freq["count"], 1) + self.assertEqual(matmul_freq["dur"], 50.0) + self.assertEqual(len(matmul_freq["freq_list"]), 1) + self.assertEqual(matmul_freq["freq_list"][0], 1000.0) + + # Verify Conv2D frequency info + conv2d_freq = dataset.op_freq["Conv2D"] + self.assertEqual(conv2d_freq["count"], 1) + self.assertEqual(conv2d_freq["dur"], 100.0) + self.assertEqual(len(conv2d_freq["freq_list"]), 1) + self.assertEqual(conv2d_freq["freq_list"][0], 800.0) diff --git a/profiler/msprof_analyze/test/ut/advisor/timeline_advice/test_timeline_op_collector.py b/profiler/msprof_analyze/test/ut/advisor/timeline_advice/test_timeline_op_collector.py index edef567259f8778896e6f3d3291fb4649664aecc..65b9a6cea045958b2fabfbca9db60b60134b01ca 100644 --- a/profiler/msprof_analyze/test/ut/advisor/timeline_advice/test_timeline_op_collector.py +++ b/profiler/msprof_analyze/test/ut/advisor/timeline_advice/test_timeline_op_collector.py @@ -1,152 +1,152 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from msprof_analyze.advisor.dataset.timeline_op_collector.timeline_op_collector import ( - OpCompileCollector, - SynchronizeStreamCollector, - MemCollector, - DataloaderCollector, - SyncBNCollector, - AtenCollector, - OptimizerCollector, - FrequencyCollector, - SpecificTaskTypeOpCollector, - TorchToNpuCollector, - AclToNpuCollector, - OpStackCollector, - StepCollector -) -from msprof_analyze.advisor.common.timeline.event import TimelineEvent -from msprof_analyze.test.ut.advisor.advisor_backend.tools.tool import recover_env - - -class TestTimelineOpCollector(unittest.TestCase): - @classmethod - def tearDownClass(cls) -> None: - recover_env() - - def setUp(self) -> None: - self.mock_step_event = TimelineEvent(dict(name="ProfilerStep#1", ts=1, dur=1000)) - self.mock_op_compile_event = TimelineEvent(dict(name="AscendCL@aclopCompileAndExecute", ts=2, dur=1)) - self.mock_sync_stream_event = TimelineEvent(dict(name="AscendCL@aclrtSynchronizeStream", dur=1000000000)) - self.mock_mem_op_event = TimelineEvent(dict(name="AscendCL@aclMallocMemInner", dur=10)) - self.mock_dataloader_event = TimelineEvent(dict(name="dataloader")) - self.mock_sync_bn_event = TimelineEvent(dict(name="syncbatchnorm")) - self.mock_aten_event = TimelineEvent(dict(name="aten::conv3d")) - self.mock_optimizer_event = TimelineEvent(dict(name="Optimizer.step#")) - self.mock_AI_CPU_event = TimelineEvent( - {"name": "index", "args": TimelineEvent({"Task Type": "AI_CPU"}), "ts": 1}) - self.mock_torch_to_npu_event = TimelineEvent(dict(name="torch_to_npu", tid=1, ts=1, ph=1, id=1)) - self.mock_acl_to_npu_event = TimelineEvent(dict(name="acl_to_npu", ts=1)) - self.mock_op_stack_event = TimelineEvent( - {"name": "aten::conv3d", "dataset_index": 1, "ts": 1, "args": TimelineEvent({"Call stack": "mock_stack"})}) - - def test_step_collector(self): - step_collector = StepCollector() - step_collector.add_op(self.mock_step_event) - step_collector.post_process() - self.assertEqual(step_collector.attribute_to_dataset.get("profiler_step"), [self.mock_step_event]) - - def test_op_compile_collector(self): - op_compile_collector = OpCompileCollector() - op_compile_collector.add_op(self.mock_op_compile_event) - op_compile_collector.post_process(op_compile_collector.op_list) - self.assertEqual(op_compile_collector.attribute_to_dataset.get("ops_compile"), op_compile_collector) - self.assertEqual(op_compile_collector.total_time, 1) - self.assertEqual(op_compile_collector.total_count, 1) - - def test_sync_stream_collector(self): - sync_stream_collector = SynchronizeStreamCollector() - sync_stream_collector.post_process() - self.assertEqual(sync_stream_collector.attribute_to_dataset.get("synchronize_stream"), []) - - def test_mem_op_collector(self): - mem_op_collector = MemCollector() - mem_op_collector.add_op(self.mock_mem_op_event) - mem_op_collector.post_process(mem_op_collector.op_list) - self.assertEqual(mem_op_collector.attribute_to_dataset.get("memory_ops"), mem_op_collector) - self.assertEqual(mem_op_collector.mem_op_info.get("AscendCL@aclMallocMemInner"), {"count": 1, "total_dur": 10}) - - def test_dataloader_collector(self): - dataloader_collector = DataloaderCollector() - dataloader_collector.add_op(self.mock_dataloader_event) - dataloader_collector.post_process() - self.assertEqual(len(dataloader_collector.attribute_to_dataset.get("dataloader")), 1) - - def test_sync_bn_collector(self): - sync_bn_collector = SyncBNCollector() - sync_bn_collector.add_op(self.mock_sync_bn_event) - sync_bn_collector.post_process(sync_bn_collector.op_list) - self.assertEqual(len(sync_bn_collector.attribute_to_dataset.get("sync_batchnorm")), 1) - - def test_aten_collector(self): - aten_collector = AtenCollector() - aten_collector.add_op(self.mock_aten_event) - aten_collector.add_op(self.mock_sync_stream_event) - aten_collector.post_process(aten_collector.op_list) - self.assertEqual(len(aten_collector.attribute_to_dataset.get("aten")), 2) - - def test_optimizer_collector(self): - optimizer_collector = OptimizerCollector() - optimizer_collector.add_op(self.mock_optimizer_event) - optimizer_collector.post_process(optimizer_collector.op_list) - self.assertEqual(len(optimizer_collector.attribute_to_dataset.get("optimizer")), 1) - - def test_specific_task_type_op_collector(self): - specific_task_type_op_collector = SpecificTaskTypeOpCollector() - specific_task_type_op_collector.add_op(self.mock_AI_CPU_event) - specific_task_type_op_collector.post_process(specific_task_type_op_collector.op_list) - key = f"{self.mock_AI_CPU_event.name}-{self.mock_AI_CPU_event.ts}" - self.assertTrue( - specific_task_type_op_collector.attribute_to_dataset.get("ops_with_task_type", {}).get(key)) - self.assertTrue(specific_task_type_op_collector.attribute_to_dataset.get("task_op_names"), [key]) - - def test_torch_to_npu_collector(self): - torch_to_npu_collector = TorchToNpuCollector() - torch_to_npu_collector.add_op(self.mock_torch_to_npu_event) - torch_to_npu_collector.post_process(torch_to_npu_collector.op_list) - key = f"{self.mock_torch_to_npu_event.ph}-{self.mock_torch_to_npu_event.id}" - self.assertTrue("1-1" in torch_to_npu_collector.attribute_to_dataset.get("torch_to_npu")) - - def test_acl_to_npu_collector(self): - acl_to_npu_collector = AclToNpuCollector() - acl_to_npu_collector.add_op(self.mock_acl_to_npu_event) - acl_to_npu_collector.post_process(acl_to_npu_collector.op_list) - self.assertEqual(acl_to_npu_collector.attribute_to_dataset.get("acl_to_npu"), - set([str(self.mock_acl_to_npu_event.ts)])) - - def test_op_stack_collector(self): - op_stack_collector = OpStackCollector() - op_stack_collector.add_op(self.mock_op_stack_event) - op_stack_collector.post_process(op_stack_collector.op_list) - self.assertTrue( - str(self.mock_op_stack_event.ts) in op_stack_collector.attribute_to_dataset.get("ops_with_stack")) - - -if __name__ == '__main__': - tester = TestTimelineOpCollector() - tester.test_step_collector() - tester.test_op_compile_collector() - tester.test_sync_stream_collector() - tester.test_mem_op_collector() - tester.test_dataloader_collector() - tester.test_sync_bn_collector() - tester.test_aten_collector() - tester.test_optimizer_collector() - tester.test_specific_task_type_op_collector() - tester.test_torch_to_npu_collector() - tester.test_acl_to_npu_collector() - tester.test_op_stack_collector() +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from msprof_analyze.advisor.dataset.timeline_op_collector.timeline_op_collector import ( + OpCompileCollector, + SynchronizeStreamCollector, + MemCollector, + DataloaderCollector, + SyncBNCollector, + AtenCollector, + OptimizerCollector, + FrequencyCollector, + SpecificTaskTypeOpCollector, + TorchToNpuCollector, + AclToNpuCollector, + OpStackCollector, + StepCollector +) +from msprof_analyze.advisor.common.timeline.event import TimelineEvent +from msprof_analyze.test.ut.advisor.advisor_backend.tools.tool import recover_env + + +class TestTimelineOpCollector(unittest.TestCase): + @classmethod + def tearDownClass(cls) -> None: + recover_env() + + def setUp(self) -> None: + self.mock_step_event = TimelineEvent(dict(name="ProfilerStep#1", ts=1, dur=1000)) + self.mock_op_compile_event = TimelineEvent(dict(name="AscendCL@aclopCompileAndExecute", ts=2, dur=1)) + self.mock_sync_stream_event = TimelineEvent(dict(name="AscendCL@aclrtSynchronizeStream", dur=1000000000)) + self.mock_mem_op_event = TimelineEvent(dict(name="AscendCL@aclMallocMemInner", dur=10)) + self.mock_dataloader_event = TimelineEvent(dict(name="dataloader")) + self.mock_sync_bn_event = TimelineEvent(dict(name="syncbatchnorm")) + self.mock_aten_event = TimelineEvent(dict(name="aten::conv3d")) + self.mock_optimizer_event = TimelineEvent(dict(name="Optimizer.step#")) + self.mock_AI_CPU_event = TimelineEvent( + {"name": "index", "args": TimelineEvent({"Task Type": "AI_CPU"}), "ts": 1}) + self.mock_torch_to_npu_event = TimelineEvent(dict(name="torch_to_npu", tid=1, ts=1, ph=1, id=1)) + self.mock_acl_to_npu_event = TimelineEvent(dict(name="acl_to_npu", ts=1)) + self.mock_op_stack_event = TimelineEvent( + {"name": "aten::conv3d", "dataset_index": 1, "ts": 1, "args": TimelineEvent({"Call stack": "mock_stack"})}) + + def test_step_collector(self): + step_collector = StepCollector() + step_collector.add_op(self.mock_step_event) + step_collector.post_process() + self.assertEqual(step_collector.attribute_to_dataset.get("profiler_step"), [self.mock_step_event]) + + def test_op_compile_collector(self): + op_compile_collector = OpCompileCollector() + op_compile_collector.add_op(self.mock_op_compile_event) + op_compile_collector.post_process(op_compile_collector.op_list) + self.assertEqual(op_compile_collector.attribute_to_dataset.get("ops_compile"), op_compile_collector) + self.assertEqual(op_compile_collector.total_time, 1) + self.assertEqual(op_compile_collector.total_count, 1) + + def test_sync_stream_collector(self): + sync_stream_collector = SynchronizeStreamCollector() + sync_stream_collector.post_process() + self.assertEqual(sync_stream_collector.attribute_to_dataset.get("synchronize_stream"), []) + + def test_mem_op_collector(self): + mem_op_collector = MemCollector() + mem_op_collector.add_op(self.mock_mem_op_event) + mem_op_collector.post_process(mem_op_collector.op_list) + self.assertEqual(mem_op_collector.attribute_to_dataset.get("memory_ops"), mem_op_collector) + self.assertEqual(mem_op_collector.mem_op_info.get("AscendCL@aclMallocMemInner"), {"count": 1, "total_dur": 10}) + + def test_dataloader_collector(self): + dataloader_collector = DataloaderCollector() + dataloader_collector.add_op(self.mock_dataloader_event) + dataloader_collector.post_process() + self.assertEqual(len(dataloader_collector.attribute_to_dataset.get("dataloader")), 1) + + def test_sync_bn_collector(self): + sync_bn_collector = SyncBNCollector() + sync_bn_collector.add_op(self.mock_sync_bn_event) + sync_bn_collector.post_process(sync_bn_collector.op_list) + self.assertEqual(len(sync_bn_collector.attribute_to_dataset.get("sync_batchnorm")), 1) + + def test_aten_collector(self): + aten_collector = AtenCollector() + aten_collector.add_op(self.mock_aten_event) + aten_collector.add_op(self.mock_sync_stream_event) + aten_collector.post_process(aten_collector.op_list) + self.assertEqual(len(aten_collector.attribute_to_dataset.get("aten")), 2) + + def test_optimizer_collector(self): + optimizer_collector = OptimizerCollector() + optimizer_collector.add_op(self.mock_optimizer_event) + optimizer_collector.post_process(optimizer_collector.op_list) + self.assertEqual(len(optimizer_collector.attribute_to_dataset.get("optimizer")), 1) + + def test_specific_task_type_op_collector(self): + specific_task_type_op_collector = SpecificTaskTypeOpCollector() + specific_task_type_op_collector.add_op(self.mock_AI_CPU_event) + specific_task_type_op_collector.post_process(specific_task_type_op_collector.op_list) + key = f"{self.mock_AI_CPU_event.name}-{self.mock_AI_CPU_event.ts}" + self.assertTrue( + specific_task_type_op_collector.attribute_to_dataset.get("ops_with_task_type", {}).get(key)) + self.assertTrue(specific_task_type_op_collector.attribute_to_dataset.get("task_op_names"), [key]) + + def test_torch_to_npu_collector(self): + torch_to_npu_collector = TorchToNpuCollector() + torch_to_npu_collector.add_op(self.mock_torch_to_npu_event) + torch_to_npu_collector.post_process(torch_to_npu_collector.op_list) + key = f"{self.mock_torch_to_npu_event.ph}-{self.mock_torch_to_npu_event.id}" + self.assertTrue("1-1" in torch_to_npu_collector.attribute_to_dataset.get("torch_to_npu")) + + def test_acl_to_npu_collector(self): + acl_to_npu_collector = AclToNpuCollector() + acl_to_npu_collector.add_op(self.mock_acl_to_npu_event) + acl_to_npu_collector.post_process(acl_to_npu_collector.op_list) + self.assertEqual(acl_to_npu_collector.attribute_to_dataset.get("acl_to_npu"), + set([str(self.mock_acl_to_npu_event.ts)])) + + def test_op_stack_collector(self): + op_stack_collector = OpStackCollector() + op_stack_collector.add_op(self.mock_op_stack_event) + op_stack_collector.post_process(op_stack_collector.op_list) + self.assertTrue( + str(self.mock_op_stack_event.ts) in op_stack_collector.attribute_to_dataset.get("ops_with_stack")) + + +if __name__ == '__main__': + tester = TestTimelineOpCollector() + tester.test_step_collector() + tester.test_op_compile_collector() + tester.test_sync_stream_collector() + tester.test_mem_op_collector() + tester.test_dataloader_collector() + tester.test_sync_bn_collector() + tester.test_aten_collector() + tester.test_optimizer_collector() + tester.test_specific_task_type_op_collector() + tester.test_torch_to_npu_collector() + tester.test_acl_to_npu_collector() + tester.test_op_stack_collector() diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cann_api_sum.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cann_api_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..ef00d2deb696ebd6431eaabd196642dc10a5859b --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cann_api_sum.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.cann_api_sum.cann_api_sum import CannApiSum +from msprof_analyze.prof_common.constant import Constant + + +class TestCannApiSum(unittest.TestCase): + + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.dump_data") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.add_helper_file") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.create_notebook") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.mapper_func") + def test_run_should_save_db_or_notebook(self, mock_mapper_func, mock_create_notebook, + mock_add_helper_file, mock_dump_data): + mock_mapper_func.return_value = [ + (0, pd.DataFrame({ + "name": ["aclnnCast"], + "durationRatio": [1.05], + "totalTimeNs": [761090], + "totalCount": [72], + "averageNs": [10570.7], + "minNs": [5530.0], + "Q1Ns": [6892.5], + "medNs": [9035.0], + "Q3Ns": [12910.0], + "maxNs": [28000.0], + "stdev": [4755.2] + }) + ), + (1, pd.DataFrame({ + "name": ["aclnnMseLoss"], + "durationRatio": [1.09], + "totalTimeNs": [271560], + "totalCount": [6], + "averageNs": [45260.7], + "minNs": [29240.0], + "Q1Ns": [35815.0], + "medNs": [52785.0], + "Q3Ns": [53075.0], + "maxNs": [53420.0], + "stdev": [10981.2] + })) + ] + params = {Constant.EXPORT_TYPE: Constant.DB} + recipe = CannApiSum(params) + recipe.run(context=None) + recipe._export_type = "notebook" + recipe.run(context=None) + self.assertEqual(recipe._stats_data.shape, (2, 12)) + self.assertEqual(recipe._stats_rank_data.shape, (2, 12)) + self.assertEqual(recipe._stats_data.iloc[0, 0], 73.7) + self.assertEqual(recipe._stats_data.iloc[1, 0], 26.3) \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cluster_time_compare_summary.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cluster_time_compare_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d44745555187d7d3a82f19facc0074d9fe3049 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_cluster_time_compare_summary.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from unittest import mock +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.cluster_time_compare_summary.cluster_time_compare_summary import \ + ClusterTimeCompareSummary +from msprof_analyze.prof_common.constant import Constant + +NAMESPACE = "msprof_analyze.prof_common" + + +class TestClusterTimeCompareSummary(unittest.TestCase): + PARAMS = { + Constant.COLLECTION_PATH: "/data", + Constant.DATA_MAP: {}, + Constant.DATA_TYPE: Constant.DB, + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: "./test_cluster_time_compare_summary", + Constant.RECIPE_NAME: "ClusterTimeCompareSummary", + Constant.RECIPE_CLASS: ClusterTimeCompareSummary, + Constant.PARALLEL_MODE: Constant.CONCURRENT_MODE, + Constant.EXPORT_TYPE: Constant.DB, + ClusterTimeCompareSummary.RANK_LIST: Constant.ALL, + } + + def test_check_params_is_valid_should_return_false_when_bp_param_does_not_exist(self): + params = {} + params.update(self.PARAMS) + self.assertFalse(ClusterTimeCompareSummary(params).check_params_is_valid()) + + def test_check_params_is_valid_should_return_false_when_export_type_is_notebook(self): + params = {Constant.EXTRA_ARGS: ["--bp", "/data2"]} + params.update(self.PARAMS) + params[Constant.EXPORT_TYPE] = Constant.NOTEBOOK + self.assertFalse(ClusterTimeCompareSummary(params).check_params_is_valid()) + + def test_check_params_is_valid_should_return_false_when_base_path_is_invalid(self): + params = {Constant.EXTRA_ARGS: ["--bp", "/data2"]} + params.update(self.PARAMS) + with mock.patch(NAMESPACE + ".path_manager.PathManager.check_input_file_path", side_effect=RuntimeError): + self.assertFalse(ClusterTimeCompareSummary(params).check_params_is_valid()) + + def test_check_params_is_valid_should_return_false_when_table_cluster_time_summary_does_not_exist(self): + params = {} + params.update(self.PARAMS) + with mock.patch(NAMESPACE + ".db_manager.DBManager.check_tables_in_db", return_value=False): + self.assertFalse(ClusterTimeCompareSummary(params).check_params_is_valid()) + + def test_check_params_is_valid_should_return_false_when_base_table_cluster_time_summary_does_not_exist(self): + params = {Constant.EXTRA_ARGS: ["--bp", "/data2"]} + params.update(self.PARAMS) + with mock.patch(NAMESPACE + ".path_manager.PathManager.check_input_file_path"), \ + mock.patch(NAMESPACE + ".db_manager.DBManager.check_tables_in_db", side_effect=[True, False]): + self.assertFalse(ClusterTimeCompareSummary(params).check_params_is_valid()) + + def test_run_when_all_parameters_are_normal(self): + params = {Constant.EXTRA_ARGS: ["--bp", "/data2"]} + params.update(self.PARAMS) + params[Constant.EXPORT_TYPE] = "" + data_base = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5] + data = [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6] + data1 = [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6] + data_diff = [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1] + base_cluster_time_summary_df_dict = { + Constant.TABLE_CLUSTER_TIME_SUMMARY: pd.DataFrame( + { + "rank": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], + "step": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + "stepTime": data_base, + "computation": data_base, + "communicationNotOverlapComputation": data_base, + "communicationOverlapComputation": data_base, + "communication": data_base, + "free": data_base, + "communicationWaitStageTime": data_base, + "communicationTransmitStageTime": data_base, + "memory": data_base, + "memoryNotOverlapComputationCommunication": data_base, + "taskLaunchDelayAvgTime": data_base + } + ) + } + cluster_time_summary_df_dict = { + Constant.TABLE_CLUSTER_TIME_SUMMARY: pd.DataFrame( + { + "rank": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7], + "step": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + "stepTime": data, + "computation": data, + "communicationNotOverlapComputation": data, + "communicationOverlapComputation": data, + "communication": data, + "free": data, + "communicationWaitStageTime": data, + "communicationTransmitStageTime": data, + "memory": data, + "memoryNotOverlapComputationCommunication": data, + "taskLaunchDelayAvgTime": data + } + ) + } + expected_result = pd.DataFrame({ + "rank": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6], + "step": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], + "stepTime": data1, + "stepTimeBase": data_base, + "stepTimeDiff": data_diff, + "computation": data1, + "computationBase": data_base, + "computationDiff": data_diff, + "communicationNotOverlapComputation": data1, + "communicationNotOverlapComputationBase": data_base, + "communicationNotOverlapComputationDiff": data_diff, + "communicationOverlapComputation": data1, + "communicationOverlapComputationBase": data_base, + "communicationOverlapComputationDiff": data_diff, + "communication": data1, + "communicationBase": data_base, + "communicationDiff": data_diff, + "free": data1, + "freeBase": data_base, + "freeDiff": data_diff, + "communicationWaitStageTime": data1, + "communicationWaitStageTimeBase": data_base, + "communicationWaitStageTimeDiff": data_diff, + "communicationTransmitStageTime": data1, + "communicationTransmitStageTimeBase": data_base, + "communicationTransmitStageTimeDiff": data_diff, + "memory": data1, + "memoryBase": data_base, + "memoryDiff": data_diff, + "memoryNotOverlapComputationCommunication": data1, + "memoryNotOverlapComputationCommunicationBase": data_base, + "memoryNotOverlapComputationCommunicationDiff": data_diff, + "taskLaunchDelayAvgTime": data1, + "taskLaunchDelayAvgTimeBase": data_base, + "taskLaunchDelayAvgTimeDiff": data_diff + }) + with mock.patch(NAMESPACE + ".path_manager.PathManager.check_input_file_path"), \ + mock.patch(NAMESPACE + ".db_manager.DBManager.check_tables_in_db", side_effect=[True, True]), \ + mock.patch(NAMESPACE + ".database_service.DatabaseService.query_data", + side_effect=[cluster_time_summary_df_dict, base_cluster_time_summary_df_dict]): + cluster_time_compare_summary = ClusterTimeCompareSummary(params) + cluster_time_compare_summary.run() + self.assertTrue(cluster_time_compare_summary.compare_result.round(2).equals(expected_result.round(2))) + diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_hccl_sum.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_hccl_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..da24461c1a4d512ad312823150f11e993b743c43 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_hccl_sum.py @@ -0,0 +1,68 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.hccl_sum.hccl_sum import HcclSum +from msprof_analyze.prof_common.constant import Constant + + +class TestHcclSum(unittest.TestCase): + + def test_reduce_func_should_calculate_all_stats_df(self): + mapper_res = [ + pd.DataFrame({ + "OpName": ["hcom_allReduce__659_0_1"], + "OpType": ["hcom_allReduce"], + "Duration": [20400.0], + "GroupName": ["123%enp123abc_600001_1_17375465987622656"], + "Rank": [0], + }), + pd.DataFrame({ + "OpName": ["hcom_allReduce__659_0_1"], + "OpType": ["hcom_allReduce"], + "Duration": [881780.0], + "GroupName": ["123%enp123abc_600001_1_17375465987622656"], + "Rank": [1], + }) + ] + expected_all_fwk_stats = pd.DataFrame({ + "FrameworkDurationNs": [0.0, 0.0], + "CannDurationNs": [442600.0, 404020.0], + "DeviceDurationNs": [434849.0, 1502410.0], + "Rank": [0, 0], + "StepId": [0, 0], + }) + recipe = HcclSum({}) + recipe.reducer_func(mapper_res) + self.assertEqual(recipe.all_rank_stats.shape, (1, 9)) + self.assertEqual(recipe.per_rank_stats.shape, (2, 10)) + self.assertEqual(recipe.top_op_stats.shape, (1, 11)) + self.assertEqual(recipe.group_name_map.shape, (1, 2)) + + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.dump_data") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.add_helper_file") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.create_notebook") + @patch("msprof_analyze.cluster_analyse.recipes.hccl_sum.hccl_sum.HcclSum.reducer_func") + @patch("msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.BaseRecipeAnalysis.mapper_func") + def test_run_should_save_db_or_notebook(self, mock_mapper_func, mock_reducer_func, mock_create_notebook, + mock_add_helper_file, mock_dump_data): + params = {Constant.EXPORT_TYPE: Constant.DB} + recipe = HcclSum(params) + recipe.run(context=None) + recipe._export_type = "notebook" + recipe.run(context=None) \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..5a21c032bed2d57606a0cd2e739d0510231f8236 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from unittest import mock +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.context import Context +from msprof_analyze.cluster_analyse.recipes.pp_chart.pp_chart import PPChart + +NAMESPACE = "msprof_analyze.cluster_analyse.recipes" + + +class TestClusterTimeCompareSummary(unittest.TestCase): + def test_calculate_micro_batch_id_for_dualpipev_when_pp_size_4_and_num_microbatches_10(self): + expected_pp_stage_mstx_num = { + 0: 44, + 1: 32, + 2: 30, + 3: 29 + } + expected_micro_batch_id_dict = { + 0: [['0', 0], ['1', 0], ['2', 0], ['3', 0], ['4', 0], ['5', 0], ['6', 1], ['10', 1], ['logits', 2], + ['10b', 2], ['10w', 2], ['11', 2], ['logits', 2], ['11b', 2], ['11w', 2], ['12', 2], ['logits', 2], + ['12b', 2], ['12w', 2], ['13', 2], ['logits', 3], ['7F+13B', 3], ['14F+0B', 3], ['logits', 3], + ['8F+14B', 3], ['15F+1B', 3], ['logits', 3], ['9F+15B', 3], ['16F+2B', 3], ['logits', 4], ['16B', 4], + ['17F+3B', 4], ['logits', 4], ['17B', 4], ['18F+4B', 4], ['logits', 4], ['18B', 4], ['19F+5B', 4], + ['logits', 5], ['19B', 5], ['6B', 6], ['7B', 6], ['8B', 6], ['9B', 6]], + 1: [['0', 0], ['1', 0], ['2', 0], ['3', 0], ['4', 1], ['10', 1], ['5', 1], ['11', 1], ['10b', 2], + ['10w', 2], ['12', 2], ['11b', 2], ['11w', 2], ['13', 2], ['6F+12B', 3], ['14F+0B', 3], ['7F+13B', 3], + ['15F+1B', 3], ['8F+14B', 3], ['16F+2B', 3], ['9F+15B', 3], ['17F+3B', 3], ['16B', 4], ['18F+4B', 4], + ['17B', 4], ['19F+5B', 4], ['18B', 5], ['6B', 5], ['19B', 6], ['7B', 6], ['8B', 6], ['9B', 6]], + 2: [['0', 0], ['1', 0], ['2', 1], ['10', 1], ['3', 1], ['11', 1], ['4', 1], ['12', 1], ['10b', 2], + ['10w', 2], ['13', 2], ['5F+11B', 3], ['14F+0B', 3], ['6F+12B', 3], ['15F+1B', 3], ['7F+13B', 3], + ['16F+2B', 3], ['8F+14B', 3], ['17F+3B', 3], ['9F+15B', 3], ['18F+4B', 3], ['16B', 4], ['19F+5B', 4], + ['17B', 5], ['6B', 5], ['18B', 5], ['7B', 6], ['19B', 6], ['8B', 6], ['9B', 6]], + 3: [['0', 1], ['10', 1], ['1', 1], ['11', 1], ['2', 1], ['12', 1], ['3', 1], ['13', 1], ['4F', 3], + ['10B', 3], ['14F+0B', 3], ['5F+11B', 3], ['15F+1B', 3], ['6F+12B', 3], ['16F+2B', 3], ['7F+13B', 3], + ['17F+3B', 3], ['8F+14B', 3], ['18F+4B', 3], ['9F+15B', 3], ['19F+5B', 3], ['16B', 5], ['6B', 5], + ['17B', 5], ['7B', 5], ['18B', 6], ['8B', 6], ['19B', 6], ['9B', 6]] + } + with (mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.load_distributed_args", + return_value={PPChart.PP_SIZE: 4}), + mock.patch(NAMESPACE + ".pp_chart.pp_chart.PPChart.load_pp_info")): + pp_chart_instance = PPChart({}) + pp_chart_instance.micro_batch_num = 10 + pp_chart_instance.calculate_micro_batch_id_for_dualpipev() + self.assertEqual(pp_chart_instance.pp_stage_mstx_num, expected_pp_stage_mstx_num) + self.assertEqual(pp_chart_instance.micro_batch_id_dict, expected_micro_batch_id_dict) + + def test_pp_chart_should_generate_table_when_pp_info_not_existed(self): + df = pd.DataFrame({"step": [0, 0], "msg": ["forward_step", "backward_step"], "startNs": [1, 4], + "endNs": [2, 5]}) + with mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.load_distributed_args", + return_value={}), \ + mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.dump_data"), \ + mock.patch(NAMESPACE + ".pp_chart.pp_chart.PPChart.load_pp_info"), \ + mock.patch("msprof_analyze.prof_exports.base_stats_export.BaseStatsExport.read_export_db", + return_value=df): + with Context.create_context() as context: + pp_chart_instance = PPChart({}) + pp_chart_instance.micro_batch_num = 10 + pp_chart_instance.run(context) + self.assertFalse(pp_chart_instance.micro_batch_id_dict) diff --git a/sample/README.md b/sample/README.md index 15238cb9f3815d6fecb0c743e6f826d2abc2988b..8e555f4870d2c39fc5cabad3092d1c17f60d3dfa 100644 --- a/sample/README.md +++ b/sample/README.md @@ -8,19 +8,10 @@ 说明:该sample目录中,每个最小目录就是一个完整的样例工程。这些样例工程本身可能以为依赖的不同存在差异。 ## 依赖说明 -- 硬件环境请参见《[昇腾产品形态说明](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fquickstart%2Fquickstart%2Fquickstart_18_0002.html)》。 -- 软件环境请参见《[CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F80RC22%2Fsoftwareinst%2Finstg%2Finstg_0000.html%3FMode%3DPmIns%26OS%3DUbuntu%26Software%3DcannToolKit)》安装昇腾设备开发或运行环境,即toolkit软件包。 - -以上环境依赖请根据实际环境选择适配的版本。 - -### 版本配套 -| 条件 | 要求 | -|---|---| -| CANN版本 | >=8.0.RC1.alpha001 | -| 硬件要求 | Atlas 800T A2 训练服务器| - -- 支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和Python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 -- 固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 +安装CANN包,并使能环境变量,并确保```ASCEND_HOME_PATH```生效,可以在CANN包安装目录下使能: +``` +source set_env.sh +``` ## 目录介绍 整体目录结构如下: @@ -100,7 +91,7 @@ mssanitizer ./*.fatbin # 默认进行memcheck检查 ``` LINK_LIBS := -L${ASCEND_HOME_PATH}/lib64 -lruntime -lascendcl -lstdc++ 修改为: - LINK_LIBS := -L${ASCEND_HOME_PATH}/lib64 -L${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib/ -lruntime_camodel -lascendcl -lstdc++ # 需要添加libruntime_camodel的依赖路径, SOC_VERSION 通过使用npu-smi info命令进行查询,获取Chip Name信息。实际配置值 为AscendChip Name,例如Chip Name取值为xxxyy,实际配置值为Ascendxxxyy。当Ascendxxxyy为代码样例路径时,需要配置ascendxxxyy。 + LINK_LIBS := -L${ASCEND_HOME_PATH}/lib64 -L${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib/ -lruntime_camodel -lascendcl -lstdc++ # 需要添加libruntime_camodel的依赖路径, SOC_VERSION 使用npu-smi info查询NPU Name ``` + 调试信息增强: ``` diff --git "a/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" "b/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" index c78d206c1a47d0e39555574ac78b111cc0d37c53..5d77e387caf7964eb405dc2aa5b7cbb009f510cc 100644 --- "a/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" +++ "b/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" @@ -2,13 +2,8 @@ | 软件类型 | 软件名 | 路径 | 类型 | 内容 | 用途说明 | |------|----------------------------------------------------|------------------------------------------|------|------------------------------------------------------------------------------------------------------------|--------------------| -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/config/config.ini | 公网地址 | https://gitee.com/ascend/mstt/blob/master/profiler/msprof_analyze/advisor/doc/Samples%20of%20Fused%20Operator%20API%20Replacement.md" | Advisor优化手段参考示例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/config/config.ini | 公网地址 | https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0067.html | Advisor优化手段参考示例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/config/config.ini | 公网地址 | https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/devtools/auxiliarydevtool/aoe_16_045.html | Advisor优化手段参考示例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/config/config.ini | 公网地址 | https://www.mindspore.cn/lite/docs/en/master/use/cloud_infer/converter_tool_ascend.html#aoe-auto-tuning | Advisor优化手段参考示例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/config/config.ini | 公网地址 | https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0059.html | Advisor优化手段参考示例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze | /profiler/msprof_analyze/config/config.ini | 公网地址 | https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze | msprof-analyze工具地址 | -| 开源软件 | MindStudio Training Tools - msprof-analyze | /profiler/msprof_analyze/LICENSE | 公网地址 | http://www.apache.org/licenses/LICENSE-2.0 | 开源软件协议地址 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/rules/aicpu_rules.ymal | 公网地址 | https://gitee.com/ascend/mstt/blob/master/profiler/msprof_analyze/advisor/doc/Samples%20of%20AI%20CPU%20Operator%20Replacement.md | AI CPU 算子替换样例 | -| 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/msprof_analyze/advisor/rules/environment_variable_info.yaml | 公网地址 | https://support.huawei.com/enterprise/zh/doc/EDOC1100371278/5eeeed85?idPath=23710424 | 组网指南 | -| 开源软件 | MindStudio Training Tools - msprof-analyze | /profiler/msprof_analyze/config/config.ini | 公网地址 | pmail_mindstudio@huawei.com | 公网邮箱 | +| 开源软件 | MindStudio Training Tools - accuracy_tools | /debug/accuracy_tools/cmake/config.ini | 公网地址 | https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz | 开源软件下载 | +| 开源软件 | MindStudio Training Tools - accuracy_tools | /debug/accuracy_tools/cmake/config.ini | 公网地址 | https://gitee.com/sinojelly/mockcpp/repository/archive/v2.7.zip | 开源软件下载 | +| 开源软件 | MindStudio Training Tools - accuracy_tools | /debug/accuracy_tools/cmake/config.ini | 公网地址 | https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.10.1.zip | 开源软件下载 | +| 开源软件 | MindStudio Training Tools - accuracy_tools | /debug/accuracy_tools/cmake/config.ini | 公网地址 | https://gitee.com/mirrors/openssl/repository/archive/OpenSSL_1_1_1k.tar.gz | 开源软件下载 | +| 开源软件 | MindStudio Training Tools - accuracy_tools | /debug/accuracy_tools/cmake/config.ini | 公网地址 | https://gitee.com/mirrors/protobuf_source/repository/archive/v3.15.0.tar.gz | 开源软件下载 |