diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9cc1deb82ff0498f1a8267cd288ecde798f308c
--- /dev/null
+++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md
@@ -0,0 +1,66 @@
+# PR Merge Template
+
+**Note: items your self-check shows to be not applicable may be marked "N/A" or simply ticked; explain special cases in writing. PRs that do not meet this standard are not allowed to merge; committers (and backup committers), please take note.**
+
+---
+
+## 1. Change Description
+- **Reason for the change:**
+- **Content of the change:**
+
+---
+
+## 2. Functional Verification
+- [ ] **Functional self-verification**
+- [ ] **Screenshots of local self-verification cases**
+- [ ] **Smoke tests passed** (fill into the self-verification report linked in the group chat; if not passed, explain why: ____________________ ; for feature code, proactively request smoke-test coverage)
+
+---
+
+## 3. Branch Merge Requirements
+- [ ] **Code merge** (make sure the latest master code is merged into both the poc and pre-research branches, and that the poc branch has also been correctly merged into pre-research.)
+
+---
+
+## 4. Code Review
+- **Requirements:**
+  - Changes of more than 200 lines require a review meeting with three or more reviewers.
+  - Review-comment density ≥ 1 per 100 lines.
+  - If the review defect density falls short of the requirement, provide an explanation.
+  - Changes of more than 1000 lines are in principle not allowed to merge and must be filed for the record.
+- [ ] **Code has been reviewed**
+- [ ] **Guarded by UT test cases** (if not, explain why: ____________________)
+
+- **Number of review comments: ____** (total number of comments from this review, examined before the commit is merged)
+
+---
+
+## 5. Security Self-Check
+
+### Python, C++:
+- [ ] **When external interfaces are added/removed/changed, update the external-input table**
+- [ ] **No private file operations; use the common functions**
+- [ ] **Array accesses must be checked for out-of-bounds scenarios**
+- [ ] **Check regular expressions for ReDoS**
+- [ ] **Check divisions for division by zero**
+- [ ] **Thoroughly validate abnormal interface return values**
+- [ ] **Thoroughly validate abnormal interface input values**
+- [ ] **Logs must not expose code details or sensitive information**
+
+### C++:
+- [ ] **Null-check pointers before use**
+- [ ] **Check numeric computations for overflow and wraparound**
+- [ ] **No memory leaks (release memory on exception paths)**
+- [ ] **Type conversions must not truncate data**
+- [ ] **When copying strings, the destination buffer is at least 1 byte larger than the source**
+- [ ] **When copying memory, the destination buffer is no smaller than the source**
+- [ ] **Set pointers to nullptr after freeing**
+
+
+---
+
+## 6. Change Notification
+- [ ] **Documentation updated**
+- [ ] **Change notification (instant message + email)**
+
+---
diff --git a/debug/OWNERS b/debug/OWNERS
index 12cfc46b262c2be6adf69481d928adc18ea1ee7d..0bda9243569f0b6bcd0ce761d7817d512b487ddd 100644
--- a/debug/OWNERS
+++ b/debug/OWNERS
@@ -3,17 +3,14 @@ options:
 approvers:
 - wangchao285
 - kun_8
-- binghamhuang
 - brightlyking
-- litian_drinksnow
 reviewers:
 - lv-kaimeng
-- binghamhuang
 - TAJh
 - jiandaobao
 - pengxiaopeng1
 - zhengxinqian
 - louyujing
-- yangchen
+- yang_chen_2001_02_14
 - shawnzhu1
 - wqc01202410
diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md
index d770019febbae686583ddff36a69a943eb50efe3..0e68d1f8d9bdaba93a2f65220f85d08eb45f8586 100644
--- a/debug/accuracy_tools/msprobe/README.md
+++ b/debug/accuracy_tools/msprobe/README.md
@@ -15,7 +15,7 @@ debugger = PrecisionDebugger(config_path='./config.json')
 ...
 debugger.start() # usually starts the tool at the beginning of the training loop
 ... # loop body
-debugger.stop() # usually stops the tool at the end of the training loop
+debugger.stop() # usually stops the tool at the end of the training loop; must be called, otherwise the dumped precision data may be incomplete
 debugger.step() # resets the tool at the end of each training iteration; not needed in non-loop scenarios
 ```
@@ -51,6 +51,8 @@ export MSPROBE_LOG_LEVEL={x}
 
 **1. Under the PyTorch framework, the tool does not yet support Fully Sharded Data Parallel (FSDP).**
 
+**2. Every path the tool reads or writes, such as config_path and dump_path, may only contain upper- and lower-case letters, digits, underscores, slashes, dots, and hyphens.**
+
 ## ⚙️ [Installation](./docs/01.installation.md)
 
 ## 🌟 New Features
@@ -125,7 +127,7 @@ MindSpore dynamic-graph scenario [offline pre-check](./docs/09.accuracy_checker_MindSpore.
 
 This feature collects and aggregates intermediate values of network layers, optimizers, and communication operators during training, helping diagnose anomalies in the computation, communication, and optimizer stages of model training.
 
-[Training-state monitoring for PyTorch scenarios](./docs/19.monitor.md)
+[Training-state monitoring compatible with both PyTorch and MindSpore](./docs/19.monitor.md)
 
 ### 10 Hierarchical visualized graph comparison
diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp
index aa9beea96dc4aac048204c92a1a7401494a74c52..80769d7fc5fbc9d36115a544e05dd00f2a7541c3 100644
--- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp
+++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -469,5 +469,40 @@ void AclDumper::FinalizeDump(ExtArgs& args) aclDumpHasSet = false; } +void KernelInitDump() { + if (AscendCLApi::LoadAclApi() != DebuggerErrno::OK) { + return; + } + + DebuggerErrno ret = InitAcl(); + if (ret != DebuggerErrno::OK) { + LOG_ERROR(ret, "Failed to call InitAcl."); + return; + } + auto aclRet = CALL_ACL_API(aclmdlInitDump); + if (aclRet != ACL_SUCCESS) { + LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, + "Failed to init acldump(" + std::to_string(aclRet) + ")."); + return; + } +} +void KernelSetDump(const std::string &filePath) { + std::string dumpPath = FileUtils::GetAbsPath(filePath); + auto aclRet = CALL_ACL_API(aclmdlSetDump, dumpPath.c_str()); + if (aclRet != ACL_SUCCESS) { + LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, + "Failed to enable acldump(" + std::to_string(aclRet) + ")."); + return; + } +} + +void KernelFinalizeDump() { + CALL_ACL_API(aclrtSynchronizeDevice); + auto aclRet = CALL_ACL_API(aclmdlFinalizeDump); + if (aclRet != ACL_SUCCESS) { + LOG_ERROR(DebuggerErrno::ERROR_EXTERNAL_API_ERROR, + "Failed to finalize acldump(" + std::to_string(aclRet) + ")."); + } +} } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp index ff1a40ae752bfc45ddcf14817d64e3df8d8f83e8..dcfad5fafcabdf944e1d4b0b0a3cd77251ce047d 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,4 +65,7 @@ private: std::map> dataProcessors; }; +void KernelInitDump(); +void KernelSetDump(const std::string &filePath); +void KernelFinalizeDump(); } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c380ed3f505795eb622f7f401558f72a54db557 --- /dev/null +++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <Python.h>
+#include <string>
+
+#include "base/ErrorInfos.hpp"
+#include "core/AclDumper.hpp"
+#include "utils/CPythonUtils.hpp"
+
+namespace MindStudioDebugger {
+
+static PyObject *CPythonKernelInitDump(PyObject *module, PyObject *args) {
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    KernelInitDump();
+    PyGILState_Release(gstate);
+    Py_RETURN_NONE;
+}
+
+static PyObject *CPythonKernelSetDump(PyObject *module, PyObject *args) {
+    const char *path;
+    if (!PyArg_ParseTuple(args, "s", &path)) {
+        LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE,
+                  "npu set dump error, cfg_file must be a string");
+        return nullptr;
+    }
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    KernelSetDump(std::string(path));
+    PyGILState_Release(gstate);
+    Py_RETURN_NONE;
+}
+
+static PyObject *CPythonKernelFinalizeDump(PyObject *module, PyObject *args) {
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    KernelFinalizeDump();
+    PyGILState_Release(gstate);
+    Py_RETURN_NONE;
+}
+
+static PyMethodDef DumpMethods[] = {
+    {"init_dump", reinterpret_cast<PyCFunction>(CPythonKernelInitDump),
+     METH_NOARGS, "Initialize dump."},
+    {"set_dump", reinterpret_cast<PyCFunction>(CPythonKernelSetDump),
+     METH_VARARGS, "Set dump."},
+    {"finalize_dump", reinterpret_cast<PyCFunction>(CPythonKernelFinalizeDump),
+     METH_NOARGS, "Finalize dump."},
+    {nullptr, nullptr, 0, nullptr}};
+
+PyMethodDef *GetDumpMethods() { return DumpMethods; }
+}  // namespace MindStudioDebugger
diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.hpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..11ae2ad4adb634e0c7cf58295127f76340796b84
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/ACLDump.hpp
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <Python.h>
+
+namespace MindStudioDebugger {
+PyMethodDef *GetDumpMethods();
+}
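For orientation: once the extension is built, the three bindings above surface as plain functions on the `_msprobe_c` module, which later files in this patch import as `msprobe.lib._msprobe_c`. A minimal usage sketch (the dump directory is a hypothetical example path):

```python
# Sketch only: drives the new kernel-dump bindings end to end, assuming the
# compiled extension is importable as msprobe.lib._msprobe_c, as done in
# mindspore_processor.py later in this patch.
from msprobe.lib import _msprobe_c

_msprobe_c.init_dump()                # loads the ACL API and calls aclmdlInitDump
_msprobe_c.set_dump("./kernel_dump")  # hypothetical path; resolved to absolute, then aclmdlSetDump
# ... run the NPU workload whose kernels should be dumped ...
_msprobe_c.finalize_dump()            # synchronizes the device, then aclmdlFinalizeDump
```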
diff --git a/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp b/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp
index a3a928e4d2a611c9b85fe2604379eecb70775381..a18c54a146f7d676d6b3c7f760e50f9e7eebe56c 100644
--- a/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp
+++ b/debug/accuracy_tools/msprobe/ccsrc/if/python/MsProbeIfPython.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 
 #include "PrecisionDebuggerIfPython.hpp"
 #include "CPythonAgent.hpp"
+#include "ACLDump.hpp"
 
 namespace MindStudioDebugger {
 
@@ -72,5 +73,13 @@ PyMODINIT_FUNC PyInit__msprobe_c(void)
     }
     Py_INCREF(cpyAgent);
 
+    PyMethodDef* dumpmethods = MindStudioDebugger::GetDumpMethods();
+    for (PyMethodDef* method = dumpmethods; method->ml_name != nullptr; ++method) {
+        if (PyModule_AddObject(m, method->ml_name, PyCFunction_New(method, nullptr)) < 0) {
+            PyErr_SetString(PyExc_ImportError, "Failed to bind dump method.");
+            Py_DECREF(m);
+            return nullptr;
+        }
+    }
     return m;
 }
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/config.json b/debug/accuracy_tools/msprobe/config.json
index 3b6c930fdd7b981d89c4711a6761f5d00ec3d547..553b7f9ee3b89215647b00fb14b70af44ea5f00c 100644
--- a/debug/accuracy_tools/msprobe/config.json
+++ b/debug/accuracy_tools/msprobe/config.json
@@ -5,6 +5,7 @@
     "step": [],
     "level": "L1",
     "enable_dataloader": false,
+    "async_dump": false,
     "tensor": {
         "scope": [],
         "list":[],
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 9591a9a726a6dc46f6b469a6ed84ba2bf1c299c0..6824fc8b42b975d5f5c84e58cd5a82bf2f5d52ef 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -103,8 +103,9 @@ class Const:
     FREE_BENCHMARK = "free_benchmark"
     RUN_UT = "run_ut"
     GRAD_PROBE = "grad_probe"
-    TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
-    DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
+    STRUCTURE = "structure"
+    TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE, STRUCTURE]
+    DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR, STRUCTURE]
     DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
     LEVEL_L0 = "L0"
     LEVEL_L1 = "L1"
@@ -197,6 +198,7 @@ class Const:
 
     # data type const
     TORCH_INT_DTYPE = ["torch.int8", "torch.int32", "torch.int64"]
+    TORCH_FLOAT_DTYPE = ["torch.bfloat16", "torch.float16", "torch.float32", "torch.float64"]
     FLOAT16 = "Float16"
     FLOAT32 = "Float32"
     BFLOAT16 = "BFloat16"
@@ -231,6 +233,8 @@ class Const:
     CLIP_GRAD = "clip_grad"
     END_PREFIX = "end_"
 
+    TENSOR_STAT_LEN = 2
+
 
 class CompareConst:
     """
@@ -539,59 +543,6 @@ class OverflowConst:
     OVERFLOW_DEBUG_MODE = 1
 
 
-class MsCompareConst:
-    # api_info field
-    MINT = "Mint"
-    MINT_FUNCTIONAL = "MintFunctional"
-    TENSOR_API = "Tensor"
-
-    API_NAME_STR_LENGTH = 4
-    MAX_RECURSION_DEPTH = 20
-
-    # Mindtorch api_info field
-    MINDTORCH_TENSOR = "Tensor"
-    MINDTORCH = "Torch"
-    MINDTORCH_FUNC = "Functional"
-    MINDTORCH_NPU = "NPU"
-    MINDTORCH_DIST = "Distributed"
-
-    MT_VALID_API_TYPES = [
-        MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
-    ]
-
-    TASK_FIELD = "task"
-    STATISTICS_TASK = "statistics"
-    FRAMEWORK = "framework"
-    TENSOR_TASK = "tensor"
-    DUMP_DATA_DIR_FIELD = "dump_data_dir"
-    DATA_FIELD = "data"
-
-    # supported api yaml
-    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
-    SUPPORTED_TENSOR_LIST_KEY = "tensor"
-
-    # detail_csv
-    DETAIL_CSV_API_NAME = "API Name"
-    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
-    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
-    DETAIL_CSV_SHAPE = "Shape"
-    DETAIL_CSV_PASS_STATUS = "Status"
-    DETAIL_CSV_MESSAGE = "Message"
-    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
-
-    # result_csv
-    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
-    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
-    RESULT_CSV_FILE_NAME = "accuracy_checking_result"
-
-    EPSILON = 1e-8
-
-    class 
ProcessStatus: - SUCCESS = "success" - API_NOT_FOUND = "api_not_found" - EXCEPTION_SKIP = "exception_skip" - - class MsgConst: """ Class for log messages const @@ -639,28 +590,37 @@ class MonitorConst: "DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3" ) + DEEPSPEED_ZERO_OPT_FILTER = "DeepSpeedZeroOptimizer" RULE_NAME = ['AnomalyTurbulence'] + SLICE_SIZE = 20480 + # used for name DOT = "." - VPP_SEP = ":" + NAME_SEP = ":" + INPUT_GRAD = "input_grad" + OUTPUT_GRAD = "output_grad" ACTV_IN = "input" ACTV_OUT = "output" ACTVGRAD_IN = "input_grad" ACTVGRAD_OUT = "output_grad" + # used for tasks + ACTV = "actv" + ACTVGRAD = "actv_grad" POST_GRAD = "post_grad" PRE_GRAD = "pre_grad" ACC_GRAD = "acc_grad" PREFIX_POST = "post" PREFIX_PRE = "pre" - OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-" - EXP_AVG = "exp_avg" - EFXP_AVG_SQ = "efxp_avg_sq" + EXP_AVG_SQ = "exp_avg_sq" + PARAM = "param" + CSV_HEADER = ["vpp_stage", "name", "step"] + CSV_HEADER_XY = ["vpp_stage", "name", "step", "micro_step"] + OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-" ANOMALY_JSON = "anomaly.json" ANALYSE_JSON = "anomaly_analyse.json" TENSORBOARD = "tensorboard" CSV = "csv" API = "api" - OPS_START_INDEX = 3 - HEADER_NAME_INDEX = 1 + HEADER_NAME = 'name' diff --git a/debug/accuracy_tools/msprobe/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py index 39c1d6a4bcfe8d83c0fc5bbcd9ed613770be80cb..d71d30224b677fb19361f62de0ee25b2d32d389f 100644 --- a/debug/accuracy_tools/msprobe/core/common/exceptions.py +++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py @@ -27,11 +27,13 @@ class MsprobeException(CodedException): INVALID_PARAM_ERROR = 0 OVERFLOW_NUMS_ERROR = 1 RECURSION_LIMIT_ERROR = 2 + INTERFACE_USAGE_ERROR = 3 err_strs = { INVALID_PARAM_ERROR: "[msprobe] 无效参数:", OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:", - RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:" + RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:", + INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: " } diff --git a/debug/accuracy_tools/msprobe/core/common/inplace_ops.yaml b/debug/accuracy_tools/msprobe/core/common/inplace_ops.yaml index eadd9b764f36f872c285b67a9a096a496f445d36..dc899cbc8620ea6e62e946660942e1940f2bfa62 100644 --- a/debug/accuracy_tools/msprobe/core/common/inplace_ops.yaml +++ b/debug/accuracy_tools/msprobe/core/common/inplace_ops.yaml @@ -250,5 +250,6 @@ inplace_distributed_op: - all_to_all - all_gather_into_tensor - reduce_scatter_tensor + - batch_isend_irecv diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 5d517a02961877c9fa8cec6967003fa330bc09e0..c06b5b64927bf47da1573df3b1d4db34dfa24cb1 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -29,6 +29,7 @@ from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.log import logger from msprobe.core.common.exceptions import MsprobeException + device = collections.namedtuple('device', ['type', 'index']) prefixes = ['api_stack', 'list', 'range', 'acl'] @@ -238,6 +239,8 @@ def md5_find(data): for data_detail in data[key_op][api_info]: if data_detail and 'md5' in data_detail: return True + if isinstance(data[key_op][api_info], bool): + continue elif data[key_op][api_info] and 'md5' in data[key_op][api_info]: return True return False @@ -302,6 +305,9 @@ def get_dump_mode(input_param): if npu_task == Const.TENSOR: return Const.ALL + if npu_task == Const.STRUCTURE: + 
return Const.STRUCTURE + if npu_task == Const.STATISTICS: npu_md5_compare = md5_find(npu_json_data['data']) bench_md5_compare = md5_find(bench_json_data['data']) @@ -406,8 +412,8 @@ def get_real_step_or_rank(step_or_rank_input, obj): if not Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE: raise MsprobeException( MsprobeException.INVALID_PARAM_ERROR, - f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and {Const.STEP_RANK_MAXIMUM_VALUE}, " - f"currently it is {element}." + f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and " + f"{Const.STEP_RANK_MAXIMUM_VALUE}, currently it is {element}." ) real_step_or_rank.append(element) continue diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index c9d3e5a19efc940c14cd7f8b4362258e4886aa7f..b9a717c0c52f11e52ac055e3cfe6a0e77fe7e44c 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -27,6 +27,7 @@ class CommonConfig: self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP) self.level = json_config.get('level') self.enable_dataloader = json_config.get('enable_dataloader', False) + self.async_dump = json_config.get("async_dump", False) self._check_config() def _check_config(self): @@ -42,6 +43,11 @@ class CommonConfig: if not isinstance(self.enable_dataloader, bool): logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean", MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) + if not isinstance(self.async_dump, bool): + logger.error_log_with_exp("async_dump is invalid, it should be a boolean", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) + elif self.async_dump: + logger.warning("async_dump is True, it may cause OOM when dumping large tensor.") class BaseConfig: diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 147bbbab0aff28aa058e2870a6666d241245df41..55229d72657c67428186bcb233371e3b9eee73e0 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -33,7 +33,7 @@ from msprobe.core.compare.highlight import find_compare_result_error_rows, highl from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ - print_compare_ends_info, read_op, get_name_and_state + print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list class ModeConfig: @@ -254,9 +254,15 @@ class Comparator: CompareConst.PARAMS_STRUCT: 0, CompareConst.PARAMS_GRAD_STRUCT: 0 } + + op_name_list = merge_list.get(CompareConst.OP_NAME) + summary_list = merge_list.get(Const.SUMMARY) data_name_list = merge_list.get('data_name') - for index, op_full_name in enumerate(merge_list[CompareConst.OP_NAME]): - data_name = data_name_list[index] if data_name_list else None + op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, + summary_list, + data_name_list) + for index, op_full_name in enumerate(op_name_reorder): + data_name = data_name_reorder[index] if data_name_reorder else None _, state = get_name_and_state(op_full_name) struct_key = 
CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
@@ -265,8 +271,7 @@ class Comparator:
             ops_all[op_full_name] = {
                 CompareConst.STRUCT: safe_get_value(merge_list, struct_to_index_mapping.get(struct_key), "merge_list",
                                                     key=struct_key),
-                CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
-                                                     key=CompareConst.SUMMARY),
+                CompareConst.SUMMARY: safe_get_value(summary_reorder, index, "summary_reorder"),
                 'data_name': data_name,
                 'stack_info': merge_list.get('stack_info')
             }
@@ -394,7 +399,7 @@ class Comparator:
 
         result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
 
-        if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
+        if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
             err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
         result_list.append(err_msg)
         return result_list
diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py
index 34553c1f9347d20d1180a057cd76f6c4a59f88cf..cf3e1c4c03e9553f5566870b7c5ebe2d890e9774 100644
--- a/debug/accuracy_tools/msprobe/core/compare/highlight.py
+++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py
@@ -188,11 +188,10 @@ def find_error_rows(result, api_batch, highlight_dict, dump_mode):
     color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
 
     api_batch_start = api_batch.start  # global start index of input in result_df
-    api_batch_input_len = api_batch.input_len  # local end index of input in result + 1
-    api_batch_output_end_index = api_batch.output_end_index  # global end index of output in result_df + 1
     api_batch_params_end_index = api_batch.params_end_index  # global end index of params in result_df + 1
-    api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start  # local slice end index of output in result
+    api_batch_output_end_index = api_batch.output_end_index  # global end index of output in result_df + 1
     api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start  # local slice end index of params in result
+    api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start  # local slice end index of output in result
 
     # error check on each single row of the API's inputs or outputs
     for i, line in enumerate(result):
@@ -202,28 +201,21 @@ def find_error_rows(result, api_batch, highlight_dict, dump_mode):
             rule.apply(line_info, color_columns, dump_mode)
 
     # compare the API's outputs against its inputs for error checking
-    for n, api_out in enumerate(result[api_batch_input_len: api_batch_output_slice_index_local]):
-        index = api_batch_start + api_batch_input_len + n
+    for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
+        index = api_batch_start + api_batch_params_slice_index_local + n
         # the single-row check only covers overflow (red); if the row already overflowed, skip further checks
         if index in red_lines:
             continue
         if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
             continue
 
-        # comparison check for input
-        for _, api_in in enumerate(result[0: api_batch_input_len]):
+        # comparison check for input/parameters; api_in here covers both input and parameters
+        for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
             if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
                 continue
             api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
             apply_comparison_rules(api_info, dump_mode, color_columns)
 
-        # comparison check for parameters
-        for _, api_params in enumerate(result[api_batch_output_slice_index_local: api_batch_params_slice_index_local]):
-            if not check_indices_numeric(api_params, [npu_max_index, bench_max_index, max_diff_index]):
-                continue
-            api_info = ApiInfo(api_input=api_params, api_output=api_out, num_pointer=index)
-            apply_comparison_rules(api_info, dump_mode, color_columns)
-
     red_lines_num_set = {x[0] for x in red_lines}
     yellow_lines_num_set = {x[0] for x in yellow_lines}
     highlight_dict.get('red_rows', set()).update(red_lines_num_set)
@@ -237,8 +229,8 @@ class ApiBatch:
         self.api_name = api_name
         self.start = start
         self.input_len = 1  # number of inputs
-        self.output_end_index = start + 1  # end index of output
         self.params_end_index = start + 1  # end index of params
+        self.output_end_index = start + 1  # end index of output
         self.params_grad_end_index = start + 1  # end index of params_grad
         # internal state flag ("input", "output", "parameters", "parameters_grad"),
         # used to control the computation of input_len, output_end_index, params_end_index, self.params_grad_end_index
 
     def set_state(self, state: str):
         """Set the current state."""
-        if state in {Const.INPUT, Const.OUTPUT, Const.PARAMS, Const.PARAMS_GRAD}:
+        if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
             self._state = state
         else:
             raise ValueError(f"Invalid state: {state}")
 
     def increment(self, state: str):
         self.set_state(state)
-        if self._state == Const.INPUT:
+        if self._state == Const.INPUT or self._state == Const.KWARGS:
             self.input_len += 1
-            self.output_end_index += 1
             self.params_end_index += 1
-        if self._state == Const.OUTPUT:
             self.output_end_index += 1
-            self.params_end_index += 1
         if self._state == Const.PARAMS:
             self.params_end_index += 1
+            self.output_end_index += 1
+        if self._state == Const.OUTPUT:
+            self.output_end_index += 1
         self.params_grad_end_index += 1
 
 
@@ -297,7 +289,8 @@ def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
             api_batches_update(api_batches, api_name, state, i)
     with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
         for api_batch in api_batches:
-            find_error_rows(result[api_batch.start: api_batch.params_end_index], api_batch, highlight_dict, dump_mode)
+            find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
+                            dump_mode)
             progress_bar.update(1)
diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/data_scope_parser.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/data_scope_parser.py
index d7713da4d95b3db8454ffbda2c080a2a2fa8e0ec..5ba5aa69a10a8aa408868697ae2982bd1349ff76 100644
--- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/data_scope_parser.py
+++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/data_scope_parser.py
@@ -112,7 +112,7 @@ class DumpDataItem:
             self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
         else:
             self.layer_scope = Const.TOP_LAYER
-        if construct_info:
+        if construct_info and Const.SEP in construct_info:
             construct_list = construct_info.split(Const.SEP)
             if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                 logger.error(
diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py
index 53b079835a1b75404d4f1ded6e0ee27bc4da4927..d0f19462ee1ccf4d72c69885c18174cec32df056 100644
--- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py
+++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/layer_mapping.py
@@ -22,7 +22,9 @@ from msprobe.core.common.utils import (add_time_with_yaml,
                                        detect_framework_by_dump_json,
                                        get_stack_construct_by_dump_json_path)
 from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
-from msprobe.core.compare.utils import read_op
+from 
msprobe.core.compare.utils import read_op, reorder_op_name_list + + class LayerTrie: def __init__(self, type_name, framework=None): @@ -225,7 +227,10 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa continue npu_full_op_names = read_full_op_names(npu_data, npu_op_name) bench_full_op_names = read_full_op_names(bench_data, bench_op_name) - mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names) + npu_full_op_names_reorder = reorder_op_name_list(npu_full_op_names) + bench_full_op_names_reorder = reorder_op_name_list(bench_full_op_names) + mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names_reorder, + bench_op_name, bench_full_op_names_reorder) data_mapping.update(mapping) if output_path: file_name = add_time_with_yaml("data_mapping") diff --git a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/postprocess_pass.py b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/postprocess_pass.py index 7f0c8129ae30e934ce16a27a1753271cb756b3f3..2946b86122d6619338a5d9ec057bf3ba96c5ac75 100644 --- a/debug/accuracy_tools/msprobe/core/compare/layer_mapping/postprocess_pass.py +++ b/debug/accuracy_tools/msprobe/core/compare/layer_mapping/postprocess_pass.py @@ -29,9 +29,10 @@ def backward_pass(data_items, name2item): data_name_list = data_item.data_name.split(Const.SEP) if not data_name_list: continue - if Const.BACKWARD in data_name_list[Const.SCOPE_DIRECTION_INDEX :]: - data_name_list[Const.SCOPE_DIRECTION_INDEX :] = [ - s.replace(Const.BACKWARD, Const.FORWARD) for s in data_name_list[Const.SCOPE_DIRECTION_INDEX :] + if Const.BACKWARD in data_name_list[Const.SCOPE_DIRECTION_INDEX:]: + data_name_list[Const.SCOPE_DIRECTION_INDEX:] = [ + s.replace(Const.BACKWARD, Const.FORWARD) + for s in data_name_list[Const.SCOPE_DIRECTION_INDEX:] ] forward_name = Const.SEP.join(data_name_list) forward_item = name2item.get(forward_name, None) diff --git a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py index 1da07be6c9145163c13107bdcc9d9816dfa431c0..b605bd59fca0b2b3a510a7a686caa94383488bd2 100644 --- a/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py +++ b/debug/accuracy_tools/msprobe/core/compare/merge_result/merge_result.py @@ -25,6 +25,7 @@ from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_ from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.utils import CompareException, add_time_with_xlsx from msprobe.core.compare.utils import table_value_is_valid +from msprobe.core.compare.merge_result.utils import replace_compare_index_dict, check_config def check_compare_result_name(file_name): @@ -58,8 +59,8 @@ def get_result_path(input_dir): """ get rank ordered compare result file path list """ - compare_result_path_list = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if - f.endswith(FileCheckConst.XLSX_SUFFIX)] + compare_result_path_list = [os.path.join(input_dir, f) + for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)] filt_compare_result_path_list = [] for file_path in compare_result_path_list: file_name = os.path.basename(file_path) @@ -170,6 +171,8 @@ def search_api_index_result(api_list, compare_index_list, result_df, rank_num, c table_value_check(index_value) api_index_dict.setdefault(api_full_name, {})[rank_num] = index_value # update api_index_dict compare_index_dict[compare_index] = 
api_index_dict
+
+    compare_index_dict = replace_compare_index_dict(compare_index_dict, compare_index_list, rank_num)
     return compare_index_dict
 
 
@@ -203,10 +206,13 @@ def result_process(compare_result_path_list, api_list):
             compare_index_list = check_index_dump_mode_consistent(dump_mode, rank_num)
             if len(compare_index_list) == 0:
                 return [], [], []
-            compare_index_dict = search_api_index_result(api_list, share_compare_index_list,
+            compare_index_list.extend([CompareConst.NPU_MAX, CompareConst.BENCH_MAX])
+            compare_index_dict = search_api_index_result(api_list, compare_index_list,
                                                          result_df, rank_num, compare_index_dict)
             compare_index_dict_list.append(compare_index_dict)
             rank_num_list.append(rank_num)
+            compare_index_list.pop()
+            compare_index_list.pop()
         else:
             logger.warning(f"Rank{rank_num} compare result is empty and will not be shown in merged result.")
@@ -313,16 +319,16 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_compare_index_list, output_path):
     for compare_index_dict, rank_num in zip(compare_index_dict_list, rank_num_list):
         header = [CompareConst.NPU_NAME, "rank" + str(rank_num)]
         result_df_list = []
-        for _, api_index_dict in compare_index_dict.items(): 
+        for _, api_index_dict in compare_index_dict.items():
             result_df = generate_result_df(api_index_dict, header)
             result_df_list.append(result_df)
-        # all_result_df_list example: [[result_df_rank1_index1, result_df_rank1_index2], [result_df_rank2_index1, result_df_rank2_index2]]
         all_result_df_list.append(result_df_list)
 
     merge_df_list = df_merge(all_result_df_list)
     final_result_df_list = []
     for i, df in enumerate(merge_df_list):
-        final_result_df_list.append((df, compare_index_list[i]))  # each df in merge_df_list corresponds one-to-one to a compare_index in compare_index_list
+        # each df in merge_df_list corresponds one-to-one to a compare_index in compare_index_list
+        final_result_df_list.append((df, compare_index_list[i]))
     save_excel(output_path, final_result_df_list)
     logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
@@ -362,13 +368,8 @@ def merge_result(input_dir, output_dir, config_path):
     compare_result_path_list = get_result_path(input_dir)  # full paths of all compare result files under input_dir; if fewer than 2, prompt and exit
     config = load_yaml(config_path)
-    if not config:
-        logger.error('config.yaml is empty, please check.')
-        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
+    config = check_config(config)
     api_list = config.get('api')
-    if not api_list:
-        logger.error('The APIs required to merge data were not found')
-        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
 
     # initialize the shared global variable share_compare_index_list
     initialize_compare_index(config)
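The helper module introduced below supplies the `check_config` and `replace_compare_index_dict` functions used above. To make the substitution concrete, here is a toy invocation with a hypothetical op name, assuming `CompareConst.NPU_MAX`/`BENCH_MAX` render as the "NPU max"/"Bench max" headers:

```python
from msprobe.core.compare.merge_result.utils import replace_compare_index_dict

# Hypothetical data: the "Cosine" value for rank 0 is N/A, so it is rewritten
# from the stashed NPU max / Bench max statistics; both stash entries are
# dropped afterwards.
op = "Distributed.all_reduce.0.forward.output.group"
compare_index_dict = {
    "Cosine": {op: {0: "N/A"}},
    "NPU max": {op: {0: "tp-0-1-2-3"}},
    "Bench max": {op: {0: "tp-0-1-2-3"}},
}
compare_index_list = ["Cosine", "NPU max", "Bench max"]

result = replace_compare_index_dict(compare_index_dict, compare_index_list, rank_num=0)
# result == {"Cosine": {op: {0: "NPU:tp-0-1-2-3 Bench:tp-0-1-2-3"}}}
```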
diff --git a/debug/accuracy_tools/msprobe/core/compare/merge_result/utils.py b/debug/accuracy_tools/msprobe/core/compare/merge_result/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce563e9682088b28e2100a3851588632a9bb4b3a
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/compare/merge_result/utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.common.const import CompareConst
+from msprobe.core.common.file_utils import logger
+from msprobe.core.common.utils import CompareException
+
+
+def replace_compare_index_dict(compare_index_dict, compare_index_list, rank_num):
+    """
+    When a compare-index value is N/A, unsupported or Nan, replace it with the
+    NPU max and Bench max values (those statistics are identical in such cases).
+
+    Example:
+        The compare-index value of Distributed.all_reduce.0.forward.output.group is N/A.
+        After replacement the value reads:
+            NPU: tp-0-1-2-3
+            Bench: tp-0-1-2-3
+    """
+
+    if CompareConst.NPU_MAX not in compare_index_dict or CompareConst.BENCH_MAX not in compare_index_dict:
+        compare_index_dict.pop(CompareConst.NPU_MAX, None)
+        compare_index_dict.pop(CompareConst.BENCH_MAX, None)
+        return compare_index_dict
+
+    # iterate over the compare indices, excluding the last two (NPU max, Bench max)
+    for compare_index in compare_index_list[:-2]:
+        op_name_index_dict = compare_index_dict[compare_index]
+        # iterate over op_item names and their corresponding compare-index values
+        for op_name, index_value in op_name_index_dict.items():
+            npu_max = compare_index_dict[CompareConst.NPU_MAX][op_name][rank_num]
+            bench_max = compare_index_dict[CompareConst.BENCH_MAX][op_name][rank_num]
+            # if the current compare-index value is N/A, unsupported or Nan, and the
+            # NPU and Bench max values share the same type, substitute them in
+            if index_value[rank_num] in [CompareConst.N_A, CompareConst.UNSUPPORTED, CompareConst.NAN]:
+                compare_index_dict[compare_index][op_name][rank_num] = f'NPU:{str(npu_max)} Bench:{str(bench_max)}'
+
+    # drop the NPU_MAX and BENCH_MAX entries
+    compare_index_dict.pop(CompareConst.NPU_MAX, None)
+    compare_index_dict.pop(CompareConst.BENCH_MAX, None)
+    return compare_index_dict
+
+
+def check_config(config):
+    """
+    Validate the contents of config.yaml.
+    Args:
+        config: the loaded config
+    Returns:
+        config
+    """
+    if not config:
+        logger.error('config.yaml is empty, please check.')
+        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
+
+    api_list = config.get('api')
+    if not api_list:
+        logger.error('The APIs required to merge data were not found.')
+        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
+    if not isinstance(api_list, list):
+        logger.error("The config format of 'api' is incorrect, please check.")
+        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
+
+    compare_index_list = config.get('compare_index', [])
+    if compare_index_list is None:
+        compare_index_list = []
+        config['compare_index'] = compare_index_list
+    if not isinstance(compare_index_list, list):
+        logger.error("The config format of 'compare_index' is incorrect, please check.")
+        raise CompareException(CompareException.MERGE_COMPARE_RESULT_ERROR)
+
+    return config
diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py
index 864f29d2fbd098fd2ed4aa4d0a27c38cc025e30e..c2c1461e452f9d2c7f4e0e2803dfe51be2a132c0 100644
--- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py
+++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py
@@ -23,7 +23,7 @@ from msprobe.core.common.const import CompareConst
 
 
 def _handle_multi_process(func, input_parma, result_df, lock):
-    process_num = int((multiprocessing.cpu_count() + 1) / 2)
+    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
     op_name_mapping_dict = read_dump_data(result_df)
 
     df_chunk_size = len(result_df) // process_num
@@ -63,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
 
 
 def _ms_graph_handle_multi_process(func, result_df, mode):
-    process_num = int((multiprocessing.cpu_count() + 1) // 4)
+    process_num = 
max(int((multiprocessing.cpu_count() + 1) // 4), 1) df_chunk_size = len(result_df) // process_num if df_chunk_size > 0: df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 9724ffbef6f39f4b0c1692bdb5558e36eab88f2d..a2edf57e5bb91400675fe01734ea7fbf0e1df893 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -131,7 +131,7 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list: return [default_item] elif not op_data: return [] - + item_list = [] if isinstance(op_data, list): for i, data in enumerate(op_data): @@ -162,7 +162,7 @@ def gen_op_item(op_data, op_name): for i in params: if i not in op_item: op_item[i] = None - + if not op_item.get('dtype'): if op_item.get('type') == 'torch.Size': op_item['dtype'] = op_data.get('type') @@ -170,6 +170,16 @@ def gen_op_item(op_data, op_name): elif op_item.get('type') == 'slice': op_item['dtype'] = op_data.get('type') op_item['shape'] = str(np.shape(np.array(op_data.get('value')))) + elif op_item.get('type') == 'ellipsis': + op_item['dtype'] = op_data.get('type') + op_item['shape'] = '[]' + for i in params: + op_item[i] = op_data.get('value') + elif op_item.get('type') == 'torch.ProcessGroup': + op_item['dtype'] = op_data.get('type') + op_item['shape'] = '[]' + for i in params: + op_item[i] = str(op_data.get('group_ranks')) else: op_item['dtype'] = str(type(op_data.get('value'))) op_item['shape'] = '[]' @@ -177,7 +187,7 @@ def gen_op_item(op_data, op_name): op_item[i] = op_data.get('value') if not op_item.get('md5'): op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}" - + return op_item @@ -386,9 +396,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict) get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT) - get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT) get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params, CompareConst.PARAMS_STRUCT) + get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT) get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad, b_num_input + b_num_output + b_num_params, b_num_params_grad, CompareConst.PARAMS_GRAD_STRUCT) @@ -412,7 +422,14 @@ def get_un_match_accuracy(result, n_dict, dump_mode): CompareConst.PARAMS_STRUCT: 0, CompareConst.PARAMS_GRAD_STRUCT: 0 } - for index, n_name in enumerate(n_dict["op_name"]): + + op_name_list = n_dict.get(CompareConst.OP_NAME) + summary_list = n_dict.get(Const.SUMMARY) + data_name_list = n_dict.get('data_name') + op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list, + summary_list, + data_name_list) + for index, n_name in enumerate(op_name_reorder): _, state = get_name_and_state(n_name) struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state) if not struct_key: @@ -440,7 +457,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode): if dump_mode == Const.ALL: result_item.extend([CompareConst.N_A] * 5) - npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY) + npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder") bench_summary_data = [CompareConst.N_A] * 4 
         result_item.extend(npu_summary_data)
         result_item.extend(bench_summary_data)
@@ -540,11 +557,51 @@ def get_name_and_state(name):
     return api, state
 
 
+def reorder_op_name_list(op_name_list):
+    if not op_name_list:
+        return op_name_list
+
+    parameters = []
+    output = []
+    parameters_grad = []
+    others = []
+    for x in op_name_list:
+        state = get_name_and_state(x)[1]
+        if state == Const.PARAMS:
+            parameters.append(x)
+        elif state == Const.OUTPUT:
+            output.append(x)
+        elif state == Const.PARAMS_GRAD:
+            parameters_grad.append(x)
+        else:
+            others.append(x)
+    # merge others, parameters and output, keeping parameters ahead of output
+    op_name_reorder = others + parameters + output + parameters_grad
+    return op_name_reorder
+
+
+def reorder_op_x_list(op_name_list, summary_list, data_name_list):
+    """Reorder the op_name, summary and data_name lists, moving parameters after input and before output; data_name is None in statistics-mode compare and is handled separately."""
+    if not op_name_list or not summary_list:
+        return op_name_list, summary_list, data_name_list
+
+    index_map = {name: index for index, name in enumerate(op_name_list)}
+
+    op_name_reorder = reorder_op_name_list(op_name_list)
+    summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
+    if data_name_list:
+        data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
+    else:
+        data_name_reorder = data_name_list
+
+    return op_name_reorder, summary_reorder, data_name_reorder
+
+
 def _compare_parser(parser):
     parser.add_argument("-i", "--input_path", dest="input_path", type=str,
                         help="<Required> The compare input path, a dict json.", required=True)
     parser.add_argument("-o", "--output_path", dest="output_path", type=str,
-                        help="<Optional> The compare task result out path. Default path: ./output",
+                        help="<Optional> The compare task result out path. Default path: ./output",
                         required=False, default="./output", nargs="?", const="./output")
     parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
                         help="<Optional> Whether to save stack info.", required=False)
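Since `reorder_op_name_list` and `reorder_op_x_list` above are the pivot of several call-site changes in this patch, a small illustration of the reordering (hypothetical op names following the dump naming scheme):

```python
from msprobe.core.compare.utils import reorder_op_x_list

# Hypothetical entry whose parameters trail the output in dump order.
op_names = [
    "Module.fc.Linear.0.forward.input.0",
    "Module.fc.Linear.0.forward.output.0",
    "Module.fc.Linear.0.forward.parameters.weight",
]
summaries = [[0.9, -0.1, 0.4, 3.2], [1.8, 0.0, 0.7, 5.1], [0.2, -0.2, 0.0, 1.0]]  # Max/Min/Mean/Norm

names, sums, data = reorder_op_x_list(op_names, summaries, None)
# names now runs input.0, parameters.weight, output.0 (parameters precede output),
# sums is permuted in lock-step, and data stays None because statistics-mode
# compare carries no data_name list.
```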
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index fa6513d325a2a694c4ea3a336f917f33f2c2ca5b..20e4489f89e4bd345595e6a1db1e39ab427d4908 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -40,6 +40,7 @@ class DataCollector:
         self.scope = ScopeFactory(self.config).build_scope()
         self.backward_module_names = {}
         self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
         atexit.register(self.write_json)
 
     @property
@@ -54,6 +55,17 @@ class DataCollector:
     def check_scope_and_pid(scope, name, pid):
         return (not scope or scope.check(name)) and pid == os.getpid()
 
+    @staticmethod
+    def set_is_recomputable(data_info, is_recompute):
+        if data_info and len(data_info) == 1 and is_recompute is not None:  # in the normal case data_info has exactly one entry
+            data_info[list(data_info.keys())[0]]["is_recompute"] = is_recompute
+
+    def reset_status(self):
+        self.optimizer_status = ""
+        self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
+        self.data_writer.reset_cache()
+        self.backward_module_names.clear()
+
     def if_return_forward_new_output(self):
         return self.data_processor.if_return_forward_new_output()
 
@@ -77,7 +89,7 @@ class DataCollector:
             logger.debug(msg)
         self.data_writer.update_data(data_info)
 
-    def forward_input_data_collect(self, name, module, pid, module_input_output):
+    def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         if self.config.task == Const.FREE_BENCHMARK:
             backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
             if self.check_scope_and_pid(self.scope, backward_name, pid):
@@ -87,37 +99,48 @@ class DataCollector:
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         if self.config.level == Const.LEVEL_L2:
             return
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
-    def forward_output_data_collect(self, name, module, pid, module_input_output):
+    def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         if self.config.level == Const.LEVEL_L2:
             return
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
-    def forward_data_collect(self, name, module, pid, module_input_output):
+    def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
-    def backward_data_collect(self, name, module, pid, module_input_output):
+    def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward(name, module, module_input_output)
         if self.config.level == Const.LEVEL_L2:
             return
         # record the name of the module that executes the backward pass
@@ -127,25 +150,34 @@ class DataCollector:
                 self.backward_module_names[module_name] = True
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
-    def backward_input_data_collect(self, name, module, pid, module_input_output):
+    def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)
 
-    def backward_output_data_collect(self, name, module, pid, module_input_output):
+    def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-        data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+        data_info = {}
+        if self.config.task != Const.STRUCTURE:
+            data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+        self.set_is_recomputable(data_info, is_recompute)
         self.handle_data(name, data_info)
 
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
             if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+                if self.optimizer_status_first_start[self.optimizer_status]:
+                    self.data_writer.update_construct({self.optimizer_status: None})
+                    self.optimizer_status_first_start[self.optimizer_status] = False
                 self.data_writer.update_construct({name: self.optimizer_status})
             else:
                 self.data_writer.update_construct({name: self.module_processor.api_parent_node})
@@ -154,6 +186,8 @@ class DataCollector:
     def handle_data(self, name, data_info, flush=False):
         if data_info:
             self.update_data(name, data_info)
+        if self.config.async_dump:
+            return
         if not flush:
             self.data_writer.flush_data_periodically()
         else:
@@ -179,6 +213,9 @@ class DataCollector:
             data_info = self.data_processor.analyze_params(grad_name, param_name, data)
         self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
 
+    def fill_stack_tensor_data(self):
+        self.data_writer.fill_stack_tensor_data()
+
     def debug_data_collect_forward(self, variable, name_with_count):
         data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
 
@@ -190,4 +227,4 @@ class DataCollector:
             self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
 
         # register tensor backward hook
-        self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
\ No newline at end of file
+        self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
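Tying the collector changes back to the configuration: a config that exercises the new lightweight collection might look like the sketch below. The field values are illustrative; "structure" (record module hierarchy and call stacks, skip tensor analysis) and "async_dump" (a boolean, validated by CommonConfig with an OOM warning when enabled) are the two switches this patch adds.

```python
import json

# Illustrative config; "task": "structure" and "async_dump" are added by this
# patch, the remaining fields follow the layout of msprobe's shipped config.json.
config = {
    "task": "structure",
    "dump_path": "./dump_path",
    "rank": [],
    "step": [],
    "level": "L0",        # level choice here is illustrative
    "async_dump": False,  # True defers flushing and may use more memory
}
with open("config.json", "w", encoding="utf-8") as fp:
    json.dump(config, fp, indent=4)
```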
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 5940886b439911dcb9232fc8e1a441d711824f2d..775a80b2418ef356867228b4ca09fad8c86cce25 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass, is_dataclass
 from typing import Tuple, Dict, Optional, Any
 from functools import partial
 import copy
+from typing import Union
 
 import numpy as np
 
@@ -78,17 +79,18 @@ class ModuleBackwardOutputs:
 
 
 class TensorStatInfo:
-    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
+    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
         self.max = max_val
         self.min = min_val
         self.mean = mean_val
         self.norm = norm_val
+        self.stack_tensor_stat = stack_tensor_stat
 
 
 class BaseDataProcessor:
     _recursive_key_stack = []
     special_type = (
-        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
+        np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
         bool, int, float, str, slice, type(Ellipsis)
     )
@@ -215,8 +217,22 @@ class BaseDataProcessor:
         return single_arg
 
     @staticmethod
-    def _analyze_numpy(value, numpy_type):
-        return {"type": numpy_type, "value": value}
+    def _analyze_numpy(ndarray, numpy_type):
+        ndarray_json = {}
+        ndarray_json.update({'type': 'numpy.ndarray'})
+        ndarray_json.update({'dtype': str(ndarray.dtype)})
+        ndarray_json.update({'shape': ndarray.shape})
+        if ndarray.size > 0:
+            ndarray_json.update({"Max": np.max(ndarray).item()})
+            ndarray_json.update({"Min": np.min(ndarray).item()})
+            ndarray_json.update({"Mean": np.mean(ndarray).item()})
+            ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
+        else:
+            ndarray_json.update({"Max": None})
+            ndarray_json.update({"Min": None})
+            ndarray_json.update({"Mean": None})
+            ndarray_json.update({"Norm": None})
+        return ndarray_json
 
     @staticmethod
     def _get_allowed_data_mode(data_mode):
@@ -235,7 +251,7 @@ class BaseDataProcessor:
         return cls.special_type
 
     @classmethod
-    def recursive_apply_transform(cls, args, transform, depth=0):
+    def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
         if depth > Const.MAX_DEPTH:
             logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
@@ -256,7 +272,7 @@ class BaseDataProcessor:
         elif isinstance(args, dict):
             return cls.apply_transform_dict(args, transform, depth)
         elif args is not None:
-            logger.warning(f"Data type {type(args)} is not supported.")
+            logger.debug(f"Data type {type(args)} is not supported.")
             return None
         else:
             return None
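For reference, the dict the reworked `_analyze_numpy` above now emits can be reproduced directly with NumPy (previously the method returned only the `{"type", "value"}` pair):

```python
import numpy as np

# Mirrors the statistics _analyze_numpy records for a non-empty array;
# illustrative only, not part of the patch.
a = np.array([[1.0, -2.0], [3.0, 4.0]])
ndarray_json = {
    "type": "numpy.ndarray",
    "dtype": str(a.dtype),             # "float64"
    "shape": a.shape,                  # (2, 2)
    "Max": np.max(a).item(),           # 4.0
    "Min": np.min(a).item(),           # -2.0
    "Mean": np.mean(a).item(),         # 1.5
    "Norm": np.linalg.norm(a).item(),  # sqrt(30), about 5.48
}
```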
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
index 21c6b6989fa82e8e83c56aac336a36d34cde4fc0..83f3c717e88f018b4ecef9ec4e2a5edec3e56c4f 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.base import BaseDataProcessor
 
 
 class DataProcessorFactory:
@@ -62,15 +63,19 @@ class DataProcessorFactory:
             cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
             cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
+            cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
         elif framework == Const.MS_FRAMEWORK:
             from msprobe.core.data_dump.data_processor.mindspore_processor import (
                 StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
                 TensorDataProcessor as MindsporeTensorDataProcessor,
-                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
+                OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
+                KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
             )
             from msprobe.mindspore.cell_processor import CellProcessor
             cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
             cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
+            cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
+            cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
             cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index 133a032f6c065dbe4672337baf759a5e33c8b44e..354989c74d0c15fbb1efecd21cd746c38c62018c 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright 2024-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,18 +16,24 @@ import zlib
 
 import mindspore as ms
-from mindspore import mint, ops
+from mindspore import mint, ops, hal
 from mindspore._c_expression.typing import Number
 import numpy as np
 
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
                                                         ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
-from msprobe.core.common.file_utils import path_len_exceeds_limit
+from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy
 from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 
+has_adump = True
+try:
+    from msprobe.lib import _msprobe_c
+except ImportError:
+    has_adump = False
+
 
 class MindsporeDataProcessor(BaseDataProcessor):
     mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,6 +43,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
         self.mindspore_object_key = {
             "dtype": self.analyze_dtype_in_kwargs
         }
+        self._async_dump_cache = {}
 
     @staticmethod
     def get_md5_for_tensor(x):
@@ -50,18 +57,9 @@ return {"type": "mindspore.dtype", "value": str(element)}
 
     @staticmethod
-    def is_hookable_element(element):
-        return hasattr(element, "register_hook") and callable(element.register_hook)
-
-    @classmethod
-    def get_special_types(cls):
-        return super().get_special_types() + cls.mindspore_special_type
-
-    def get_stat_info(self, data):
+    def get_stat_info_sync(data):
         tensor_stat = TensorStatInfo()
-        if data.numel() == 0:
-            return tensor_stat
-        elif data.dtype == ms.bool_:
+        if data.dtype == ms.bool_:
             data_np = data.asnumpy()
             tensor_stat.max = np.max(data_np).item()
             tensor_stat.min = np.min(data_np).item()
@@ -91,17 +89,64 @@
             api_register.norm_inner_op_set_hook_func()
         return tensor_stat
 
+    @staticmethod
+    def get_stat_info_async(data):
+        tensor_stat = TensorStatInfo()
+        stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
+        if data.dtype == ms.complex64 or data.dtype == ms.complex128:
+            logger.warning("Async dump does not support complex data!")
+            return tensor_stat
+        elif data.dtype == ms.bool_:
+            tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
+        elif not data.shape:
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
+        else:
+            if not ops.is_floating_point(data) or data.dtype == ms.float64:
+                data = data.to(ms.float32)
+            api_register.norm_inner_op_set_ori_func()
+            get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
+            get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
+            get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
+            if hasattr(mint, "norm"):
+                get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
+            else:
+                get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
+                [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
+            api_register.norm_inner_op_set_hook_func()
+        return tensor_stat
+
+    @staticmethod
+    def is_hookable_element(element):
+        return hasattr(element, "register_hook") and callable(element.register_hook)
+
+    @classmethod
+    def get_special_types(cls):
+        return super().get_special_types() + cls.mindspore_special_type
+
+    def 
get_stat_info(self, data): + tensor_stat = TensorStatInfo() + if data.numel() == 0: + return tensor_stat + else: + if self.config.async_dump: + return MindsporeDataProcessor.get_stat_info_async(data) + else: + return MindsporeDataProcessor.get_stat_info_sync(data) + def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: return self.mindspore_object_key[suffix_stack[-1]](element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) if converted_numpy is not element: - return self._analyze_numpy(converted_numpy, numpy_type) + return {"type": numpy_type, "value": converted_numpy} if isinstance(element, Number): return self.analyze_dtype_in_kwargs(element) if isinstance(element, ms.Tensor): return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) + if isinstance(element, np.ndarray): + return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) return {} @@ -111,13 +156,17 @@ class MindsporeDataProcessor(BaseDataProcessor): tensor_json = { 'type': 'mindspore.Tensor', 'dtype': str(tensor.dtype), - 'shape': tensor.shape, - 'Max': self.transfer_type(tensor_stat.max), - 'Min': self.transfer_type(tensor_stat.min), - 'Mean': self.transfer_type(tensor_stat.mean), - 'Norm': self.transfer_type(tensor_stat.norm), + 'shape': tensor.shape } - if self.config.summary_mode == Const.MD5: + + if tensor_stat.stack_tensor_stat is None: + tensor_json.update({'Max': self.transfer_type(tensor_stat.max)}) + tensor_json.update({'Min': self.transfer_type(tensor_stat.min)}) + tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)}) + tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)}) + else: + tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat}) + if self.config.summary_mode == Const.MD5 and not self.config.async_dump: tensor_md5 = self.get_md5_for_tensor(tensor) tensor_json.update({Const.MD5: tensor_md5}) return tensor_json @@ -128,13 +177,28 @@ class StatisticsDataProcessor(MindsporeDataProcessor): class TensorDataProcessor(MindsporeDataProcessor): + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_tensor_as_npy(tensor, file_path) + self._async_dump_cache.clear() + def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) - save_tensor_as_npy(tensor, file_path) + if self.config.async_dump: + self._async_dump_cache[file_path] = tensor.copy() + else: + save_tensor_as_npy(tensor, file_path) return single_arg + def _analyze_numpy(self, ndarray, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + save_npy(ndarray, file_path) + ndarray_json = super()._analyze_numpy(ndarray, suffix) + ndarray_json.update({"data_name": dump_data_name}) + return ndarray_json + class OverflowCheckDataProcessor(MindsporeDataProcessor): __slots__ = ["cached_tensors_and_file_paths"] @@ -215,3 +279,61 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self._analyze_maybe_overflow_tensor(single_arg) single_arg.update({"data_name": dump_data_name}) return single_arg + + +class KernelDumpDataProcessor(MindsporeDataProcessor): + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + 
self.enable_kernel_dump = True
+
+    @staticmethod
+    def start_kernel_dump(config_path):
+        hal.synchronize()
+        _msprobe_c.init_dump()
+        _msprobe_c.set_dump(config_path)
+        hal.synchronize()
+
+    @staticmethod
+    def stop_kernel_dump():
+        hal.synchronize()
+        _msprobe_c.finalize_dump()
+        hal.synchronize()
+
+    @staticmethod
+    def _print_unsupported_log(api_name):
+        logger.warning(f"The kernel dump does not support the {api_name} API.")
+
+    def analyze_forward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package was not compiled with adump, so kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_forward_output(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def analyze_backward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package was not compiled with adump, so kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_backward(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def reset_status(self):
+        self.enable_kernel_dump = True
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index 2a525c1338baeca75c881dbc56a93dd3538a2201..05b68592041fa8f67c65230cdc53f88d22522e46 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,7 @@ from typing import List
 import numpy as np
 import torch
 from torch import distributed as dist
+from torch.distributed.distributed_c10d import _get_default_group
 
 from msprobe.core.common.const import Const
 from msprobe.core.common.file_utils import path_len_exceeds_limit
@@ -40,7 +41,16 @@ class PytorchDataProcessor(BaseDataProcessor):
-    pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
+    pytorch_special_type = (
+        torch.device,
+        torch.dtype,
+        torch.Size,
+        torch.Tensor,
+        torch.memory_format,
+        dist.ProcessGroup,
+        dist.P2POp,
+        dist.ReduceOp
+    )
     memory_format = {
         torch.contiguous_format: "contiguous_format",
         torch.channels_last: "channels_last",
@@ -54,6 +64,7 @@
             "device": self.analyze_device_in_kwargs,
             "dtype": self.analyze_dtype_in_kwargs
         }
+        self._async_dump_cache = {}
 
     @staticmethod
     def get_md5_for_tensor(x):
@@ -82,33 +93,64 @@ return {"type": "torch.dtype", "value": str(element)}
 
     @staticmethod
-    def get_stat_info(data):
+    def get_stat_info_async(data):
         tensor_stat = TensorStatInfo()
-        if data.is_meta:
+        if torch.is_complex(data):
+            logger.warning("Async dump does not support complex data!")
             return tensor_stat
-        data_clone = data.detach()
-        if data_clone.numel() == 0:
-            return tensor_stat
-        elif data_clone.dtype == torch.bool:
-            tensor_stat.max = torch.any(data_clone).item()
-            tensor_stat.min = torch.all(data_clone).item()
-        elif not data_clone.shape:
-            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
-        elif torch.is_complex(data_clone):
-            data_np = data_clone.cpu().numpy()
+        elif data.dtype == torch.bool:
+            tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
+                [torch.any(data), torch.all(data)]))
+        elif not data.shape:
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
+        else:
+            if not data.is_floating_point() or data.dtype == torch.float64:
+                data = data.float()
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
+                torch.max(data),
+                torch.min(data),
+                torch.mean(data),
+                torch.norm(data)
+            ]))
+        return tensor_stat
+
+    @staticmethod
+    def get_stat_info_sync(data):
+        tensor_stat = TensorStatInfo()
+        if torch.is_complex(data):
+            data_np = data.cpu().numpy()
             data_abs = np.abs(data_np)
             tensor_stat.max = np.max(data_abs).item()
             tensor_stat.min = np.min(data_abs).item()
             tensor_stat.mean = np.mean(data_abs).item()
+        elif data.dtype == torch.bool:
+            tensor_stat.max = torch.any(data).item()
+            tensor_stat.min = torch.all(data).item()
+        elif not data.shape:
+            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
         else:
-            if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
-                data_clone = data_clone.float()
-            tensor_stat.max = torch.max(data_clone).item()
-            tensor_stat.min = torch.min(data_clone).item()
-            tensor_stat.mean = torch.mean(data_clone).item()
-            tensor_stat.norm = torch.norm(data_clone).item()
+            if not data.is_floating_point() or data.dtype == torch.float64:
+                data = data.float()
+            tensor_stat.max = torch.max(data).item()
+            tensor_stat.min = torch.min(data).item()
+            tensor_stat.mean = torch.mean(data).item()
+            tensor_stat.norm = torch.norm(data).item()
         return tensor_stat
 
+    @staticmethod
+    def get_stat_info(data, async_dump=False):
+        
tensor_stat = TensorStatInfo() + if data.is_meta: + return tensor_stat + data_clone = data.detach() + if data_clone.numel() == 0: + return tensor_stat + else: + if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump: + return PytorchDataProcessor.get_stat_info_sync(data_clone) + else: + return PytorchDataProcessor.get_stat_info_async(data_clone) + @staticmethod def handle_tensor_extremum_nan_inf(tensor, operator): data_clone = tensor.detach() @@ -149,7 +191,6 @@ class PytorchDataProcessor(BaseDataProcessor): def _analyze_memory_format(arg): # 获取内存格式 format_type = PytorchDataProcessor.memory_format.get(arg) - return {"type": "torch.memory_format", "format": format_type} @staticmethod @@ -161,9 +202,18 @@ class PytorchDataProcessor(BaseDataProcessor): group_id = PytorchDataProcessor.process_group_hash(arg) group_info.update({"group_id": group_id}) except Exception as e: - logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.") + logger.warning(f"Failed to get process group ranks info with error info: {e}.") return group_info + @staticmethod + def _analyze_reduce_op(arg): + op_type = None + try: + op_type = str(arg) + except Exception as e: + logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.") + return {"type": "torch.distributed.ReduceOp", "value": op_type} + @classmethod def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type @@ -177,11 +227,17 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_memory_format(element) if isinstance(element, dist.ProcessGroup): return self._analyze_process_group(element) + if isinstance(element, dist.P2POp): + return self._analyze_p2pop(element) + if isinstance(element, dist.ReduceOp): + return self._analyze_reduce_op(element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) if converted_numpy is not element: - return self._analyze_numpy(converted_numpy, numpy_type) + return {"type": numpy_type, "value": converted_numpy} if isinstance(element, torch.Tensor): return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) + if isinstance(element, np.ndarray): + return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) return {} @@ -191,26 +247,45 @@ class PytorchDataProcessor(BaseDataProcessor): module_input_output.update_output_with_args_and_kwargs() return super().analyze_forward_output(name, module, module_input_output) + def _analyze_p2pop(self, arg): + p2pop_info = {"class_type": "torch.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, []) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op.__name__}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = PytorchDataProcessor.process_group_hash( + arg.group) if arg.group else PytorchDataProcessor.process_group_hash(_get_default_group()) + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): - tensor_stat = self.get_stat_info(tensor) + tensor_stat = self.get_stat_info(tensor, self.config.async_dump) tensor_json = {} tensor_json.update({'type': 'torch.Tensor'}) tensor_json.update({'dtype': str(tensor.dtype)}) 
tensor_json.update({"shape": tensor.shape}) - tensor_json.update({"Max": tensor_stat.max}) - tensor_json.update({"Min": tensor_stat.min}) - tensor_json.update({"Mean": tensor_stat.mean}) - tensor_json.update({"Norm": tensor_stat.norm}) - tensor_json.update({"requires_grad": tensor.requires_grad}) - - if tensor_stat.max is not None: - if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") - if tensor_stat.min is not None: - if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") - - if self.config.summary_mode == Const.MD5: + if tensor_stat.stack_tensor_stat is None: + tensor_json.update({"Max": tensor_stat.max}) + tensor_json.update({"Min": tensor_stat.min}) + tensor_json.update({"Mean": tensor_stat.mean}) + tensor_json.update({"Norm": tensor_stat.norm}) + tensor_json.update({"requires_grad": tensor.requires_grad}) + if tensor_stat.max is not None: + if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") + if tensor_stat.min is not None: + if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): + tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") + + else: + tensor_json.update({"requires_grad": tensor.requires_grad}) + tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat}) + + if self.config.summary_mode == Const.MD5 and not self.config.async_dump: tensor_md5 = self.get_md5_for_tensor(tensor) tensor_json.update({Const.MD5: tensor_md5}) return tensor_json @@ -221,14 +296,29 @@ class StatisticsDataProcessor(PytorchDataProcessor): class TensorDataProcessor(PytorchDataProcessor): + def dump_async_data(self): + for file_path, tensor in self._async_dump_cache.items(): + save_pt(tensor.contiguous(), file_path) + self._async_dump_cache.clear() + def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - saved_tensor = tensor.clone().contiguous().detach() - save_pt(saved_tensor, file_path) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) + if self.config.async_dump: + self._async_dump_cache[file_path] = tensor.clone().detach() + else: + saved_tensor = tensor.clone().contiguous().detach() + save_pt(saved_tensor, file_path) return single_arg + def _analyze_numpy(self, ndarray, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + save_pt(torch.tensor(ndarray), file_path) + ndarray_json = super()._analyze_numpy(ndarray, suffix) + ndarray_json.update({"data_name": dump_data_name}) + return ndarray_json + class OverflowCheckDataProcessor(PytorchDataProcessor): __slots__ = ["cached_tensors_and_file_paths"] diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index aa61b79deefee5cffbc9a01d9e3f525382d2241c..b1e26d16f9741765c1c9600a64efb112aa0f42d7 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -16,10 +16,12 @@ import csv import os import copy +import numpy as np from msprobe.core.common.const import Const, FileCheckConst -from msprobe.core.common.file_utils import change_mode, FileOpen, save_json +from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json from 
msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException class DataWriter: @@ -132,4 +134,30 @@ class DataWriter: if self.cache_construct: self.write_construct_info_json(self.construct_file_path) if self.cache_debug: - self.write_debug_info_json(self.debug_file_path) \ No newline at end of file + self.write_debug_info_json(self.debug_file_path) + + def fill_stack_tensor_data(self): + self.process_stat_data_recursive(self.cache_data) + + def process_stat_data_recursive(self, data, depth=0): + if depth > Const.MAX_DEPTH: + logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.") + raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR) + if isinstance(data, dict): + if "tensor_stat" in data.keys(): + tensor_stat = data["tensor_stat"] + if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]): + logger.warning("Some bad data in async dump") + else: + tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1] + if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE: + tensor_stat_data = tensor_stat_data.cpu() + for index, stat in zip(tensor_stat_index, tensor_stat_data): + data.update({index: stat.item()}) + del data["tensor_stat"] + else: + for key in data.keys(): + self.process_stat_data_recursive(data[key], depth + 1) + elif isinstance(data, (list, tuple)): + for i in data: + self.process_stat_data_recursive(i, depth + 1) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/data_dump/scope.py b/debug/accuracy_tools/msprobe/core/data_dump/scope.py index 5640b97a8726d7329080518a84c704d8699f2401..7632dcf30c9eb4cc6047cde5fff5d230176b9fc0 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/scope.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/scope.py @@ -45,7 +45,7 @@ class ScopeFactory: if self.level == Const.LEVEL_MIX: return mix_range_scope - + if not self.scope: return api_range_scope if api_range_scope.is_valid and module_range_scope.is_valid: @@ -73,21 +73,21 @@ class BaseScope(ABC): def rectify_args(scope, api_list): if not isinstance(api_list, list): raise ScopeException(ScopeException.InvalidApiStr, - f"api_list参数须配置为列表,实际类型为{type(api_list)}.") + f"api_list参数须配置为列表,实际类型为{type(api_list)}.") for api in api_list: if not isinstance(api, str): raise ScopeException(ScopeException.InvalidApiStr, - f"api_list中的元素须配置为字符串,实际类型为{type(api)}.") + f"api_list中的元素须配置为字符串,实际类型为{type(api)}.") if isinstance(scope, str): scope = [scope] return scope, api_list if not isinstance(scope, list): raise ScopeException(ScopeException.InvalidScope, - f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.") + f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.") for s in scope: if not isinstance(s, str): raise ScopeException(ScopeException.InvalidScope, - f"scope列表元素要求类型为字符串,实际类型为{type(s)}.") + f"scope列表元素要求类型为字符串,实际类型为{type(s)}.") return scope, api_list @abstractmethod @@ -108,7 +108,7 @@ class ListScope(BaseScope): def rectify_args(scope, api_list): if scope and api_list: raise ScopeException(ScopeException.ArgConflict, - f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.") + f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.") return super(ListScope, ListScope).rectify_args(scope, api_list) def check(self, name): @@ -134,23 +134,23 @@ class RangeScope(BaseScope, ABC): if self.level == Const.LEVEL_L1: if not re.match(api_pattern, name): raise ScopeException(ScopeException.InvalidScope, - 
f"scope参数格式错误,要求格式为api完整命名,实际为{name}.") - + f"scope参数格式错误,要求格式为api完整命名,实际为{name}.") + if self.level == Const.LEVEL_L0: if not re.match(module_pattern, name): raise ScopeException(ScopeException.InvalidScope, - f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.") + f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.") if self.level == Const.LEVEL_MIX: if not re.match(api_pattern, name) and not re.match(module_pattern, name): raise ScopeException(ScopeException.InvalidScope, - f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.") + f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.") def rectify_args(self, scope, api_list): scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list) if scope and len(scope) != 2: raise ScopeException(ScopeException.InvalidScope, - f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.") + f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.") for name in scope: self.check_name_pattern(name) return scope, api_list @@ -230,7 +230,7 @@ class ModuleRangeScope(RangeScope): class MixRangeScope(RangeScope): def check_scope_is_valid(self): return True if self.scope else False - + def begin_module(self, module_name): if self.scope and module_name == self.scope[0]: self.in_scope = True @@ -249,12 +249,12 @@ class MixRangeScope(RangeScope): def check_api_list(self, api_name): if not self.api_list: return True - + for name in self.api_list: if name in api_name: return True return False - + def check(self, name): """ dump时调用的接口,根据scope和api_list判断是否需要dump @@ -272,4 +272,3 @@ class MixRangeScope(RangeScope): if self.scope and name == self.scope[1]: self.in_scope = False return result - \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/01.installation.md b/debug/accuracy_tools/msprobe/docs/01.installation.md index 7b0d8245b86336a412c760f88a0d9127d11001b7..1ab5f6419ba07ec749bad139f874fbc7301fd8b3 100644 --- a/debug/accuracy_tools/msprobe/docs/01.installation.md +++ b/debug/accuracy_tools/msprobe/docs/01.installation.md @@ -16,6 +16,8 @@ pip install mindstudio-probe |版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码| |:--:|:--:|:--:|:--:|:--:|:--:| +|1.2.2|2025.2.26|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|1db0cf4572bc0305c68705b74775f652c6cb2c2bedb6c6e57f43e31ab273b288| +|1.2.1|2025.2.07|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl)|b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b| |1.2.0|2025.1.13|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl)|1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc| |1.1.1|2024.12.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl)|577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735| |1.1.0|2024.10.14|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.0-py3-none-any.whl)|83a5a9b7c65a357639f8c9636d88c693b4cf0eb590d4f8f5cb56395ba69b1f6d| @@ -54,6 +56,30 @@ pip install ./mindstudio_probe*.whl # 特性变更说明 +## 1.2.0 + +【数据采集】 + +- 模块级dump支持采集权重及权重梯度 +- 修复原地覆盖类API前向输入数据采集不正确的问题 +- seed_all接口支持控制dropout失效功能 + +【精度预检】 + +- MindSpore场景新增支持Tensor类的mint API的预检 + +【训练状态监控】 + +- 支持FSDP和ZeRO-0 +- 异常排序支持前向激活值和反向梯度 + +【分级可视化构图比对】 + +- 
支持graph结构分页展示,支持graph批量构建和比对
+- 支持溢出检测模式
+
 ## 1.1.1
 
 【数据采集】
 
diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
index e0a4e56ddad7cac5bdfb2c52be5688cd0f5411df..b7ecb61409116d7123ea7b6a0813b8588db58ca2 100644
--- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
+++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md
@@ -18,6 +18,7 @@ | step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | | level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,PyTorch 与 MindSpore 均支持,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch 与 MindSpore 动态图场景支持;
"L2":dump kernel 级精度数据,PyTorch场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch 与 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | | enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | +| async_dump | 异步 dump 开关,bool 类型。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式暂不支持复数类型 tensor
的统计量计算。 #### 1.1.1 模块级精度数据 dump 说明 @@ -29,6 +30,8 @@ PyTorch 与 MindSpore 均支持。 模块指的是继承 nn.Module 类(PyTorch场景)或 nn.Cell 类(MindSpore场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 + + ### 1.2 task 配置为 statistics @@ -161,4 +164,3 @@ PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore intervals就是根据值分布bounds划分出的区间。 MindSpore静态图模式下,L0级别中暂不支持"MD5" - diff --git a/debug/accuracy_tools/msprobe/docs/03.config_examples.md b/debug/accuracy_tools/msprobe/docs/03.config_examples.md index 715f261c64e7481fd1d087498c76b954b9cee29d..542250fac243f3ab2f1d0aff87bc509ac7c1a675 100644 --- a/debug/accuracy_tools/msprobe/docs/03.config_examples.md +++ b/debug/accuracy_tools/msprobe/docs/03.config_examples.md @@ -100,6 +100,18 @@ } ``` +### 1.6 task 配置为 structure + +```json +{ + "task": "structure", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "mix" +} +``` + ## 2 MindSpore 静态图场景 ### 2.1 task 配置为 statistics @@ -228,3 +240,15 @@ } } ``` + +### 3.5 task 配置为 structure + +```json +{ + "task": "structure", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "mix" +} +``` diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index c3698003a43fc229ced1d88a69d6de2ddb1f8d95..db9a989c9d1c731fd9099d311f3ab3b95e5c7d5d 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -44,7 +44,7 @@ level 配置为"L0"或"mix"时,必须在该接口或 **start** 接口中配置 debugger.start(model=None) ``` -1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module 或 list[torch.nn.Module] 类型,默认未配置。 +1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。 level 配置为"L0"或"mix"时,必须在该接口或 **PrecisionDebugger** 接口中配置该参数。 本接口中的 model 比 PrecisionDebugger 中 model 参数优先级更高,会覆盖 PrecisionDebugger 中的 model 参数。 @@ -52,8 +52,10 @@ level 配置为"L0"或"mix"时,必须在该接口或 **PrecisionDebugger** 接 **功能说明**:停止精度数据采集。在 **start** 函数之后的任意位置添加。 若 **stop** 函数添加在反向计算代码(如loss.backward)之后,则会采集 **start** 和该函数之间的前反向数据。 -若 **stop** 函数添加在反向计算代码之前,则需要将 **step** 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 -**step** 函数详细介绍见1.5章节。使用示例可参见 [2.1 快速上手](#21-快速上手) 和 [2.2 采集完整的前反向数据](#22-采集完整的前反向数据)。 +若 **stop** 函数添加在反向计算代码之前,则需要将 [**step**](#15-step) 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 +使用示例可参见 [2.1 快速上手](#21-快速上手) 和 [2.2 采集完整的前反向数据](#22-采集完整的前反向数据)。 + +**注意**:**stop** 函数必须调用,否则可能导致精度数据落盘不全。 **原型**: @@ -329,11 +331,12 @@ if __name__ == "__main__": │ | ├── rank0 │ | │ ├── dump_tensor_data | | | | ├── Tensor.permute.1.forward.pt -| | | | ├── Module.conv1.Conv2d.forward.0.input.0.pt # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 -| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | ├── Module.conv1.Conv2D.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | ├── Functional.linear.5.backward.output.pt # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 | | | | ... 
-| | | | └── Functional.linear.5.backward.output.pt # 命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 +| | | | ├── Module.conv1.Conv2d.forward.0.input.0.pt # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 +| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.pt # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2D.parameters_grad.weight.pt # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.pt。 │ | | ├── dump.json │ | | ├── stack.json │ | | └── construct.json diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index 9826ef8abd86cd79741bc5bf3ae35c04925f3f7f..e3755663809ad3176e6d11328dad4c8f17a68098 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -69,12 +69,15 @@ dump 的"tensor"模式采集数据量大小,可以参考[数据量基线](data **原型**: ```Python -PrecisionDebugger(config_path=None) +PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step=None) ``` **参数说明**: 1. config_path:指定 dump 配置文件路径,string 类型。参数示例:"./config.json"。未配置该路径时,默认使用 [config.json](../config.json) 文件的默认配置,配置选项含义可见 [config.json 介绍](./02.config_introduction.md)。 +2. 其他参数均在 [config.json](../config.json) 文件中可配,详细配置可见 [config.json 介绍](./02.config_introduction.md)。 + +此接口的参数均不是必要,且优先级高于 [config.json](../config.json) 文件中的配置,但可配置的参数相比 config.json 较少。 #### 6.1.1 start @@ -88,13 +91,15 @@ start(model=None) **参数说明**: -1. model:指具体的 mindspore.nn.Cell对象,默认不配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有Cell 对象数据。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 +1. model:指定需要采集数据的实例化模型,支持传入mindspore.nn.Cell、List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell] 类型, 默认未配置。Cell级别("L0" level)dump 与 "mix" level dump 时,必须传入 model 才可以采集 model 内的所有Cell 对象数据。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 #### 6.1.2 stop **功能说明**:停止精度数据采集。在 **start** 函数之后的任意位置添加。若 **stop** 函数添加在反向计算代码之后,则会采集 **start** 和该函数之间的前反向数据。 -若 **stop** 函数添加在反向计算代码之前,则需要将 **step** 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 -**step** 函数详细介绍见6.1.3章节。**仅未使用 Model 高阶 API 的动态图场景支持。** +若 **stop** 函数添加在反向计算代码之前,则需要将 [**step**](#613-step) 函数添加到反向计算代码之后,才能采集 **start** 和该函数之间的前反向数据。 +**仅未使用 Model 高阶 API 的动态图场景支持。** + +**注意**:**stop** 函数必须调用,否则可能导致精度数据落盘不全。 **原型**: @@ -144,6 +149,8 @@ save(variable, name, save_backward=True) | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | + + ### 6.2 msprobe.mindspore.common.utils.MsprobeStep **功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别,和静态图场景下 L0级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度任意训练代码段。 @@ -393,9 +400,10 @@ dump 结果目录结构示例如下: | | | | ... 
| | | | ├── Jit.AlexNet.0.forward.input.0.npy | | | | ├── Primitive.conv2d.Conv2D.0.forward.input.0.npy -| | | | ├── Cell.conv1.Conv2D.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 -| | | | ├── Cell.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 -| | | | └── Cell.relu.ReLU.forward.0.input.0.npy # 命名格式为{Cell}.{cell_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Cell的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Cell的第1个参数的第1个元素。 +| | | | ├── Cell.conv1.Conv2D.forward.0.parameters.weight.npy # 模块参数数据:命名格式为{Cell}.{cell_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | ├── Cell.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | └── Cell.relu.ReLU.forward.0.input.0.npy # 命名格式为{Cell}.{cell_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Cell的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Cell的第1个参数的第1个元素。 +| | | | # 当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Cell}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Cell.0.relu.ReLU.forward.0.input.0.npy。 │ | | ├── dump.json │ | | ├── stack.json │ | | └── construct.json diff --git a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md index 5f0c094e308b22d2dc998501be0fea149ff2b26d..b07568e25a2915a4e8e5c2157e7de4252410f38d 100644 --- a/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/07.accuracy_checker_PyTorch.md @@ -295,3 +295,13 @@ a:误差比对法指标。 - npu_scaled_masked_softmax - npu_swiglu + +- npu_apply_adam + +- npu_group_norm_silu + +- npu_mish + +- npu_moe_gating_top_k_softmax + +- npu_sort_v2 diff --git a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md index d14daf09d3a36c860a646e1e2279b134d02425bc..3bf65032edae2b8e35c5818d5c030c9ce4c79e95 100644 --- a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md @@ -2,9 +2,9 @@ ## 1 简介 -**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API,输出精度情况的诊断和分析。工具以模型中所有 Mint API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 Mint API。本工具支持**随机生成模式和真实数据模式**b。 +**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API 以及 Msadapter场景下迁移的 Mindspore API,输出精度情况的诊断和分析。工具以模型中所有 API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 API。本工具支持**随机生成模式和真实数据模式**b。 -a. 支持 Mindspore 版本:2.4; +a. 支持 Mindspore 版本:2.4/2.5; b. 
(可选)当使用Msadapter时,由于需要环境中同时存在 Torch 与 Msadapter,所以只支持在**安装原生Torch**的场景下通过export PYTHONPATH="xx/msadapter/build/lib"等通过**环境变量使能Msadapter的方式**的环境中进行预检,预检工具能够自动索引得到所需的 Torch 与 Msadapter环境,环境安装详细参考:[msadapter官网](https://gitee.com/mindspore/msadapter)。 diff --git a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md index 0ef8313e31ca5ac10299cdabe2b0adf5be38bade..b4525d738d849a17ca5049bd2214784c6f788d21 100644 --- a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md @@ -270,16 +270,16 @@ PyTorch 精度比对是以 CPU 或 GPU 的计算结果为标杆,通过计算 在比对结果中的Err_message列呈现比对结果颜色标记的原因,具体含义如下: 红色标记情况: -1. 一个 API 或模块的 One Thousandth Err Ratio 的 input > 0.9 同时 output < 0.6(真实数据模式); -2. 一个 API 或模块的 output 的最大值相对误差 (Max diff 除以 max(0.01, Bench max)) > 0.5(统计数据模式); -3. 一个 API 或模块的 NPU 的最大值或最小值中存在 nan/inf/-inf(真实数据模式、统计数据模式); -4. 一个 API 或模块的最大值绝对误差大于 1e+10(真实数据模式,统计数据模式)。 +1. 一个 API 或模块的 NPU 的最大值或最小值中存在 nan/inf/-inf(真实数据模式、统计数据模式); +2. 一个 API 或模块的最大值绝对误差大于 1e+10(真实数据模式,统计数据模式); +3. 一个 API 或模块的 One Thousandth Err Ratio 的 input/parameters > 0.9 同时 output < 0.6(真实数据模式)(仅标记output); +4. 一个 API 或模块的 output 的最大值相对误差 (Max diff 除以 max(0.01, Bench max)) > 0.5(统计数据模式)(仅标记output)。 -黄色标记情况: -1. 一个 API 或模块的 One Thousandth Err Ratio 的 input - output > 0.1(真实数据模式); -2. 一个 API 或模块的 Cosine 的 input - output > 0.1(真实数据模式); -3. 一个 API 或模块的 output 的最大值相对误差 > 0.1 同时 input < 0.01(真实数据模式,统计数据模式); -4. 一个 API 或模块的 input 与 output 的最大值绝对误差都大于 1,同时 output 比 input 大一个数量级以上(真实数据模式、统计数据模式)。 +黄色标记情况(仅标记output): +1. 一个 API 或模块的 input/parameters 与 output 的最大值绝对误差都大于 1,同时 output 比 input/parameters 大一个数量级以上(真实数据模式、统计数据模式); +2. 一个 API 或模块的 One Thousandth Err Ratio 的 input/parameters - output > 0.1(真实数据模式); +3. 一个 API 或模块的 output 的最大值相对误差 > 0.1 同时 input/parameters < 0.01(真实数据模式,统计数据模式); +4. 一个 API 或模块的 Cosine 的 input/parameters - output > 0.1(真实数据模式)。 ### 3.3 比对结果(Result)——统计数据模式、MD5 模式 @@ -342,6 +342,10 @@ MD5 模式: 本功能是将多卡比对场景的比对结果,进行通信算子数据提取和汇总,输出整理好的通信算子多卡比对精度表。 +**使用场景** + +已完成精度比对,获得多卡精度比对结果,但是通信算子数据分布在多个结果件中,不利于精度问题的分析。通过此功能,可以汇总多卡通信算子数据,减少问题定位时间。 + **约束** 不支持MD5比对结果。 @@ -389,3 +393,23 @@ compare_index: 2. rank*列为多卡数据。 3. 不同比对指标的数据通过不同sheet页呈现。 4. 如果一个API或module在某张卡上找不到数据,汇总结果中将空白呈现。 +5. 如果比对指标值为N/A,unsupported,Nan,表示无法计算该比对指标值,汇总结果将以”NPU:’NPU max值‘ Bench:’Bench max值‘“呈现。 +6. 针对图示案例,此处NPU:N/A Bench:N/A表示output为None。 + +
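+上面第 5 条的替代呈现规则可以用下面的示意函数理解(函数名与入参均为演示假设,并非工具实现):
+
+```python
+def format_summary_cell(value, npu_max, bench_max):
+    # 比对指标无法计算(N/A、unsupported、Nan)时,用 NPU/Bench 的 max 值替代呈现
+    if value in ("N/A", "unsupported", "Nan"):
+        return f"NPU:{npu_max} Bench:{bench_max}"
+    return value
+```
+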
+如何基于group信息查看分组数据:
+
+以Distributed.all_reduce.0.forward为例。这个API对多卡数据做规约操作,输出为一个group内的规约结果,同一个group内的输出保持一致。
这个API中,rank0-3为一个group,Distributed.all_reduce.0.forward.input.group展示为tp-0-1-2-3,rank0-3输出一致;rank4-7为一个group,展示为tp-4-5-6-7,rank4-7输出一致。
group除了这种形式,还有如[0, 1, 2, 3]的呈现形式。 + +
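+针对上述分组行为,下面给出一个可独立运行的 torch.distributed 示意(假设通过 torchrun 启动 8 个进程,gloo 后端仅作演示,并非工具代码):
+
+```python
+import torch
+import torch.distributed as dist
+
+if __name__ == "__main__":
+    # 假设:torchrun --nproc_per_node=8 demo.py
+    dist.init_process_group(backend="gloo")
+    rank = dist.get_rank()
+    # 所有进程都需要参与每个子 group 的创建
+    group_a = dist.new_group([0, 1, 2, 3])
+    group_b = dist.new_group([4, 5, 6, 7])
+    group = group_a if rank < 4 else group_b
+    x = torch.tensor([float(rank)])
+    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
+    # rank0-3 输出一致(0+1+2+3=6),rank4-7 输出一致(4+5+6+7=22)
+    print(f"rank {rank}: {x.item()}")
+    dist.destroy_process_group()
+```
+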
+常见通信API预期结果: + +1. Distributed.all_gather:多卡数据汇总,每张卡输入可以不一致,同group内输出一致,输出是张量列表。 +2. Distributed.all_gather_into_tensor:多卡数据汇总,每张卡输入可以不一致,同group内输出一致,输出是张量。 +3. Distributed.all_reduce:多卡数据规约操作,每张卡输入可以不一致,同group内输出一致,为规约结果。 +4. Distributed.reduce_scatter:多卡数据规约操作,每张卡输入可以不一致,输出为group内规约结果的不同部分,输入是张量列表。 +5. Distributed.reduce_scatter_tensor:多卡数据规约操作,每张卡输入可以不一致,输出为group内规约结果的不同部分,输入是张量。 +6. Distributed.broadcast:输入为要广播的数据,输出为广播后的数据。 +7. Distributed.isend:点对点通信,输入为要发送的数据,输出为发送的数据。 +8. Distributed.irecv:点对点通信,输入为原数据,输出为接收的新数据。 +9. Distributed.all_to_all_single:输出数据为所有卡上的数据切分后合并的结果。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md index 0768f77fb2ad952edda0e25741484297410cf2d5..1b1824a774f15a86106585669d5f3412b3faca2e 100644 --- a/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/11.accuracy_compare_MindSpore.md @@ -187,6 +187,10 @@ layer_mapping可以从Layer层识别整网的API和Cell,简化配置。 本功能是将多卡比对场景的比对结果,进行通信算子数据提取和汇总,输出整理好的通信算子多卡比对精度表。 +**使用场景** + +已完成精度比对,获得多卡精度比对结果,但是通信算子数据分布在多个结果件中,不利于精度问题的分析。通过此功能,可以汇总多卡通信算子数据,减少问题定位时间。 + **约束** - 不支持MD5比对结果。 @@ -235,6 +239,26 @@ compare_index: 2. rank*列为多卡数据。 3. 不同比对指标的数据通过不同sheet页呈现。 4. 如果一个API或module在某张卡上找不到数据,汇总结果中将空白呈现。 +5. 如果比对指标值为N/A,unsupported,Nan,表示无法计算该比对指标值,汇总结果将以”NPU:’NPU max值‘ Bench:’Bench max值‘“呈现。 +6. 针对图示案例,此处NPU:N/A Bench:N/A表示output为None。 + +
+如何基于group信息查看分组数据:
+
+以Distributed.all_reduce.0.forward为例。这个API对多卡数据做规约操作,输出为一个group内的规约结果,同一个group内的输出保持一致。
这个API中,rank0-3为一个group,Distributed.all_reduce.0.forward.input.group展示为tp-0-1-2-3,rank0-3输出一致;rank4-7为一个group,展示为tp-4-5-6-7,rank4-7输出一致。
group除了这种形式,还有如[0, 1, 2, 3]的呈现形式。 + +
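+为帮助理解下表中 all_gather 与 reduce_scatter 的输入输出关系,这里给出一个纯 numpy 的简化模拟(函数名为演示假设,并非工具实现):
+
+```python
+import numpy as np
+
+def simulate_all_gather(per_rank_inputs):
+    # 每张卡输入可不一致,同 group 内每张卡都得到相同的张量列表
+    return [list(per_rank_inputs) for _ in per_rank_inputs]
+
+def simulate_reduce_scatter(per_rank_inputs):
+    # 先对各卡输入按元素求和(规约),再把结果的不同部分散发给各 rank
+    reduced = np.sum(per_rank_inputs, axis=0)
+    return [reduced[i] for i in range(len(per_rank_inputs))]
+
+inputs = [np.arange(4) + rank for rank in range(4)]  # 4 卡,各卡输入不同
+print(simulate_all_gather(inputs)[0])                # 每卡输出一致
+print(simulate_reduce_scatter(inputs))               # 每卡得到规约结果的不同分片
+```
+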
+常见通信API预期结果: + +1. Distributed.all_gather:多卡数据汇总,每张卡输入可以不一致,同group内输出一致,输出是张量列表。 +2. Distributed.all_gather_into_tensor:多卡数据汇总,每张卡输入可以不一致,同group内输出一致,输出是张量。 +3. Distributed.all_reduce:多卡数据规约操作,每张卡输入可以不一致,同group内输出一致,为规约结果。 +4. Distributed.reduce_scatter:多卡数据规约操作,每张卡输入可以不一致,输出为group内规约结果的不同部分,输入是张量列表。 +5. Distributed.reduce_scatter_tensor:多卡数据规约操作,每张卡输入可以不一致,输出为group内规约结果的不同部分,输入是张量。 +6. Distributed.broadcast:输入为要广播的数据,输出为广播后的数据。 +7. Distributed.isend:点对点通信,输入为要发送的数据,输出为发送的数据。 +8. Distributed.irecv:点对点通信,输入为原数据,输出为接收的新数据。 +9. Distributed.all_to_all_single:输出数据为所有卡上的数据切分后合并的结果。 ## 4 附录 @@ -389,10 +413,6 @@ pt_outputs: {cell_name}.{class_name}: {module_name}.{class_name} ``` -冒号左侧为MindSpore框架cell模块的{cell_name}.{class_name},冒号右侧为PyTorch框架module模块的{module_name}.{class_name}。 - -{cell_name}.{class_name}从dump cell模块级.npy文件名获取,命名格式为:`{Cell}.{cell_name}.{class_name}.{前向反向}.{index}.{input/output}.{参数序号}` - 文件内容示例: ```yaml @@ -400,6 +420,20 @@ fc2.Dense: fc2.Linear conv1.Conv2d: conv3.Conv2d ``` +冒号左侧为MindSpore框架cell模块的{cell_name}.{class_name},冒号右侧为PyTorch框架module模块的{module_name}.{class_name}。 + +```yaml +{cell_name}.{class_name}从dump cell模块级.npy文件名获取,命名格式为: +{Cell}.{cell_name}.{class_name}.{forward/backward}.{index}.{input/output}.{参数序号/参数名} +或 +{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name} + +{module_name}.{class_name}从dump module模块级.npy文件名获取,命名格式为: +{Module}.{module_name}.{class_name}.{forward/backward}.{index}.{input/output}.{参数序号/参数名} +或 +{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name} +``` + ### 4.5 自定义映射文件(data_mapping) 文件名格式:\*.yaml,*为文件名,可自定义。 @@ -408,9 +442,11 @@ conv1.Conv2d: conv3.Conv2d ```yaml # API -{api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号} +{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号/参数名}: {api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号/参数名} # 模块 -{Cell}.{cell_name}.{class_name}.{前向反向}.{index}.{input/output}.{参数序号}: {Module}.{module_name}.{前向反向}.{index}.{input/output}.{参数序号} +{Cell}.{cell_name}.{class_name}.{forward/backward}.{index}.{input/output}.{参数序号/参数名}: {Module}.{module_name}.{class_name}.{forward/backward}.{index}.{input/output}.{参数序号/参数名} +或 +{Cell}.{cell_name}.{class_name}.parameters_grad.{parameter_name}: {Module}.{module_name}.{class_name}.parameters_grad.{parameter_name} ``` 冒号左侧为MindSpore框架API的名称和Cell模块的名称,冒号右侧为PyTorch框架API的名称和module模块名称。 @@ -424,6 +460,8 @@ API和模块名称请分别从《[MindSpore 场景的精度数据采集](./06.da Functional.flash_attention_score.4.forward.input.0: NPU.npu_fusion_attention.4.forward.input.0 # 模块 Cell.relu.ReLU.forward.0.input.0: Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.forward.0.input.0 +或 +Cell.relu.ReLU.parameters_grad.weight: Module.module.language_model.embedding.word_embedding.VocabParallelEmbedding.parameters_grad.weight ``` 当dump.json文件中存在“data_name”字段时,API和模块名称为data_name字段去掉文件后缀,如下图红框处所示: diff --git a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md index f34fca3311f2bdf6835dba638b0a51d4b44e5206..983477554e138f3e547f2d3efcf14fdfc4a991a0 100644 --- a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md @@ -26,7 +26,9 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 ### 1.5 其他说明 -溢出数据采集功能在昇腾 NPU 上支持饱和模式和 INF/NAN 模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 
侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 +溢出数据采集功能在昇腾 NPU 上支持饱和模式(仅支持 Atlas 训练系列产品)和 INF/NAN 模式。 + +INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 INF/NAN 模式的使能方式如下: diff --git a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md index 33ff4a0259aef02d122022402966c65358e8efff..ef83aa17237d1cc56b8a67bf4b3ec9f57647fb9c 100644 --- a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md @@ -11,7 +11,7 @@ export INF_NAN_MODE_ENABLE=1 export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" ``` -**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 溢出检测任务的配置示例见[MindSpore 静态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#23-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)、[MindSpore 动态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#33-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)。 diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index deebfb54e6b1742b947203e76c2f8bd9698d93ca..1c197ba5496378130d8d04b6f847ee2f35c3e946 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -4,73 +4,77 @@ 训练状态轻量化监控工具,能够在较低性能损耗下收集和记录模型训练过程中的激活值、权重梯度、优化器状态和通信算子的中间值,实时呈现训练状态。 -- [快速上手](#快速上手) - - [权重监控](#权重监控) - - [权重梯度监控](#权重梯度监控) - - [激活值监控](#激活值监控) - - [优化器状态监控](#优化器状态监控) - - [csv格式数据转tensorboard可视化显示](#csv格式数据转tensorboard可视化显示) -- [详细配置](#详细配置) - ## 安装 -参见[msprobe安装](./01.installation.md) -要求torch版本不低于2.0。 +参见[msprobe安装](./01.installation.md)。 + +要求: + +- PyTorch场景:torch不低于**2.0** +- MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,暂不支持**msadapter**套件 + +## 功能介绍 +下表中字段为训练状态轻量化监控工具的完整功能点: + +| 功能 | 说明 | 支持场景 | +| ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------- | +| [权重监控](#权重监控) | 开启权重监控 | PyTorch、MindSpore | +| [权重梯度监控](#权重梯度监控) | 开启权重梯度监控 | PyTorch、MindSpore | +| [激活值监控](#激活值监控) | 开启激活值监控 | PyTorch、MindSpore | +| [优化器状态监控](#优化器状态监控) | 开启优化器状态监控 | PyTorch、MindSpore | +| [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore | +| [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch | +| [Module全量监控](#Module全量监控) | 对全量module的输入输出做监控 | PyTorch、MindSpore | +| [Parameter全量监控](#Parameter全量监控) | 对全量Parameter的输入输出做监控 | PyTorch、MindSpore | +| [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`均支持,`ndigits`仅PyTorch支持 | PyTorch、MindSpore | +| [梯度异常时序判断](#梯度异常时序判断) | 梯度异常时自动梯度落盘 | PyTorch | +| 
[csv格式数据转tensorboard可视化显示](#csv格式数据转tensorboard可视化显示) | 将csv转为tensorboard文件显示 | PyTorch | +| [动态启停](#动态启停) | 训练过程中动态修改配置开启监控 | PyTorch、MindSpore | +| [功能重载](#功能重载) | 训练中开启激活值监控。待废弃,请使用动态启停功能代替。 | PyTorch | ## 快速上手 根据需求监控相应对象。比如在loss上扬,grad norm正常的异常训练过程中,优先考虑监控模型前向过程;在grad norm异常的训练过程中,监控权重和激活值的梯度。 推荐使用方式:权重梯度的监控性能损耗小(20B dense模型全量权重梯度监控,时间增加<1%,内存增加<1%),可以长期开启。激活值监控性能损耗大,在必要时开启或者仅监控部分。 ### 工具使能 -在训练脚本中使能工具,在配置文件(json)中控制工具行为。 +在实际训练代码中找到模型、优化器定义的位置,使能monitor工具,通过配置文件(json)控制工具行为。如下分别为Pytorch场景和MindSpore场景下的使能方式。 + +- Pytorch使能方式: ```python -# megatorn中构建初始化模型和优化器。在实际训练代码中找到模型、优化器(optional)定义的位置 -# Megatron-LM(core_r0.6.0) megatron/training.py, def pretrain: +# Megatron-LM(core_r0.6.0) training.py model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type) -# 使能工具 +... from msprobe.pytorch import TrainerMon -# 监控工具初始化 monitor = TrainerMon( config_file_path="./monitor_config.json", - process_group=None, params_have_main_grad=True, # 权重是否使用main_grad,通常megatron为True,deepspeed为False。默认为True。 - opt_ty=None # 优化器类型,默认为None,具体取值参考公开接口 ) -monitor.set_wrapped_optimizer(optimizer) # 挂载监控对象 -monitor.monitor_gnorm_with_ad( +monitor.set_monitor( model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size, - optimizer=None, + optimizer=optimizer, dp_group=None, tp_group=None, - start_iteration=0 - ) - - -# optional -# 可以在任意位置获取当前的参数梯度统计量 -reduced, unreduced = monitor.generate_wgrad_metrics() -# 可以在任意位置获取当前的激活值、激活值梯度统计量 -actv, actv_grad = monitor.generate_xy_metrics() + start_iteration=0 # 断点续训时提供当前iteration,默认从0开始 +) ``` -补充deepspeed下常用框架的使能位置,提供参考。 +*注意*:补充deepspeed下常用框架的使能位置。 -注意deepspeed与megaton的区别在于optimizer的传值不同,`optimizer=optimizer.optimizer`。若未使用deepspeed,则直接传optimizer,`optimizer=optimizer`。 +deepspeed与accelerate、transformers同时使用时,optimizer传值方式为`optimizer=optimizer.optimizer`,若未使用deepspeed,单独使用accelerate、transformers,optimizer传值方式为`optimizer=optimizer`。 -- accelerate +1) 同时使用deepspeed和accelerate时,工具使能位置参考如下: ```python model, optimizer, trainloader, evalloader, schedular = accelerator.prepare(...) - +... monitor = TrainerMon(...) -monitor.set_wrapped_optimizer(optimizer.optimizer) # optimizer.optimizer为DeepSpeedZeroOptimizer -monitor.monitor_gnorm_with_ad(....optimizer=optimizer.optimizer) +monitor.set_monitor(....optimizer=optimizer.optimizer) ``` -- transformers +2. 同时使用deepspeed和transformers时,工具使能位置参考如下: ```python # src/transformers/trainer.py @@ -78,19 +82,107 @@ class Trainer: def _inner_training_loop: ... monitor = TrainerMon(...) - monitor.set_wrapped_optimizer(self.optimizer.optimizer) - monitor.monitor_gnorm_with_ad(....optimizer=self.optimizer.optimizer) + monitor.set_monitor(....optimizer=self.optimizer.optimizer) for epoch in range(epochs_trained, num_train_epochs): ... ``` +- MindSpore使能方式: +```python +... 
+from msprobe.mindspore import TrainerMon +monitor = TrainerMon( + config_file_path="./monitor_config.json", + process_group=None, + params_have_main_grad=True, # 权重是否使用main_grad,通常megatron为True,deepspeed为False。默认为True。 +) +# 挂载监控对象 +monitor.set_monitor( + model, + grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size, + optimizer=optimizer, + dp_group=None, + tp_group=None +) +``` + + +### 权重监控 +- 工具配置示例: +```json +{ + "targets": { + }, + "param_distribution": true, + "format": "csv", + "ops": ["norm", "min", "max", "nans"] +} +``` +`targets`中指定module包含的所有权重都会被监控。`targets`为空时,默认监控全部module。 +设置`param_distribution`为true,表示开启权重监控功能,默认值为false。 + +### 权重梯度监控 +- 工具配置示例: +```json +{ + "targets": { + }, + "wg_distribution": true, + "format": "csv", + "ops": ["norm", "min", "max", "nans"] +} +``` +`targets`中指定module包含的所有权重都会被监控。`targets`为空时,默认监控全部module。 +设置`wg_distribution`(weight grad, noted as `wg`) 为true,表示开启权重梯度监控功能,默认值为false。 + +### 激活值监控 + +- 工具配置 +```json +{ + "targets": { + }, + "xy_distribution": true, + "forward_only": false, + "backward_only": false, + "all_xy": true, + "format": "csv", + "ops": ["norm", "min", "max", "nans"] +} +``` +`all_xy`为true表示监控全量module激活值,若需要对指定模块设置监控对象,在`targets`中进行配置,配置方式参考 [指定监控对象](#指定监控对象) 。 + +设置`xy_distribution`为true表示开启激活值监控功能,默认值为false。 + +注意:`forward_only`和`backward_only`均为true时,触发warning,前反向均不采集;默认值均为false时,前反向均采集。 + + +### 优化器状态监控 +- 工具配置示例: +```json +{ + "targets": { + }, + "mv_distribution": true, + "format": "csv", + "ops": ["norm", "min", "max", "nans"] +} +``` +`targets`中指定module包含的所有权重都会被监控。`targets`为空时,默认监控全部module。 +设置`mv_distribution`为true表示开启优化监控功能(1st moment noted as `m`, 2nd moment noted as `v`),默认值为false。[什么是mv](https://arxiv.org/pdf/1412.6980) + +本工具针对分布式计算框架megatron和deepspeed框架做了适配,暂不支持其他框架。 + + +## 高阶功能 + ### 指定监控对象 工具支持对nn.Module(**激活值监控**)和nn.Parameter(**权重监控**、**权重梯度监控、优化器监控**)对象实现相应的监控行为,在配置文件的"targets"(dict)字段指定,targets格式为{module_name/param_name: {filed: format}}。 -- 打印模型结构 -工具提供可选项"print_struct"打印模型结构,帮助配置targets。工具会在在第一个step后打印结构并停止训练进程,模型结构默认打印在`$MONITOR_OUTPUT_DIR/module_struct.json`。 +#### 打印模型结构 +工具提供可选项`print_struct`打印模型结构,帮助配置targets。工具会在在第一个step后打印结构并停止训练进程,模型结构默认打印在`$MONITOR_OUTPUT_DIR/module_struct.json`。 ```json { "print_struct": true @@ -98,7 +190,7 @@ class Trainer: ``` 输出样例: -"config"字段用于配置文件中指定module target。其余为各个元素的shape和dtype。 +字段`config`用于配置文件中指定module target。其余为各个元素的shape和dtype。 ```json "0:63.mlp.linear_fc2": { @@ -140,7 +232,8 @@ class Trainer: } } ``` -**Module全量监控**:工具提供简便的全量module监控方式。或不配置targets、all_xy字段,同样表示全量监控。 +#### Module全量监控 +工具提供简便的全量module监控方式。或不配置targets、all_xy字段,同样表示全量监控。 ```json { @@ -165,7 +258,8 @@ class Trainer: } ``` -**Parameter全量监控**:工具提供简便的全量parameter监控方式。或不配置targets,同样表示全量监控。 +#### Parameter全量监控 +工具提供简便的全量parameter监控方式。或不配置targets,同样表示全量监控。 ```json { @@ -183,17 +277,19 @@ class Trainer: } ``` -- 输出路径 -通过环境变量`MONITOR_OUTPUT_DIR`设置,默认为"monitor_output"。 +#### 输出路径 +通过环境变量`MONITOR_OUTPUT_DIR`设置monitor输出路径,默认为`./monitor_output/`。 ```shell export MONITOR_OUTPUT_DIR=/xxx/output_dir ``` - 输出格式 - 通过可选配置项`format`指定。可以是\["tensorboard"(缺省值), "csv", "api"\]。 + 通过可选配置项`format`指定,当前支持`csv`, `tensorboard`, `api`。其中`csv`为默认缺省值。 - - format: tensorboard - 监控结果写入tensorboard的event文件,启动tensorboard查看 + - **tensorboard** + 监控结果写入tensorboard的event文件,启动tensorboard查看。 + 激活值监控任务的tag为{vpp_stage}:{module_name}.{input or output}:{micro_step}/{rank}/{task}\_{ops} + 其他监控任务的tag为{vpp_stage}:{param_name}/{rank}/{task}\_{ops} ```shell tensorboard --logdir=$MONITOR_OUTPUT_DIR ``` @@ -202,68 
+298,25 @@ export MONITOR_OUTPUT_DIR=/xxx/output_dir ssh -N -L localhost:6006:localhost:6006 your_username@remote_server_address ``` - - format: csv - 监控结果写入csv文件中,可以通过`ndigits`字段设置小数位数. + - **csv** + 监控结果写入csv文件中,可以通过`ndigits`字段设置小数位数。 + 表头为 vpp_stage | name | step | micro_step(optional) | *ops |。 + 仅在激活值监控的输出文件中包含micor_step。 + 激活值监控的name为.\, 其他任务的name为> - - format: api - 监控结果不落盘,在训练过程中可以通过`generate_wgrad_metrics`、`generate_xy_metrics`等接口获取。 + - **api** + 监控结果不落盘,在训练过程中可以通过`generate_wgrad_metrics`、`generate_xy_metrics`等接口获取,使用方式参考[公开接口](#公开接口) 。 - 统计量 -通过配置项"ops"指定。可以是["norm", "min", "max", "mean", "nans","zeros"]。其中"nans"统计tensor中nan的数量,"zeros"统计tensor中数值小于"eps"的比例。 - -### 权重监控 -- 工具配置示例: -```json -{ - "targets": { - "": {} - }, - "param_distribution": true, - "format": "csv", - "ops": ["norm", "min", "max", "nans"] -} -``` -"targets"中指定module包含的所有权重都会被监控。整个model的name为空字符串可以覆盖全量权重。 -设置"param_distribution"开启权重监控功能。 - -使用deepspeed的zero优化器时,需要在工具中指定优化器类型并传入优化器,获取梯度切分行为已还原参数梯度。 -```python -from msprobe.pytorch import TrainerMon -# 以zero1优化器举例,opt_ty取值DeepSpeedZeroOptimizer_Stage1_or_2 -# 示例为deepspeed,params_have_main_grad取值False -monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False, opt_ty="DeepSpeedZeroOptimizer_Stage1_or_2") -monitor.set_wrapped_optimizer(optimizer) # optimzier为训练框架自定义的优化器 -monitor.monitor_gnorm_with_ad( - model, grad_acc_steps=model.grad_acc_steps, optimizer=optimizer) -``` - +通过配置项`ops`指定。当前支持`norm`, `min`, `max`, `mean`, `nans`,`zeros`。其中`nans`监控tensor中`nan`的数量,`zeros`统计tensor中数值小于`eps`的比例。 -### 权重梯度监控 -- 工具配置示例: -```json -{ - "targets": { - "": {} - }, - "wg_distribution": true, - "format": "csv", - "ops": ["norm", "min", "max", "nans"] -} -``` -"targets"中指定module包含的所有权重都会被监控。整个model的name为空字符串可以覆盖全量梯度。 -设置"wg_distribution"(weight grad, noted as `wg`)开启梯度监控功能。 +- csv输出件合并 -使用deepspeed的zero优化器时,需要在工具中指定优化器类型并传入优化器,获取梯度切分行为已还原参数梯度。 -```python -from msprobe.pytorch import TrainerMon -# 以zero1优化器举例,opt_ty取值DeepSpeedZeroOptimizer_Stage1_or_2 -# 示例为deepspeed,params_have_main_grad取值False -monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False, opt_ty="DeepSpeedZeroOptimizer_Stage1_or_2") -monitor.set_wrapped_optimizer(optimizer) # optimzier为训练框架自定义的优化器 -monitor.monitor_gnorm_with_ad( - model, grad_acc_steps=model.grad_acc_steps, optimizer=optimizer) -``` + 提供csv输出件合并功能,在配置json文件中设置`step_count_per_record`,表示每个csv文件存储多个step的监控数据。默认值为1,表示每个csv文件记录一个step的监控数据。 + + 如下图所示为梯度监控结果示例,配置`step_count_per_record`为5,连续监控10个step,每个csv文件记录了5个step的梯度数据。其中`grad_reduced_0-4.csv`为step0至step4共计5个step的聚合后梯度数据,`grad_unreduced_0-4.csv`为step0至step4共计5个step的聚合前梯度数据。 + ![step_count_per_record](img/monitor/step_count_per_record.png) ### 梯度异常时序判断 1. 训练前配置相关参数 @@ -277,7 +330,11 @@ monitor.monitor_gnorm_with_ad( ``` 2. 
实例化工具时传入流水线并行group ```python -monitor = TrainerMon("./monitor_config.json", process_group=mpu.get_pipeline_model_parallel_group(), params_have_main_grad=True) +monitor = TrainerMon( + "./monitor_config.json", + process_group=mpu.get_pipeline_model_parallel_group(), + params_have_main_grad=True # 权重是否使用main_grad,通常megatron为True,deepspeed为False。默认为True。 +) ``` 训练过程中,检测到异常后打屏提示,并将异常信息按照rank分组写入json文件,文件路径默认为`monitor_output/anomaly_detected`,异常信息示例如下: @@ -307,91 +364,14 @@ python3 -m msprobe.pytorch.monitor.anomaly_analyse -d $MONITOR_OUTPUT_DIR/anomal ``` 异常事件分析结束,将topk事件写入文件`anomaly_detected/anomaly_analyse.json`。异常分析支持以下参数配置: -| 字段名 | 解释 | 是否必选 | -| ------ | -------- | -------- | -|-d 或 --data_path| 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。|是 | -|-o 或 --out_path| 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件。 | 否 | -|-k 或 --topk| 指定保留前topk个异常,默认为8。 | 否 | -|-s 或 --step_list| 指定分析的step范围,默认为[]。 | 否 | - -### 激活值监控 - -- 工具配置 -```json -{ - "targets": { - "module.module.language_model.encoder.layers.0": { - "input": "tuple[2]", - "output": "tensor" - } - }, - "print_struct": false, - "xy_distribution": true, - "forward_only": true, - "backward_only": false, - "all_xy": true, - "format": "csv", - "ops": ["norm", "min", "max", "nans"] -} -``` -设置"xy_distribution"为true表示开启激活值监控功能,"all_xy"为true表示监控全量module激活值。 - -forward_only和backward_only均为true时,触发warning,前反向均不采集;均为false时,前反向均采集。 -```python -from msprobe.pytorch import TrainerMon -# 以zero1优化器举例,opt_ty取值DeepSpeedZeroOptimizer_Stage1_or_2 -# 示例为deepspeed,params_have_main_grad取值False -monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False, opt_ty="DeepSpeedZeroOptimizer_Stage1_or_2") -monitor.set_wrapped_optimizer(optimizer) # optimzier为训练框架自定义的优化器 -monitor.monitor_gnorm_with_ad( - model, grad_acc_steps=model.grad_acc_steps, optimizer=optimizer) -``` +| 字段名 | 解释 | 是否必选 | +| ----------------- | ------------------------------------------------------------ | -------- | +| -d 或 --data_path | 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。 | 是 | +| -o 或 --out_path | 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件。 | 否 | +| -k 或 --topk | 指定保留前topk个异常,默认为8。 | 否 | +| -s 或 --step_list | 指定分析的step范围,默认为[]。 | 否 | - -### 功能重载 -- 统计量 -可以在训练过程中修改`TrainerMon`实例的`ops`属性, 调整监控的统计量。 -```python -if {some condition}: - monitor.ops = ["min", "max"] -``` - -- 训练过程中开关激活值监控 -激活值监控的性能损耗较大, 推荐仅在必要时开启, 比如发现loss出现尖刺, 根据loss的异常开启激活值监控. 
-```python -if {some condition}: - monitor.reload_xy(xy_distribution=True) -``` - -### 优化器状态监控 -- 工具配置示例: -```json -{ - "targets": { - "module.encoder.layers.0": {}, - "module.embedding.word_embedding.weight": {} - }, - "mv_distribution": true, - "format": "csv", - "ops": ["norm", "min", "max", "nans"] -} -``` -"targets"中指定module包含的所有权重都会被监控。 -设置"mv_distribution"表示开启优化监控功能(1st moment noted as `m`, 2nd moment noted as `v`)。[什么是mv](https://arxiv.org/pdf/1412.6980) - -本工具针对分布式计算框架megatron和deepspeed框架做了适配,暂不支持其他框架。 - -```python -from msprobe.pytorch import TrainerMon -# 以zero1优化器举例,opt_ty取值DeepSpeedZeroOptimizer_Stage1_or_2 -# 示例为deepspeed,params_have_main_grad取值False -monitor = TrainerMon("./monitor_config.json", params_have_main_grad=False, opt_ty="DeepSpeedZeroOptimizer_Stage1_or_2") -monitor.set_wrapped_optimizer(optimizer) # optimzier为训练框架自定义的优化器 -monitor.monitor_gnorm_with_ad( - model, grad_acc_steps=model.grad_acc_steps, optimizer=optimizer) -``` - ### csv格式数据转tensorboard可视化显示 将csv数据转换为tensorboard格式数据。 @@ -413,58 +393,97 @@ csv2tensorboard_by_step( ) ``` +### 动态启停 +动态启停模式:支持用户在训练过程中随时启动/更新监控。 -## 公开接口 +用户可在训练开始前通过配置环境变量DYNAMIC_MONITOR=True来确认开启动态启停模式,该模式下需要配合config.json文件中的dynamic_on字段来使用。 -```python -TrainerMon.__init__(config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None -``` +在动态启停模式下,启动和停止分别由如下控制: + +- 启动: + 首次监控:config.json文件中dynamic_on字段为true,代表是否需要开启监控。 + 非首次监控:config文件时间戳更新且config.json文件中dynamic_on字段为true。 +- 停止: + 到达collect_times之后自动停止并改config.json文件中dynamic_on字段为false,可再通过上述操作重启。 + +大部分情况下,用户可在看到异常趋势后再手动更新config.json文件并打开dynamic_on开关;此外,使用时若想要在一开始就启动监控,可直接打开dynamic_on开关做基础配置的监测(首次不要求时间戳更新) -| 参数 | 说明 | 是否必选 | -| ----- | -------------------- | -------- | -| config_file_path |json配置文件路径。 | 是 | -| process_group | 传入ProcessGroup对象,用以确定pipeline并行不同rank异常间时序,megatron下通过core.parallel_state.get_pipeline_model_parallel_group()获得。 | 否 | -| params_have_main_grad |权重是否使用main_grad,通常megatron为True,deepspeed为False。默认为True。 | 否 | -| opt_ty |优化器类型,默认为None。
-Megatron_DistributedOptimizer:megatron分布式优化器;
-Megatron_Float16OptimizerWithFloat16Params:megatron混合精度优化器;
-Megatron_ChainedDistributedOptimizer:megatron分布式优化器序列;
-Megatron_ChainedFloat16OptimizerWithFloat16Params:megatron混合精度优化器序列;
-DeepSpeedZeroOptimizer_Stage1_or_2:DeepSpeed Zero1和Zero2;
-DeepSpeedZeroOptimizer_Stage3:DeepSpeed Zero3。 | 否 | +注意事项: + +- 默认监控启动皆统一在配置初始化或查询到更新后的下一步,也就是若第n步挂上hook则第n+1步才启动采集,如需采集第0步数据请用静态模式。 +- config中途修改出错时,若此时不在监控就不生效,若在监控则用原配置继续。 +- 达到collect_times之后会自动将该值置为false待下次改true重启。 + +### 功能重载 +此功能将在2026年废弃。请使用[动态启停](#动态启停)功能代替。 +- 统计量 +可以在训练过程中修改`TrainerMon`实例的`ops`属性, 调整监控的统计量。 ```python -TrainerMon.monitor_gnorm_with_ad(model, grad_acc_steps, optimizer, dp_group, tp_group, start_iteration) -> None +if {some condition}: + monitor.ops = ["min", "max"] ``` -| 参数 | 说明 | 是否必选 | -| ----- | -------------------- | -------- | -| model |需要监控的模型,需要是一个torch.nn.Module。 | 是 | -| grad_acc_steps | 梯度累积步数。 | 是 | -| optimizer | 需要patch的优化器 | 否 | -| dp_group | 数据并行的通信组。
dp域通信后,且没有使用分布式优化器时,group内所有rank的梯度相同,落盘数据冗余。
提供dp_group后,工具仅保留每个dp_group的第一个rank的梯度。 | 否 | -| tp_group | 张量并行的通信组。
tp域通信后,group内部分参数所有rank的梯度相同,落盘数据冗余。
提供tp_group后,工具仅保留每个tp_group中冗余参数在第一个rank的梯度。
当前适配Megatron core_v0.6.0, 通过权重属性"tensor_model_parallel"判断是否冗余。 | 否 | -| start_iteration | 训练的起始iteration,影响工具计数 | 否 | +- 训练过程中开关激活值监控 +激活值监控的性能损耗较大, 推荐仅在必要时开启, 比如发现loss出现尖刺, 根据loss的异常开启激活值监控. +```python +if {some condition}: + monitor.reload_xy(xy_distribution=True) +``` +## 公开接口 +- monitor工具初始化 ```python -TrainerMon.set_wrapped_optimizer(_wrapped_optimizer) -> None +TrainerMon.__init__(config_file_path, process_group=None, params_have_main_grad=True) -> None ``` -| 参数 | 说明 | 是否必选 | -| ----- | -------------------- | -------- | -| _wrapped_optimizer |megatron、deepspeed创建好的混合精度优化器。 | 是 | +| 参数 | 说明 | 是否必选 | +| --------------------- | ------------------------------------------------------------ | -------- | +| config_file_path | json配置文件路径。 | 是 | +| process_group | 传入ProcessGroup对象,用以确定pipeline并行不同rank异常间时序,megatron下通过core.parallel_state.get_pipeline_model_parallel_group()获得。仅在异常时序判断功能中使用。 | 否 | +| params_have_main_grad | 权重是否使用main_grad,通常megatron为True,deepspeed为False。默认为True。 | 否 | +- 模型挂载monitor工具 +```python +TrainerMon.set_monitor(model, grad_acc_steps, optimizer, dp_group=None, tp_group=None, start_iteration=0) -> None +``` +| 参数 | 说明 | 是否必选 | +| --------------- | ------------------------------------------------------------ | -------- | +| model | 需要监控的模型,需要是一个torch.nn.Module或者mindspore.nn.Cell。 | 是 | +| grad_acc_steps | 梯度累积步数。 | 是 | +| optimizer | 需要patch的优化器。 | 否 | +| dp_group | 数据并行的通信组。
dp域通信后,且没有使用分布式优化器时,group内所有rank的梯度相同,落盘数据冗余。
提供dp_group后,工具仅保留每个dp_group的第一个rank的梯度。 | 否 | +| tp_group | 张量并行的通信组。
tp域通信后,group内部分参数所有rank的梯度相同,落盘数据冗余。
提供tp_group后,工具仅保留每个tp_group中冗余参数在第一个rank的梯度。
当前适配Megatron core_r0.6.0, 通过权重属性"tensor_model_parallel"判断是否冗余。 | 否 | +| start_iteration | 训练的起始iteration,影响工具计数。**仅PyTorch场景支持此参数**。 | 否 | + +- csv输出件转tensorboard输出件 ```python csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_type_list=None) -> None ``` -| 参数 | 说明 | 是否必选 | -| ----- | -------------------- | -------- | -| monitor_path | 待转换的csv存盘目录。 | 是 | -| time_start | 起始时间戳。搭配time_end一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | -| time_end | 结束时间戳。搭配time_start一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | -| process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | -| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]。
不指定就转换全部数据。 | 否 | +| 参数 | 说明 | 是否必选 | +| -------------- | ------------------------------------------------------------ | -------- | +| monitor_path | 待转换的csv存盘目录。 | 是 | +| time_start | 起始时间戳。搭配time_end一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | +| time_end | 结束时间戳。搭配time_start一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。 | 是 | +| process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | +| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"]。
不指定就转换全部数据。 | 否 | +- 在模型任意位置获取当前参数**梯度**统计量 ```python -TrainerMon.generate_wgrad_metrics() -> tuple[dict[dict]] +TrainerMon.generate_wgrad_metrics() -> tuple[dict, dict] +``` +具体使用方式如下: +```python +reduced, unreduced = monitor.generate_wgrad_metrics() ``` +- 在模型任意位置获取当前参数**激活值**统计量 ```python -TrainerMon.generate_xy_metrics() -> tuple[dict[dict]] +TrainerMon.generate_xy_metrics() -> tuple[dict, dict] +``` +具体使用方式如下: +```python +actv, actv_grad = monitor.generate_xy_metrics() ``` @@ -476,6 +495,10 @@ TrainerMon.generate_xy_metrics() -> tuple[dict[dict]] "targets": { "language_model.encoder.layers.0": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"} }, + "dynamic_on": false, + "start_step": 0, + "collect_times": 100000000, + "step_interval": 1, "print_struct": false, "module_ranks": [0,1,2,3], "ur_distribution": true, @@ -491,39 +514,45 @@ TrainerMon.generate_xy_metrics() -> tuple[dict[dict]] "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], "dump": false }, - "format": "tensorboard", + "format": "csv", "ops": ["min", "max", "norm", "zeros", "nans", "mean"], "eps": 1e-8, "ndigits": 12, "step_count_per_record": 1, - "append_output": [] + "append_output": [], + "squash_name": true } ``` 下面详细解释各个字段: -| 字段名字 | 是否必选 | 解释 | -| ------------------------------------------------------------ | -------- | -------- | -|"targets"| 可选 |指定需要监控的模型层和监控对象, 例如transformer的第0层language_model.encoder.layers.0,可选择监控input、output、input_grad、output_grad。如果不清楚模型结构, 可以将 "print_struct" 字段设置为 true, 监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。未配置时默认为全量监控。| -|"input"| 可选 |"tuple[2]:0"的意思是目标module的前向input参数为长度为2的tuple, 我们关心的是tuple第0个元素。| -|"output"| 必选 |"tensor"的意思是目标module的前向output参数类型为tensor| -|"input_grad"| 可选 |"tuple[2]:0"的意思是目标module的后向input_grad参数是长度为2的tuple, 我们关心的是tuple的第0个元素。| -|"output_grad"| 必选 |"tuple[1]:0"的意思是目标module的后向input_grad参数是长度为1的tuple, 我们关心的是tuple的第0个元素。| -|"print_struct"| 可选 |设置为true后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。不填默认为false。| -|"module_ranks"| 可选 |用于在分布式训练场景中希望控制在哪些rank开启module监控。如果不填,则默认在所有rank开启。| -|"ur_distribution"| 可选 |若为true则会统计adam优化器指定模块(targets中指定)参数的update和ratio向量的数值分布,并展示在heatmap里,默认为false,同时format字段必须设置为tensorboard。
依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。 | -|"xy_distribution"| 可选 |若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。| -|"all_xy"| 可选 |开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。| -|"forward_only"| 可选 |开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。| -|"backward_only"| 可选 |开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。| -|"mv_distribution"| 可选 |若为true则会监控指定模块中的参数的优化器状态, 默认为false。需要在TrainerMon构造函数正确指定opt_ty。 目前支持megatron和Deepspeed的分布式优化器。
-Megatron_DistributedOptimizer:megatron分布式优化器;
-Megatron_Float16OptimizerWithFloat16Params:megatron混合精度优化器;
-Megatron_ChainedDistributedOptimizer:megatron分布式优化器序列;
-Megatron_ChainedFloat16OptimizerWithFloat16Params:megatron混合精度优化器序列;
-DeepSpeedZeroOptimizer_Stage0:DeepSpeed Zero0
-DeepSpeedZeroOptimizer_Stage1_or_2:DeepSpeed Zero1和Zero2;
-DeepSpeedZeroOptimizer_Stage3:DeepSpeed Zero3。
未使用megatron和deepspeed框架时,opt_ty默认为None,无需传入。 | -|"wg_distribution"| 可选 |若为true则会监控指定模块的参数梯度, 默认为false。 | -|"param_distribution"| 可选 |若为true则会监控指定模块的参数, 默认为false。 | -|"alert"| 可选 | "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence, 如果统计标量超出历史均值的指定浮动范围(threshold 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。当"dump"字段配置为true表示异常事件写入文件,默认为false。 | -|"cc_distribution"| 可选 |其中"enable"字段控制通信监控模块的开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。"cc_codeline"字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; 模块会在第二个optimize.step之前打印通信日志,包括通信api的调用栈、输入dtype、通信group。 "cc_log_only"为true时,仅打印日志,不监控通信的输入输出,并在打印后中断训练。可以根据通信日志设置"cc_codeline",规避与训练过程不相关的通信,比如一些时间、metrics的同步。| -|"format"| 可选 | 数据落盘格式,默认为tensorboard,可选 \["tensorboard", "csv", "api"\]。 | -|"ops"| 可选 |类型为list,与ur_distribution、xy_distribution、mv_distribution、wg_distribution、mg_direction、cc_distribution配合,监控所选张量的统计指标,目前支持"min"、"max"、"norm"、"mean"、"zeros"、"nans"。其中,zeros代表监控所选张量的元素小于eps的比例,nans代表张量中nan的数量。当ops中无有效指标时,默认监控norm指标。| -|"eps"| 可选 |若ops里包含"zeros"则需要配置,默认为1e-8。| -|"ndigits"| 可选 |"format"为"csv"时,设置落盘文件中的小数位数,默认为6。| -|"step_count_per_record"| 可选 | "format"为"csv"时生效,每个csv记录多少个step的数据,默认为1。| -|"append_output"| 可选 | 适用于断点续训场景。多卡场景下生效,指定两个时间戳,将输出续写到这两个时间戳范围间的输出件中,不在范围内的rank不被续写。时间戳应来自原有输出件目录前缀,例如["Dec03_21-34-40", "Dec03_21-34-41"]。默认为[],不续写。 | +| 字段名字 | 是否必选 | 解释 | +| ----------------------- | -------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| "targets" | 可选 | 指定需要监控的模型层和监控对象, 例如transformer的第0层language_model.encoder.layers.0,可选择监控input、output、input_grad、output_grad。如果不清楚模型结构, 可以将 "print_struct" 字段设置为 true, 监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。未配置时默认为全量监控。 | +| "input" | 可选 | "tuple[2]:0"的意思是目标module的前向input参数为长度为2的tuple, 我们关心的是tuple第0个元素。 | +| "output" | 必选 | "tensor"的意思是目标module的前向output参数类型为tensor | +| "input_grad" | 可选 | "tuple[2]:0"的意思是目标module的后向input_grad参数是长度为2的tuple, 我们关心的是tuple的第0个元素。 | +| "output_grad" | 必选 | "tuple[1]:0"的意思是目标module的后向input_grad参数是长度为1的tuple, 我们关心的是tuple的第0个元素。 | +| "dynamic_on" | 可选 | 在动态启停时使用,true代表打开监控,false代表关闭监控,默认值为false,且达到collect_times之后会自动将该值置为false待下次改true重启。**仅PyTorch场景支持此参数**。 | +| "collect_times" | 可选 | 设置采集次数,达到该次数后停止监控,默认值为100000000,目的是一直采集。 | +| "start_step" | 可选 | 设置开始采集step,模型训练达到start_step后开始监控采集,默认值为0,表示从step0开始监控采集。 | +| "step_interval" | 可选 | 设置采集step间隔,默认值为1,表示每个step均采集监控数据。 | +| "print_struct" | 可选 | 设置为true后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。不填默认为false。**仅PyTorch场景支持此参数**。 | +| "module_ranks" | 可选 | 用于在分布式训练场景中希望控制在哪些rank开启module监控。如果不填,则默认在所有rank开启。 | +| "ur_distribution" | 可选 | 若为true则会统计adam优化器指定模块(targets中指定)参数的update和ratio向量的数值分布,并展示在heatmap里,默认为false,同时format字段必须设置为tensorboard。
依赖histc算子, 需要CANN8.0.rc2以上版本, 否则会有严重的性能问题。**仅PyTorch场景支持此参数**。 | +| "xy_distribution" | 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 | +| "all_xy" | 可选 | 开启xy_distribution后生效,若为true,监控所有module。默认为false。
与targets同时生效,all_xy配置为true时,若targets配置module_xx和指定对象,则module_xx按targets配置生效,其他module则监控全部对象,包含input、output、input_grad、output_grad。 | +| "forward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的前向,targets中的input_grad、output_grad不生效。默认为false。 | +| "backward_only" | 可选 | 开启xy_distribution后生效,若为true,仅监控指定module的反向,targets中的input、output不生效。默认为false。 | +| "mv_distribution" | 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。需要在TrainerMon构造函数正确指定opt_ty。 目前支持megatron和Deepspeed的分布式优化器。
-Megatron_DistributedOptimizer:megatron分布式优化器;
-Megatron_Float16OptimizerWithFloat16Params:megatron混合精度优化器;
-Megatron_ChainedDistributedOptimizer:megatron分布式优化器序列;
-Megatron_ChainedFloat16OptimizerWithFloat16Params:megatron混合精度优化器序列;
-DeepSpeedZeroOptimizer_Stage0:DeepSpeed Zero0;
-DeepSpeedZeroOptimizer_Stage1_or_2:DeepSpeed Zero1和Zero2;
-DeepSpeedZeroOptimizer_Stage3:DeepSpeed Zero3。
未使用megatron和deepspeed框架时,opt_ty默认为None,无需传入。 | +| "wg_distribution" | 可选 | 若为true则会监控指定模块的参数梯度, 默认为false。 | +| "param_distribution" | 可选 | 若为true则会监控指定模块的参数, 默认为false。 | +| "alert" | 可选 | "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence, 如果统计标量超出历史均值的指定浮动范围(threshold 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。当"dump"字段配置为true表示异常事件写入文件,默认为false。**仅PyTorch场景支持此参数**。 | +| "cc_distribution" | 可选 | 其中"enable"字段控制通信监控模块的开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。"cc_codeline"字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; 模块会在第二个optimize.step之前打印通信日志,包括通信api的调用栈、输入dtype、通信group。 "cc_log_only"为true时,仅打印日志,不监控通信的输入输出,并在打印后中断训练。可以根据通信日志设置"cc_codeline",规避与训练过程不相关的通信,比如一些时间、metrics的同步。**仅PyTorch场景支持此参数**。 | +| "format" | 可选 | 数据落盘格式,默认值为"csv",可选 \["csv", "tensorboard", "api"\]。仅PyThon和MindSpore动态图场景支持此参数,且MindSpore动态图场景仅支持\["csv"\]。 | +| "ops" | 可选 | 类型为list,与ur_distribution、xy_distribution、mv_distribution、wg_distribution、mg_direction、cc_distribution配合,监控所选张量的统计指标,目前支持"min"、"max"、"norm"、"mean"、"zeros"、"nans"。其中,zeros代表监控所选张量的元素小于eps的比例,nans代表张量中nan的数量。当ops中无有效指标时,默认监控norm指标。 | +| "eps" | 可选 | 若ops里包含"zeros"则需要配置,默认为1e-8。 | +| "ndigits" | 可选 | "format"为"csv"时,设置落盘文件中的小数位数,默认为6。**仅PyTorch场景支持此参数**。 | +| "step_count_per_record" | 可选 | "format"为"csv"时生效,每个csv记录多少个step的数据,默认为1。 | +| "append_output" | 可选 | 适用于断点续训场景。多卡场景下生效,指定两个时间戳,将输出续写到这两个时间戳范围间的输出件中,不在范围内的rank不被续写。时间戳应来自原有输出件目录前缀,例如["Dec03_21-34-40", "Dec03_21-34-41"]。默认为[],不续写。**仅PyTorch场景支持此参数**。 | +| "squash_name" | 可选 | 是否简化参数名/模块名,多模态场景建议关闭,默认为True | diff --git a/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md b/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md index b2d67e585a1ebdc7b332e4c7957100ff9a3e60d4..34cdc2aa99b8f6f1ab65a2b692424506f2563b56 100644 --- a/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/21.visualization_PyTorch.md @@ -43,14 +43,14 @@ msprobe -f pytorch graph -i ./compare.json -o ./output ``` **命令行参数说明**: -| 参数名 | 说明 | 是否必选 | -|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| -| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | -| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | -| -lm 或 --layer_mapping | 跨套件比对,例如同一个模型分别使用了DeepSpeed和Megatron套件的比对场景。配置该参数时表示开启跨套件Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer)。 | 否 | -| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | -| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | -| -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | +| 参数名 | 说明 | 是否必选 | +|------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| +| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | +| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | +| -lm 或 --layer_mapping | 
跨套件比对,例如同一个模型分别使用了DeepSpeed和Megatron套件的比对场景。配置该参数时表示开启跨套件Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer),如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 | 否 | +| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | +| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | +| -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | #### 3.1.1 匹配说明 @@ -302,6 +302,16 @@ msprobe -f pytorch graph -i ./compare.json -o ./output ├── compare_stepn_rankn_{timestamp}.vis ``` +#### 3.2.4 仅模型结构比对 + +适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。 + +使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。 + +dump配置请参考[dump配置示例](./03.config_examples.md#16-task-配置为-structure) + +得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 + ## 4.启动tensorboard ### 4.1 可直连的服务器 @@ -456,3 +466,11 @@ yaml文件中只需配置待调试侧与标杆侧模型代码中功能一致但 } ``` +# FAQ +1. 图比对场景,节点呈现灰色,且没有精度比对数据,怎么处理? + +节点呈现灰色,代表左边待调试侧节点与右边标杆侧节点没有匹配上,可能有以下几点原因: + +- **标杆侧确实没有能与待调试侧匹配上的节点**,属于代码实现上的差异,请确认此差异是否正常,是否会影响到整网精度。 +- **节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致**,导致节点无法匹配,具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明)。如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 +- **节点名称不一致**,导致节点无法匹配,可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 diff --git a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md index 166f704721d58ed134711aabfc4d395b1bd86d0a..12306b8be027e7cee715f99f75b00f7504ba8252 100644 --- a/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/22.visualization_MindSpore.md @@ -43,14 +43,14 @@ msprobe -f mindspore graph -i ./compare.json -o ./output ``` **命令行参数说明**: -| 参数名 | 说明 | 是否必选 | -|-------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | -| -o 或 --output_path | 配置比对结果文件存盘目录,str 类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | -| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer)。 | 否 | -| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | -| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | -| -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | +| 参数名 | 说明 | 是否必选 | +|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | +| -i 或 --input_path | 指定比对文件,参考[比对文件说明](#313-比对文件说明) | 是 | +| -o 或 --output_path | 配置比对结果文件存盘目录,str 
类型。文件名称基于时间戳自动生成,格式为:`compare_{timestamp}.vis或build_{timestamp}.vis`。 | 是 | +| -lm 或 --layer_mapping| 跨框架比对,MindSpore和PyTorch的比对场景。配置该参数时表示开启跨框架Layer层的比对功能,指定模型代码中的Layer层后,可以识别对应dump数据中的模块或API。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件(Layer)](#71-自定义映射文件layer), 如何配置自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 | 否 | +| -oc 或 --overflow_check | 是否开启溢出检测模式,开启后会在输出vis文件中(`compare_{timestamp}.vis或build_{timestamp}.vis`)对每个溢出节点进行标记溢出等级,溢出等级说明参考[溢出等级说明](#312-溢出等级说明) | 否 | +| -f 或 --fuzzy_match | 是否开启模糊匹配,bool类型。模糊匹配说明参考[匹配说明](#311-匹配说明) | 否 | +| -cs 或 --complete_stack | 是否使用完整的堆栈信息,bool类型。默认使用精简的堆栈信息,数据量小有助于增加流畅度。完整堆栈和精简堆栈信息参考[堆栈信息说明](#72-堆栈信息说明) | 否 | #### 3.1.1 匹配说明 @@ -303,6 +303,17 @@ msprobe -f mindspore graph -i ./compare.json -o ./output ├── compare_stepn_rankn_{timestamp}.vis ``` +#### 3.2.4 仅模型结构比对 + +适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。 + +使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。 + +dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为-structure) + +得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。 + + ## 4.启动tensorboard ### 4.1 可直连的服务器 @@ -471,3 +482,11 @@ yaml文件中只需配置MindSpore与PyTorch模型代码中功能一致但名称 ] } ``` +# FAQ +1. 图比对场景,节点呈现灰色,且没有精度比对数据,怎么处理? + +节点呈现灰色,代表左边待调试侧节点与右边标杆侧节点没有匹配上,可能有以下几点原因: + +- **标杆侧确实没有能与待调试侧匹配上的节点**,属于代码实现上的差异,请确认此差异是否正常,是否会影响到整网精度。 +- **节点的输入或输出type、shape不一致,参数个数不一致,节点所在层级的父层级不一致**,导致节点无法匹配,具体匹配规则见[匹配说明](#311-匹配说明),可尝试使用模糊匹配功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明)。如果是参数shape不一致,即使是模糊匹配功能也无法让节点匹配上,请检查参数shape不一致是否合理。 +- **节点名称不一致**,导致节点无法匹配,可使用layer mapping功能,如何使用此功能请参考[构图命令行说明](#31-构图命令行说明),如何自定义映射文件请参考[模型分级可视化如何配置layer mapping映射文件](./visualization/layer_mapping_example.md)。 diff --git a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md index 4a82a0775b072ddc0fc7672240c5b6cc92acf5b5..f994dc2301bcae6b23dc7a7503297aa4fe5b3724 100644 --- a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md @@ -1,16 +1,18 @@ # dump.json文件说明及示例 -## 1. dump.json文件介绍(Pytorch) +## 1. dump.json文件示例(PyTorch) ### 1.1 L0级别 -L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以Pytorch的Conv2d模块为例,网络中模块调用代码为: -`output = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)(input)` +L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为: +`output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` -dump.json文件中包含以下字段: +dump.json文件中包含以下数据名称: -1. `Module.conv2.Conv2d.forward.0`为模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 -2. `Module.conv2.Conv2d.parameters_grad`为模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 -3. 
`Module.conv2.Conv2d.backward.0`为模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 +- `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 +- `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 +- `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 + +**说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv1.Conv2d.forward.0`。 ```json { @@ -167,12 +169,12 @@ dump.json文件中包含以下字段: ``` ### 1.2 L1级别 -L1级别的dump.json文件包括API的前反向的输入输出。以Pytorch的relu函数为例,网络中API调用代码为: - `output = torch.nn.functional.relu(input)` +L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为: +`output = torch.nn.functional.relu(input)` -dump.json文件中包含以下字段: -1. `Functional.relu.0.forward`为API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 -2. `Functional.relu.0.backward`为API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。 +dump.json文件中包含以下数据名称: +- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 +- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。 ```json { @@ -272,12 +274,14 @@ mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。 以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为: -`output = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)(input)` +`output = self.conv2(input) # self.conv2 = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)` + +dump.json文件中包含以下数据名称: +- `Cell.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 +- `Cell.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 +- `Cell.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 -dump.json文件中包含以下字段: -1. `Cell.conv2.Conv2d.forward.0`为模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 -2. `Cell.conv2.Conv2d.parameters_grad`为模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 -3. `Cell.conv2.Conv2d.backward.0`为模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 +**说明**:当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Cell}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Cell.0.conv2.Conv2d.forward.0`。 ```json { @@ -429,9 +433,9 @@ dump.json文件中包含以下字段: L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为: `output = mindspore.ops.relu(input)` - dump.json文件中包含以下字段: -1. `Functional.relu.0.forward`为API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 -2. 
`Functional.relu.0.backward`为API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
+  dump.json文件中包含以下数据名称:
+- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
+- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
 
 ```json
 {
diff --git a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md
index 0f6e2b9c39e6782b57def33026257cbb29904ca6..6f4d519d5f61d5efaaffe54a1bde4f140b539f72 100644
--- a/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md
+++ b/debug/accuracy_tools/msprobe/docs/28.debugger_save_instruction.md
@@ -91,3 +91,4 @@ PrecisionDebugger.save(dict_variable, "dict_variable", save_backward=False)
 
 - indexes: 索引,在保存嵌套结构数据时的索引。例如:嵌套结构为`{"key1": "value1", "key2": ["value2", "value3"]}`,"value2"的索引为"key2.0"
 - file_suffix:文件后缀,pytorch场景为"pt",mindspore场景为"npy"
+
diff --git a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b8cc558aa22526158033cfb35f31203d8b04278
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md
@@ -0,0 +1,69 @@
+# MindSpore 场景的 kernel dump 说明
+
+当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。
+
+本文主要介绍 kernel dump 的配置示例和采集结果,msprobe 数据采集功能的详细使用参考《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》。
+
+## 1 kernel dump 配置示例
+
+使用 kernel dump 时,list 中必须填写一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。
+API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。
+
+```json
+{
+    "task": "tensor",
+    "dump_path": "/home/data_dump",
+    "level": "L2",
+    "rank": [],
+    "step": [],
+    "tensor": {
+        "scope": [],
+        "list": ["Functional.linear.0.backward"]
+    }
+}
+```
+
+## 2 结果文件介绍
+
+### 2.1 采集结果说明
+
+如果 API kernel 级数据采集成功,会打印以下信息:
+
+```bash
+The kernel data of {api_name} is dumped successfully.
+```
+
+注意:如果打印该信息后,没有数据生成,参考**常见问题3.1**进行排查。
+
+如果 kernel dump 遇到不支持的 API,会打印以下信息:
+
+```bash
+The kernel dump does not support the {api_name} API.
+```
+
+其中 {api_name} 为该不支持的 API 的名称。
+
+### 2.2 输出文件说明
+kernel dump 采集成功后,会在指定的 dump_path 目录下生成如下文件:
+
+```
+├── /home/data_dump/
+│   ├── step0
+│   │   ├── 20241201103000  # 日期时间格式,表示2024-12-01 10:30:00
+│   │   │   ├── 0  # 表示 device id
+│   │   │   │   ├──{op_type}.{op_name}.{task_id}.{stream_id}.{timestamp}  # kernel 层算子数据
+│   │   │   ...
+│   │   ├── kernel_config_{device_id}.json  # kernel dump 在接口调用过程中生成的中间文件,一般情况下无需关注
+│   │   ...
+│   ├── step1
+│   ...
+```
+成功采集到数据后,可以使用 msprobe 工具提供的《[PyTorch 场景的数据解析](./14.data_parse_PyTorch.md)》功能分析数据。
+
+## 3 常见问题
+
+### 3.1 采集结果文件为空,有可能是什么原因?
+
+1. 首先需要确认工具使用方式、配置文件内容、list 填写的 API 名称格式是否都正确无误。
+
+2. 其次需要确认 API 是否运行在昇腾 NPU 上,如果是运行在其他设备上则不会存在 kernel 级数据。
diff --git a/debug/accuracy_tools/msprobe/docs/FAQ.md b/debug/accuracy_tools/msprobe/docs/FAQ.md
index ea4c1022d8a82203354f5c411ee62001f7fe21af..833ca07a236f33e69b102d4acb45d35cd6fe7e3a 100644
--- a/debug/accuracy_tools/msprobe/docs/FAQ.md
+++ b/debug/accuracy_tools/msprobe/docs/FAQ.md
@@ -31,6 +31,11 @@
    ```
    在上述场景中,若希望采集relu数据,只需要将`relu(x)`修改为`torch.relu(x)`即可。
 
+4. 在使用L0 dump时,发现有些 module 的数据没有采集下来,原因是什么? 
+ - 确认日志打印中是否存在`The {module_name} has registered deprecated register_backward_hook`信息, + 该信息说明 module 挂载了被 PyTorch 框架废弃的 register_backward_hook,这与工具使用的 register_full_backward_hook 接口会产生冲突,故工具会跳过该 module 的反向数据采集。 + - 如果您希望所有 module 数据都能采集下来,可以将模型中使用的 register_backward_hook 接口改为 PyTorch 框架推荐的 register_full_backward_pre_hook 或 register_full_backward_hook 接口。 + # 2 精度预检(PyTorch) 1. 预检工具在 dump 和 run_ut 的过程中,是否需要同时开启或关闭 jit 编译(jit_compile)? diff --git a/debug/accuracy_tools/msprobe/docs/img/merge_result.png b/debug/accuracy_tools/msprobe/docs/img/merge_result.png index fc92b03544a44b2399f17c4f418d7396f0c3ee91..a8c97a9f619206ae53a86df2013c2ba19202713b 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/merge_result.png and b/debug/accuracy_tools/msprobe/docs/img/merge_result.png differ diff --git a/debug/accuracy_tools/msprobe/docs/img/monitor/step_count_per_record.png b/debug/accuracy_tools/msprobe/docs/img/monitor/step_count_per_record.png new file mode 100644 index 0000000000000000000000000000000000000000..9347d3ecae01b6d4717db8fcdd6c40b6766fa908 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/img/monitor/step_count_per_record.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/GPTModel.png b/debug/accuracy_tools/msprobe/docs/visualization/GPTModel.png new file mode 100644 index 0000000000000000000000000000000000000000..71c1ff2e5bd9a38489d6ff0b7365936508660fec Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/GPTModel.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/ParallelMLP.png b/debug/accuracy_tools/msprobe/docs/visualization/ParallelMLP.png new file mode 100644 index 0000000000000000000000000000000000000000..d76650c9103c2d81d6b07458832945a237b43acc Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/ParallelMLP.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/layer_mapping_example.md b/debug/accuracy_tools/msprobe/docs/visualization/layer_mapping_example.md new file mode 100644 index 0000000000000000000000000000000000000000..35acb23ab4c44a763909fce40ae5cec136b584f6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/visualization/layer_mapping_example.md @@ -0,0 +1,132 @@ +# 模型分级可视化如何配置layer mapping映射文件 + +## 1.使用场景 +同框架跨套件比对(例如PyTorch DeepSpeed vs Megatron),或者跨框架比对(例如PyTorch vs MindSpore),**由于代码实现的差异,导致一些模型层级和层级命名有所不同无法进行匹配**,需要进行layer层名称映射,才能够比对。 + +## 2.模块命名说明 + +由于有些节点的名称比较长,例如Module.module.module.language_model.embedding.Embedding.forward.0,在图节点上由于字符串过长无法完整显示,forward或backward信息被省略,**因此节点中显示的名称字符串去掉了Module前缀,并将forward或backward信息提取到名称字符串的第二位展示**。 + +![module_name.png](./module_name.png) + +![module_name1.png](./module_name1.png) + +### 2.1 命名格式 + +**{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}** + +**layer mapping主要是针对module_name的映射** + +#### 2.1.1 命名示例 + +- **Module.module.Float16Module.forward.0** -----> Module{**Module**}.module{**module_name**}.Float16Module{**class_name**}.forward.0{**调用次数**} +- **Module.module.module.GPTModel.forward.0** -----> Module{**Module**}.module.module{**module_name**}.GPTModel{**class_name**}.forward.0{**调用次数**} +- **Module.module.module.language_model.TransformerLanguageModel.forward.0** -----> Module{**Module**}.module.module.language_model{**module_name**}.TransformerLanguageModel{**class_name**}.forward.0{**调用次数**} +- **Module.module.module.language_model.embedding.Embedding.forward.0** -----> 
Module{**Module**}.module.module.language_model.embedding{**module_name**}.Embedding{**class_name**}.forward.0{**调用次数**} + +可以看到,module_name随着模型层级的深入在变长,**embedding层module_name拼接了它的上层language_model、上上层module和顶层module**。 + +## 3.示例 + +如图所示,左边为NPU模型,右边为GPU模型,由于代码实现上的差异,导致模型层级和层级命名有所不同,导致节点无法匹配,**图上节点显示为灰色,表示节点未匹配**。 + +![no_mapping.png](./no_mapping.png) + +### 3.1 看图分析 + +同一模型使用了不同套件或者框架,虽然两个模型的层级关系和层级命名可能有所不同,但也可以从图上的**节点名称**看出一些匹配关系,例如同是embedding层,代码里也是会命名为xxx_embedding,不会命名为xxx_norm,体现在节点名称上也是带有embedding的信息,并且层级关系也是大致相同的。 + +![no_mapping_analyze.png](./no_mapping_analyze.png) + +分析可知,节点匹配关系如下: + +**注意,仅需关注module_name的差异** + +| NPU节点名称 | GPU节点名称 | module_name差异 | +|-------------------|----------------------------------------------------------------|---------------------------| +| Module.module.Float16Module.forward.0 | Module.model.FloatModule.forward.0 | NPU为module,GPU为model | +| Module.module.module.GPTModel.forward.0 | Module.model.module.GPT2Model.forward.0 | NPU为module,GPU为module,无差异 | +| Module.module.module.language_model.TransformerLanguageModel.forward.0 | 无 | NPU多了一层 | +| Module.module.module.language_model.embedding.Embedding.forward.0 | Module.module.module.embedding.LanguageModelEmbedding.forward.0 | NPU为language_model.embedding,GPU为embedding | +| Module.module.module.language_model.rotary_pos_emb.RotaryEmbedding.forward.0 | Module.module.module.rotary_pos_emb.RotaryEmbedding.forward.0 | NPU为language_model.rotary_pos_emb,GPU为rotary_pos_emb | +| Module.module.module.language_model.encoder.ParallelTransformer.forward.0 | Module.module.module.decoder.TransformerBlock.forward.0 | NPU为language_model.encoder,GPU为decoder | +| Module.module.module.language_model.encoder.layers.0.ParallelTransformerLayer.forward.0 | Module.module.module.decoder.layers.0.TransformerLayer.forward.0 | 父层级有差异,本层级NPU和GPU都叫layers,无差异 | + +### 3.2 构建layer_mapping配置文件 +准备一个命名为mapping.yaml文件,建立**module_name**的映射关系 + +#### 3.2.1 顶层模块映射 +NPU和GPU侧的模块Module.module.Float16Module.forward.0和Module.model.FloatModule.forward.0处于图的顶层,需要进行如下配置: + +![top_layer.png](./top_layer.png) + +```yaml +TopLayer: + module: model +``` + +#### 3.2.2 其他模块映射 +配置module下的子模块,虽然两边的class_name不同(NPU侧为GPTModel,GPU侧为GPT2Model),**但是仅需取NPU侧也就是左边图的class_name进行配置,无需关心右边图的class_name叫什么**。 + +**这里涉及到跨层级的配置,NPU多了一层language_model层**,将language_model作为embedding层、rotary_pos_emb层和encoder层的前缀,进行如下配置: + +![GPTModel.png](./GPTModel.png) + +```yaml +GPTModel: + language_model.embedding: embedding + language_model.rotary_pos_emb: rotary_pos_emb + language_model.encoder: decoder +``` +然后看Module.module.module.language_model.encoder.ParallelTransformer.forward.0层下的子模块: + +此层下的若干个层,NPU和GPU的层名都叫layers,**当前层名称相同,则不用进行配置**。 + +### 3.3 查看效果 + +执行命令,指定-lm: +``` +msprobe -f pytorch graph -i ./compare.json -o ./output -lm ./mapping.yaml +``` +或 +``` +msprobe -f mindspore graph -i ./compare.json -o ./output -lm ./mapping.yaml +``` +可以看到,除了language_model层(NPU多的一层,GPU没有层与其匹配),其余在mapping.yaml文件配置的层均匹配上了。 + +![mapping.png](./mapping.png) + +### 3.4 继续配置 + +展开节点过程中,如果发现还有未匹配节点,则继续配置mapping.yaml + +![no_mapping1.png](./no_mapping1.png) + +按前一章过程进行分析配置,分析可知,节点匹配关系如下: + +| NPU节点名称 | GPU节点名称 | 差异 | +|-------------------|------------------------------------------------------------------|---------------------------------------------| +| Module.module.module.language_model.encoder.layers.0.mlp.dense_h_to_4h.ColumnParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc1.TELayerNormColumnParallelLinear.forward.0 | NPU为dense_h_to_4h,GPU为linear_fc1 | +| 
Module.module.module.language_model.encoder.layers.0.mlp.dense_4h_to_h.RowParallelLinear.forward.0 | Module.module.module.decoder.layers.0.mlp.linear_fc2.TERowParallelLinear.forward.0 | NPU为dense_4h_to_h,GPU为linear_fc2 | + +![ParallelMLP.png](./ParallelMLP.png) + +追加mapping.yaml配置: + +```yaml +TopLayer: + module: model + +GPTModel: + language_model.embedding: embedding + language_model.rotary_pos_emb: rotary_pos_emb + language_model.encoder: decoder + +ParallelMLP: + dense_h_to_4h: linear_fc1 + dense_4h_to_h: linear_fc2 +``` + +执行命令,查看效果,可以看到节点已成功匹配上。 + +![mapping1.png](./mapping1.png) diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mapping.png b/debug/accuracy_tools/msprobe/docs/visualization/mapping.png new file mode 100644 index 0000000000000000000000000000000000000000..fb03d85fab802ed881b75b5eba67bff815f97b30 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mapping.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/mapping1.png b/debug/accuracy_tools/msprobe/docs/visualization/mapping1.png new file mode 100644 index 0000000000000000000000000000000000000000..1ec713f29ca812b900adcab22c198e8705bbe1bb Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/mapping1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/module_name.png b/debug/accuracy_tools/msprobe/docs/visualization/module_name.png new file mode 100644 index 0000000000000000000000000000000000000000..8e959dc7ce8d9c8e5dec72853b02e29bdaa2389c Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/module_name.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/module_name1.png b/debug/accuracy_tools/msprobe/docs/visualization/module_name1.png new file mode 100644 index 0000000000000000000000000000000000000000..764fa08166050123e12ebd87d9a4012a64d688bb Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/module_name1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/no_mapping.png b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping.png new file mode 100644 index 0000000000000000000000000000000000000000..47693dc78cdf0c205184f7be2bf2cd196a4b5ce8 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/no_mapping1.png b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping1.png new file mode 100644 index 0000000000000000000000000000000000000000..88f8dc9e7aa89775f2a33c6c41d314a60af8ab76 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping1.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/no_mapping_analyze.png b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping_analyze.png new file mode 100644 index 0000000000000000000000000000000000000000..f9ff18681a2b99507658152921a38dc00e1a2918 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/no_mapping_analyze.png differ diff --git a/debug/accuracy_tools/msprobe/docs/visualization/top_layer.png b/debug/accuracy_tools/msprobe/docs/visualization/top_layer.png new file mode 100644 index 0000000000000000000000000000000000000000..9f2482969c6b5340e15251a794b339858e010ba4 Binary files /dev/null and b/debug/accuracy_tools/msprobe/docs/visualization/top_layer.png differ diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py index 
048f0992789b42465fa99b8738ff6db0531bf562..089c29eb098ad4305edcca1306462f8924dd9291 100644 --- a/debug/accuracy_tools/msprobe/mindspore/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -25,3 +25,4 @@ except ImportError: from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger from msprobe.mindspore.common.utils import seed_all +from msprobe.mindspore.monitor.module_hook import TrainerMon \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py index 98c6b4b98530ec447c2e239c11b5d4d7b927d874..557d731e042913da3a622035219ec8dea0409ab4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py @@ -16,7 +16,7 @@ import os from tqdm import tqdm -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml from msprobe.core.common.utils import add_time_as_suffix from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo @@ -25,6 +25,7 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context, trim_output_compute_element_list) +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer @@ -156,6 +157,7 @@ class ApiAccuracyChecker: real_api_str = Const.SEP.join(api_name_str_list[1:-2]) api_list = load_yaml(yaml_path) supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) + supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True @@ -165,6 +167,9 @@ class ApiAccuracyChecker: if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True + if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True return False def parse(self, api_info_path): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py index f42702be0b114e40e5e31dc4326bd9ca21f82202..36e506f67737cdea4452ba27f4fad0524d4c2884 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py @@ -15,11 +15,13 @@ import mindspore from mindspore import ops -from msprobe.core.common.const import Const, MsCompareConst +from msprobe.core.common.const import Const from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str from 
msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple +from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger @@ -64,7 +66,9 @@ api_parent_module_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional, (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops, + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion } @@ -83,7 +87,9 @@ api_parent_module_str_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed" + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops", + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion" } @@ -125,7 +131,8 @@ class ApiRunner: err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) api_type_str, api_sub_name = api_name_list[0], api_name_list[1] - if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \ + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ and api_platform == Const.MS_FRAMEWORK: err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) @@ -139,9 +146,9 @@ class ApiRunner: def get_api_instance(api_type_str, api_sub_name, api_platform): """ Args: - api_type_str: str, Union["MintFunctional", "Mint", "Tensor"] + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] api_sub_name: str, e.g. 
"relu" - api_platform: str: Union["mindpore", "torch"] + api_platform: str: Union["mindpore", "pytorch"] Return: api_instance: function object @@ -151,9 +158,12 @@ class ApiRunner: mindspore.mint.{api_sub_name} <--> torch.{api_sub_name} mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name} """ - - api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) - api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) + if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch": + api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform)) + else: + api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform)) + api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform)) full_api_name = api_parent_module_str + Const.SEP + api_sub_name if not hasattr(api_parent_module, api_sub_name): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py index ead03d25ea5c2e6bb0422486f1939c5b31ee589b..da2f8ad612fcf3a42083894ff1b8e56db757f919 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py @@ -18,9 +18,10 @@ from abc import ABC, abstractmethod import mindspore import numpy as np import torch -from msprobe.core.common.const import CompareConst, MsCompareConst +from msprobe.core.common.const import CompareConst from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class CompareResult: diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py new file mode 100644 index 0000000000000000000000000000000000000000..cb268efeae90a51465493c65caa948045bae4913 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py @@ -0,0 +1,602 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import namedtuple +import torch +import torch.nn as nn +import numpy as np + +from einops import rearrange + + +from msprobe.pytorch.common.utils import logger + +GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86 +SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM" + +FaForwardParams = namedtuple("FaForwardParams", + ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"]) +FaBackwardParams = namedtuple("FaBackwardParams", + ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"]) +RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams", + ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"]) + + +def softmax_forward(x): + x_max = torch.max(x, dim=-1, keepdims=True)[0] + x_sub = x.sub(x_max) + y = torch.exp(x_sub) + x_sum = y.sum(dim=-1, keepdims=True) + res = y.div(x_sum) + return res, x_max, x_sum + + +def softmax_grad(dp, softmax_res): + muls = dp * softmax_res + muls_r = muls.sum(dim=-1, keepdims=True) + sub_r = dp - muls_r + res = sub_r * softmax_res + return res + + +def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype): + if num_kv_heads == 0 or num_kv_heads > num_heads: + raise ValueError(f"num_kv_heads must be non-zero and bigger than num_heads.") + + factor = num_heads // num_kv_heads + kv_shape = kv_tensor.shape + b = kv_shape[0] + s = kv_shape[2] + d = kv_shape[3] + kv_res = torch.zeros([b, num_heads, s, d]).to(dtype) + for i in range(num_heads): + j = i // factor + kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :] + return kv_res + + +def calculate_qk(q, k, attn_mask, pse, scalar_value): + if k.dim() != 4: + raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})") + + if k.dim() == 3: + k = k.unsqueeze(1) # 在head维度扩展 + + if pse is None or len(pse.shape) == 0: + qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value) + else: + qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scalar_value) + if attn_mask is None or len(attn_mask.shape) == 0: + return qk + else: + qk = qk + attn_mask.bool() * (-40000.0) # -10000 + return qk + + +def fusion_attention_forward(forward_params): + q = forward_params.q + k = forward_params.k + v = forward_params.v + drop_mask = forward_params.drop_mask + attn_mask = forward_params.attn_mask + pse = forward_params.pse + scalar_value = forward_params.scalar_value + keep_prob = forward_params.keep_prob + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, softmax_max, softmax_sum = softmax_forward(qk) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res + else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(backward_params): + dx = backward_params.dx + q = backward_params.q + k = backward_params.k + v = backward_params.v + softmax_res = backward_params.softmax_res + drop_mask = backward_params.drop_mask + pse = backward_params.pse + scalar_value = backward_params.scalar_value + keep_prob = backward_params.keep_prob + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value) + dq 
= torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + try: + if input_layout == "BSH": + b, s1, h1 = query.shape + _, s2, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "SBH": + s1, b, h1 = query.shape + s2, _, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "BSND": + b, s1, n1, d = query.shape + _, s2, n2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + elif input_layout == "BNSD": + b, n1, s1, d = query.shape + _, n2, s2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + if d == 0: + raise ValueError(f"Value d must be non-zero.") + _dtype = query.dtype + ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype) + return ret + + +def convert_from_bnsd(_input, input_layout): + """ + transform qkv from bnsd to input_layout. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) + input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + """ + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + """ + transform qkv from input_layout to bnsd. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + n (int): num_heads + input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) + """ + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(GTYPE) + + +def convert_from_bsnd(_input, input_layout): + """ + transform qkv from bsnd to input_layout. 
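+    e.g. input_layout="SBH": (B,S,N,D) => (S,B,N*D)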
+
+
+def convert_from_bnsd(_input, input_layout):
+    """
+    Transform qkv from bnsd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,N,S,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,N,S,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b n s d -> s b (n d)').contiguous()
+    elif input_layout == "BSND":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bnsd(_input, n, input_layout):
+    """
+    Transform qkv from input_layout to bnsd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,N,S,D)
+        out = rearrange(_input, 's b (n d) -> b n s d', n=n)
+    elif input_layout == "BSND":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out.to(GTYPE)
+
+
+def convert_from_bsnd(_input, input_layout):
+    """
+    Transform qkv from bsnd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,S,N,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,S,N,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,S,N,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+    elif input_layout == "BNSD":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bsnd(_input, n, input_layout):
+    """
+    Transform qkv from input_layout to bsnd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,S,N,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+    elif input_layout == "BNSD":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} is not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out
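+
+
+# These conversions are plain einops reshapes; for example,
+# rearrange(x, 'b s (n d) -> b n s d', n=8) maps a [2, 4096, 1024] BSH tensor
+# to [2, 8, 4096, 128]. Only convert_to_bnsd additionally casts to GTYPE so the
+# golden computation runs in higher precision.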
+
+
+def generate_attn_mask(*args):
+    """
+    When sparse_mode is 2, 3 or 4, the small-op-to-fused-op path applies the
+    2048x2048 template-mask optimization; to rebuild the basic implementation,
+    the template has to be reversed back to the original dense mask:
+    ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
+    """
+
+    sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args
+    shape = [s1, s2]
+
+    if attn_mask is not None:
+        # When FA already receives an attn_mask, it can be assumed to be the converted
+        # mask matrix; the three sparse-template modes have to be reversed back.
+        if sparse_mode in (2, 3, 4):
+            logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}")
+
+            if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048:
+                if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)):
+                    if sparse_mode == 2:
+                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+                    elif sparse_mode == 3:
+                        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+                    elif sparse_mode == 4:
+                        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+                        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+                        attn_mask = attn_mask_u + attn_mask_l
+                    logger.debug(f"attn_mask restored from the sparse template, shape {attn_mask.shape}")
+                    return attn_mask.to(dtype)
+
+        return attn_mask.to(dtype)
+
+    if sparse_mode == 0:
+        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+        attn_mask = attn_mask_u + attn_mask_l
+    elif sparse_mode == 1:  # no sparse
+        attn_mask = torch.from_numpy(np.zeros(shape))
+    elif sparse_mode == 2:
+        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
+    elif sparse_mode == 3:
+        attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1))
+    elif sparse_mode == 4:
+        attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
+        attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
+        attn_mask = attn_mask_u + attn_mask_l
+    # Note: sparse_mode=5 cannot occur here, because it requires attn_mask to be
+    # passed in (in BNSS or B1SS format), so in that case FA's input can already
+    # be treated as the correct attn_mask.
+    return attn_mask.to(dtype)
+
+
+def generate_kv(key, value, n1, n2):
+    # broadcast kv when the head counts differ (grouped-query attention)
+    if not (n1 == n2):
+        k_new = broadcast_kv(n1, n2, key, key.dtype)
+        v_new = broadcast_kv(n1, n2, value, value.dtype)
+    else:
+        k_new = key
+        v_new = value
+    return k_new, v_new
+
+
+def rebuild_softmax_by_qkv(q, k, attn_mask, pse, scalar_value):
+    """
+    attention = softmax(QK^T/sqrt(d))V
+    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))
+    """
+    logger.info("Using QKV to rebuild original softmax")
+    qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+    softmax_res, _, _ = softmax_forward(qk)
+    return softmax_res
+
+
+def rebuild_softmax_by_max_sum(softmax_params):
+    """
+    attention = softmax(QK^T/sqrt(d))V
+    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i
+    """
+    q = softmax_params.q
+    k = softmax_params.k
+    attn_mask = softmax_params.attn_mask
+    pse = softmax_params.pse
+    scalar_value = softmax_params.scalar_value
+    softmax_max = softmax_params.softmax_max
+    softmax_sum = softmax_params.softmax_sum
+    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
+
+    qk = calculate_qk(q, k, attn_mask, pse, scalar_value)
+    if softmax_max.shape[-1] == 0:
+        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
+    repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
+    softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
+        softmax_sum.repeat(1, 1, 1, repeat_dim))
+    return softmax_res
+
+
+def get_head_num(*args, **kwargs):
+    if kwargs.get("head_num", None):
+        head_num = kwargs.get("head_num")
+    elif len(args) >= 4:
+        head_num = args[3]
+    else:
+        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+    return head_num
+
+
+def get_input_layout(*args, **kwargs):
+    if kwargs.get("input_layout", None):
+        input_layout = kwargs.get("input_layout")
+    elif len(args) >= 5:
+        input_layout = args[4]
+    else:
+        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")
+    return input_layout
+
+
+def npu_fusion_attention_forward_patch(*args, **kwargs):
+    if len(args) < 2:
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should be greater than or equal to 2.")
+
+    # query, key, value, head_num, input_layout
+    head_num = get_head_num(*args, **kwargs)
+    input_layout = get_input_layout(*args, **kwargs)
+
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
+    if n1 == n2 and s1 == s2:
+        logger.debug(f"running case: BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    else:
+        logger.debug(f"running case: BNSD = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"n1 must be a positive multiple of n2, please check: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+    new_kwargs = {
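+        # keep_prob is pinned to 1 because this golden path never applies dropout
+        # (FlashAttentionScore passes drop_mask=None), keeping the reference
+        # deterministic; 2147483647 (INT32_MAX) is only the fallback when the
+        # caller did not pass pre/next token windows, i.e. no window limit.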
"keep_prob": 1, + "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention_backward_patch(*args, **kwargs): + if len(args) != 6: + raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5]) + if n1 == n2 and s1 == s2: + logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (n1 % n2 == 0 and n1 >= n2): + raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.") + + dims_kwargs = { + "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2, + "d": d, "h1": h1, "h2": h2, "dtype": dtype + } + + new_kwargs = { + "keep_prob": 1, + "scalar_value_value": kwargs.get("scalar_value_value", 1 / (d ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "softmax_max": kwargs.get("softmax_max"), + "softmax_sum": kwargs.get("softmax_sum"), + "softmax_in": kwargs.get("softmax_in"), + "attention_in": kwargs.get("attention_in"), + "seed": kwargs.get("seed", 0), + "offset": kwargs.get("offset", 0), + "numels": kwargs.get("numels", 0), + "attn_mask": kwargs.get("attn_mask") + } + + return args, dims_kwargs, new_kwargs + + +class FlashAttentionScore(nn.Module): + def __init__(self): + super(FlashAttentionScore, self).__init__() + # You can initialize any parameters here if necessary + + def forward(self, *inputs, **kwargs): + # Extract the inputs for the attention calculation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs) + query, key, value = new_args[0], new_args[1], new_args[2] + + input_layout = get_input_layout(*inputs, **kwargs) + + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tokens") + pse = new_kwargs.get("real_shift") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + + attn_mask = generate_attn_mask(*args_temp) + query = convert_to_bnsd(query, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + forward_params = FaForwardParams( + q=query, + k=key, + v=value, + drop_mask=None, + attn_mask=attn_mask, + pse=pse, + scalar_value=scalar_value, + keep_prob=keep_prob + ) + + out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params) + + # If output dimension is 5, reshape accordingly + if out_golden.dim() == 5: + out_golden = out_golden.reshape(out_golden.size(0), + out_golden.size(1) * 
+        forward_params = FaForwardParams(
+            q=query,
+            k=key,
+            v=value,
+            drop_mask=None,
+            attn_mask=attn_mask,
+            pse=pse,
+            scalar_value=scalar_value,
+            keep_prob=keep_prob
+        )
+
+        out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
+
+        # If output dimension is 5, reshape accordingly
+        if out_golden.dim() == 5:
+            out_golden = out_golden.reshape(out_golden.size(0),
+                                            out_golden.size(1) * out_golden.size(2),
+                                            out_golden.size(3), out_golden.size(4))
+
+        out_golden = convert_from_bnsd(out_golden, input_layout)
+
+        # Move results to CPU; the softmax stats are repeated to 8 columns to match
+        # the layout returned by the fused operator.
+        out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+        return out_golden
+
+    def backward(self, *inputs, **kwargs):
+        # Recompute the golden gradients for the backward pass
+        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs)
+        query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
+        n1 = dims_kwargs.get("n1")
+        n2 = dims_kwargs.get("n2")
+        s1 = dims_kwargs.get("s1")
+        s2 = dims_kwargs.get("s2")
+        b = dims_kwargs.get("b")
+        dtype = dims_kwargs.get("dtype")
+        attn_mask = new_kwargs.get("attn_mask")
+        keep_prob = new_kwargs.get("keep_prob")
+        sparse_mode = new_kwargs.get("sparse_mode")
+        pre_tockens = new_kwargs.get("pre_tockens")
+        next_tockens = new_kwargs.get("next_tockens")
+        pse = new_kwargs.get("pse")
+        softmax_max = new_kwargs.get("softmax_max")
+        softmax_sum = new_kwargs.get("softmax_sum")
+        scalar_value = new_kwargs.get("scalar_value")
+
+        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+        attn_mask = generate_attn_mask(*args_temp)
+
+        query = convert_to_bnsd(query, n1, input_layout)
+        dx = convert_to_bnsd(dx, n1, input_layout)
+        key = convert_to_bnsd(key, n2, input_layout)
+        value = convert_to_bnsd(value, n2, input_layout)
+
+        k_new, v_new = generate_kv(key, value, n1, n2)
+
+        if SOFTMAX_BUILD_MODE == "QKV":
+            softmax_res = rebuild_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value)
+        else:
+            softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum)
+            softmax_res = rebuild_softmax_by_max_sum(softmax_params)
+
+        backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob)
+        dq, dk, dv = fusion_attention_backward(backward_params)
+
+        # Reshape as needed
+        if dq.dim() == 5:
+            dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
+        if dk.dim() == 5:
+            dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
+        if dv.dim() == 5:
+            dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))
+
+        dq = convert_from_bnsd(dq, input_layout)
+        dk = convert_from_bnsd(dk, input_layout)
+        dv = convert_from_bnsd(dv, input_layout)
+
+        return dq.cpu(), dk.cpu(), dv.cpu()
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1344541e89c4dafd9d49d63e3fdea117366bdd9
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore
+
+
+class FusionOperator:
+    """
+    Parent class for all fusion operators, defining their common interface and attributes.
+    """
+
+    # initialize the operator registry
+    def __init__(self):
+        self.flash_attention_score = None  # holds the FlashAttentionScore operator
+        self._register_operators()
+
+    def __getattr__(self, name):
+        """__getattr__ runs only after normal attribute lookup has failed, so raise
+        directly; calling hasattr/getattr here would recurse endlessly."""
+        raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")
+
+    def _register_operators(self):
+        """Register operators so they can be invoked as fusion.xxx."""
+        self.flash_attention_score = FlashAttentionScore()
+
+
+fusion = FusionOperator()
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py
index 748adf7d02cafe3983fe1990b40b1e77e993698b..fc2680d68a5697dae165c70a276b21038f87fbe0 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py
@@ -16,12 +16,13 @@
 import os
 import csv
 
-from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+from msprobe.core.common.const import Const, CompareConst
 from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
 from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
 from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
 from msprobe.core.common.file_utils import check_file_or_directory_path
 from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class ResultCsvEntry:
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
index e764140badf4c107ea83044353aba19a1c412fe0..1913675ad162bf690fc0aed5fc84c245ae4f73ca 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py
@@ -27,10 +27,11 @@
 import numpy as np
 from tqdm import tqdm
 
 # 本地应用/库特定导入
-from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+from msprobe.core.common.const import Const, CompareConst
 from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
 from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
 from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class MultiApiAccuracyChecker(ApiAccuracyChecker):
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py
index 1680c0adaa757d9774c024af5b62a22d4676bdf3..7b319382eb4eba4abac3bd6894cc3b0262032d88 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py
@@ -19,7 +19,8 @@
 import sys
 from pathlib import Path
 import mindspore
 from msprobe.mindspore.common.log import logger
-from msprobe.core.common.const import Const, CompareConst, MsCompareConst
+from
msprobe.core.common.const import Const, CompareConst +from msprobe.mindspore.common.const import MsCompareConst import torch as mindtorch from torch import Tensor as mindtorch_tensor import torch.nn.functional as mindtorch_func @@ -33,7 +34,7 @@ def is_mindtorch(): mindtorch_check_result = False try: import torch as test_torch - from mindspore._c_expression import Tensor as MindsporeTensor + from mindspore import Tensor as MindsporeTensor except ImportError: return mindtorch_check_result tensor = test_torch.tensor(0.0) diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph.py index 14c05ec6cec08f1fd83d8cdfaeebb22af1f2bb93..69c067de0fc6ab1e646073bd2fe962766186ab9b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph.py @@ -20,7 +20,7 @@ class GraphNode: def __init__(self, name: str, pos: int = -1, unique_name: str = "", operator_name: str = "", return_variable: str = "", return_value: str = "", var_inputs: List[str] = None, has_constant_input: bool = False, - unique_id: str="", scope: str = "", code_info: List[str] = None, + unique_id: str = "", scope: str = "", code_info: List[str] = None, is_subgraph: bool = False, attrs: Union[Dict[str, str], List[str]] = None): self.name = name self.unique_name = unique_name diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py index 9afd114d23b2d08856c24a8590cdcd8756010631..ee35750fb35c100e2025b0dcbdd9e20ef998b2ee 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py @@ -124,7 +124,8 @@ class Parser: scope_match = scope_pattern.search(text, end_pos) scope = scope_match.group(1) if scope_match else "" - id_pattern = re.compile(r'.*cnode_primal_attrs:\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE) + id_pattern = re.compile(r'.*cnode_primal_attrs:' + r'\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE) unique_id_match = id_pattern.search(text, end_pos, scope_match.start()) unique_id = unique_id_match.group(1) if unique_id_match else None diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index 3ca03bc3552b6e652e8d83ce5ee6dbc7100d5ace..b41dc5ce012dc5353a2f62607eabc604fda4eb3a 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -71,6 +71,67 @@ class Const: } +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + 
DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + + class ProcessStatus: + SUCCESS = "success" + API_NOT_FOUND = "api_not_found" + EXCEPTION_SKIP = "exception_skip" + + + + class FreeBenchmarkConst: ADD_NOISE = "add_noise" BIT_NOISE = "bit_noise" diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 5c713e02c05f965450a341cc17df780b6de0f0eb..ded3faaa22b565ef35c17a7596782976ddf9125d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -108,8 +108,8 @@ class MsprobeStep(ms.train.Callback): class Dropout(ops.Dropout): - def __init__(self, keep_prob=0.5, Seed0=0, Seed1=1): - super().__init__(1., Seed0, Seed1) + def __init__(self, keep_prob=0.5, seed0=0, seed1=1): + super().__init__(1., seed0, seed1) class Dropout2D(ops.Dropout2D): @@ -151,11 +151,10 @@ def is_mindtorch(): mindtorch_check_result = False try: import torch - from mindspore._c_expression import Tensor except ImportError: return mindtorch_check_result tensor = torch.tensor(0.0) - if isinstance(tensor, Tensor): + if isinstance(tensor, ms.Tensor): mindtorch_check_result = True return mindtorch_check_result diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 1a6b2e20e4dc8e9f52b3f2de06913a80faca55a4..8509a7f38add0c2e8d3f3638f4c247895e07bd6d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -29,7 +29,7 @@ from msprobe.core.common.utils import CompareException, check_compare_param, che from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.check import dtype_mapping from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping -from msprobe.core.compare.utils import set_stack_json_path +from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list class MappingConfig: @@ -353,7 +353,14 @@ class MSComparator(Comparator): merge_list = self.gen_merge_list(data_json, data_name, stack_json_data) if not merge_list: continue - for op_name in merge_list[CompareConst.OP_NAME]: + + op_name_list = merge_list.get(CompareConst.OP_NAME) + summary_list = merge_list.get(Const.SUMMARY) + data_name_list = merge_list.get('data_name') + op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list, + summary_list, + data_name_list) + for op_name in op_name_reorder: result[CompareConst.OP_NAME].append(op_name) if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name): struct = merge_list[CompareConst.INPUT_STRUCT].pop(0) @@ -367,10 +374,10 @@ class MSComparator(Comparator): result[Const.SHAPE].append(struct[1]) if self.dump_mode == Const.MD5: result[Const.MD5].append(struct[2]) - result[Const.SUMMARY].append(merge_list[Const.SUMMARY].pop(0)) + result[Const.SUMMARY].append(summary_reorder.pop(0)) result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else 
None)
                 if self.dump_mode == Const.ALL:
-                    result['data_name'].append(merge_list['data_name'].pop(0))
+                    result['data_name'].append(data_name_reorder.pop(0))
         return pd.DataFrame(result)
diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
index 89e5d875df2ab1596e61dacbe2be3ad96d24cf25..92155b4ec4ebd636477ef67f1c75b43e7a82b802 100644
--- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +16,11 @@
 import os
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import create_directory
 from msprobe.mindspore.common.const import Const as MsConst
 from msprobe.mindspore.common.const import FreeBenchmarkConst
+from msprobe.core.common.log import logger
 
 
 class DebuggerConfig:
@@ -39,6 +41,7 @@
         self.check_mode = task_config.check_mode
         self.framework = Const.MS_FRAMEWORK
         self.summary_mode = task_config.summary_mode
+        self.async_dump = common_config.async_dump if common_config.async_dump else False
         self.check()
         create_directory(self.dump_path)
 
@@ -49,7 +52,7 @@
             if not task_config.handler_type else task_config.handler_type)
         self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
         if self.handler_type == FreeBenchmarkConst.FIX and \
-           self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
+                self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
             raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
                              f"but got {self.pert_type}.")
         if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
@@ -69,4 +72,27 @@
             self.file_format = "npy"
         if not self.check_mode:
             self.check_mode = "all"
+        if not isinstance(self.async_dump, bool):
+            raise Exception("The parameter async_dump must be a bool.")
+        if self.async_dump and self.task == Const.TENSOR and not self.list:
+            raise Exception("When async_dump is true in the tensor task, the parameter list cannot be empty.")
+        if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.warning_on_rank_0(
+                f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
+                f"If not, the default level is {Const.LEVEL_MIX}."
+ ) + self.level_ori = Const.LEVEL_MIX return True + + def check_config_with_l2(self): + if self.level_ori != Const.LEVEL_L2: + return + if self.task != Const.TENSOR: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"When level is set to L2, the task must be set to tensor.") + if self.scope: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"When level is set to L2, the scope cannot be configured.") + if not self.list or len(self.list) != 1: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"When level is set to L2, the list must be configured as a list with one api name.") diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 844ba6b2a34dae4a285b6e085f378abfe0f22a68..a7082d3e569755c93b04c99af11ab70bf11e73d0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -14,12 +14,15 @@ # limitations under the License. import os -from collections import defaultdict +from collections import defaultdict, namedtuple import mindspore as ms from mindspore._c_expression import MSContext -from msprobe.core.common.const import Const, MsgConst +from msprobe.core.common.const import Const, FileCheckConst, MsgConst +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.file_utils import FileChecker +from msprobe.core.common.utils import get_real_step_or_rank from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param @@ -39,11 +42,15 @@ except ImportError: _msprobe_c = None +ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_path", "level"]) + + class PrecisionDebugger: _instance = None task_not_need_service = [Const.GRAD_PROBE] - def __new__(cls, config_path=None, opt=None): + def __new__(cls, config_path=None, task=None, dump_path=None, + level=None, step=None, opt=None): if not cls._instance: cls._instance = super().__new__(cls) cls._instance.initialized = False @@ -52,7 +59,8 @@ class PrecisionDebugger: cls.first_start = False return cls._instance - def __init__(self, config_path=None): + def __init__(self, config_path=None, task=None, dump_path=None, + level=None, step=None): if self.initialized: return self.initialized = True @@ -61,11 +69,20 @@ class PrecisionDebugger: if not config_path: config_path = os.path.join(os.path.dirname(__file__), "../../config.json") + + config_params = ConfigParameters(config_path, task, dump_path, level) + self.check_input_params(config_params) + common_config, task_config = parse_json_config(config_path) + common_config.task = task if task else common_config.task self.task = common_config.task if self.task == Const.GRAD_PROBE: self.gm = GradientMonitor(common_config, task_config) return + common_config.step = get_real_step_or_rank( + step, Const.STEP) if step is not None else common_config.step + common_config.level = level if level else common_config.level + common_config.dump_path = dump_path if dump_path else common_config.dump_path self.config = DebuggerConfig(common_config, task_config) if _msprobe_c: @@ -73,11 +90,35 @@ class PrecisionDebugger: self.config.execution_mode = self._get_execution_mode() if self._need_service(): + self.config.check_config_with_l2() self.service = 
Service(self.config) Runtime.step_count = 0 Runtime.is_running = False + @staticmethod + def check_input_params(args): + if args.config_path is not None: + if not isinstance(args.config_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string") + file_checker = FileChecker( + file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX) + file_checker.common_check() + + if args.task is not None and args.task not in Const.TASK_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}") + + if args.dump_path is not None: + if not isinstance(args.dump_path, str): + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string") + + if args.level is not None and args.level not in Const.LEVEL_LIST: + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}") + @staticmethod def _get_execution_mode(): jit_level = ms.context.get_jit_config().get(MsConst.JIT_LEVEL) @@ -96,6 +137,16 @@ class PrecisionDebugger: else: return MsConst.PYNATIVE_MODE + @staticmethod + def _is_graph_dump(config): + if config.level != MsConst.KERNEL: + return False + if not config.list: + return True + is_graph = any(item.startswith("name-regex") for item in config.list) + is_graph |= all("." not in item for item in config.list) + return is_graph + @classmethod def start(cls, model=None): instance = cls._instance @@ -194,4 +245,4 @@ class PrecisionDebugger: if instance.config.execution_mode != MsConst.PYNATIVE_MODE: return False else: - return instance.config.task != Const.FREE_BENCHMARK and instance.config.level != MsConst.KERNEL \ No newline at end of file + return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py index 6025930280c73587a7cf3d5f029f0d392f3d7c86..7aee1deccd9689985c7a2e270648bd0877cd7cf3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py @@ -106,6 +106,7 @@ class ApiRegistry: self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr) self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr) self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr) + self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr) self.set_api_attr(torch_npu, self.torch_npu_hook_attr) else: self.set_api_attr(Tensor, self.tensor_hook_attr) @@ -121,6 +122,7 @@ class ApiRegistry: self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr) self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr) self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr) + self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr) self.set_api_attr(torch_npu, self.torch_npu_ori_attr) else: self.set_api_attr(Tensor, self.tensor_ori_attr) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py index 59909d3e212fa0decef5326494c1f826c8fadc27..b68a7d995a56497a219281c5a43d692c46cfac4d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py +++ 
b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py @@ -40,10 +40,11 @@ def __init__(self, build_hook) -> None: self.prefix = self.prefix_api_name self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix) + forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix) self.register_forward_pre_hook(forward_pre_hook) self.register_forward_hook(forward_hook) register_backward_hook_functions["full"](self, backward_hook) + register_backward_hook_functions["pre"](self, backward_pre_hook) # 重载call,加全局标志。 diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py index cea8614610af6152b208e1173da635ef800a9722..0e97929ecd7f8444b19fd531efc49883d0df58de 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py @@ -87,9 +87,16 @@ class ApiTemplate(HOOKCell): super().__init__(hook) @staticmethod - def async_to_sync(handle): - if hasattr(handle, "wait"): - handle.wait() + def async_to_sync(output): + # Fake handle, used to return after the CommHandle executes the wait method + fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() + if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): + output[1].wait() + output = (output[0], fake_handle) + elif hasattr(output, "wait"): + output.wait() + output = fake_handle + return output def construct(self, *args, **kwargs): if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): @@ -99,9 +106,7 @@ class ApiTemplate(HOOKCell): if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: - self.async_to_sync( - output[1] if isinstance(output, tuple) and len(output) > 1 else output) - + output = self.async_to_sync(output) return output def forward(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index 4eee5eed724dbf63178069672d88db273e74f719..0a32200639a1f3805f815c37caaef5d3bb64c82f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -16,7 +16,6 @@ import os from collections import defaultdict -from mindspore import Tensor from mindspore._c_expression import PyNativeExecutor_ try: from mindspore.common.api import _MindsporeFunctionExecutor @@ -24,9 +23,8 @@ except ImportError: from mindspore.common.api import _JitExecutor as _MindsporeFunctionExecutor from msprobe.core.common.log import logger -from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.core.common.const import Const -from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs +from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.mindspore.dump.hook_cell.api_registry import api_register @@ -43,8 +41,8 @@ def dump_jit(name, in_feat, out_feat, is_forward): if JitDump.need_dump(): if is_forward: JitDump.jit_count[result] += 1 - name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \ - Const.FORWARD + name_template = (Const.JIT + Const.SEP + result + Const.SEP + + str(JitDump.jit_count[result]) + Const.SEP + Const.FORWARD) 
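+            # name_template takes the form "<jit>.<func>.<n>.forward", with Const.JIT
+            # as the jit marker and Const.SEP as the tool's name separator, so jit
+            # calls are keyed in the dump the same way as hooked api calls.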
JitDump.data_collector.update_api_or_module_name(name_template) module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat) JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py new file mode 100644 index 0000000000000000000000000000000000000000..aff10d79dc8879e3a5a4053f8e61d9bddc225f71 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_dump/kernel_config.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from msprobe.core.common.file_utils import save_json + + +def create_kernel_config_json(dump_path, cur_rank): + kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json" + kernel_config_path = os.path.join(dump_path, kernel_config_name) + config_info = { + "dump": { + "dump_list": [], + "dump_path": dump_path, + "dump_mode": "all", + "dump_op_switch": "on" + } + } + save_json(kernel_config_path, config_info, indent=4) + return kernel_config_path diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py index 51088ae74f14417ccbbedd7773f90af7772fe8d9..57b7de4fa567d73a19178256d79f5e4cbeb38864 100644 --- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py @@ -99,7 +99,10 @@ class ApiPyNativeSelfCheck: def wrap_backward_hook(cell, grad_input, grad_output): return backward_hook(cell, grad_input, grad_output) - return pre_hook, wrap_forward_hook, wrap_backward_hook + def pre_backward_hook(cell, grad_input): + return None + + return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook def store_original_func(self): for api_name in self.api_list: diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py index c875d52794fea28d77532a53078afde5abbd51e9..8a154f4d65f63e55f6b0cf3165d3c905bcb68546 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py @@ -16,6 +16,7 @@ import multiprocessing import os import time +from dataclasses import dataclass from multiprocessing import Process from typing import List @@ -23,6 +24,7 @@ import mindspore as ms import numpy as np from mindspore.common.parameter import Parameter from mindspore.communication import get_rank + from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path, write_csv, remove_path, move_file, load_npy) from msprobe.core.grad_probe.constant import GradConst @@ -31,6 +33,16 @@ from 
msprobe.mindspore.common.log import logger from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext +@dataclass +class GradDumpConfig: + dump_dir: str + g_name: str + dump_step: Parameter + grad: ms.Tensor + level: str + bounds: List + + def get_rank_id(): try: rank_id = get_rank() @@ -40,35 +52,35 @@ def get_rank_id(): @ms.jit -def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List): +def grad_dump(config: GradDumpConfig): """ Dump gradient statistic data. level0: [step, max, min, norm, shape_dim, shape] level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data """ - dump_path = os.path.join(dump_dir, g_name) + dump_path = os.path.join(config.dump_dir, config.g_name) dump_dir_path = dump_path + "_dir" save_op = ms.ops.TensorDump() - grad_flat = grad.reshape(-1) + grad_flat = config.grad.reshape(-1) max_val = grad_flat.max(axis=0).float() min_val = grad_flat.min(axis=0).float() norm_val = grad_flat.norm(ord=2).float() - shape = grad.shape - extrem_list = [dump_step[0].float(), max_val, min_val, norm_val] + shape = config.grad.shape + extrem_list = [config.dump_step[0].float(), max_val, min_val, norm_val] extrem_stat = ms.ops.stack(extrem_list) shape_list = [len(shape)] + list(shape) shape_stat = ms.Tensor(shape_list).float() level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0) level_stat = level0_stat - if level == GradConst.LEVEL2: - zero_grad = (grad == 0).sum() - dist_dim = ms.Tensor([len(bounds) + 2]).float() - bucket_result = ms.ops.bucketize(grad.float(), bounds) + if config.level == GradConst.LEVEL2: + zero_grad = (config.grad == 0).sum() + dist_dim = ms.Tensor([len(config.bounds) + 2]).float() + bucket_result = ms.ops.bucketize(config.grad.float(), config.bounds) bucket_result = bucket_result.astype(ms.int8) - dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)] + dist_stat = [(bucket_result == i).sum() for i in range(len(config.bounds) + 1)] dist_stat.append(zero_grad) dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty dist_stat = ms.ops.stack(dist_stat, axis=0).float() @@ -76,8 +88,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level_stat = level2_stat save_op(dump_path, level_stat) - if level == GradConst.LEVEL1 or level == GradConst.LEVEL2: - grad_direction = grad > 0 + if config.level == GradConst.LEVEL1 or config.level == GradConst.LEVEL2: + grad_direction = config.grad > 0 save_op(dump_dir_path, grad_direction) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py index 7006708a53867db91cdd0e39300383720189802c..1aa9fcfad10815d5845de66ab0ea6d4d7211741f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py @@ -26,7 +26,7 @@ from msprobe.core.grad_probe.constant import GradConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.grad_probe.global_context import grad_context from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator -from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id +from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id, GradDumpConfig from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput from msprobe.mindspore.grad_probe.utils import 
save_grad_direction, get_adapted_level @@ -66,8 +66,10 @@ def hook_graph_mode_optimizer(opt, hook_input): for index, grad_value in enumerate(gradients): if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list: continue - grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step, - grad_value, hook_input.level, hook_input.bounds) + conf = GradDumpConfig(dump_dir=hook_input.dump_dir, g_name=hook_input.g_names[index], + dump_step=self.dump_step, grad=grad_value, level=hook_input.level, + bounds=hook_input.bounds) + grad_dump(conf) ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step) self.assignadd(self.dump_step, self.global_step_increase_tensor) out = hook_input.func(gradients) diff --git a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py index 5b3b7c3fd85f810216afdf2a5b6dd6d09578d26c..27e42d52ba6190ec7e7531af25464e6aa3996b2b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py +++ b/debug/accuracy_tools/msprobe/mindspore/mindtorch/mindtorch_adaptor.py @@ -177,7 +177,6 @@ def _call_impl(self, *args, **kwargs): result = apply_backward_hook_on_tensors(bw_pre_hook, result) return result - except Exception: # run always called hooks if they have not already been run # For now only forward hooks have the always_call option but perhaps diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..3544ebbd025614349585bc799b15e00a5c2c7956 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py @@ -0,0 +1,404 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
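+
+# Anomaly detection for the MindSpore training-state monitor: scan rules such as
+# AnomalyTurbulence, GradAnomalyData records describing detected anomalies, and
+# CSV writers that check every scalar they record against the configured rules.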
+ +import itertools +import os +import sys +import statistics as st +from abc import ABC +from dataclasses import dataclass, field +from typing import List +from collections import defaultdict + +import pandas as pd + +from mindspore import ops +from mindspore import _no_grad +from msprobe.core.common.log import logger +from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.common.const import FileCheckConst, MonitorConst + + +class ScanRule(ABC): + name = "ScanRule" + + def apply(self, history, cur): + raise NotImplementedError("abstract method apply is not implemented") + + +class AnomalyTurbulence(ScanRule): + name = "AnomalyTurbulence" + + def __init__(self, threshold) -> None: + self.threshold = threshold + + def apply(self, history, cur): + baseline = st.mean(history) if isinstance(history, list) else history + + up_bound = baseline + baseline * self.threshold + if baseline > 0: + return cur > up_bound + else: + return cur < up_bound + + +class AnomalyScanner: + + @staticmethod + def load_rules(specs: List[dict]): + """ + specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] + """ + if specs is None: + return [] + alert_rules = [] + for spec in specs: + # 使用get方法获取键值,如果键不存在则返回None + rule_cls_name = spec.get("rule_name") + rule_args = spec.get("args") + + # 检查必要的键是否存在 + if rule_cls_name is None or rule_args is None: + logger.warning(f"Spec is missing required keys: {spec}") + continue + + cur_module = sys.modules.get(__name__) + try: + rule_cls = getattr(cur_module, rule_cls_name) + except AttributeError: + logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") + continue + + try: + rule_instance = rule_cls(**rule_args) + alert_rules.append(rule_instance) + except Exception as e: + logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") + continue + + return alert_rules + + @staticmethod + def scan(scan_rules: List[ScanRule], history, cur): + anomaly = False + for rule in scan_rules: + anomaly = rule.apply(history, cur) + if anomaly: + return anomaly, rule.name + return anomaly, None + + +class BCOLORS: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') + message (str): anomaly detect message + step (int): training step + """ + if not isinstance(tag, tuple) or len(tag) != 2: + raise ValueError("tag must be a tuple with length 2") + tag_name = tag[0] + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(tag_name, -1) + if MonitorConst.NAME_SEP in param_name: + vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) + else: + vpp_stage = 0 + + return GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + + +class TrainStage: + DEFAULT_STAGE = -1 + FORWARD_STAGE = 0 + BACKWARD_STAGE = 1 + OPTIMIZER_STAGE = 2 + + 
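+# A minimal sketch of how the scan rules above are driven (threshold is illustrative):
+#     rules = AnomalyScanner.load_rules(
+#         [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}])
+#     anomaly, rule_name = AnomalyScanner.scan(rules, history=[1.0, 1.1], cur=2.0)
+#     # -> (True, "AnomalyTurbulence"), since 2.0 > mean(history) * (1 + 0.5)
+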
+FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] +BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT, + MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] +OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] +TRAIN_STAGE = { + **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, + **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, + **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} +} + + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + """ + 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 + 比较规则为: + step 和 micro_step 值越小优先级越高; + vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; + call_id 值越小优先级越高。 + """ + if not isinstance(other, GradAnomalyData): + return NotImplemented + + self_train_stage = self.get_train_stage(self.tag_name) + other_train_stage = self.get_train_stage(other.tag_name) + + def vpp_pp_comparator(anomaly): + """ + Determine the priority rule for vpp and pp based on train stage + Forward stage prefers smaller vpp and pp + Other stages prefer larger vpp and pp + """ + if self_train_stage == TrainStage.FORWARD_STAGE: + return anomaly.vpp_stage, anomaly.pp_stage + else: + return -anomaly.vpp_stage, -anomaly.pp_stage + + self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] + other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] + return self_cmp < other_cmp + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + @staticmethod + def get_train_stage(tag_name): + """ + :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" + :return: int, if forward return 0; if backward return 1; if optimizer return 2 + """ + key_ = tag_name.split("/")[-1] + return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) + + def to_dict(self): + return self.__dict__ + + def get_key(self): + # 0:1.self_attention.core_attention_flash_0/rank0/input_grad + return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) + + +@dataclass +class WriterInput: + path: str + ad_rules: list + job_id: str + anomaly_factory: AnomalyDataFactory = None + ndigits: int = 6 + step_count_per_record: int = 1 + + +class BaseWriterWithAD: + def __init__(self, writer_input: WriterInput): + self.tag2scalars = {} + self.ad_rules = writer_input.ad_rules + self.job_id = writer_input.job_id + self.anomaly_factory = writer_input.anomaly_factory + self.anomalies = [] + self.ndigits = writer_input.ndigits + + def get_anomalies(self): + """返回已检测到的异常列表 + """ + return self.anomalies + + def clear_anomalies(self): + self.anomalies.clear() + + def add_scalar(self, tag, scalar_value, global_step=None, need_explain=False): + """If an anomaly is detected, the anomaly information is recorded and added to self.anomalies. + Args: + tag (tuple): tuple of tag_name and tag like ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min'). + scalar_value (float): scalar_value. + global_step (int): global_step. 
+ Returns: + None + """ + detected = False + if self.ad_rules: + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) + if detected: + exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + logger.info(f"{BCOLORS.WARNING}> {exception_message}{BCOLORS.ENDC}") + # append to self.anomalies for dump + if self.anomaly_factory: + self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step)) + + def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False): + if not metric_value: + return + tensors = [] + tags = list(itertools.product(metric_value.keys(), op_list)) + for op2tensor in metric_value.values(): + tensors.extend(op2tensor.values()) + with _no_grad(): + metric_list = ops.stack(tensors).tolist() if tensors else [] + for tag, metric in zip(tags, metric_list): + self.add_scalar(tag, metric, step, need_explain) + + def _ad(self, scalar_value, history): + return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value) + + def _update_tag2scalars(self, tag, scalar_value): + """Update the average and count of a scalar value associated with a tag. + + This method is used to maintain a running average of scalar values for each tag. + + + Args: + tag (str): The tag identifier. + scalar_value (float): The scalar value to be added. + + Returns: + float: The average value before update. + """ + if tag not in self.tag2scalars: + self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + avg = self.tag2scalars[tag]['avg'] + new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) + self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['count'] += 1 + return avg + + +class CSVWriterWithAD(BaseWriterWithAD): + def __init__(self, writer_input: WriterInput): + super().__init__(writer_input) + + path = writer_input.path + self.log_dir = path + create_directory(path) + change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY) + self.context_dict = defaultdict(list) + self.header = [] + self.step_count_per_record = writer_input.step_count_per_record + + def get_step_interval(self, step): + count = step // self.step_count_per_record + return count * self.step_count_per_record, (count + 1) * self.step_count_per_record - 1 + + def write_csv(self, prefix, step): + """ + Args: + prefix[str]: prefix of output csv file e.g. 
grad_unreduced
+ step[int]
+ """
+ if len(self.context_dict) == 0:
+ return
+
+ step_start, step_end = self.get_step_interval(step)
+ filepath = os.path.join(self.log_dir, f'{prefix}_{step_start}-{step_end}.csv')
+ if not os.path.exists(filepath):
+ data_frame = pd.DataFrame(columns=self.header)
+ write_df_to_csv(data_frame, filepath)
+
+ new_data = []
+ for name, metric_value in self.context_dict.items():
+ if MonitorConst.NAME_SEP not in name:
+ new_data.append([name] + [step] + metric_value)
+ else:
+ new_data.append(name.split(MonitorConst.NAME_SEP) + [step] + metric_value)
+ new_data = pd.DataFrame(new_data).round(self.ndigits)
+ write_df_to_csv(new_data, filepath, mode='a+', header=False)
+ self.context_dict = defaultdict(list)
+
+ def add_scalar(self, tag, scalar_value, global_step, need_explain=False):
+ """
+ ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min')
+ """
+ super().add_scalar(tag, scalar_value, global_step, need_explain=need_explain)
+ split_name = tag[0].split('/')
+ name = split_name[0]
+ if need_explain:
+ if 'pre' in split_name[-1]:
+ name += '.input'
+ if 'post' in split_name[-1]:
+ name += '.output'
+ self.context_dict[name].append(scalar_value)
+
+ def write_metrics(self, op_list, metric_value, step, prefix='', need_explain=False):
+ need_explain = prefix == 'other'
+ super().write_metrics(op_list, metric_value, step, prefix=prefix, need_explain=need_explain)
+
+ # generate csv headers
+ # set hashmap to reduce the number of headers generated.
+ # forward norms use the input.ops_/output.ops_ headers; backward ones use input_grad.ops_/output_grad.ops_
+ if prefix in {"actv", "actv_grad"}:
+ if prefix == "actv":
+ input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
+ else:
+ input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
+ ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, op_list)]
+ csv_header = ["module_name", "step", *ops_]
+ else:
+ csv_header = ["param_name", "step", *op_list]
+
+ keys = list(metric_value.keys())
+ if keys and MonitorConst.NAME_SEP in keys[0]:
+ csv_header.insert(0, "vpp_stage")
+
+ self.header = csv_header
+ self.write_csv(prefix, step)
+ self.header = []
+
+ def close(self):
+ pass
diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/__init__.py b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/__init__.py
similarity index 100%
rename from debug/accuracy_tools/msprobe/pytorch/functional/__init__.py
rename to debug/accuracy_tools/msprobe/mindspore/monitor/distributed/__init__.py
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/distributed_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/distributed_ops.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6f336b2fffd81c3e3aa60a4dec1c743e31f2609b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/distributed_ops.yaml
@@ -0,0 +1,15 @@
+communication.comm_func:
+ - all_reduce
+ - all_gather_into_tensor
+ - reduce
+ - reduce_scatter_tensor
+ - all_to_all_single_with_output_shape
+ - all_to_all_with_output_shape
+ - batch_isend_irecv
+ - broadcast
+ - gather_into_tensor
+ - scatter_tensor
+ - send
+ - recv
+ - isend
+ - irecv
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/stack_blacklist.yaml b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/stack_blacklist.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..068935cebec687497d75b688fad228866a0b3622
--- /dev/null
+++
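A quick sketch of the record-window arithmetic in `get_step_interval` above: with `step_count_per_record = 5`, steps fall into fixed windows and each window maps to one CSV file (file-name pattern as in `write_csv`):

```python
def get_step_interval(step, step_count_per_record):
    count = step // step_count_per_record
    return count * step_count_per_record, (count + 1) * step_count_per_record - 1

assert get_step_interval(0, 5) == (0, 4)   # -> grad_unreduced_0-4.csv
assert get_step_interval(7, 5) == (5, 9)   # -> grad_unreduced_5-9.csv
```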
b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/stack_blacklist.yaml @@ -0,0 +1,5 @@ +stack: +- msprobe/mindspore/monitor/distributed +- site-packages/mindspore/nn/cell.py +- multiprocessing +- debugpy \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/wrap_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..33fd58c7278c6245140e50a984f44e59b90c69de --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/distributed/wrap_distributed.py @@ -0,0 +1,300 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import re + +import numpy as np + +from mindspore import nn, Tensor, ops, _no_grad +from mindspore import communication +from mindspore.communication import comm_func, get_rank + +from msprobe.core.common.const import MonitorConst, Const +from msprobe.core.common.file_utils import load_yaml +from msprobe.mindspore.monitor.utils import get_metrics, get_summary_writer_tag_name + +enable_communication = True +try: + from mindspore._c_expression import CommHandle as CommHandle_ +except ImportError: + enable_communication = False + + +RANK = None + +OpsPath = os.path.join(os.path.dirname(__file__), "distributed_ops.yaml") +WrapDistributedOps = load_yaml(OpsPath).get("communication.comm_func", []) + +StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml") +StackBlackList = load_yaml(StackBlackListPath).get("stack", []) + +distributed_func = {} +for f in dir(comm_func): + distributed_func[f] = getattr(comm_func, f) + +ORIGIN_WAIT = CommHandle_.wait if enable_communication else None +PENDING_ASYNC_CC_BY_HANDLE = {} + + +def get_distributed_ops(): + global WrapDistributedOps + _all_distributed_ops = dir(comm_func) + return set(WrapDistributedOps) & set(_all_distributed_ops) + + +class DistributedOPTemplate(nn.Cell): + def __init__(self, op_name, pre_hooks, post_hooks): + super(DistributedOPTemplate, self).__init__() + self.op_name_ = str(op_name) + self.__name__ = self.op_name_ + self.cc_hooks = [] + for pre_hook in pre_hooks: + handle = self.register_forward_pre_hook(pre_hook) + self.cc_hooks.append(handle) + for hook in post_hooks: + handle = self.register_forward_hook(hook) + self.cc_hooks.append(handle) + + def construct(self, *args, **kwargs): + return distributed_func.get(self.op_name_)(*args, **kwargs) + + def forward(self, *args, **kwargs): + return distributed_func.get(self.op_name_)(*args, **kwargs) + + +class ApiRegistry: + def __init__(self): + self.distributed_attr_origin = {} + self.distributed_attr_hooked = {} + + @staticmethod + def store_ori_attr(ori_api_group, api_list, api_ori_attr): + for api in api_list: + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(ori_api_group, sub_module_name) + api_ori_attr[api] = 
getattr(sub_module, sub_op)
+ else:
+ api_ori_attr[api] = getattr(ori_api_group, api)
+
+ @staticmethod
+ def set_api_attr(api_group, attr_dict):
+ for cc_api_name, cc_api_entry_func in attr_dict.items():
+ if Const.SEP in cc_api_name:
+ sub_module_name, sub_op = cc_api_name.rsplit(Const.SEP, 1)
+ sub_module = getattr(api_group, sub_module_name, None)
+ if sub_module is not None:
+ setattr(sub_module, sub_op, cc_api_entry_func)
+ else:
+ setattr(api_group, cc_api_name, cc_api_entry_func)
+
+ @staticmethod
+ def redirect_wait():
+ global ORIGIN_WAIT
+ global PENDING_ASYNC_CC_BY_HANDLE
+ if not ORIGIN_WAIT:
+ return
+
+ def make_wrapped_wait(_handle_cls):
+ def wrapped_wait(*args, **kwargs):
+ # run the original wait, then flush any stats callback parked for this handle
+ ORIGIN_WAIT(*args, **kwargs)
+ if args[0] in PENDING_ASYNC_CC_BY_HANDLE:
+ store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0])
+ store_func()
+
+ return wrapped_wait
+
+ CommHandle_.wait = make_wrapped_wait(CommHandle_)
+
+ def redirect_api(self):
+ self.set_api_attr(comm_func, self.distributed_attr_hooked)
+ self.redirect_wait()
+
+ def restore_api(self):
+ if not ORIGIN_WAIT:
+ return
+ self.set_api_attr(comm_func, self.distributed_attr_origin)
+ setattr(CommHandle_, 'wait', ORIGIN_WAIT)
+
+ def initialize_hook(self, pre_hooks, post_hooks):
+ self.store_ori_attr(comm_func, get_distributed_ops(), self.distributed_attr_origin)
+ cc_hooks = []
+ for op_name in get_distributed_ops():
+ self.distributed_attr_hooked[op_name] = DistributedOPTemplate(op_name, pre_hooks, post_hooks)
+ cc_hooks.extend(self.distributed_attr_hooked[op_name].cc_hooks)
+ return cc_hooks
+
+
+def get_process_group(process_group):
+ return (
+ process_group
+ if process_group
+ else comm_func.HCCL_WORLD_GROUP
+ )
+
+
+def stack_filter(stack):
+ for pattern in StackBlackList:
+ if re.search(pattern, stack):
+ return False
+ return True
+
+
+def get_callstack():
+ callstack = []
+ for (_, path, line, func, _, _) in inspect.stack():
+ stack_line = f'{path}[{line}]'
+ if stack_filter(stack_line):
+ callstack.append(stack_line + ' ' + func)
+ return callstack
+
+
+@_no_grad()
+def op_aggregate(op, tensorlist):
+ if isinstance(tensorlist, Tensor):
+ return tensorlist
+ if not tensorlist:
+ return Tensor(float('nan'))
+ if op == 'min':
+ return min(tensorlist)
+ if op == 'max':
+ return max(tensorlist)
+ if op == 'norm':
+ return sum(tensorlist)
+ if op == 'zeros':
+ return sum(tensorlist) / len(tensorlist)
+ if op == 'nans':
+ return sum(tensorlist)
+ if op == 'mean':
+ return sum(tensorlist) / len(tensorlist)
+ return Tensor(float('nan'))
+
+
+def update_data(old, new):
+ for tag, op2tensor in new.items():
+ if tag not in old:
+ old[tag] = {}
+ for op, tensor in op2tensor.items():
+ if op not in old[tag]:
+ old[tag][op] = [tensor]
+ else:
+ old[tag][op].append(tensor)
+ return old
+
+
+def is_target_line(codeline):
+ stack = get_callstack()
+ whole_stack = ';'.join(stack)
+ if codeline == []:
+ return True
+ for pattern in codeline:
+ if re.search(pattern, whole_stack):
+ return True
+ return False
+
+
+@_no_grad()
+def catch_data(cc_context, cc_name, ops_list, args, prefix):
+ tensor_args = {}
+ for arg in args:
+ if isinstance(arg, Tensor):
+ key = get_summary_writer_tag_name(cc_name, f'{prefix}_{len(tensor_args)}', RANK)
+ tensor_args[key] = arg
+ elif isinstance(arg, list) and arg:
+ if isinstance(arg[0], Tensor):
+ stacked_arg = ops.stack(arg)
+ elif isinstance(arg[0], comm_func.P2POp):
+ stacked_arg = ops.stack([op.tensor for op in arg])
+ else:
+ # lists of anything else carry no tensors to collect
+ continue
+ key = get_summary_writer_tag_name(cc_name, f'{prefix}_{len(tensor_args)}', RANK)
+ tensor_args[key] = stacked_arg
+
+
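The handle-keyed bookkeeping used by `redirect_wait` and `PENDING_ASYNC_CC_BY_HANDLE` boils down to one pattern; here is a toy model in plain Python (no MindSpore), where a stats callback is parked under the communication handle and only runs once `wait()` completes:

```python
pending = {}

class FakeHandle:
    """Stand-in for CommHandle_; the real code wraps CommHandle_.wait."""
    def wait(self):
        pending.pop(self, lambda: None)()   # flush the parked callback, if any

handle = FakeHandle()
collected = []
pending[handle] = lambda: collected.append("post-collective stats")
handle.wait()
assert collected == ["post-collective stats"]
```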
new_data = get_metrics(ops_list, tensor_args, 1e-8) + cc_context.data = update_data(cc_context.data, new_data) + + +def create_async_callback_func(context, cc_name, ops_list, args, prefix): + def store_data(): + catch_data(context, cc_name, ops_list, args, prefix) + + return store_data + + +def create_hooks(context, monitor): + def cc_log_hook(module, inputs): + stack = ';'.join(get_callstack()) + monitor.cc_logged_stack[module.op_name_].add(stack) + return + + def cc_pre_hook(module, inputs): + if not is_target_line(monitor.cc_codeline): + return + catch_data(context[module.op_name_], module.op_name_, monitor.ops, inputs, MonitorConst.PREFIX_PRE) + return + + def cc_hook(module, inputs, out=None): + if not is_target_line(monitor.cc_codeline): + return out + if out and enable_communication: # async + if isinstance(out, CommHandle_): + PENDING_ASYNC_CC_BY_HANDLE[out] = create_async_callback_func( + context[module.op_name_], + module.op_name_, + monitor.ops, inputs, + MonitorConst.PREFIX_POST + ) + elif isinstance(out, list): # batch_isend_irecv + for out_element in out: + if isinstance(out_element, comm_func.P2POp): + PENDING_ASYNC_CC_BY_HANDLE[out_element] = create_async_callback_func( + context[module.op_name_], + module.op_name_, + monitor.ops, inputs, + MonitorConst.PREFIX_POST + ) + elif isinstance(out, tuple): + if len(out) == 2 and isinstance(out[1], CommHandle_): + PENDING_ASYNC_CC_BY_HANDLE[out[1]] = create_async_callback_func( + context[module.op_name_], + module.op_name_, + monitor.ops, inputs, + MonitorConst.PREFIX_POST + ) + + return out + catch_data(context[module.op_name_], module.op_name_, monitor.ops, inputs, MonitorConst.PREFIX_POST) + return out + + global RANK + pre_hooks = [] + hooks = [] + RANK = str(get_rank()) + if communication.GlobalComm.INITED and RANK not in monitor.module_rank_list and monitor.module_rank_list != []: + return [pre_hooks, hooks] + + if monitor.cc_log_only: + pre_hooks.append(cc_log_hook) + return [pre_hooks, hooks] + + if monitor.cc_pre_hook: + pre_hooks.append(cc_pre_hook) + + hooks.append(cc_hook) + + return [pre_hooks, hooks] + + +api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/features.py b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py new file mode 100644 index 0000000000000000000000000000000000000000..be958dadfe8fcc50f26f16c93b3a090269235d1e --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/features.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
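Taken together, `catch_data` and `update_data` accumulate one list of values per (tag, op), and the step-final aggregation later reduces each list with `op_aggregate`. A plain-float sketch of those reduction rules (the real function operates on MindSpore tensors and returns NaN for unknown ops or empty lists):

```python
def aggregate(op, values):
    if op == 'min':
        return min(values)
    if op == 'max':
        return max(values)
    if op in ('mean', 'zeros'):
        return sum(values) / len(values)   # averaged across calls
    if op in ('norm', 'nans'):
        return sum(values)                 # norms and NaN counts accumulate
    return float('nan')

assert aggregate('max', [1.0, 3.0, 2.0]) == 3.0
assert aggregate('nans', [0, 2, 1]) == 3
```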
+ +from mindspore import mint, ops, _no_grad +from mindspore import Tensor +from mindspore import dtype as mstype + + +@_no_grad() +def square_sum(x: Tensor): + return (x * x).sum() + + +@_no_grad() +def get_min(x: Tensor): + return mint.min(x) + + +@_no_grad() +def get_mean(x: Tensor): + return mint.mean(x.astype(mstype.float32)) + + +@_no_grad() +def get_norm(x: Tensor): + norm_func = mint.norm if hasattr(mint, "norm") else ops.norm + return norm_func(x.astype(mstype.float32)) + + +@_no_grad() +def get_max(x: Tensor): + return mint.max(x) + + +@_no_grad() +def get_zeros(x: Tensor, eps: float): + return mint.sum(mint.abs(x) < eps) / x.numel() + + +@_no_grad() +def get_nans(t): + return ops.isnan(t.astype(mstype.float32)).sum() + + +FUNC_MAP = {"min" : get_min, + "max" : get_max, + "mean" : get_mean, + "norm" : get_norm, + "nans" : get_nans, + "zeros": get_zeros + } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..068be9ff6c782bec2bf637999ef5f0eabe0c2675 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -0,0 +1,870 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
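A quick smoke test of the statistics helpers in `features.py` above (a sketch; it assumes a working MindSpore install and that `mint.sum` accepts the boolean mask produced in `get_zeros`):

```python
import mindspore as ms
from mindspore import Tensor

x = Tensor([0.0, 1e-10, 0.5, -2.0], ms.float32)
print(float(get_zeros(x, 1e-8)))   # 0.5   -> two of four elements are below eps
print(float(get_norm(x)))          # ~2.06 -> sqrt(0.25 + 4.0)
print(float(get_nans(x)))          # 0.0   -> no NaNs present
```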
+
+import os
+import re
+import uuid
+from collections import defaultdict
+from datetime import datetime
+
+import pytz
+import mindspore as ms
+from mindspore import Tensor, mint
+from mindspore import nn, _no_grad
+from mindspore.communication import get_rank
+
+from msprobe.core.common.log import logger
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.file_utils import load_json, save_json
+from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
+ is_skip_step, get_metrics, get_single_metrics, get_target_output_dir
+from msprobe.mindspore.monitor.module_spec_verifier import validate_config_spec
+from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \
+ CSVWriterWithAD, BaseWriterWithAD, WriterInput
+from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
+ get_process_group
+
+FORMAT_MAPPING = {
+ MonitorConst.CSV: CSVWriterWithAD,
+ MonitorConst.API: BaseWriterWithAD
+}
+
+
+def get_output_base_dir():
+ return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
+def get_param_struct(param):
+ res = {}
+ if isinstance(param, (tuple, list)):
+ res['config'] = f'{type(param).__name__}[{len(param)}]'
+ for i, x in enumerate(param):
+ res[i] = f'size={tuple(x.shape)}, dtype={x.dtype}' if isinstance(x, Tensor) else f'{type(x)}'
+ elif isinstance(param, Tensor):
+ res['config'] = 'tensor'
+ res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
+ else:
+ res['config'] = f'{type(param)}'
+ logger.warning(f'Unsupported type({type(param)}); please check the type of param {param}')
+ return res
+
+
+def param_is_not_tensor_parallel_duplicate(param, tp_group):
+ return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
+ mint.distributed.get_rank(group=tp_group) == 0
+ )
+
+
+def param_is_data_parallel_duplicate(dp_group):
+ return mint.distributed.get_rank(group=dp_group) != 0
+
+
+def squash_param_name(param_name):
+ # e.g. "model.layers.3.mlp.fc1.weight" -> "3.mlp.fc1.weight";
+ # names matching no pattern are returned unchanged.
+ # Raw strings avoid invalid-escape warnings on '\.' in newer Python versions.
+ for pattern in [r'layers?\.(.*)', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
+ match = re.findall(pattern, param_name)
+ if match:
+ return match[0]
+ return param_name
+
+
+# Used For Module Forward & Backward Collect
+class ModuleHookContext:
+ def __init__(self, module_name) -> None:
+ self.step = 0
+ self.micro_step = 0
+ self.actv = defaultdict(dict)
+ self.actvgrad = []
+ self.module_name = module_name
+ self.struct = {}
+ self.format_by_arg = {}
+ self.verified = False
+ self.focused_in_col = 0
+ self.focused_out_col = 0
+ self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found
+
+ def set_format_by_arg(self, key_name: str, target_config: dict):
+ cared = target_config.get(self.module_name, self.struct)
+ if key_name in cared:
+ if isinstance(cared[key_name], dict):
+ # current cared is self.struct
+ config = cared[key_name].get('config')
+ self.format_by_arg[key_name] = config
+ else:
+ # current cared is target_config[self.module_name]
+ self.format_by_arg[key_name] = cared[key_name]
+ elif key_name in ['input', 'input_grad']:
+ self.ignore_in = True
+
+ def reset(self):
+ self.actv.clear()
+ self.actvgrad.clear()
+
+
+start_step = 0
+
+
+# Used For Optimizer Weight Grad & M/V Collect
+class OptimizerContext:
+ def __init__(self) -> None:
+ self.step = start_step
+ self.param_mg_direction = defaultdict(float)
+ self.param_adam_update = defaultdict()
+ self.param_adam_ratio =
defaultdict() + self.param_weight_grad = defaultdict() + self.param_exp_avg = defaultdict() + self.exp_avg_metric = {} + self.param_exp_avg_sq = defaultdict() + self.exp_avg_sq_metric = {} + self.metric_dict = {} + self.param_metric = {} + + def reset(self) -> None: + self.param_mg_direction.clear() + self.param_adam_update.clear() + self.param_adam_ratio.clear() + self.param_weight_grad.clear() + self.param_exp_avg.clear() + self.exp_avg_metric.clear() + self.param_exp_avg_sq.clear() + self.exp_avg_sq_metric.clear() + self.metric_dict.clear() + self.param_metric.clear() + + +# Used For Weight Grad Collect +class GradContext: + def __init__(self) -> None: + self.pre = {} + self.post = {} + self.acc_metric = {} + self.acc = {} + self.actv = {} + + def reset(self): + self.pre.clear() + self.post.clear() + self.acc_metric.clear() + self.acc.clear() + self.actv.clear() + + +class CommunicationContext: + def __init__(self) -> None: + self.data = {} + + @staticmethod + def _agg(data): + aggregated_data = {} + for tag, op2tensorlist in data.items(): + aggregated_data[tag] = {} + for op, tensorlist in op2tensorlist.items(): + aggregated_data[tag][op] = op_aggregate(op, tensorlist) + return aggregated_data + + def reset(self): + self.data = {} + + def aggregate(self): + self.data = self._agg(self.data) + + +class TrainerMon: + def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None: + # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置 + self.config_file_path = config_file_path + self.process_group = process_group + self.params_have_main_grad = params_have_main_grad + self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 + self.config = load_json(config_file_path) + validate_config(self.config) + + local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区 + cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S') + self.unique_id = str(uuid.uuid4())[:8] + self.output_base_dir = get_output_base_dir() + time_tags = self.config.get("append_output", []) + try: + self.rank = get_rank() + if time_tags: + output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1]) + if str(self.rank) in output_append_dirs: + self.tensorboard_dir = output_append_dirs[str(self.rank)] + logger.info(f"Append rank({self.rank}) result to {self.tensorboard_dir}") + else: + self.tensorboard_dir = os.path.join(self.output_base_dir, + f"{cur_time}-rank{self.rank}-{self.unique_id}") + except Exception as e: + self.rank = 0 + self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}") + + self.pp_stage = 0 + self.group_mates = [0] + + # TYPE2: 只会在set_monitor()主调中赋值的变量 + self.model = None + self.vpp = False + self.dp_group = None + self.tp_group = None + self.micro_batch_number = 1 + + # TYPE3: 会随着训练中途config配置更新或监控状态改变而重置的变量 + self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) + self.optimizer_context = defaultdict(OptimizerContext) + self.cc_context = defaultdict(CommunicationContext) + self.grad_context = GradContext() + self.handles = defaultdict(list) + self.param2name = defaultdict(str) + self.name2index = defaultdict() + self.name2indices = defaultdict() + self.name2param = {} + self.duplicate_param = {} + self.name2tag = {} + self.param_name_call_id = {} + self.call_id = 0 + self.module_struct = defaultdict(dict) + self.grad_accs = [] + self.weight_hooked = False + self.optimizer_hooked = False + 
self.param_registered = False + self.struct_printed = False + + # 动静态区分 + self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true' + if self.dynamic_enable: + logger.warning(f"DYNAMIC_MONITOR is set, " + f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}") + self.monitoring = False + else: + self.set_config() + # 静态且collect_times>0时在第0步self.monitoring就可以True, 动态默认在下一步开启 + if self.collect_times > 0: + self.monitoring = True + + def set_config(self): + self.start_step = self.config.get("start_step", 0) + self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集 + self.step_interval = self.config.get("step_interval", 1) + self.has_collect_times = 0 # 重设采集计数器 + self.print_struct = self.config.get("print_struct", False) + self.targets = self.config.get("targets", None) + self.is_select = self.config.get("is_select", False) + self.module_rank_list = self.config.get("module_ranks", []) + self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore + self.eps = self.config.get('eps', 1e-8) + self.ops = self.config.get('ops', []) # monitor mean/max/norm/min/nan... + self.ndigits = self.config.get('ndigits', 6) + self.all_xy = self.config.get('all_xy', False) + self.xy_distribution = self.config.get('xy_distribution', False) + self.forward_only = self.config.get('forward_only', False) + self.backward_only = self.config.get('backward_only', False) + self.ur_distribution = self.config.get('ur_distribution', False) # vector and ratio vector of adam + self.mv_distribution = self.config.get("mv_distribution", False) # m/v of adam + self.wg_distribution = self.config.get("wg_distribution", False) + self.param_distribution = self.config.get("param_distribution", False) + self.mg_direction = self.config.get('mg_direction', False) # main grad direction + self.cc_distribution = self.config.get("cc_distribution", {}) # communication ops + if not self.cc_distribution.get('enable', False): + self.cc_log_only = False + else: + self.cc_codeline = self.cc_distribution.get('cc_codeline', []) + self.cc_log_only = self.cc_distribution.get('cc_log_only', False) + self.cc_logged_stack = defaultdict(set) + self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() + self.common_info() + + # 初始化AnomalyData工厂 + alert_setting = self.config.get('alert', {"rules": []}) + self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) + self.anomaly_data_factory = None + if alert_setting.get('dump', False): + self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates) + + # 初始化writer, 创建输出目录 + if self.format not in FORMAT_MAPPING: + logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") + self.format = MonitorConst.CSV + writer = FORMAT_MAPPING[self.format] + self.step_count_per_record = self.config.get('step_count_per_record', 1) + self.summary_writer = writer( + WriterInput( + self.tensorboard_dir, + self.alert_rules, + self.unique_id, + self.anomaly_data_factory, + self.ndigits, + self.step_count_per_record + ) + ) + + def common_info(self): + if not self.xy_distribution: + logger.info("> module input/output input_grad/output_grad is not monitored. ") + if self.forward_only: + logger.info("> only module forward is monitored. 
") + if not self.ur_distribution: + logger.info("> update vector and ratio vector of adam is not monitored. ") + if not self.mv_distribution: + logger.info("> momentum and variance of adam is not monitored. ") + if not self.wg_distribution: + logger.info("> weight grad of specified module is not monitored. ") + if not self.mg_direction: + logger.info('> grad and momentum direction will not be compared.') + if not self.cc_distribution.get('enable', False): + logger.info("> cc operator is not monitored.") + + def set_monitor( + self, + model, + optimizer, + grad_acc_steps=1, + tp_group=None, + dp_group=None, + start_iteration=0 + ): + global start_step + start_step = start_iteration + self.micro_batch_number = grad_acc_steps + self.dp_group = dp_group + self.tp_group = tp_group + self.hook_step_final(optimizer) + if not isinstance(model, list): + model = [model] + self.model = model + if len(model) > 1: + self.vpp = True + logger.info('vpp enabled') + if not self.dynamic_enable: + self.register_hooks(optimizer) + + def hook_step_final(self, optimizer): + def step_final_hook(optimizer, *args, **kwargs): + context = self.optimizer_context[optimizer] + # 静态在第0步就可以保存, 动态在第0步不可以, 因为动态设计的就是重置后下一步开启, 第0步的self.monitoring还是False + if self.monitoring: + module_rank_valid = self.is_target_rank() + step_condition = (context.step >= self.start_step and ( + context.step - self.start_step) % self.step_interval == 0) + if module_rank_valid and step_condition: + self.has_collect_times += 1 + self.write_xy_tb(context.step) + self.write_grad_tb(context.step) + self.write_mv_tb(context) + self.write_param_tb(context) + + if context.metric_dict: + self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other') + context.metric_dict.clear() + + self.summary_writer.clear_anomalies() + self.call_id = 0 + self.param_name_call_id.clear() + + if self.has_collect_times >= self.collect_times: + self._remove_all_hooks_final(optimizer) + + context.step += 1 + self.dynamic_monitor(optimizer) + + optimizer.register_forward_hook(step_final_hook) + return + + def dynamic_monitor(self, optimizer): + """ + If dynamic monitor enabled and config.json updated, + remove hooks and register new hooks according to new configuration. 
+ """ + context = self.optimizer_context[optimizer] + if not self.dynamic_enable: + return + try: + # 如果文件时间戳没变, 可以不读取节省时间 + config_timestamp = os.path.getmtime(self.config_file_path) + if config_timestamp == self.config_timestamp: + return + # 更新config文件最新修改时间戳 + self.config_timestamp = config_timestamp + config = load_json(self.config_file_path) + except Exception as e: + logger.error(f"get config.json wrong because {e}, not updated, please check!!!") + return + + if config.get("dynamic_on", False): + try: + validate_config(config) + self.config = config + self.set_config() + logger.warning(f"config is updated at step{context.step - 1}, " + f"will start new hook at step{context.step}.") + except Exception as e: + logger.error(f"set config wrong because {e}, not updated, please check!!!") + return + + self._remove_all_hooks() + self.register_hooks(optimizer) + + def register_hooks(self, optimizer): + self._register_param_name() + self.hook_modules() + self.hook_optimizer(optimizer) + self._patch_grad_sync() + self.monitoring = True + + def hook_modules(self): + if not self.is_target_rank(): + return + module_in_all_stage = [key for key in self.targets.keys() if MonitorConst.NAME_SEP not in key] + + for key in module_in_all_stage: + struct = self.targets.pop(key) + self.targets.update( + {f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))}) + + hooked_count = 0 + for vpp_stage, model_chunk in enumerate(self.model): + if not isinstance(model_chunk, nn.Cell): + logger.info("Target Model is not Cell") + continue + vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' + targets = [x for x, _ in model_chunk.cells_and_names()] if self.print_struct else self.targets.keys() + hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + logger.info(f"> {hooked_count} modules are monitored.") + + def hook_optimizer(self, optimizer): + def optimizer_pre_hook_function(opt, grad_names, gradients): + context = self.optimizer_context[opt] + if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, + self.collect_times): + return + gradient_list = gradients[0] if isinstance(gradients, tuple) else gradients + is_select = self.is_select + for idx, grad in enumerate(gradient_list): + grad_name = grad_names[idx] + if is_select and grad_name not in self.targets: + continue + get_single_metrics(self.ops, grad_name, grad, context.param_weight_grad) + + if self.mv_distribution: + # fetch mean + for param in m_list: + name = param.name + if is_select and name not in self.targets: + continue + get_single_metrics(self.ops, name, param, context.exp_avg_metric) + # fetch variance + for param in v_list: + name = param.name + if is_select and name not in self.targets: + continue + get_single_metrics(self.ops, name, param, context.exp_avg_sq_metric) + if self.param_distribution: + for param in param_list: + get_single_metrics(self.ops, param.name, param, context.param_metric) + self.generate_wgrad_metrics() + metric_dict = {} + for cc in self.cc_context.values(): + cc.aggregate() + metric_dict.update(cc.data) + cc.reset() + + if not metric_dict: + return + context.metric_dict = metric_dict + return + + def optimizer_pre_hook_wrapper(func, grad_names): + def wrapper(opt, gradients): + return func(opt, grad_names, gradients) + return wrapper + + if self.optimizer_hooked or not self.is_target_rank(): + return + + m_list = [] + v_list = [] + param_list = [] + grad_names = [] + for param in optimizer.get_parameters(): + if MonitorConst.EXP_AVG_SQ in 
param.name: + v_list.append(param) + elif MonitorConst.EXP_AVG in param.name: + m_list.append(param) + elif param.name in ['global_step', 'learning_rate']: + pass + else: + param_list.append(param) + grad_names.append(param.name) + + handle = optimizer.register_forward_pre_hook( + optimizer_pre_hook_wrapper(optimizer_pre_hook_function, grad_names)) + self.handles['optimizer'].append(handle) + self.optimizer_hooked = True + return + + def generate_wgrad_metrics(self): + if not self.wg_distribution: + return {}, {} + + if self.weight_hooked: + try: + get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric) + except Exception as e: + logger.warning(f"An error occurred while generating wgrad pre metrics") + return {}, {} + + grad_dict = {} + for param, name in self.param2name.items(): + if self.duplicate_param.get(name, False): + continue + grad = param.main_grad if self.params_have_main_grad else param.grad + if grad is None: + logger.warning(f"grad is None: {name}, maybe something wrong happened.") + continue + tag = self.name2tag.get(name, {}).get(MonitorConst.POST_GRAD) + self._register_param_call_id("hook_optimizer", tag) + grad_dict[tag] = grad + try: + get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) + except Exception as e: + logger.warning(f"An error occurred while generating wgrad post metrics") + return {}, {} + return self.grad_context.post, self.grad_context.pre + + def write_xy_tb(self, step): + if not self.xy_distribution: + return + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + if len(fwd_context.actv) == 0: + continue + self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv') + fwd_context.actv.clear() + if self.grad_context.actv: + self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad') + + def write_param_tb(self, opt_context): + if not self.param_distribution: + return + self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param') + + def write_mv_tb(self, opt_context): + if not self.mv_distribution: + return + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg') + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq') + + def write_grad_tb(self, step): + if not self.wg_distribution: + return + + self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') + self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') + + def is_target_rank(self): + if self.module_rank_list and (self.rank not in self.module_rank_list): + return False + return True + + def build_tbtag_tensor_map(self, module_name, tag, tensor): + metrics = {} + key = get_summary_writer_tag_name(module_name, tag, str(self.rank)) + if isinstance(tensor, Tensor): + self._register_param_call_id("_hook_module", key) + metrics[key] = tensor + return metrics + + def _register_param_name(self): + for vpp_stage, model_chunk in enumerate(self.model): + prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}' + self._register_chunk(model_chunk, prefix) + + def _register_chunk(self, model_chunk, prefix): + index = 0 + for param in model_chunk.get_parameters(): + param_name = param.name + if not param.requires_grad: + continue + if self._is_target_param(param_name, param, prefix): + name = prefix + squash_param_name(param_name) + if name in self.param2name.values(): + name = prefix + param_name + 
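One detail worth calling out in the parameter scan above: the `elif` order matters because (assuming the constants hold their usual literal values) `MonitorConst.EXP_AVG` ("exp_avg") is a substring of `MonitorConst.EXP_AVG_SQ` ("exp_avg_sq"), so the variance buffers must be matched first. A sketch with hypothetical parameter names:

```python
def classify(name, exp_avg="exp_avg", exp_avg_sq="exp_avg_sq"):
    if exp_avg_sq in name:                 # check the longer marker first
        return "v_list"
    elif exp_avg in name:
        return "m_list"
    elif name in ("global_step", "learning_rate"):
        return "skipped"
    return "param_list"

assert classify("adam_v.exp_avg_sq.fc1.weight") == "v_list"
assert classify("adam_m.exp_avg.fc1.weight") == "m_list"   # not swallowed by the sq branch
assert classify("fc1.weight") == "param_list"
```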
self.param2name[param] = name + self.name2param[name] = param + self.name2index[name] = index + + if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): + self.duplicate_param[name] = True + if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): + self.duplicate_param[name] = True + self.name2tag[name] = { + MonitorConst.PRE_GRAD: get_summary_writer_tag_name(name, MonitorConst.PRE_GRAD, self.rank), + MonitorConst.POST_GRAD: get_summary_writer_tag_name(name, MonitorConst.POST_GRAD, self.rank) + } + index += 1 + + def _hook_module(self, target_names, module, vpp_stage=''): + if not isinstance(module, nn.Cell): + # nothing to hook + return 0 + + def fwd_hook_fun(module, module_input, module_output, name): + if module not in self.module_fwd_hook_context_by_module: + self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name) + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + if not context.struct: + context.struct = { + MonitorConst.ACTV_IN: get_param_struct(module_input), + MonitorConst.ACTV_OUT: get_param_struct(module_output) + } + if self.print_struct: + self.module_struct[context.module_name].update(context.struct) + return + if not module.training: + return + if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, + self.collect_times): + step_accumulates_one(context, self.micro_batch_number) + return + if not context.format_by_arg: + context.set_format_by_arg(MonitorConst.ACTV_IN, self.targets) + context.set_format_by_arg(MonitorConst.ACTV_OUT, self.targets) + if not context.format_by_arg: + return + if not context.verified: + if not context.ignore_in: + context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN], + module_input, context.module_name, + MonitorConst.ACTV_IN) + context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT], + module_output, context.module_name, + MonitorConst.ACTV_OUT) + context.verified = True + + tbtag_tensor_map = {} + if not context.ignore_in: + cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN, + cared_input)) + cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT, + cared_output)) + try: + get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) + except Exception as e: + logger.warning(f"An error occurred while generating forward activation metrics") + + step_accumulates_one(context, self.micro_batch_number) + return + + def bwd_hook_fun(module, input_grad, output_grad): + context: ModuleHookContext = self.module_bwd_hook_context_by_module[module] + if not context.struct: + context.struct = { + MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad), + MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad) + } + if self.print_struct: + self.module_struct[context.module_name].update(context.struct) + return + + if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, + self.collect_times): + step_accumulates_one(context, self.micro_batch_number) + return + + if not context.format_by_arg: + context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.targets) + 
context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.targets) + if not context.format_by_arg: + return + if not context.verified: + if not context.ignore_in: + context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN], + input_grad, context.module_name, + MonitorConst.ACTVGRAD_IN) + context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT], + output_grad, context.module_name, + MonitorConst.ACTVGRAD_OUT) + context.verified = True + + tbtag_tensor_map = {} + if not context.ignore_in: + cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] + tbtag_tensor_map.update( + self.build_tbtag_tensor_map( + f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, cared_input_grad)) + cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT, + cared_output_grad)) + + if context.micro_step == 0 and context.actvgrad: + logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " + f"maybe something wrong happened. Now clear it.") + context.actvgrad.clear() + try: + get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv) + except Exception as e: + logger.warning(f"An error occurred while generating backward activation metrics: {e}") + + step_accumulates_one(context, self.micro_batch_number) + return + + def fwd_hook_fun_wrapper(fwd_hook_fun, name): + def wrapper(module, module_input, module_output): + return fwd_hook_fun(module, module_input, module_output, name) + return wrapper + + if self.backward_only and self.forward_only: + logger.warning('not enable backward_only and forward_only simultaneously') + hooked_count = 0 + if self.xy_distribution or self.print_struct: + for module_name, submodule in module.cells_and_names(): + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if not self.backward_only: + handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name)) + self.handles['xy'].append(handle) + if not self.forward_only: + handle = submodule.register_backward_hook(bwd_hook_fun) + self.handles['xy'].append(handle) + self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) + logger.info(f"> {name} is monitored successfully") + hooked_count += 1 + return hooked_count + + def _patch_grad_sync(self): + if not self.wg_distribution: + return + self._hook_weights() + + def _hook_weights(self): + context = self.grad_context + + @_no_grad() + def param_hook(grad, context_dict, param, key): + param.micro_step += 1 + self._register_param_call_id("param_hook", key) + if param.micro_step == self.micro_batch_number: + param.micro_step = 0 + context_dict[key] = grad + + def param_hook_wrapper(param_hook, context_dict, param, key): + def wrapper(grad): + return param_hook(grad, context_dict, param, key) + return wrapper + + for param, name in self.param2name.items(): + key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) + setattr(param, 'micro_step', 0) + handle = param.register_hook(param_hook_wrapper(param_hook, context_dict=context.acc, param=param, key=key)) + self.handles['wgrads'].append(handle) + self.weight_hooked = True + + def _is_target_param(self, param_name, param, prefix): + if not self.targets: + return True + squash_name = 
prefix + squash_param_name(param_name) + name = prefix + param_name + for target in self.targets.keys(): + if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target): + setattr(param, "zero_out_wgrad", True) + return True + return False + + def _is_target_module(self, module_name, targets, vpp_stage): + if self.all_xy or self.print_struct: + return vpp_stage + squash_param_name(module_name) + for pattern in [ + vpp_stage + squash_param_name(module_name), + vpp_stage + module_name, + ]: + if pattern in targets: + return pattern + return "" + + def _register_param_call_id(self, hook_name: str, key: str): + """ + :param hook_name: + :param key: str, '0:relu_0/output_grad' + :return: + """ + logger.debug(f"{hook_name} {key}: {self.call_id}") + self.param_name_call_id[key] = self.call_id + self.call_id += 1 + + def _remove_all_hooks(self): + # 清空hook handle + for handle in self.handles['xy']: + handle.remove() + self.handles['xy'].clear() + # 清空对应context缓存 + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + fwd_context.reset() + for _, bwd_context in self.module_bwd_hook_context_by_module.items(): + bwd_context.reset() + self.grad_context.reset() # 权重梯度和激活值梯度都在这 + + for handle in self.handles['wgrads']: + handle.remove() + self.handles['wgrads'].clear() + self.weight_hooked = False + + if self.optimizer_hooked: + for handle in self.handles['optimizer']: + handle.remove() + self.handles['optimizer'].clear() + for _, context in self.optimizer_context.items(): + context.reset() + self.optimizer_hooked = False + + for handle in self.handles['cc']: + handle.remove() + self.handles['cc'].clear() + for _, context in self.cc_context.items(): + context.reset() + + # 清空节点缓存 + self.param2name.clear() + self.name2index.clear() + self.name2indices.clear() + self.name2param.clear() + self.duplicate_param.clear() + self.name2tag.clear() + self.module_struct.clear() + self.grad_accs.clear() + + # 关闭采集状态 + self.monitoring = False + + def _remove_all_hooks_final(self, optimizer): + if self.dynamic_enable: + # 结束后自动重置dynamic_on为False等待用户手动开启 + try: + config = load_json(self.config_file_path) + config['dynamic_on'] = False + save_json(self.config_file_path, config, indent=2) + config_timestamp = os.path.getmtime(self.config_file_path) + self.config_timestamp = config_timestamp + logger.info( + "Finish monitor, set config'dynamic_on=False, will restart by set it to True and update config") + except Exception as e: + logger.warning(f"Finish monitor, set config'dynamic_on=False fail because {e}, please check!!!") + logger.info("Finish monitor") + self._remove_all_hooks() diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..c06e8ea10f6a2178c3670e596ad64e333db44cab --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_spec_verifier.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
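To make the `_hook_weights` bookkeeping above concrete: each parameter hook counts micro-batches and records the accumulated gradient only on the last micro-batch of a global step. A plain-Python sketch of that counter:

```python
MICRO_BATCH_NUMBER = 4
state, recorded = {"micro_step": 0}, []

def param_hook(grad):
    state["micro_step"] += 1
    if state["micro_step"] == MICRO_BATCH_NUMBER:
        state["micro_step"] = 0
        recorded.append(grad)   # the gradient is fully accumulated at this point

for g in range(8):              # two global steps of four micro-batches each
    param_hook(g)
assert recorded == [3, 7]       # only the final micro-batch of each step is kept
```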
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import abc
+from mindspore import Tensor
+
+from msprobe.core.common.log import logger
+
+
+# Registry storing every validator implementation class
+config_validator_registry = {}
+
+
+def register_config_validator(cls):
+ """Decorator used to register ConfigValidator implementations"""
+ config_validator_registry[cls.__name__] = cls
+ return cls
+
+
+class ConfigValidator(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def check_pattern_match(self, config_spec: str):
+ pass
+
+ @abc.abstractmethod
+ def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+ pass
+
+
+@register_config_validator
+class TensorValidator(ConfigValidator):
+ def check_pattern_match(self, config_spec: str):
+ pattern = re.compile(r"tensor")
+ return pattern.match(config_spec)
+
+ def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+ if not isinstance(actual_data, Tensor):
+ raise ValueError(
+ f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
+
+
+@register_config_validator
+class TupleValidator(ConfigValidator):
+ def check_pattern_match(self, config_spec: str):
+ pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
+ return pattern.match(config_spec)
+
+ def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+ length, index = pattern_match.groups()
+ if index is None:
+ index = 0
+ length, index = int(length), int(index)
+
+ if not (0 <= index < length):
+ raise ValueError(
+ f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
+ f"y must be greater than or equal to 0 and less than x.")
+ if not isinstance(actual_data, tuple):
+ raise ValueError(
+ f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
+ if len(actual_data) != length:
+ raise ValueError(
+ f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
+ f"actual is {len(actual_data)}, please check.")
+ return index
+
+
+def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
+ focused_col = None
+ for _, validator_cls in config_validator_registry.items():
+ config_validator = validator_cls()
+ pattern_match = config_validator.check_pattern_match(config_spec)
+ if pattern_match:
+ try:
+ focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
+ except ValueError as e:
+ logger.warning(f"config spec validate failed: {str(e)}")
+ return focused_col
+ logger.warning(f"config spec in {module_name} {data_type} not supported, "
+ rf"expected spec: 'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
+ return focused_col
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..506ad6c3f91c7c73e5e12109a6ea617309df72c0
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
@@ -0,0 +1,301 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
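For the spec grammar handled by `TupleValidator` above, the pattern accepts both `tuple[x]` and `tuple[x]:y`; a standalone check:

```python
import re

TUPLE_SPEC = re.compile(r"tuple\[(\d+)\]:?(\d+)?")

assert TUPLE_SPEC.match("tuple[2]:1").groups() == ("2", "1")
assert TUPLE_SPEC.match("tuple[3]").groups() == ("3", None)   # index defaults to 0
assert TUPLE_SPEC.match("tensor") is None                     # handled by TensorValidator
```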
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from datetime import datetime
+from mindspore import dtype as mstype, Tensor
+
+from msprobe.mindspore.monitor.features import FUNC_MAP
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.utils import is_int
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import check_file_or_directory_path
+
+
+def get_single_metrics(op_list, tag, tensor, output=None):
+ if output is None:
+ output = {}
+ if tag not in output:
+ output[tag] = {}
+ for op in op_list:
+ func = FUNC_MAP.get(op)
+ statistic = func(tensor)
+ if hasattr(statistic, "dtype") and statistic.dtype == mstype.bfloat16:
+ statistic = float(statistic)
+ statistic = Tensor(statistic)
+ output[tag][op] = statistic.astype(mstype.float32)
+
+
+def get_metrics(op_list, tag2tensor, eps, output=None):
+ if output is None:
+ output = {}
+ for tag, tensor in tag2tensor.items():
+ if tag not in output:
+ output[tag] = {}
+ get_single_metrics(op_list, tag, tensor, output)
+ return output
+
+
+def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
+ if rank is None:
+ return f"{module_or_param_name}/{tag}"
+ else:
+ return f"{module_or_param_name}/rank{rank}/{tag}"
+
+
+def step_accumulates_one(context, micro_batch_number):
+ """
+ :param context: ModuleHookContext
+ :param micro_batch_number: number of micro-batches per global step.
+ :return:
+ """
+ context.micro_step += 1
+ if context.micro_step == micro_batch_number:
+ context.micro_step = 0
+ context.step += 1
+
+
+def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_times=1e8):
+ """
+ If the current step is before start_step or does not land on step_interval, skip it.
+ :param step: current training step, int
+ :param start_step: int
+ :param step_interval: int
+ :param has_collect_times: number of collections already done, int
+ :param collect_times: collection budget, int
+ :return: whether to skip this step, bool
+ """
+ return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
+
+
+def validate_ops(ops):
+ if not isinstance(ops, list):
+ raise TypeError("ops should be a list")
+ valid_ops = []
+ for op in ops:
+ if op not in MonitorConst.OP_LIST:
+ logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
+ continue
+ valid_ops.append(op)
+ if not valid_ops:
+ default_op = MonitorConst.OP_LIST[0]
+ valid_ops.append(default_op)
+ logger.info(f"There are no valid ops; default op {default_op} is used")
+ return valid_ops
+
+
+def validate_ranks(ranks):
+ if not isinstance(ranks, list):
+ raise TypeError("module_ranks should be a list")
+ for rank in ranks:
+ if not isinstance(rank, str):
+ raise TypeError(f"element in module_ranks should be a str, got {type(rank)}")
+
+
+def validate_targets(targets):
+ if not isinstance(targets, dict):
+ raise TypeError('targets in config.json should be a dict')
+ for module_name, field in targets.items():
+ if not isinstance(module_name, str):
+ raise TypeError('key of targets should be module_name[str] in config.json')
+ if not isinstance(field, dict):
+ raise TypeError('values of targets should be the cared fields, e.g.
{"input": "tensor"} in config.json') + + +def validate_print_struct(print_struct): + if not isinstance(print_struct, bool): + raise TypeError("print_struct should be a bool") + + +def validate_ur_distribution(ur_distribution): + if not isinstance(ur_distribution, bool): + raise TypeError('ur_distribution should be a bool') + + +def validate_xy_distribution(xy_distribution): + if not isinstance(xy_distribution, bool): + raise TypeError('xy_distribution should be a bool') + + +def validate_wg_distribution(wg_distribution): + if not isinstance(wg_distribution, bool): + raise TypeError('wg_distribution should be a bool') + + +def validate_mg_distribution(mg_distribution): + if not isinstance(mg_distribution, bool): + raise TypeError('mg_distribution should be a bool') + + +def validate_param_distribution(param_distribution): + if not isinstance(param_distribution, bool): + raise TypeError('param_distribution should be a bool') + + +def validate_cc_distribution(cc_distribution): + if not isinstance(cc_distribution, dict): + raise TypeError('cc_distribution should be a dictionary') + expected_keys = { + 'enable': bool, + 'cc_codeline': list, + 'cc_pre_hook': bool, + 'cc_log_only': bool + } + for key, value in cc_distribution.items(): + if key in expected_keys: + if not isinstance(value, expected_keys[key]): + raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}') + else: + raise TypeError(f'{key} of cc_distribution is not supported.') + + +def validate_alert(alert): + if not isinstance(alert, dict): + raise TypeError('alert should be a dictionary') + rules = alert.get('rules') + if rules and isinstance(rules, list): + for rule in rules: + rule_name = rule.get("rule_name") + if rule_name and rule_name not in MonitorConst.RULE_NAME: + raise TypeError(f"{rule_name} is not supported") + args = rule.get("args") + if args and isinstance(args, dict): + threshold = args.get("threshold") + if not isinstance(threshold, float) or threshold < 0: + raise TypeError('threshold must be float and not less than 0') + dump = alert.get('dump') + if dump and not isinstance(dump, bool): + raise TypeError('dump must be bool.') + + +def validate_step_count_per_record(step_count_per_record): + if not is_int(step_count_per_record): + raise TypeError('step_count_per_record must be int.') + if step_count_per_record < 1: + raise ValueError("step_count_per_record must greater than 0") + if step_count_per_record > 1e6: + raise ValueError("step_count_per_record must smaller than 1e6") + + +def validate_start_step(start_step): + if not is_int(start_step): + raise TypeError('start_step must be int.') + if start_step < 0: + raise ValueError("start_step must greater than 0") + if start_step > 1e8: + raise ValueError("start_step must smaller than 1e8") + + +def validate_step_interval(step_interval): + if not is_int(step_interval): + raise TypeError('step_interval must be int.') + if step_interval < 1: + raise ValueError("step_interval must greater than 1") + if step_interval > 1e8: + raise ValueError("step_interval must smaller than 1e8") + + +def validate_collect_times(collect_times): + if not is_int(collect_times): + raise TypeError('collect_times must be int.') + if collect_times < 1: + raise ValueError("collect_times must greater than 1") + + +def validate_config(config): + config['ops'] = validate_ops(config.get('ops', [])) + + eps = config.get('eps', 1e-8) + if not isinstance(eps, float): + raise TypeError("eps should be a float") + + ranks = config.get("module_ranks", []) + validate_ranks(ranks) 
+ + targets = config.get("targets", {}) + validate_targets(targets) + + print_struct = config.get('print_struct', False) + validate_print_struct(print_struct) + + ur_distribution = config.get('ur_distribution', False) + validate_ur_distribution(ur_distribution) + + xy_distribution = config.get('xy_distribution', False) + validate_xy_distribution(xy_distribution) + + wg_distribution = config.get('wg_distribution', False) + validate_wg_distribution(wg_distribution) + + mg_distribution = config.get('mg_distribution', False) + validate_mg_distribution(mg_distribution) + + param_distribution = config.get('param_distribution', False) + validate_param_distribution(param_distribution) + + cc_distribution = config.get('cc_distribution', {}) + validate_cc_distribution(cc_distribution) + + alert = config.get('alert', {}) + validate_alert(alert) + + step_count_per_record = config.get('step_count_per_record', 1) + validate_step_count_per_record(step_count_per_record) + + start_step = config.get('start_step', 0) + validate_start_step(start_step) + + step_interval = config.get('step_interval', 1) + validate_step_interval(step_interval) + + collect_times = config.get('collect_times', int(1e8)) + validate_collect_times(collect_times) + + if not targets: + if xy_distribution: + config["all_xy"] = True + config["targets"] = {"": {}} + config["is_select"] = False + else: + config["is_select"] = True + + +def time_str2time_digit(time_str): + time_format = '%b%d_%H-%M-%S' + try: + time_digit = datetime.strptime(time_str, time_format) + except Exception as e: + raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ + of existing output dirpath, like 'Dec03_21-34-40'.") from e + return time_digit + + +def get_target_output_dir(monitor_path, time_start, time_end): + check_file_or_directory_path(monitor_path, isdir=True) + time_start = time_str2time_digit(time_start) if time_start is not None else time_start + time_end = time_str2time_digit(time_end) if time_end is not None else time_end + if time_start and time_end and time_start > time_end: + raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") + result = {} + for dirname in os.listdir(monitor_path): + match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) + if not match: + continue + time_tag = match.group(1) + rank = match.group(2) + target_time = time_str2time_digit(time_tag) + start_ok = time_start is None or target_time >= time_start + end_ok = time_end is None or target_time <= time_end + if start_ok and end_ok: + result[rank] = os.path.join(monitor_path, dirname) + return result diff --git a/debug/accuracy_tools/msprobe/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py index 2585938899da4e9db06ae3c008df599ac868c3f1..f20ed804c5bb8d8fbe4dba3e208060e8f52a3120 100644 --- a/debug/accuracy_tools/msprobe/mindspore/ms_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. 
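The directory matching in `get_target_output_dir` above hinges on the `%b%d_%H-%M-%S` timestamp prefix; the `Dec03_21-34-40` example quoted in the error message parses like this:

```python
from datetime import datetime

t = datetime.strptime("Dec03_21-34-40", "%b%d_%H-%M-%S")
assert (t.month, t.day, t.hour, t.minute, t.second) == (12, 3, 21, 34, 40)
```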
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,7 +48,7 @@ class StatisticsConfig(BaseConfig): single_opt = ["statistics", "md5"] muti_opt = ["md5", "max", "min", "mean", "l2norm"] if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt: - raise Exception("summary_mode is invalid") + raise Exception("summary_mode is invalid") if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode): raise Exception("summary_mode is invalid") @@ -106,12 +106,18 @@ class GradProbeConfig(BaseConfig): check_numeral_list_ascend(self.bounds) +class StructureConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + + TaskDict = { Const.TENSOR: TensorConfig, Const.STATISTICS: StatisticsConfig, Const.OVERFLOW_CHECK: OverflowCheckConfig, Const.FREE_BENCHMARK: FreeBenchmarkConfig, - Const.GRAD_PROBE: GradProbeConfig + Const.GRAD_PROBE: GradProbeConfig, + Const.STRUCTURE: StructureConfig } diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 125d40011af73404abcfad6c3791f2e574dfb1c1..5afbd046be4caf29c4b247a0f8fdd655c5208fd0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -22,6 +22,7 @@ import mindspore as ms from mindspore import nn from mindspore.common.api import _no_grad from mindspore.ops.primitive import Primitive + try: from mindspore.common._pijit_context import PIJitCaptureContext except ImportError: @@ -33,7 +34,8 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError, Mspro from msprobe.core.common.file_utils import create_directory from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation from msprobe.core.data_dump.data_collector import build_data_collector -from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs +from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, + ModuleBackwardInputs) from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger @@ -43,6 +45,7 @@ from msprobe.mindspore.dump.hook_cell.api_registry import api_register from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService from msprobe.mindspore.dump.jit_dump import JitDump from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell +from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json if is_mindtorch(): import torch @@ -64,23 +67,33 @@ class Service: self.current_rank = None self.dump_iter_dir = None self.start_call = False - self.check_level_valid() self.should_stop_service = False self.params_grad_info = {} + self.hook_handle_dict = {} # 提前注册,确保注册尽可能多的API hook self.register_api_hook() self.init_for_debug_level() @staticmethod - def check_model_valid(model): - if not model: - return model - targer_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") - if not isinstance(model, targer_module_type[0]): + def check_model_valid(models): + target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell") + if models is None or isinstance(models, target_module_type[0]): + return models + error_model = None + if isinstance(models, (list, tuple)): + for model in models: 
+ if not isinstance(model, target_module_type[0]): + error_model = model + break + else: + error_model = models + + if error_model is not None: + error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] " + f"type, currently there is a {type(error_model)} type.") raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, f"model 参数必须是 {targer_module_type[1]} 类型。" - ) - return model + MsprobeException.INVALID_PARAM_ERROR, error_info) + return models @staticmethod def prepare_module_input_output(target_type, cell, input_data, output): @@ -90,12 +103,6 @@ class Service: module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) return module_input_output - def check_level_valid(self): - if self.config.level == Const.LEVEL_L2: - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported." - ) - def build_hook(self, target_type, name): def pre_hook(api_or_cell_name, cell, input_data): if not self.should_execute_hook(target_type, cell, True): @@ -134,7 +141,12 @@ class Service: if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): for param_name, param in params_dict.items(): if param.requires_grad: - param.register_hook(grad_hook(cell, ori_name, param_name)) + name = ori_name + Const.SEP + param_name + old_handle = self.hook_handle_dict.get(name) + if old_handle and hasattr(old_handle, "remove"): + old_handle.remove() + handle = param.register_hook(grad_hook(cell, ori_name, param_name)) + self.hook_handle_dict[name] = handle def init_params_grad_info(cell, params_dict): ''' @@ -164,10 +176,15 @@ class Service: module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output) if target_type == BaseScope.Module_Type_Module: api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) - params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(recurse=False).items()} - setattr(module_input_output, Const.PARAMS, params_dict) + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in cell.parameters_dict(recurse=False).items() + } + setattr(module_input_output, Const.PARAMS, params_dict) # 判断是否需要注册参数hook - if not hasattr(cell, 'params_grad_name') and params_dict: + if params_dict: ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0] grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook @@ -209,6 +226,16 @@ class Service: self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output) self.inner_switch = False + def pre_backward_hook(api_or_cell_name, cell, grad_input): + if not self.should_execute_hook(target_type, cell, False): + return + self.inner_switch = True + module_input = ModuleBackwardInputs(grad_input=grad_input) + self.data_collector.update_api_or_module_name(api_or_cell_name) + self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input) + + self.inner_switch = False + pid = os.getpid() if target_type == BaseScope.Module_Type_Module: full_forward_name = name + Const.FORWARD @@ -219,6 +246,7 @@ class Service: pre_forward_hook = functools.partial(pre_hook, full_forward_name) forward_hook = functools.partial(forward_hook, full_forward_name) backward_hook = functools.partial(backward_hook, full_backward_name) + pre_backward_hook = 
functools.partial(pre_backward_hook, full_backward_name) def wrap_pre_forward_hook(cell, input_data): return pre_forward_hook(cell, input_data) @@ -229,7 +257,10 @@ class Service: def wrap_backward_hook(cell, grad_input, grad_output): return backward_hook(cell, grad_input, grad_output) - return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook + def wrap_pre_backward_hook(cell, grad_input): + return pre_backward_hook(cell, grad_input) + + return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook def update_primitive_counters(self, primitive_name): if primitive_name not in self.primitive_counters: @@ -240,6 +271,10 @@ class Service: def step(self): if self.config.level == Const.LEVEL_DEBUG: return + if self.config.async_dump: + self.data_collector.fill_stack_tensor_data() + if self.config.task == Const.TENSOR: + self.data_collector.data_processor.dump_async_data() self.data_collector.write_json() self.current_iter += 1 self.data_collector.update_iter(self.current_iter) @@ -276,7 +311,10 @@ class Service: if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: JitDump.set_config(self.config) JitDump.set_data_collector(self.data_collector) - ms.common.api._MindsporeFunctionExecutor = JitDump + if hasattr(ms.common.api, "_MindsporeFunctionExecutor"): + ms.common.api._MindsporeFunctionExecutor = JitDump + else: + ms.common.api._JitExecutor = JitDump ms.common.api._PyNativeExecutor.grad = JitDump.grad if pijit_label: PIJitCaptureContext.__enter__ = self.empty @@ -308,6 +346,10 @@ class Service: self.switch = False self.primitive_switch = False self.start_call = False + if self.config.async_dump: + self.data_collector.fill_stack_tensor_data() + if self.config.task == Const.TENSOR: + self.data_collector.data_processor.dump_async_data() self.data_collector.write_json() JitDump.jit_dump_switch = False @@ -337,6 +379,12 @@ class Service: create_directory(self.config.dump_path) self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") cur_rank = self.current_rank if self.current_rank is not None else '' + if self.config.level == Const.LEVEL_L2: + create_directory(self.dump_iter_dir) + kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank) + self.config.kernel_config_path = kernel_config_path + return + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") create_directory(dump_dir) if self.config.task in self.data_collector.tasks_need_tensor_data: @@ -360,11 +408,24 @@ class Service: pass def register_api_hook(self): - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]: + if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.") api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_set_hook_func() + def get_cells_and_names(self): + cells_and_names_with_index = {} + + def get_cell_or_module(model): + return model.named_modules() if is_mindtorch() else model.cells_and_names() + + if isinstance(self.model, (list, tuple)): + for index, model in enumerate(self.model): + cells_and_names_with_index[str(index)] = get_cell_or_module(model) + else: + cells_and_names_with_index["-1"] = get_cell_or_module(self.model) + return cells_and_names_with_index + def register_primitive_hook(self): if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]: return @@ -372,14 +433,12 @@ class Service: return primitive_set = set() - if 
is_mindtorch(): - cells_and_names = self.model.named_modules() - else: - cells_and_names = self.model.cells_and_names() - for _, cell in cells_and_names: - for attribute, value in vars(cell).items(): - if isinstance(value, Primitive): - primitive_set.add((attribute, value)) + cells_and_names_with_index = self.get_cells_and_names() + for cells_and_names in cells_and_names_with_index.values(): + for _, cell in cells_and_names: + for attribute, value in vars(cell).items(): + if isinstance(value, Primitive): + primitive_set.add((attribute, value)) for pname, primitive in primitive_set: primitive_class_name = primitive.__class__.__name__ @@ -395,35 +454,38 @@ class Service: if not self.model: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"The current level is {self.config.level}, the model cannot be None") - - cell_names_and_type = (self.model.named_modules(), Const.MODULE) \ - if is_mindtorch() else (self.model.cells_and_names(), Const.CELL) - - for name, cell in cell_names_and_type[0]: - if cell == self.model: - continue - - prefix = (cell_names_and_type[1] + Const.SEP + name + Const.SEP + - cell.__class__.__name__ + Const.SEP) - _, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) - cell.register_forward_hook(forward_hook) - cell.register_forward_pre_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START)) - cell.register_forward_hook( - self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - - register_backward_hook_functions["full"](cell, backward_hook) - register_backward_hook_functions["pre"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - register_backward_hook_functions["full"]( - cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + model_type = Const.MODULE if is_mindtorch() else Const.CELL + cells_and_names_with_index = self.get_cells_and_names() + + for index, cells_and_names in cells_and_names_with_index.items(): + model = self.model if index == "-1" else self.model[int(index)] + for name, cell in cells_and_names: + if cell == model: + continue + cell_index = (index + Const.SEP) if index != "-1" else "" + prefix = (model_type + Const.SEP + cell_index + name + + Const.SEP + cell.__class__.__name__ + Const.SEP) + _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix) + cell.register_forward_hook(forward_hook) + cell.register_forward_pre_hook( + self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START)) + cell.register_forward_hook( + self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) + + register_backward_hook_functions["full"](cell, backward_hook) + register_backward_hook_functions["pre"]( + cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START)) + register_backward_hook_functions["full"]( + cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) def reset_status(self): self.primitive_hook_service.primitive_counters.clear() - self.data_collector.data_writer.reset_cache() + self.data_collector.reset_status() JitDump.jit_count = defaultdict(int) self.params_grad_info.clear() - + if self.config.level == Const.LEVEL_L2: + self.data_collector.data_processor.reset_status() + return if self.config.step and self.current_iter not in self.config.step: return if self.config.rank and self.current_rank not in self.config.rank: @@ -478,4 +540,4 @@ class Service: # backward save if save_backward: - self.data_collector.debug_data_collect_backward(variable, 
grad_name_with_count) \ No newline at end of file + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index c88519f4de361eee6b11bb2fd98c11f21678cca1..8e0386fde6dccc071c3d9d8e1a86729a2c483c7c 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -55,7 +55,6 @@ def main(): _merge_result_parser(merge_result_parser) is_torch_available = is_module_available("torch") - is_mindspore_available = is_module_available("mindspore") if len(sys.argv) < 4: parser.print_help() diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index d8767e2df47bcb636c67e7ce960dcadb39139933..ce84e6b35b74e55a90915350ff3ef2da3f7ba441 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import torch -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if torch_version_above_or_equal_2: - from msprobe.pytorch.monitor.module_hook import TrainerMon from .compare.distributed_compare import compare_distributed from .compare.pt_compare import compare from .common.utils import seed_all -from .debugger.precision_debugger import PrecisionDebugger -from .functional.module_dump import module_dump, module_dump_end +from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end + +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +if torch_version_above_or_equal_2: + from msprobe.pytorch.monitor.module_hook import TrainerMon diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py index e457afc71d3401c3fdd03d9ef745ccbbb1f63545..976fb7f5f258eaa4e6a57caf596f5bbfc39acfa5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py @@ -21,11 +21,12 @@ from msprobe.pytorch.common.log import logger class CompareColumn: __slots__ = [ - 'bench_type', 'npu_type','shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth', + 'bench_type', 'npu_type', 'shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth', 'rel_err_ten_thousandth', 'inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio', 'small_value_err_ratio', 'max_rel_error', 'mean_rel_error', 'rmse', 'eb', 'max_ulp_error', 'mean_ulp_error', 'ulp_error_proportion', 'error_rate', 'rel_err_thousandth' ] + def __init__(self): self.bench_type = CompareConst.SPACE self.npu_type = CompareConst.SPACE @@ -76,13 +77,15 @@ class CompareColumn: class ApiPrecisionOutputColumn: __slots__ = [ - 'api_name', 'small_value_err_ratio','small_value_err_status', 'rmse_ratio', 'rmse_status', - 'max_rel_err_ratio','max_rel_err_status','mean_rel_err_ratio','mean_rel_err_status', 'eb_ratio', - 'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status','rel_err_ratio', + 'api_name', 'small_value_err_ratio', 
'small_value_err_status', 'rmse_ratio', 'rmse_status', + 'max_rel_err_ratio', 'max_rel_err_status', 'mean_rel_err_ratio', 'mean_rel_err_status', 'eb_ratio', + 'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status', 'rel_err_ratio', 'rel_err_ratio_status', 'abs_err_ratio', 'abs_err_ratio_status', 'error_rate', 'error_rate_status', - 'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status','rel_err_thousandth', - 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm', 'compare_message' + 'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status', + 'rel_err_thousandth', 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm', + 'compare_message' ] + def __init__(self): self.api_name = CompareConst.SPACE self.small_value_err_ratio = CompareConst.SPACE diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_input.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_input.py index 414acabb21d0c86d617fb88a00c8f1f238624280..8c21def9d859a3ec5637e396f9396527c5fc8979 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_input.py @@ -30,7 +30,8 @@ class CompareInput: rel_err_orign (float or array-like, optional): The original relative error values. Defaults to None. Methods: - __init__(bench_output, device_output, compare_column, dtype, rel_err_orign): Initializes an instance of CompareInput. + __init__(bench_output, device_output, compare_column, dtype, rel_err_orign): + Initializes an instance of CompareInput. """ def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None): self.bench_output = bench_output diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py index f2526468a48927fcd9b720eb3f766ed07a253801..797210f09c3b55a64002a4aa84a3d39770ae803c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -""" -# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,17 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" + import argparse import json import os import re + import math import numpy as np import torch - -from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, ulp_standard_api, thousandth_standard_api +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, \ +ulp_standard_api, thousandth_standard_api from msprobe.core.common.file_utils import FileOpen, load_json, save_json from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int from msprobe.core.common.const import Const, MonitorConst, MsgConst @@ -78,6 +80,7 @@ class APIInfo: def is_supported_type(self): return self.api_type in OPERATOR_TYPE + class CommonConfig: def __init__(self, json_config): self.dump_json_path = json_config.get('dump_json_path') @@ -147,6 +150,7 @@ class CommonConfig: if not is_int(self.iter_times): raise ValueError(f'iter_times is invalid, it should be an int') + class APIExtractor: def __init__(self, api_name, dump_json_path, output_file): self.api_name = api_name @@ -186,6 +190,7 @@ class APIExtractor: elif DATA_NAME in data: data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME]) + class OperatorScriptGenerator: def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward): self.common_config = common_config @@ -238,7 +243,8 @@ class OperatorScriptGenerator: ordinal_number: how many times the same api has been called direction_status: forward random_seed: if mode is random_data, random seed is random_seed - iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, iter_times does not matter + iter_times: if mode is random_data, generate iter_times group of data; if mode is real_data, + iter_times does not matter args_element_assignment: code for args assignment args_list_generator_device: code for generate args list on device args_list_generator_bench: code for generate args list on bench @@ -267,17 +273,25 @@ class OperatorScriptGenerator: internal_settings["iter_times"] = 1 else: internal_settings["iter_times"] = self.common_config.iter_times - internal_settings["args_element_assignment"] = self.generate_args_element_assignment_code(self.args_info_forward) - internal_settings["args_list_generator_device"] = self.generate_args_list(self.args_info_forward, flag_device=True) - internal_settings["args_list_generator_bench"] = self.generate_args_list(self.args_info_forward, flag_device=False) - internal_settings["kwargs_value_assignment"] = self.generate_kwargs_value_assignment_code(self.kwargs_info_forward) - internal_settings["kwargs_dict_generator_device"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True) - internal_settings["kwargs_dict_generator_bench"] = self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False) + internal_settings["args_element_assignment"] = \ + self.generate_args_element_assignment_code(self.args_info_forward) + internal_settings["args_list_generator_device"] = \ + self.generate_args_list(self.args_info_forward, flag_device=True) + internal_settings["args_list_generator_bench"] = \ + self.generate_args_list(self.args_info_forward, flag_device=False) + internal_settings["kwargs_value_assignment"] = \ + self.generate_kwargs_value_assignment_code(self.kwargs_info_forward) + internal_settings["kwargs_dict_generator_device"] = \ + self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True) + 
internal_settings["kwargs_dict_generator_bench"] = \ + self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False) if self.common_config.propagation == Const.BACKWARD: internal_settings["args_element_assignment_backward"] = self.generate_args_element_assignment_code( self.args_info_backward) - internal_settings["args_list_generator_device_backward"] = self.generate_args_list(self.args_info_backward, flag_device=True) - internal_settings["args_list_generator_bench_backward"] = self.generate_args_list(self.args_info_backward, flag_device=False) + internal_settings["args_list_generator_device_backward"] = \ + self.generate_args_list(self.args_info_backward, flag_device=True) + internal_settings["args_list_generator_bench_backward"] = \ + self.generate_args_list(self.args_info_backward, flag_device=False) else: internal_settings["args_element_assignment_backward"] = '' internal_settings["args_list_generator_device_backward"] = '' @@ -290,12 +304,15 @@ class OperatorScriptGenerator: args_element_assignment = "" for index, arg in enumerate(args_info): if isinstance(arg, (list, tuple)): - new_args_element_assignment = self.recursive_args_element_assignment(arg, name_number + "_" + str(index)) + new_args_element_assignment = \ + self.recursive_args_element_assignment(arg, name_number + "_" + str(index)) args_element_assignment += new_args_element_assignment else: arg["parameter_name"] = "arg" + name_number + "_" + str(index) - args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0] - args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0] + args_element_assignment += " " + "arg_info" + name_number + "_" + str(index) + " = " + \ + "{}".format(str(arg)) + MsgConst.SPECIAL_CHAR[0] + args_element_assignment += " " + "arg" + name_number + "_" + str(index) + " = " + \ + "generate_data(arg_info" + name_number + "_" + str(index) + ")" + MsgConst.SPECIAL_CHAR[0] return args_element_assignment @@ -320,7 +337,8 @@ class OperatorScriptGenerator: args_list_generator += ".to(device)" if flag_bench: args_list_generator += '.to(torch.device("cpu"))' - args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + ".dtype), " + arg.get("parameter_name") + ".dtype))" + args_list_generator += ".to(RAISE_PRECISION.get(str(" + arg.get("parameter_name") + \ + ".dtype), " + arg.get("parameter_name") + ".dtype))" args_list_generator += Const.COMMA return args_list_generator @@ -338,12 +356,15 @@ class OperatorScriptGenerator: if info.get("type") == "torch.device" or info.get("type") == "torch.dtype": kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + info.get("value") else: - kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0] - kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0] + kwargs_value_assignment += " " + "kwarg_info_" + key_name + name_number + " = " + \ + "{}".format(str(info)) + MsgConst.SPECIAL_CHAR[0] + kwargs_value_assignment += " " + "kwarg_" + key_name + name_number + " = " + \ + "generate_data(kwarg_info_" + key_name + name_number + ")" + MsgConst.SPECIAL_CHAR[0] info["parameter_name"] = "kwarg_" + key_name + name_number else: for index, arg in enumerate(info): - 
new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + "_" + str(index)) + new_kwargs_value_assignment = self.recursive_kwargs_value_assignment(arg, key_name, name_number + \ + "_" + str(index)) kwargs_value_assignment += new_kwargs_value_assignment return kwargs_value_assignment @@ -363,7 +384,8 @@ class OperatorScriptGenerator: kwargs_dict_generator += ".to(device)" if flag_bench: kwargs_dict_generator += '.to(torch.device("cpu"))' - kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + ".dtype), " + info.get("parameter_name") + ".dtype))" + kwargs_dict_generator += ".to(RAISE_PRECISION.get(str(" + info.get("parameter_name") + \ + ".dtype), " + info.get("parameter_name") + ".dtype))" else: (left_bracket, right_bracket) = ("[", "]") if isinstance(info, list) else ("(", ")") kwargs_dict_generator += left_bracket @@ -377,7 +399,7 @@ class OperatorScriptGenerator: def generate_kwargs_dict(self, kwargs_info, flag_device): kwargs_dict_generator = "" for key, value in kwargs_info.items(): - kwargs_dict_generator += '"' + key + '"' + MonitorConst.VPP_SEP + kwargs_dict_generator += '"' + key + '"' + MonitorConst.NAME_SEP if flag_device: kwargs_dict_generator += self.recursive_kwargs_dict(value, flag_device=True) + Const.COMMA else: @@ -393,6 +415,7 @@ def _op_generator_parser(parser): help=" Path of extract api_name.json.", required=True) + def parse_json_config(json_file_path): if not json_file_path: config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -401,6 +424,7 @@ def parse_json_config(json_file_path): common_config = CommonConfig(json_config) return common_config + def _run_operator_generate_commond(cmd_args): common_config = parse_json_config(cmd_args.config_input) @@ -434,7 +458,8 @@ def _run_operator_generate_commond(cmd_args): internal_settings = op_generate.get_settings(api_full_name_forward) template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template") - operator_script_path = os.path.join(cmd_args.api_output_path, "{0}.py".format(internal_settings.get("api_full_name"))) + operator_script_path = os.path.join(cmd_args.api_output_path, + "{0}.py".format(internal_settings.get("api_full_name"))) try: with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py index c1d72e8d1fcbea3036771b414d2c56b786521418..82df8c54e87ea1627159a52aef2544028ab21b22 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py @@ -20,6 +20,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_ ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST from msprobe.core.common.const import CompareConst + class StandardRegistry: """ Registry class for managing comparison standards and functions. 
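# The StandardRegistry above maps each API to the comparison standard it is
# checked against (binary, absolute, ulp, thousandth, accumulative error).
# Its full body is not part of this hunk; a minimal sketch of the dispatch
# pattern such a registry implies (names here are illustrative only, not
# msprobe's actual API):

_standards = {}

def register_standard(name):
    """Associate a comparison function with a standard name."""
    def decorator(func):
        _standards[name] = func
        return func
    return decorator

@register_standard("absolute")
def absolute_compare(bench, device, threshold=1e-6):
    # Pass when every element is within an absolute tolerance of the bench.
    return all(abs(b - d) <= threshold for b, d in zip(bench, device))

# Usage: _standards["absolute"]([1.0, 2.0], [1.0, 2.0 + 1e-9]) -> True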
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py index a26dc16dca3330110c8ba116cd93b9c8197e2286..df181588ad01836186c82df6fc2d23eef63238f0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py @@ -135,7 +135,7 @@ class UlpPrecisionCompare(BasePrecisionCompare): CompareConst.ULP_ERR_STATUS: CompareConst.ERROR } compare_result = CompareConst.ERROR - metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + \ + metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + \ "ERROR: ULP误差不满足标准\n" metrics.update({CompareConst.COMPARE_RESULT: compare_result}) return metrics @@ -150,7 +150,7 @@ class UlpPrecisionCompare(BasePrecisionCompare): else: status, final_message = \ self._get_fp16_ulp_err_status(ulp_err_proportion, ulp_err_proportion_ratio) - metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message + metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message status_dict = { CompareConst.ULP_ERR_STATUS: status diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 59415a387e64fff2942c7f230f94cf34992de116..9d89b2de32f70c6fa7abf38add49b58a13531d7a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -393,10 +393,11 @@ def get_output_dtype(api_info): """ output_dtype = None output_info = api_info.get(Const.OUTPUT) - if output_info: - output_dtype = output_info[0].get(Const.DTYPE) - module_name, attribute_name = get_module_and_atttribute_name(output_dtype) - output_dtype = get_attribute(module_name, attribute_name) + if output_info and isinstance(output_info[0], dict): + output_str_dtype = output_info[0].get(Const.DTYPE) + if output_str_dtype in Const.TORCH_FLOAT_DTYPE: + module_name, attribute_name = get_module_and_atttribute_name(output_str_dtype) + output_dtype = get_attribute(module_name, attribute_name) return output_dtype diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..408929685e0c5de9984f06674df9ad6a76cd1281 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam.py @@ -0,0 +1,215 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
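# The new bench file below rebuilds the fused NPU Adam update out of
# elementwise torch ops. For reference, the equivalent update written
# directly: a sketch of the standard Adam step that the helpers decompose,
# assuming scalar hyperparameters and same-shaped tensors:
import torch

def adam_step_reference(var, m, v, grad, lr, beta1, beta2, epsilon,
                        beta1_power, beta2_power):
    # m_t = beta1 * m + (1 - beta1) * grad, same as m + (beta1 - 1) * (m - grad)
    m_t = beta1 * m + (1 - beta1) * grad
    # v_t = beta2 * v + (1 - beta2) * grad^2
    v_t = beta2 * v + (1 - beta2) * grad * grad
    # lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
    lr_t = lr * (1 - beta2_power) ** 0.5 / (1 - beta1_power)
    # var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))
    var_t = var - lr_t * m_t / (epsilon + torch.sqrt(v_t))
    return var_t, m_t, v_t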
+ +from collections import namedtuple +import torch + + +VarParams = namedtuple('VarParams', ['var', 'lr_t', 'm_t', 'beta1_broad', 'grad', 'epsilon', 'v_t']) + + +def _output_m_compute(m, beta1_broad, grad): + """ + _output_m_compute + do compute m_t = m + (beta1 - 1) * (m - grad) + """ + input_dtype = m.dtype + + sneg_one = torch.ones((1), dtype=input_dtype) * -1 + sneg_one = sneg_one.to(beta1_broad.device) + + # `formula; beta1 -1` + vsub_beta1_1 = torch.add(beta1_broad, sneg_one) + + # `formula; m - grad` + vsub_m_grad = torch.sub(m, grad) + + # `formula; (beta1 - 1) * (m - grad)` + vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad) + + # `formula; m_t = m + (beta1 - 1) * (m - grad)` + m_t = torch.add(m, vmul_m) + + return m_t + + +def _output_v_compute(v, beta2, grad): + """ + _output_v_compute + do compute v_t = v + (1 - beta2)*(grad*grad -v) + """ + input_dtype = v.dtype + + sneg_one = torch.ones((1), dtype=input_dtype) * -1 + + # `formula; broadcast beta2 to vector` + beta2_tensor = torch.tensor(beta2, dtype=input_dtype) + beta2_broad = beta2_tensor.expand_as(v) + + # `formula; beta2 - 1` + vsub_beta2_1 = torch.add(beta2_broad, sneg_one) + vsub_beta2_1 = vsub_beta2_1.to(v.device) + + # `formula; grad * grad` + vmul_grad_grad = torch.mul(grad, grad) + + # `formula; (v - grad*grad)` + vsub_v_grad = torch.sub(v, vmul_grad_grad) + + # `formula; (beta2 -1) * (v - grad * grad)` + vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad) + + # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)` + v_t = torch.add(v, vmul_grad) + + return v_t + + +def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor): + """ + _inner_lr_compute + `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)` + """ + + input_dtype = compute_shape_tensor.dtype + + s_one = torch.ones((1), dtype=input_dtype) + + s_neg_one = torch.ones((1), dtype=input_dtype) * -1 + + # `formula; (1 - beta2_power)` + v_neg_beta2_power = torch.mul(beta2_power, s_neg_one) + v_add_beta2_power = torch.add(v_neg_beta2_power, s_one) + + # `formula; sqrt(1 - beta2_power)` + v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power) + + # `formula; (1 - beta1_power)` + v_neg_beta1_power = torch.mul(beta1_power, s_neg_one) + v_add_beta1_power = torch.add(v_neg_beta1_power, s_one) + + # `formula; learning_rate * (sqrt(1-beta2_power)` + res = torch.mul(lr, v_sqrt_beta2_power) + + # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)` + res = torch.div(res, v_add_beta1_power) + return res.expand_as(compute_shape_tensor) + + +def _inner_eps_add_sqrt_vt_compute(epsilon, v_t): + """ + (epsilon + sqrt(v_t) ) + """ + # `formula; sqrt(v_t)` + sqrt_vt = torch.sqrt(v_t) + + # `formula; broadcast epsilon to vector` + input_dtype = v_t.dtype + epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype) + epsilon_broad = epsilon_tensor.expand_as(v_t) + epsilon_broad = epsilon_broad.to(sqrt_vt.device) + + # `formula; epsilon + sqrt(v_t)` + v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad) + + return v_add_sqrt_v + + +def _output_var_t_compute_use_nesterov(varparams): + """ + _output_var_t_compute_use_nesterov + `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))` + `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))` + """ + var = varparams.var + lr_t = varparams.lr_t + m_t = varparams.m_t + beta1_broad = varparams.beta1_broad + grad = varparams.grad + epsilon = varparams.epsilon + v_t = varparams.v_t + + input_dtype = var.dtype + + s_one = torch.ones((1), 
dtype=input_dtype) + + s_neg_one = torch.ones((1), dtype=input_dtype) * -1 + + # `formula; m_t * beta1` + v_muls_mt_beta1 = torch.mul(m_t, beta1_broad) + + # `formula; 1 -beta1` + v_neg_beta1 = torch.mul(beta1_broad, s_neg_one) + vsub_1_beta1 = torch.add(v_neg_beta1, s_one) + + # `formula; (1-beta1)* grad` + v_mul_grad = torch.mul(vsub_1_beta1, grad) + + # `formula; (m_t*beta1 + (1 - beta1)*grad)` + v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad) + + # `formula; lr_t * (m_t*beta1 + (1 - beta1) * grad)` + # broadcast lr_t to vector + + lrt_broad = lr_t.expand_as(var) + v_mul_left = torch.mul(lrt_broad, v_div_left) + + # `formula; (epsilon + sqrt(v_t))` + v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t) + + # `formula; lr_t * (m_t*beta1 + (1-beta1)*grad / (epsilon + sqrt(v_t))` + v_div_res = torch.div(v_mul_left, v_add_sqrt_v) + + # `formula; var - lr_t * (m_t*beta1 + (1-beta1)*grad) / (epsilon + sqrt(v_t))` + v_t = torch.sub(var, v_div_res) + + return v_t + + +def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t): + """ + _output_var_t_compute + `var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))` + """ + # `formula; lr_t * m_t` + lr_t = lr_t.to(m_t.device) + v_mul_left = torch.mul(lr_t, m_t) + + # `formula; (epsilon + sqrt(v_t))` + v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t) + + # `formula; lr_t * m_t /(epsilon + sqrt(v_t))` + v_div_res = torch.div(v_mul_left, v_add_sqrt_v) + + # `formula; var - lr_t * m_t / (epsilon + sqrt(v_t))` + v_t = torch.sub(var, v_div_res) + + return v_t + + +def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out): + var, m, v = out + input_dtype = m.dtype + beta1_tensor = torch.tensor(beta1, dtype=input_dtype).to(m.device) + beta1_broad = beta1_tensor.expand_as(m) + m_t = _output_m_compute(m, beta1_broad, grad) + v_t = _output_v_compute(v, beta2, grad) + lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad) + if use_nesterov: + var_params = VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t) + var_t = _output_var_t_compute_use_nesterov(var_params) + else: + var_t = _output_var_t_compute(var, lr_t, m_t, epsilon, v_t) + return var_t, m_t, v_t diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/group_norm_silu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/group_norm_silu.py new file mode 100644 index 0000000000000000000000000000000000000000..c8757083c56b78cabbb83ec5d2b7b80f0edd8421 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/group_norm_silu.py @@ -0,0 +1,27 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
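# The fused GroupNorm+SiLU bench function added below can be cross-checked
# against the unfused composition of public torch ops. A sketch, assuming an
# NCHW input and per-channel gama/beta; it corresponds to res[0] of the fused
# version (the fused call also returns the per-group mean and rstd):
import torch
import torch.nn.functional as F

def group_norm_silu_reference(x, gama, beta, group, eps):
    # Normalize over `group` channel groups, then apply SiLU (x * sigmoid(x)).
    y = F.group_norm(x, group, weight=gama, bias=beta, eps=eps)
    return F.silu(y)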
+ +import torch + + +def npu_group_norm_silu(x, gama, beta, group, eps): + if len(x.shape) != 4: + raise ValueError("x shape should be (N, C, H, W)") + res = torch.ops.aten.native_group_norm(x, gama, beta, x.shape[0], x.shape[1], x.shape[2] * x.shape[3], group, eps) + res = list(res) + if not res: + raise ValueError("run native_group_norm failed") + res[0] = torch.nn.functional.silu(res[0]) + return res diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/mish.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/mish.py new file mode 100644 index 0000000000000000000000000000000000000000..f395a30ee60db57ab9a298a637c8318ffce7aec4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/mish.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def npu_mish(x): + mish = torch.nn.Mish() + return mish(x) diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..be15935ce9c9f77bc0a8447902f7f4a7b536a7fb --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
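# Note on softmax_func below: subtracting the per-row max before exp() is the
# standard numerical-stability trick. Softmax is shift-invariant, so
# softmax(x)_i = exp(x_i - c) / sum_j exp(x_j - c) for any constant c, and
# choosing c = max(x) keeps every exponential in (0, 1] and avoids overflow.
# (Also note that the zero-sum guard in softmax_func returns the scalar 0
# rather than a tensor; callers rely on the sum being nonzero in practice.)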
+ +import torch +import numpy as np + + +def softmax_func(x, axis=None): + x = x.float() + x_max = x.max(dim=axis, keepdims=True).values + x_sub = x - x_max + y = torch.exp(x_sub) + x_sum = y.sum(dim=axis, keepdims=True) + ans = 0 if (x_sum == 0).any() else y / x_sum + return ans + + +def npu_moe_gating_top_k_softmax(x, finished_optional, k): + input_dtype = x.dtype + num_expert = x.shape[-1] + softmax = softmax_func(x, -1) + softmax = softmax.to(input_dtype) + expert_idx = torch.argsort(-softmax, dim=-1, stable=True) + expert_idx = expert_idx[:, :k] + y = torch.gather(softmax, index=expert_idx, dim=-1) + if finished_optional is not None: + finished_optional = finished_optional.view(finished_optional.shape[0], 1) + finished_optional = finished_optional.expand(-1, k) + expert_idx = torch.where(finished_optional, num_expert, expert_idx) + row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t() + + return y, expert_idx, row_idx diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py index 9e6f110c3da1dcf353bac34bc3cd51fe6981e1cf..58a585f5a05f4b2d533d150db3a9fbfd907f5a07 100644 --- a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py @@ -30,6 +30,7 @@ numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False """ +from collections import namedtuple import torch import numpy as np from einops import rearrange @@ -54,6 +55,14 @@ GTYPE = torch.float64 # arm host必须选择float64,x86环境选择float32即 SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM" +FaForwardParams = namedtuple("FaForwardParams", + ["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"]) +FaBackwardParams = namedtuple("FaBackwardParams", + ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scale", "keep_prob"]) +RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams", + ["q", "k", "atten_mask", "pse", "scale", "softmax_max", "softmax_sum"]) + + def softmax_forward(x): x_max = torch.max(x, dim=-1, keepdims=True)[0] x_sub = x.sub(x_max) @@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale): return qk -def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob): +def fusion_attention_forward(forward_params): + q = forward_params.q + k = forward_params.k + v = forward_params.v + drop_mask = forward_params.drop_mask + atten_mask = forward_params.atten_mask + pse = forward_params.pse + scale = forward_params.scale + keep_prob = forward_params.keep_prob qk = calculate_qk(q, k, atten_mask, pse, scale) softmax_res, softmax_max, softmax_sum = softmax_forward(qk) if drop_mask is None or len(drop_mask.shape) == 0: @@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr return y, softmax_max, softmax_sum -def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob): +def fusion_attention_backward(backward_params): + dx = backward_params.dx + q = backward_params.q + k = backward_params.k + v = backward_params.v + softmax_res = backward_params.softmax_res + drop_mask = backward_params.drop_mask + pse = backward_params.pse + scale = backward_params.scale + keep_prob = backward_params.keep_prob dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) if drop_mask is None or len(drop_mask.shape) == 0: drop_res = softmax_res.permute(0, 1, 3, 2) @@ -368,11 +394,18 @@ def 
rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale): return softmax_res -def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum): +def rebuild_softmax_by_max_sum(softmax_params): """ attention = softmax(QK^T/sqrt(d))V softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) """ + q = softmax_params.q + k = softmax_params.k + atten_mask = softmax_params.atten_mask + pse = softmax_params.pse + scale = softmax_params.scale + softmax_max = softmax_params.softmax_max + softmax_sum = softmax_params.softmax_sum logger.info("Using softmax_max and softmax_sum to rebuild original softmax") qk = calculate_qk(q, k, atten_mask, pse, scale) if softmax_max.shape[-1] == 0: @@ -502,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs): key = convert_to_bnsd(key, n2, input_layout) value = convert_to_bnsd(value, n2, input_layout) k_new, v_new = generate_kv(key, value, n1, n2) - out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new, - drop_mask=None, atten_mask=atten_mask, - pse=pse, scale=scale, - keep_prob=keep_prob) + forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob) + out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params) if out_golden.dim() == 5: out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3), out_golden.size(4)) @@ -546,9 +577,10 @@ def npu_fusion_attention_grad(*args, **kwargs): if SOFTMAX_BUILD_MODE == "QKV": softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value) else: - softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum) - - dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob) + softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum) + softmax_res = rebuild_softmax_by_max_sum(softmax_params) + backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob) + dq, dk, dv = fusion_attention_backward(backward_params) # N不等长适配by cdy if not (n1 == n2): diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/sort_v2.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/sort_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bd1c141f83158632cffd9ce6238f191fbfe826 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/sort_v2.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
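# npu_sort_v2, added below, mirrors torch.sort but returns only the sorted
# values; the indices are discarded and the `out` parameter is accepted purely
# for signature compatibility. Illustrative behaviour:
#
#     >>> npu_sort_v2(torch.tensor([3., 1., 2.]))
#     tensor([1., 2., 3.])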
+
+import torch
+
+
+def npu_sort_v2(x, dim=-1, descending=False, out=None):
+    y, _ = torch.sort(x, dim=dim, descending=descending)
+    return y
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
index 70c145d0b6ab48932cf3102cc3d5742b0d113635..b46dbdac7c4620d2dccd31aff8217b80583391c3 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
@@ -24,7 +24,8 @@ def parse_json_info_forward_backward(json_path):
     real_data_path = dump_json.get("dump_data_dir")
     dump_data = dump_json.get("data")
     if dump_data is None:
-        raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
+        raise ParseJsonException(ParseJsonException.InvalidDumpJson,
+                                 "something wrong with dump, no data found in dump.json")
 
     if not dump_data:
         logger.warning("data field is empty, no overflow data found.")
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
index f59d027ea0106f124a610753b1437dbc44e170a7..16067f6d2bee70645bcc337d1809a14f41ae5b96 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
@@ -18,6 +18,7 @@ import os
 import pickle
 import random
 import stat
+import inspect
 from functools import wraps
 
 import numpy as np
@@ -404,6 +405,48 @@ def load_api_data(api_data_bytes):
     return buffer
 
 
+def is_recomputation():
+    """Check if the current operation is in the re-computation phase.
+
+    This function inspects the current call stack to determine whether the current operation is in the
+    re-computation phase. It uses a blacklist mechanism; the Megatron and MindSpeed frameworks are supported.
+    Megatron: the 'backward' function is called from the 'torch/autograd/function.py' file.
+    MindSpeed: the 'checkpoint_function_backward' function is called from the 'torch/autograd/function.py'
+    file, or a custom module (using CheckpointWithoutOutput) whose 'recompute_fn' function is executed within
+    the 'torch/utils/checkpoint.py' file.
+
+    Returns:
+        bool: True if in the re-computation phase, False otherwise.
+    """
+    backward_function_indices = []
+    call_stack = inspect.stack()
+
+    # MindSpeed CheckpointWithoutOutput case: 'recompute_fn' executing within 'torch/utils/checkpoint.py'.
+ for frame_info in call_stack: + if frame_info.function == "recompute_fn" and frame_info.filename.endswith('torch/utils/checkpoint.py'): + del call_stack + return True + + # Identify indices in the call stack where the specific function is being executed + for idx, frame_info in enumerate(call_stack): + if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward': + backward_function_indices.append(idx) + + # Check if the execution is within 'torch/autograd/function.py' file + for idx in backward_function_indices: + # The Megatron and MindSpeed L0&L1 scenes + if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'): + del call_stack + return True + # The latest MindSpeed L2 and ModelLink scenes + if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'): + del call_stack + return True + + del call_stack + return False + + def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): @@ -420,4 +463,13 @@ def check_save_param(variable, name, save_backward): logger.warning("PrecisionDebugger.save_backward name not valid, " "should be bool. " "Skip current save process.") - raise ValueError \ No newline at end of file + raise ValueError + + +def replace_last_occurrence(text, old, new): + if text is None: + return text + index = text.rfind(old) + if index != -1: + return text[:index] + text[index:].replace(old, new, 1) + return text diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 1df261658577f95de0104d8901687a2734c41051..77e78bc38063602e64b533291d60b9b12fd2ae00 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. 
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@ class DebuggerConfig:
         self.task = task or common_config.task or Const.STATISTICS
         self.rank = common_config.rank if common_config.rank else []
         self.step = common_config.step if common_config.step else []
-        self.level = level or common_config.level or "L1"
+        self.level = level or common_config.level or Const.LEVEL_L1
         self.enable_dataloader = common_config.enable_dataloader
         self.scope = task_config.scope if task_config.scope else []
         self.list = task_config.list if task_config.list else []
@@ -34,10 +34,7 @@ class DebuggerConfig:
         self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
         self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
         self.framework = Const.PT_FRAMEWORK
-
-        if self.level == Const.LEVEL_L2:
-            self.is_backward_kernel_dump = False
-            self._check_and_adjust_config_with_l2()
+        self.async_dump = common_config.async_dump if common_config.async_dump else False
 
         if self.task == Const.FREE_BENCHMARK:
             self.fuzz_device = task_config.fuzz_device
@@ -64,6 +61,10 @@ class DebuggerConfig:
 
         self.check()
 
+        if self.level == Const.LEVEL_L2:
+            self.is_backward_kernel_dump = False
+            self._check_and_adjust_config_with_l2()
+
     def check_kwargs(self):
         if self.task and self.task not in Const.TASK_LIST:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -74,6 +75,19 @@ class DebuggerConfig:
         if not self.dump_path:
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
                                    f"The dump_path not found.")
+        if not isinstance(self.async_dump, bool):
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"The parameter async_dump should be a bool.")
+        if self.async_dump and self.task == Const.TENSOR and not self.list:
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+                                   f"When async_dump is true in the tensor task, the parameter list cannot be "
+                                   f"empty.")
+        if self.task == Const.STRUCTURE and self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.warning_on_rank_0(
+                f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}; "
+                f"falling back to the default level {Const.LEVEL_MIX}."
+            )
+            self.level = Const.LEVEL_MIX
 
     def check(self):
         self.check_kwargs()
@@ -89,21 +103,25 @@
             logger.error_on_rank_0(
                 f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' parameter.")
             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
-        if instance.model and not isinstance(instance.model, list):
-            instance.model = [instance.model]
-        if start_model:
-            if not isinstance(start_model, list):
-                instance.model = [start_model]
-            else:
-                instance.model = start_model
-        for single_model in instance.model:
-            if not isinstance(single_model, torch.nn.Module):
-                logger.error_on_rank_0(
-                    f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, "
-                    f"currently there is a {type(single_model)} type."
-                )
-                raise MsprobeException(
-                    MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module or list[torch.nn.Module]")
+
+        instance.model = start_model if start_model is not None else instance.model
+        if isinstance(instance.model, torch.nn.Module):
+            return
+
+        error_model = None
+        if isinstance(instance.model, (list, tuple)):
+            for model in instance.model:
+                if not isinstance(model, torch.nn.Module):
+                    error_model = model
+                    break
+        else:
+            error_model = instance.model
+
+        if error_model is not None:
+            error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
+                          f"type, currently there is a {type(error_model)} type.")
+            raise MsprobeException(
+                MsprobeException.INVALID_PARAM_ERROR, error_info)
 
     def _check_and_adjust_config_with_l2(self):
         if self.scope:
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
index e07616bc69676a508cc7bc836ea6402e8eceb4a3..5bb1d3a14e82d7b4bce9d7da8921a1d701e82222 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,6 +23,7 @@ from msprobe.core.common.utils import get_real_step_or_rank
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import check_save_param
 from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
+from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper
 from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.pytorch.service import Service
@@ -50,7 +51,7 @@ class PrecisionDebugger:
                  dump_path=None,
                  level=None,
                  model=None,
-                 step=None,
+                 step=None
                  ):
         if not hasattr(self, "initialized"):
             config_params = ConfigParameters(config_path,
@@ -67,12 +68,13 @@ class PrecisionDebugger:
             if self.task == Const.GRAD_PROBE:
                 self.gm = GradientMonitor(common_config, task_config)
                 return
-            if step:
+            if step is not None:
                 common_config.step = get_real_step_or_rank(step, Const.STEP)
             self.config = DebuggerConfig(
                 common_config, task_config, task, dump_path, level
             )
             self.service = Service(self.config)
+            self.module_dumper = ModuleDumper(self.service)
             self.enable_dataloader = self.config.enable_dataloader
             if self.enable_dataloader:
                 logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
@@ -171,6 +173,35 @@ class PrecisionDebugger:
         instance.service.save(variable, name, save_backward)
 
 
+def module_dump(module, dump_name):
+    if not isinstance(module, torch.nn.Module):
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR,
+            f"the module argument in module_dump must be a torch.nn.Module subclass"
+        )
+    if not isinstance(dump_name, str):
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR,
+            f"the dump_name argument in module_dump must be a str type"
+        )
+    instance = PrecisionDebugger._instance
+    if not instance:
+        raise MsprobeException(
+            MsprobeException.INTERFACE_USAGE_ERROR,
+            f"PrecisionDebugger must be instantiated before using module_dump interface"
+        )
+    instance.module_dumper.start_module_dump(module, dump_name)
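A minimal usage sketch for the two module-level dump entry points added here. This is hypothetical: the config path and the small model are placeholders, and `module_dump` must be called only after a `PrecisionDebugger` exists, exactly as the checks above enforce.

```python
import torch
import torch.nn as nn
from msprobe.pytorch.debugger.precision_debugger import (
    PrecisionDebugger, module_dump, module_dump_end,
)

debugger = PrecisionDebugger(config_path="./config.json")  # placeholder config

model = nn.Linear(8, 4)
inputs = torch.randn(2, 8)

module_dump(model, "fc1")   # hook this module under the dump name "fc1"
loss = model(inputs).sum()
loss.backward()             # forward/backward data is captured by the hooks
module_dump_end()           # restore wrapped APIs and remove the hooks
```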
+
+
+def module_dump_end():
+    instance = PrecisionDebugger._instance
+    if not instance:
+        raise MsprobeException(
+            MsprobeException.INTERFACE_USAGE_ERROR,
+            f"PrecisionDebugger must be instantiated before using module_dump_end interface"
+        )
+    instance.module_dumper.stop_module_dump()
+
 
 def iter_tracer(func):
     def func_wrapper(*args, **kwargs):
diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/__init__.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py
new file mode 100644
index 0000000000000000000000000000000000000000..4700de6f1f9f3b5ddfb9507decb6f8739b5eda9b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.hook_module.api_registry import api_register
+
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+
+
+class ModuleDumper:
+    def __init__(self, service):
+        self.service = service
+        self.hook_handle_list = []
+
+    def start_module_dump(self, module, dump_name):
+        api_register.api_originality()
+        self.register_hook(module, dump_name)
+
+    def stop_module_dump(self):
+        api_register.api_modularity()
+        for hook_handle in self.hook_handle_list:
+            if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
+                hook_handle.remove()
+        self.hook_handle_list.clear()
+
+    def register_hook(self, module, dump_name):
+        prefix_name = (
+            BaseScope.Module_Type_Module + Const.SEP +
+            dump_name + Const.SEP +
+            module.__class__.__name__ + Const.SEP
+        )
+        module_processor = self.service.module_processor
+        _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.service.build_hook(
+            BaseScope.Module_Type_Module,
+            prefix_name
+        )
+
+        if module_processor.has_register_backward_hook(module):
+            logger.warning(
+                f"The {dump_name} module has registered deprecated register_backward_hook, "
+                f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
+            )
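+        # torch >= 2.0 supports kwargs-aware forward hooks and full backward
+        # pre-hooks; earlier versions fall back to a dedicated forward hook
+        # plus full backward hooks, skipped when a deprecated hook exists.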
+ ) + if torch_version_above_or_equal_2: + forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True) + else: + if not module_processor.has_register_backward_hook(module): + backward_hook_handle = module.register_full_backward_hook( + module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) + ) + self.hook_handle_list.append(backward_hook_handle) + forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2) + self.hook_handle_list.append(forward_hook_handle) + if not module_processor.has_register_backward_hook(module): + backward_hook_handle = module.register_full_backward_hook(backward_hook) + self.hook_handle_list.append(backward_hook_handle) + + forward_pre_hook_handle = module.register_forward_pre_hook( + module_processor.node_hook(prefix_name + Const.FORWARD, Const.START) + ) + forward_hook_handle = module.register_forward_hook( + module_processor.node_hook(prefix_name + Const.FORWARD, Const.STOP) + ) + self.hook_handle_list.extend([forward_pre_hook_handle, forward_hook_handle]) + if torch_version_above_or_equal_2 and not module_processor.has_register_backward_hook(module): + backward_pre_hook_handle = module.register_full_backward_pre_hook( + module_processor.node_hook(prefix_name + Const.BACKWARD, Const.START) + ) + backward_hook_handle = module.register_full_backward_hook( + module_processor.node_hook(prefix_name + Const.BACKWARD, Const.STOP) + ) + self.hook_handle_list.extend([backward_pre_hook_handle, backward_hook_handle]) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index 34246148a9d54f1bc7c31b02fc3fe781566a9ef5..b5ca1da461fd4235a09172de4b9dcea34a624e58 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,22 +16,26 @@ from functools import wraps import torch -from torch.utils.checkpoint import set_checkpoint_early_stop -from torch.utils.checkpoint import checkpoint as origin_checkpoint from msprobe.core.common.const import Const from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import replace_last_occurrence +from torch.utils.checkpoint import checkpoint as origin_checkpoint +from torch.utils.checkpoint import set_checkpoint_early_stop from torch.utils.hooks import BackwardHook torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' + def checkpoint_without_early_stop(*args, **kwargs): with set_checkpoint_early_stop(False): return origin_checkpoint(*args, **kwargs) + def replace_checkpoint(): torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop + class ModuleProcesser: module_count = {} module_stack = [] @@ -42,29 +46,8 @@ class ModuleProcesser: self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) - BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook) replace_checkpoint() - @staticmethod - def filter_tensor_and_tuple(func): - @wraps(func) - def wrap_by_filter_tensor_and_tuple(*args, **kwargs): - # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook - # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] - if not isinstance(args[1], (torch.Tensor, tuple)): - for item_str in dir(args[1]): - item = getattr(args[1], item_str) - # 处理tensor或者只包含tensor的元组 - if isinstance(item, torch.Tensor) or \ - (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)): - args_new = (args[0], item) - result = func(*args_new, **kwargs) - setattr(args[1], item_str, result) - return args[1] - return func(*args, **kwargs) - - return wrap_by_filter_tensor_and_tuple - @staticmethod def clone_return_value(func): @wraps(func) @@ -78,11 +61,11 @@ class ModuleProcesser: def clone_if_tensor(result): if isinstance(result, torch.Tensor): return result.clone() - elif isinstance(result, tuple): + elif type(result) is tuple: return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) - elif isinstance(result, list): + elif type(result) is list: return list(ModuleProcesser.clone_if_tensor(x) for x in result) - elif isinstance(result, dict): + elif type(result) is dict: return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} else: return result @@ -96,13 +79,20 @@ class ModuleProcesser: return ModuleProcesser.module_count[module_name] @staticmethod - def remove_deprecated_backward_hook_if_exist(module): - if hasattr(module, '_backward_hooks') and \ - len(module._backward_hooks) > 0 and \ - module._is_full_backward_hook is False: - module._backward_hooks.clear() - module._is_full_backward_hook = None - logger.warning("Found deprecated backward hooks. 
Removing them and switching to full backward hooks.") + def has_register_backward_hook(module): + return hasattr(module, '_backward_hooks') and \ + len(module._backward_hooks) > 0 and \ + module._is_full_backward_hook is False + + @staticmethod + def get_modules_and_names(models): + modules_and_names_with_index = {} + if isinstance(models, (list, tuple)): + for index, model in enumerate(models): + modules_and_names_with_index[str(index)] = model.named_modules() + else: + modules_and_names_with_index["-1"] = models.named_modules() + return modules_and_names_with_index @classmethod def reset_module_stats(cls): @@ -111,40 +101,41 @@ class ModuleProcesser: cls.api_parent_node = "" cls.module_node = {} - def hook_modules(self, models, build_hook): + def register_module_hook(self, models, build_hook): logger.info_on_rank_0("The init dump is enabled, and the module dump function will not be available.") - for model in models: - self.register_module_hook(model, build_hook) - - def register_module_hook(self, model, build_hook): - for name, module in model.named_modules(): - if module == model: - continue - - prefix_name = ( - BaseScope.Module_Type_Module + Const.SEP + - name + Const.SEP + - module.__class__.__name__ + Const.SEP - ) - pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook( - BaseScope.Module_Type_Module, - prefix_name - ) - if torch_version_above_or_equal_2: - module.register_forward_hook(forward_hook, with_kwargs=True) - else: - self.remove_deprecated_backward_hook_if_exist(module) - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) - module.register_forward_hook(forward_hook_torch_version_below_2) - self.remove_deprecated_backward_hook_if_exist(module) - module.register_full_backward_hook(backward_hook) - - module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START)) - module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP)) - if torch_version_above_or_equal_2: - module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START)) - self.remove_deprecated_backward_hook_if_exist(module) - module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) + modules_and_names_with_index = self.get_modules_and_names(models) + for index, modules_and_names in modules_and_names_with_index.items(): + model = models if index == "-1" else models[int(index)] + for name, module in modules_and_names: + if module == model: + continue + module_index = (index + Const.SEP) if index != "-1" else "" + prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index + + name + Const.SEP + module.__class__.__name__ + Const.SEP) + pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook( + BaseScope.Module_Type_Module, + prefix_name + ) + + if self.has_register_backward_hook(module): + logger.warning( + f"The {prefix_name[:-1]} has registered deprecated register_backward_hook," + f"which may cause abnormal data dump. The backward data dump for this module will be skipped." 
+ ) + if torch_version_above_or_equal_2: + module.register_forward_hook(forward_hook, with_kwargs=True) + else: + if not self.has_register_backward_hook(module): + module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) + module.register_forward_hook(forward_hook_torch_version_below_2) + if not self.has_register_backward_hook(module): + module.register_full_backward_hook(backward_hook) + + module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START)) + module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP)) + if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module): + module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START)) + module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP)) def node_hook(self, name_prefix, start_or_stop, **kwargs): @@ -192,9 +183,9 @@ class ModuleProcesser: if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name: module.mindstudio_reserved_name = [] module.mindstudio_reserved_name.append(full_name) - forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD) - ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace( - Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None + forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD) + ModuleProcesser.module_node[full_name] = replace_last_occurrence( + ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD) ModuleProcesser.api_parent_node = None if self.scope: self.scope.begin_module(full_name) diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py index e3ac6947f69074b5041a44e23c5a720e62680941..247e2cd0ed5ea11047cc0d75954dbc1e92b889f4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/function_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -27,6 +27,11 @@ from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotar from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \ npu_scaled_masked_softmax_backward from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward +from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam +from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu +from msprobe.pytorch.bench_functions.mish import npu_mish +from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax +from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2 from msprobe.pytorch.common.utils import logger @@ -79,7 +84,8 @@ class Register(dict): npu_custom_functions = Register() npu_custom_functions([ npu_apply_adam_w, npu_confusion_transpose, npu_fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention, - npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention + npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu, gpu_fusion_attention, npu_apply_adam, + npu_group_norm_silu, npu_mish, npu_moe_gating_top_k_softmax, npu_sort_v2 ]) # register for npu custom backward bench functions diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py deleted file mode 100644 
index e8fae0cd301e595848857b5ccf9d0bddc590ea73..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/functional/module_dump.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from msprobe.core.common.const import Const -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.data_dump.scope import BaseScope -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger -from msprobe.pytorch.hook_module.api_registry import api_register -from msprobe.pytorch.service import torch_version_above_or_equal_2 - -hook_handle_list = [] - - -def module_dump(module, dump_name): - if not isinstance(module, nn.Module): - logger.error("The parameter module in module_dump must be a Module subclass.") - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) - if not isinstance(dump_name, str): - logger.error("The parameter dump_name in module_dump must be a str type.") - raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) - - api_register.api_originality() - register_hook(module, dump_name) - - -def module_dump_end(): - api_register.api_modularity() - remove_hook() - hook_handle_list.clear() - - -def register_hook(module, dump_name): - prefix = BaseScope.Module_Type_Module + Const.SEP + dump_name + Const.SEP + module.__class__.__name__ + Const.SEP - - pdg = PrecisionDebugger() - _, forward_hook, backward_hook, forward_hook_torch_version_below_2 = \ - pdg.service.build_hook(BaseScope.Module_Type_Module, prefix) - - if torch_version_above_or_equal_2: - forward_hook_handle = module.register_forward_hook(forward_hook, with_kwargs=True) - hook_handle_list.append(forward_hook_handle) - else: - pdg.service.module_processor.remove_deprecated_backward_hook_if_exist(module) - full_backward_hook_handle = module.register_full_backward_hook( - pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) - forward_hook_handle = module.register_forward_hook(forward_hook_torch_version_below_2) - hook_handle_list.extend([full_backward_hook_handle, forward_hook_handle]) - pdg.service.module_processor.remove_deprecated_backward_hook_if_exist(module) - full_backward_hook_handle = module.register_full_backward_hook(backward_hook) - - forward_pre_hook_handle = module.register_forward_pre_hook( - pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.START)) - forward_hook_handle = module.register_forward_hook( - pdg.service.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - hook_handle_list.extend([full_backward_hook_handle, forward_pre_hook_handle, forward_hook_handle]) - - if torch_version_above_or_equal_2: - backward_pre_hook_handle = module.register_full_backward_pre_hook( - pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - 
pdg.service.module_processor.remove_deprecated_backward_hook_if_exist(module) - full_backward_hook_handle = module.register_full_backward_hook( - pdg.service.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) - hook_handle_list.extend([backward_pre_hook_handle, full_backward_hook_handle]) - - -def remove_hook(): - for hook_handle in hook_handle_list: - if isinstance(hook_handle, torch.utils.hooks.RemovableHandle): - hook_handle.remove() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index c2a1927cbc2f84392c90d3c51937dbdfb80e1fd3..4bc22f51ceb5497f307fb4ac3226c8c590ea459a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -1911,4 +1911,5 @@ distributed: - all_to_all_single - all_to_all - all_gather_into_tensor - - reduce_scatter_tensor \ No newline at end of file + - reduce_scatter_tensor + - batch_isend_irecv \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py index 987676c03338d11606b729ba8eab6791853fbb86..1cd11842c31bacdad7c1bb90f98ac81c3415a40e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py @@ -57,6 +57,10 @@ class DistributedOPTemplate(HOOKModule): if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: if handle and hasattr(handle, 'wait'): handle.wait() + if self.op_name_ == "batch_isend_irecv": + if isinstance(handle, list): + for req in handle: + req.wait() return handle diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py index 128e71d253f311ff3f2bac8528d18a206f9abe00..63f20b1928c80e1e29d7cb8224f267c246fcaa8b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,20 +14,20 @@ # limitations under the License. 
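The `batch_isend_irecv` handling added above matters because that op returns one async work handle per P2P operation; data captured before every handle completes could reflect unfinished buffers. A minimal sketch of the call pattern being wrapped, assuming an initialized `torch.distributed` process group (function and variable names here are illustrative, not part of the patch):

```python
import torch.distributed as dist

def exchange_with_neighbors(send_buf, recv_buf, prev_rank, next_rank):
    ops = [
        dist.P2POp(dist.isend, send_buf, next_rank),
        dist.P2POp(dist.irecv, recv_buf, prev_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)  # returns a list of async work handles
    for req in reqs:
        req.wait()  # recv_buf is only valid after all handles complete
    return recv_buf
```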
import itertools import os -import sys import statistics as st +import sys from abc import ABC +from collections import defaultdict from dataclasses import dataclass, field from typing import List -from collections import defaultdict import pandas as pd import torch from torch.utils.tensorboard import SummaryWriter -from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv from msprobe.core.common.const import FileCheckConst, MonitorConst +from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.pytorch.common.log import logger class ScanRule(ABC): @@ -136,8 +136,8 @@ class AnomalyDataFactory(ABC): tag_name = tag[0] param_name = tag_name.split('/')[0] call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.VPP_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.VPP_SEP)[0]) + if MonitorConst.NAME_SEP in param_name: + vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) else: vpp_stage = 0 @@ -161,16 +161,17 @@ class TrainStage: OPTIMIZER_STAGE = 2 -FORWARD_KEY = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] -BACKWARD_KEY = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT, - MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EFXP_AVG_SQ] +FORWARD_KEY = [MonitorConst.ACTV] +BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD, + MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] +OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] TRAIN_STAGE = { **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} } + @dataclass(eq=True) class GradAnomalyData: rank: int = 0 @@ -220,7 +221,7 @@ class GradAnomalyData: @staticmethod def get_train_stage(tag_name): """ - :param tag_name: "0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/efxp_avg_sq" + :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" :return: int, if forward return 0; if backward return 1; if optimizer return 2 """ key_ = tag_name.split("/")[-1] @@ -288,10 +289,19 @@ class BaseWriterWithAD: tags = list(itertools.product(metric_value.keys(), ops)) for op2tensor in metric_value.values(): tensors.extend(op2tensor.values()) + if not tensors: + return + + n_slices = len(tensors) // MonitorConst.SLICE_SIZE with torch.no_grad(): - metric_list = torch.stack(tensors).cpu() if tensors else [] - for tag, metric in zip(tags, metric_list): - self.add_scalar(tag, metric, step) + for i in range(n_slices + 1): + begin = i * MonitorConst.SLICE_SIZE + end = (i+1) * MonitorConst.SLICE_SIZE + if begin == len(tensors): + continue + metric_list = torch.stack(tensors[begin:end]).cpu() + for tag, metric in zip(tags[begin:end], metric_list): + self.add_scalar(tag, metric, step) def _ad(self, scalar_value, history): return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value) @@ -351,11 +361,11 @@ class CSVWriterWithAD(BaseWriterWithAD): new_data = [] for name, metric_value in self.context_dict.items(): - if MonitorConst.VPP_SEP not in name: - new_data.append([name] + [step] + metric_value) - else: - new_data.append(name.split(MonitorConst.VPP_SEP) + [step] + metric_value) - new_data = pd.DataFrame(new_data).round(self.ndigits) + new_line = name.split(MonitorConst.NAME_SEP) + metric_value + new_line.insert(2, 
step) + new_data.append(new_line) + + new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) @@ -371,26 +381,11 @@ class CSVWriterWithAD(BaseWriterWithAD): def write_metrics(self, ops, metric_value, step, prefix=''): super().write_metrics(ops, metric_value, step, prefix='') - # generate csv headers - # set hashmap to reduce the number of headers generated. - # 前向的norm用input.ops_和output.ops_,反向的用input_grad.ops_和output_grad.ops_ - if prefix in {"actv", "actv_grad"}: - if prefix == "actv": - input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] - else: - input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT] - ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, ops)] - csv_header = ["module_name", "step", *ops_] + if prefix in [MonitorConst.ACTV, MonitorConst.ACTVGRAD]: + self.header = MonitorConst.CSV_HEADER_XY + ops else: - csv_header = ["param_name", "step", *ops] - - keys = list(metric_value.keys()) - if keys and MonitorConst.VPP_SEP in keys[0]: - csv_header.insert(0, "vpp_stage") - - self.header = csv_header + self.header = MonitorConst.CSV_HEADER + ops self.write_csv(prefix, step) - self.header = [] def close(self): pass diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py index 3a332c4f696181c28cbc25a0f377e14d304282bc..6ffd1ffabe7b113ff4e61786d4d9f0709b8b605b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,47 +12,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime import os import re -import datetime from multiprocessing import Process import pytz from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import read_csv, create_directory, remove_path +from msprobe.core.common.utils import is_int from msprobe.pytorch.common.log import logger from msprobe.pytorch.monitor.utils import get_target_output_dir -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.utils import is_int all_data_type_list = ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param"] -csv_file_suffix = r"_\d+-\d+\.csv" +CSV_FILE_SUFFIX = r"_\d+-\d+\.csv" -def parse_step_line(data, line_id, name, ops): - vp_id = data["vpp_stage"][line_id] - module_name = data[name][line_id] - step = data["step"][line_id] +def parse_step_line(line, ops): + vp_id = line["vpp_stage"] + module_name = line[MonitorConst.HEADER_NAME] + step = line["step"] vpp_name = f"vp{vp_id}:{module_name}" + if 'micro_step' in line: + vpp_name = f'{vpp_name}{MonitorConst.NAME_SEP}micro{line["micro_step"]}' ops_result = {} for op in ops: - ops_result[op] = data[op][line_id] + ops_result[op] = line[op] return vpp_name, step, ops_result def parse_step_fn(filepath): data = read_csv(filepath) - - header = list(data.keys()) - name = header[MonitorConst.HEADER_NAME_INDEX] - ops = header[MonitorConst.OPS_START_INDEX:] - + ops = [k for k in data.keys() if k in MonitorConst.OP_LIST] parse_step_result = {} - for line_id in range(len(data)): - vpp_name, step, ops_result = parse_step_line(data, line_id, name, ops) + for _, line in data.iterrows(): + vpp_name, step, ops_result = parse_step_line(line, ops) if vpp_name not in parse_step_result: parse_step_result[vpp_name] = {} if step in parse_step_result[vpp_name]: @@ -65,7 +63,7 @@ def write_step(output_dirpath, parse_step_result, rank, data_type): tb_output_path = os.path.join(output_dirpath, f"rank{rank}", data_type) if os.path.exists(tb_output_path): remove_path(tb_output_path) - logger.warning(f"existing path {tb_output_path} will be recovered") + logger.warning(f"existing path {tb_output_path} will be recovered") writer = SummaryWriter(tb_output_path) for vpp_name, step_data_dict in parse_step_result.items(): step_data_list = [(step, ops) for step, ops in step_data_dict.items()] @@ -82,7 +80,10 @@ def update_dict(dict1, dict2): for key, value in dict2.items(): if key in dict1: if isinstance(dict1[key], dict) and isinstance(value, dict): - update_dict(dict1[key], value) + try: + update_dict(dict1[key], value) + except Exception as e: + raise Exception(f"Error updating nested dict failed at key '{key}': {e}") from e else: raise Exception(f"duplicate key: {key}") else: @@ -91,13 +92,13 @@ def update_dict(dict1, dict2): def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list): - for dir in tqdm(target_output_dirs): - dirpath = dir["path"] - rank = dir["rank"] + for directory in tqdm(target_output_dirs): + dirpath = directory["path"] + rank = directory["rank"] for data_type in data_type_list: all_step_result = {} for filename in os.listdir(dirpath): - if not re.match(f"{data_type}{csv_file_suffix}", filename): + if not re.match(f"{data_type}{CSV_FILE_SUFFIX}", filename): continue filepath = os.path.join(dirpath, filename) try: @@ -105,7 +106,7 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list): except Exception as e: logger.error(f"csv2tensorboard parse {filepath} 
failed \n {e}") break - + all_step_result = update_dict(all_step_result, parse_step_result) if all_step_result: write_step(output_dirpath, all_step_result, rank, data_type) @@ -127,7 +128,14 @@ def check_data_type_list(data_type_list): raise ValueError(f"data type({data_type}) is not supported, supported data type: {all_data_type_list}") -def csv2tensorboard_by_step(monitor_path, time_start=None, time_end=None, process_num=1, data_type_list=None, output_dirpath=None): +def csv2tensorboard_by_step( + monitor_path, + time_start=None, + time_end=None, + process_num=1, + data_type_list=None, + output_dirpath=None +): check_process_num(process_num) check_data_type_list(data_type_list) target_output_dirs = get_target_output_dir(monitor_path, time_start, time_end) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py index b9e00f95b9374ea673eb86d406b2b3c84793ce1a..b2fa26a58e702120fcabd5d82f8e1e0ed27f3bc4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import re -import inspect import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn -from msprobe.core.common.file_utils import load_yaml from msprobe.core.common.const import MonitorConst +from msprobe.core.common.file_utils import load_yaml from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name try: @@ -57,10 +57,13 @@ class DistributedOPTemplate(nn.Module): super(DistributedOPTemplate, self).__init__() self.op_name_ = str(op_name) self.__name__ = self.op_name_ + self.cc_hooks = [] for pre_hook in pre_hooks: - self.register_forward_pre_hook(pre_hook, with_kwargs=True) + handle = self.register_forward_pre_hook(pre_hook, with_kwargs=True) + self.cc_hooks.append(handle) for hook in post_hooks: - self.register_forward_hook(hook, with_kwargs=True) + handle = self.register_forward_hook(hook, with_kwargs=True) + self.cc_hooks.append(handle) def forward(self, *args, **kwargs): return distributed_func.get(self.op_name_)(*args, **kwargs) @@ -120,8 +123,11 @@ class ApiRegistry: def initialize_hook(self, pre_hooks, post_hooks): self.store_ori_attr(dist, get_distributed_ops(), self.distributed_attr_origin) + cc_hooks = [] for op_name in get_distributed_ops(): self.distributed_attr_hooked[op_name] = DistributedOPTemplate(op_name, pre_hooks, post_hooks) + cc_hooks.extend(self.distributed_attr_hooked[op_name].cc_hooks) + return cc_hooks def get_process_group(process_group): @@ -243,12 +249,14 @@ def create_hooks(context, monitor): monitor.ops, args, MonitorConst.PREFIX_POST ) - elif isinstance(out, list): # batch_isend_irecv + elif isinstance(out, list): # batch_isend_irecv for out_element in out: - PENDING_ASYNC_CC_BY_HANDLE[out_element] = create_async_callback_func(context[module.op_name_], - module.op_name_, - monitor.ops, args, - MonitorConst.PREFIX_POST) + PENDING_ASYNC_CC_BY_HANDLE[out_element] = create_async_callback_func( + context[module.op_name_], + module.op_name_, + monitor.ops, 
args, + MonitorConst.PREFIX_POST + ) return out catch_data(context[module.op_name_], module.op_name_, monitor.ops, args, MonitorConst.PREFIX_POST) return out diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 5e1e848fc67cb7c017ce69248a6cd159c59b9dc5..62497514d3c1f0e6655f60c4538d00d5ac55ce3e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,9 +22,12 @@ from functools import partial import pytz import torch import torch.distributed as dist -from msprobe.core.common.const import MonitorConst +from torch.utils.hooks import BackwardHook + +from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import is_recomputation from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \ CSVWriterWithAD, BaseWriterWithAD, WriterInput @@ -34,18 +37,16 @@ from msprobe.pytorch.monitor.features import get_sign_matches from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \ TensorMetrics, squash_param_name from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec -from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon -from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation, \ +from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory +from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ get_output_base_dir, get_target_output_dir from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer -from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook -from torch.utils.hooks import BackwardHook - torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not torch_version_above_or_equal_2: raise ValueError("monitor require torch>=2.0") + FORMAT_MAPPING = { MonitorConst.TENSORBOARD: SummaryWriterWithAD, MonitorConst.CSV: CSVWriterWithAD, @@ -65,7 +66,6 @@ def param_is_data_parallel_duplicate(dp_group): class ModuleHookContext: def __init__(self, module_name) -> None: - self.step = 0 self.micro_step = 0 self.actv = defaultdict(dict) self.actvgrad = [] @@ -86,9 +86,6 @@ class ModuleHookContext: :param target_config: target obj in config json. :return: """ - valid_key = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT, MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT] - if key_name not in valid_key: - raise ValueError(f"key({key_name}) error, valid_key: {valid_key}") cared = target_config.get(self.module_name, self.struct) if key_name in cared: target_module_config = cared[key_name] @@ -101,9 +98,14 @@ class ModuleHookContext: else: logger.warning_on_rank_0(f"target module config error, result maybe empty." 
f"module_name: {self.module_name}, key_name: {key_name}") + self.format_by_arg[key_name] = None else: self.format_by_arg[key_name] = self.struct.get(key_name).get('config') + def reset(self): + self.actv.clear() + self.actvgrad.clear() + start_step = 0 @@ -111,7 +113,6 @@ start_step = 0 class OptimizerContext: def __init__(self) -> None: self.step = start_step - self.param_effective_rank = defaultdict(float) self.param_mg_direction = defaultdict(float) self.param_adam_update = defaultdict() self.param_adam_ratio = defaultdict() @@ -123,6 +124,18 @@ class OptimizerContext: self.metric_dict = {} self.param_metric = {} + def reset(self): + self.param_mg_direction.clear() + self.param_adam_update.clear() + self.param_adam_ratio.clear() + self.param_weight_grad.clear() + self.param_exp_avg.clear() + self.exp_avg_metric.clear() + self.param_exp_avg_sq.clear() + self.exp_avg_sq_metric.clear() + self.metric_dict.clear() + self.param_metric.clear() + class CommunicationContext: def __init__(self) -> None: @@ -163,132 +176,86 @@ class GradContext: class TrainerMon: tensor_metrics = TensorMetrics() - def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None: - """ - opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer" - """ - self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) - self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) - self.optimizer_context = defaultdict(OptimizerContext) - self.cc_context = defaultdict(CommunicationContext) - self.grad_context = GradContext() + def __init__(self, config_file_path, process_group=None, params_have_main_grad=True) -> None: + # TYPE1: 只在这里初始化的变量, 不会随着训练中途config配置改变而重置 + self.config_file_path = config_file_path self.process_group = get_process_group(process_group) self.params_have_main_grad = params_have_main_grad - self.opt_ty = opt_ty + self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) + self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) + self.origin_step_func = None + self.origin_start_grad_sync = None + self.config_timestamp = 0 # 后面有校验时间戳, 首次监控无需为了更新config文件时间戳而去改, 可通过dynamic_on开关直接打开 self.config = load_json(config_file_path) validate_config(self.config) - self.module_rank_list = self.config.get("module_ranks", []) - self.format = self.config.get('format', 'tensorboard') - self.eps = self.config.get('eps', 1e-8) - self.ops = self.config.get('ops', []) - self.ndigits = self.config.get('ndigits', 6) - self.all_xy = self.config.get('all_xy', False) - self.xy_distribution = self.config.get('xy_distribution', False) - self.forward_only = self.config.get('forward_only', False) - self.backward_only = self.config.get('backward_only', False) - self.ur_distribution = self.config.get('ur_distribution', False) - self.mv_distribution = self.config.get("mv_distribution", False) - self.wg_distribution = self.config.get("wg_distribution", False) - self.param_distribution = self.config.get("param_distribution", False) - self.mg_direction = self.config.get('mg_direction', False) - self.cc_distribution = self.config.get("cc_distribution", {}) - if not self.cc_distribution.get('enable', False): - self.cc_log_only = False - else: - self.cc_codeline = self.cc_distribution.get('cc_codeline', []) - self.cc_log_only = self.cc_distribution.get('cc_log_only', False) - self.cc_logged_stack = defaultdict(set) - self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) - 
api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) - api_register.redirect_api() - - self.common_info() - - alert_setting = self.config.get('alert', {"rules": []}) - self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) - - # 设置时区,使用 'UTC' 作为示例 + self.squash_name = self.config.get('squash_name', True) # 不允许修改防止前后名字对不上 local_tz = pytz.timezone("Asia/Shanghai") # 根据需要调整为目标时区 - cur_time = datetime.now(local_tz).strftime('%b%d_%H-%M-%S') - unique_id = str(uuid.uuid4())[:8] - output_base_dir = get_output_base_dir() - + self.unique_id = str(uuid.uuid4())[:8] + self.output_base_dir = get_output_base_dir() time_tags = self.config.get("append_output", []) - if time_tags: - output_append_dirs = get_target_output_dir(output_base_dir, time_tags[0], time_tags[1]) if dist.is_initialized(): - rank = dist.get_rank() - if time_tags and str(rank) in output_append_dirs: - tensorboard_dir = output_append_dirs[str(rank)] - logger.info(f"append rank({rank}) result to {tensorboard_dir}") + self.rank = dist.get_rank() + if time_tags: + output_append_dirs = get_target_output_dir(self.output_base_dir, time_tags[0], time_tags[1]) + if str(self.rank) in output_append_dirs: + self.tensorboard_dir = output_append_dirs[str(self.rank)] + logger.info(f"append rank({self.rank}) result to {self.tensorboard_dir}") else: - tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}") - pp_stage = dist.get_group_rank(self.process_group, rank) - group_mates = dist.get_process_group_ranks(self.process_group) + self.tensorboard_dir = os.path.join(self.output_base_dir, + f"{cur_time}-rank{self.rank}-{self.unique_id}") + self.pp_stage = dist.get_group_rank(self.process_group, self.rank) + self.group_mates = dist.get_process_group_ranks(self.process_group) else: - rank = 0 - tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}") - pp_stage = 0 - group_mates = [0] - self.rank = rank - - # 初始化AnomalyData工厂 - self.anomaly_data_factory = None - if alert_setting.get('dump', False): - self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates) - - if self.format not in FORMAT_MAPPING: - raise ValueError(f"Unsupported format: {self.format}") - writer = FORMAT_MAPPING[self.format] - self.step_count_per_record = self.config.get('step_count_per_record', 1) - - if (rank in self.module_rank_list) or len(self.module_rank_list) == 0: - self.summary_writer = writer( - WriterInput( - tensorboard_dir, - self.alert_rules, - unique_id, - self.anomaly_data_factory, - self.ndigits, - self.step_count_per_record - ) - ) - # 初始化anomaly detected文件目录 - if self.anomaly_data_factory: - self.anomaly_data_writer = AnomalyDataWriter(os.path.join(output_base_dir, "anomaly_detected"), rank) - self.anomaly_data_writer.init_detected_json() - - # A HeatmapVisualizer instance is associated with an image - self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.micro_batch_number = 1 + self.rank = 0 + self.tensorboard_dir = os.path.join(self.output_base_dir, f"{cur_time}-rank{self.rank}-{self.unique_id}") + self.pp_stage = 0 + self.group_mates = [0] + # TYPE2: 只会在set_monitor()主调中赋值的变量 self.model = None - self.weight_hooked = False - self.optimizer_hooked = False - self.param_registered = False self.vpp = False self.dp_group = None self.tp_group = None self.enable_megatron = False + self.micro_batch_number = 1 + self.optimizer_class = None + self.optimizer_mon = None + # TYPE3: 
variables reset whenever the config is updated or the monitoring state changes mid-training
+        self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
+        self.optimizer_context = defaultdict(OptimizerContext)
+        self.cc_context = defaultdict(CommunicationContext)
+        self.grad_context = GradContext()
+        self.handles = defaultdict(list)
         self.param2name = defaultdict(str)
         self.name2index = defaultdict()
         self.name2indices = defaultdict()
         self.name2param = {}
-        self.param_name_call_id = {}
         self.duplicate_param = {}
         self.name2tag = {}
+        self.param_name_call_id = {}
         self.call_id = 0
+        self.module_struct = defaultdict(dict)
         self.grad_accs = []
-        self.handles = defaultdict(list)
-
-        self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
-        self.print_struct = self.config.get("print_struct", False)
+        self.weight_hooked = False
+        self.optimizer_hooked = False
+        self.param_registered = False
         self.struct_printed = False
-        self.module_struct = defaultdict(dict)
+
+        # distinguish dynamic and static monitoring modes
+        self.dynamic_enable = os.getenv("DYNAMIC_MONITOR", 'False').lower() == 'true'
+        if self.dynamic_enable:
+            logger.warning(f"DYNAMIC_MONITOR is set, "
+                           f"please make sure you have 'dynamic_on' and 'collect_times' in {self.config_file_path}")
+            self.monitoring = False
+        else:
+            self.set_config()
+            # in static mode with collect_times > 0, monitoring can already be True at step 0;
+            # dynamic mode turns on at the next step by default
+            if self.collect_times > 0:
+                self.monitoring = True
 
     def __del__(self):
         if hasattr(self, "summary_writer"):
@@ -303,8 +270,16 @@ class TrainerMon:
         self._ops = validate_ops(value)
 
     @staticmethod
-    def set_wrapped_optimizer(_wrapped_optimizer):
-        OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer)
+    def has_register_backward_hook(module_name, module):
+        if hasattr(module, '_backward_hooks') and \
+                len(module._backward_hooks) > 0 and \
+                module._is_full_backward_hook is False:
+            logger.warning(
+                f"The {module_name} has registered deprecated register_backward_hook, "
+                f"which may cause abnormal data dump. The backward input/output for this module will be skipped."
+ ) + return True + return False @staticmethod def generate_cc_metrics(cc_name, cc_tensor): @@ -317,18 +292,76 @@ class TrainerMon: cc_tensor.reset() return metrics - def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): - rank = None - if dist.is_initialized(): - rank = dist.get_rank() - if (rank not in rank_list) and len(rank_list) != 0: - return - self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) + def set_config(self): + logger.info(f"current config: {self.config}") + self.start_step = self.config.get("start_step", 0) + self.collect_times = self.config.get("collect_times", 100000000) # 默认大值, 目的是一直采集 + self.step_interval = self.config.get("step_interval", 1) + self.has_collect_times = 0 # 重设采集计数器 + self.print_struct = self.config.get("print_struct", False) + self.module_rank_list = self.config.get("module_ranks", []) + self.format = self.config.get('format', MonitorConst.CSV) + self.eps = self.config.get('eps', 1e-8) + self.ops = self.config.get('ops', []) + self.ndigits = self.config.get('ndigits', 6) + self.all_xy = self.config.get('all_xy', False) + self.xy_distribution = self.config.get('xy_distribution', False) + self.forward_only = self.config.get('forward_only', False) + self.backward_only = self.config.get('backward_only', False) + self.ur_distribution = self.config.get('ur_distribution', False) + self.mv_distribution = self.config.get("mv_distribution", False) + self.wg_distribution = self.config.get("wg_distribution", False) + self.param_distribution = self.config.get("param_distribution", False) + self.mg_direction = self.config.get('mg_direction', False) + self.cc_distribution = self.config.get("cc_distribution", {}) - def build_tbtag_tensor_map(self, module_name, tag, tensor): - key = get_summary_writer_tag_name(module_name, tag, self.rank) - self._register_param_call_id("_hook_module", key) - return {key: tensor} + if not self.cc_distribution.get('enable', False): + self.cc_log_only = False + else: + self.cc_codeline = self.cc_distribution.get('cc_codeline', []) + self.cc_log_only = self.cc_distribution.get('cc_log_only', False) + self.cc_logged_stack = defaultdict(set) + self.cc_pre_hook = self.cc_distribution.get('cc_pre_hook', False) + self.handles['cc'] = api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) + api_register.redirect_api() + + self.common_info() + + # 初始化AnomalyData工厂 + alert_setting = self.config.get('alert', {"rules": []}) + self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) + self.anomaly_data_factory = None + if alert_setting.get('dump', False): + self.anomaly_data_factory = AnomalyDataFactory(self.rank, self.pp_stage, self.group_mates) + + # 初始化writer, 创建输出目录 + if self.format not in FORMAT_MAPPING: + logger.error(f"Unsupported format: {self.format}, use default format: {MonitorConst.CSV}") + self.format = MonitorConst.CSV + + if self.ur_distribution and self.format != 'tensorboard': + logger.error("can only set ur_distribution when format is 'tensorboard', cancel ur_distribution") + self.ur_distribution = False + + writer = FORMAT_MAPPING[self.format] + self.step_count_per_record = self.config.get('step_count_per_record', 1) + + if (self.rank in self.module_rank_list) or len(self.module_rank_list) == 0: + self.summary_writer = writer( + WriterInput( + self.tensorboard_dir, + self.alert_rules, + self.unique_id, + self.anomaly_data_factory, + self.ndigits, + self.step_count_per_record + ) + ) + # 初始化anomaly 
detected文件目录 + if self.anomaly_data_factory: + self.anomaly_data_writer = AnomalyDataWriter(os.path.join(self.output_base_dir, "anomaly_detected"), + self.rank) + self.anomaly_data_writer.init_detected_json() def common_info(self): if not self.xy_distribution: @@ -345,32 +378,20 @@ class TrainerMon: logger.info_on_rank_0('> grad and momentum direction will not be compared.') if not self.cc_distribution.get('enable', False): logger.info_on_rank_0("> cc operator is not monitored.") - if not self.opt_ty: - if self.ur_distribution: - raise Exception("ur_distribution cannot be enabled with unknown optimizer.") - if self.mv_distribution: - raise Exception("mv_distribution cannot be enabled with unknown optimizer.") - def hook_modules(self, model: torch.nn.Module, grad_acc_steps): + def hook_modules(self): if self.module_rank_list and (self.rank not in self.module_rank_list): return - if not isinstance(model, list): - model = [model] - self.model = model - self._register_param_name(model) - - self.micro_batch_number = grad_acc_steps - targets = self.config['targets'] - module_in_all_stage = [key for key in targets.keys() if MonitorConst.VPP_SEP not in key] + module_in_all_stage = [key for key in targets.keys() if MonitorConst.NAME_SEP not in key] for key in module_in_all_stage: struct = targets.pop(key) - targets.update({f'{vpp_stage}{MonitorConst.VPP_SEP}{key}': struct for vpp_stage in range(len(model))}) + targets.update({f'{vpp_stage}{MonitorConst.NAME_SEP}{key}': struct for vpp_stage in range(len(self.model))}) hooked_count = 0 - for vpp_stage, model_chunk in enumerate(model): - vpp_stage = f'{vpp_stage}{MonitorConst.VPP_SEP}' + for vpp_stage, model_chunk in enumerate(self.model): + vpp_stage = f'{vpp_stage}{MonitorConst.NAME_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ 'targets'].keys() hooked_count += self._hook_module(targets, model_chunk, vpp_stage) @@ -394,12 +415,67 @@ class TrainerMon: return wrapped_setup + BackwardHook.setup_input_hook = wrap_hook_setup(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook) - - if not self.optimizer_hooked: - self.hook_optimizer() return + def set_monitor( + self, + model, + optimizer, + grad_acc_steps=1, + tp_group=None, + dp_group=None, + start_iteration=0 + ): + """External interface""" + global start_step + start_step = start_iteration + logger.info(f'grad acc steps {grad_acc_steps}') + self.micro_batch_number = grad_acc_steps + self.dp_group = dp_group + self.tp_group = tp_group + self.optimizer_mon, self.optimizer_class = OptimizerMonFactory.create_optimizer_mon(optimizer) + self.hook_step_final(optimizer) + if not isinstance(model, list): + model = [model] + self.model = model + if len(model) > 1: + self.vpp = True + self._smallest_rank_print('vpp enabled') + if not self.dynamic_enable: + self.register_hooks(optimizer) + + def register_hooks(self, optimizer): + self._register_param_name() + self.hook_optimizer(optimizer) + self._patch_grad_sync() + self.hook_modules() + self.monitoring = True + + def adhoc_check(self, target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): + rank = None + if dist.is_initialized(): + rank = dist.get_rank() + if (rank not in rank_list) and len(rank_list) != 0: + return + self.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) + + def build_tbtag_tensor_map(self, module_name, tag, tensor): + key = get_summary_writer_tag_name(module_name, 
tag, self.rank) + self._register_param_call_id("_hook_module", key) + return {key: tensor} + + def generate_param_map(self, tag, param_tensor): + metrics = {} + for name in self.param2name.values(): + key = get_summary_writer_tag_name(name, tag, self.rank) + self._register_param_call_id("optimizer_pre_step_hook", key) + if name not in param_tensor or param_tensor[name] is None: + continue + metrics[key] = param_tensor[name] + return metrics + def generate_param_metrics(self, opt_context): if not self.param_distribution: return @@ -410,8 +486,8 @@ class TrainerMon: return opt_context.exp_avg_metric = {} opt_context.exp_avg_sq_metric = {} - m_tag_tensor_map = self.generate_param_map('exp_avg', opt_context.param_exp_avg) - v_tag_tensor_map = self.generate_param_map('efxp_avg_sq', opt_context.param_exp_avg_sq) + m_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG, opt_context.param_exp_avg) + v_tag_tensor_map = self.generate_param_map(MonitorConst.EXP_AVG_SQ, opt_context.param_exp_avg_sq) get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric) get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric) @@ -435,32 +511,8 @@ class TrainerMon: grad_dict[tag] = grad get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post) - return self.grad_context.post, self.grad_context.pre - - def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None, start_iteration=0): - """External interface""" - global start_step - start_step = start_iteration - logger.info(f'grad acc steps {grad_acc_steps}') - self.hook_optimizer(optimizer) - self.micro_batch_number = grad_acc_steps - - self.dp_group = dp_group - self.tp_group = tp_group - - self._register_param_name(model) - self._patch_grad_sync() - self.hook_modules(model, grad_acc_steps) - - def generate_param_map(self, tag, param_tensor): - metrics = {} - for name in self.param2name.values(): - key = get_summary_writer_tag_name(name, tag, self.rank) - self._register_param_call_id("optimizer_pre_step_hook", key) - if name not in param_tensor or param_tensor[name] is None: - continue - metrics[key] = param_tensor[name] - return metrics + unreduced_grad = self.grad_context.acc_metric if self.weight_hooked else self.grad_context.pre + return self.grad_context.post, unreduced_grad def generate_xy_metrics(self): actv = {} @@ -472,12 +524,14 @@ class TrainerMon: return actv, actv_grad def reload_xy(self, xy_distribution=False): + logger.warning("reload_xy() is deprecated and will be removed in a future version. 
" + "Use DYNAMIC_MONITOR instead.") self.xy_distribution = xy_distribution for handle in self.handles['xy']: handle.remove() self.handles['xy'].clear() - self.hook_modules(self.model, self.micro_batch_number) + self.hook_modules() for _, fwd_context in self.module_fwd_hook_context_by_module.items(): fwd_context.actv.clear() @@ -490,21 +544,23 @@ class TrainerMon: for _, fwd_context in self.module_fwd_hook_context_by_module.items(): if len(fwd_context.actv) == 0: continue - self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, 'actv') + self.summary_writer.write_metrics(self.ops, fwd_context.actv, step, MonitorConst.ACTV) fwd_context.actv.clear() if self.grad_context.actv: - self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, 'actv_grad') + self.summary_writer.write_metrics(self.ops, self.grad_context.actv, step, MonitorConst.ACTVGRAD) def write_param_tb(self, opt_context): if not self.param_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, 'param') + self.summary_writer.write_metrics(self.ops, opt_context.param_metric, opt_context.step, MonitorConst.PARAM) def write_mv_tb(self, opt_context): if not self.mv_distribution: return - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, opt_context.step, 'exp_avg') - self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq') + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_metric, + opt_context.step, MonitorConst.EXP_AVG) + self.summary_writer.write_metrics(self.ops, opt_context.exp_avg_sq_metric, + opt_context.step, MonitorConst.EXP_AVG_SQ) def write_grad_tb(self, step): if not self.wg_distribution: @@ -516,35 +572,37 @@ class TrainerMon: self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced') self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced') - def hook_optimizer(self, optimizer=None): + def hook_optimizer(self, optimizer): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] - if self.opt_ty in MonitorConst.DEEPSPEED_OPT_TY: - if not self.name2indices: - self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name, - self.name2index) - mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name, - self.name2indices) - self.param2name = mv_result.grad - else: - mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name) - context.param_exp_avg = mv_result.exp_avg - context.param_exp_avg_sq = mv_result.exp_avg_sq - context.param_adam_update = mv_result.update - context.param_adam_ratio = mv_result.ratio if (self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed): self._save_module_struct() if not self.cc_log_only: - raise Exception("exit after first step when print model struct") + raise Exception("exit after first monitor step when print model struct") if self.cc_log_only and context.step > 0: self._smallest_rank_print("> Used communication ops and corresponding stack") self._smallest_rank_print( json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()})) raise Exception("exit after first step when print cc stack") + # skip generate metrics + if context.step < self.start_step or (context.step - self.start_step) % self.step_interval != 0: + 
return
+ if MonitorConst.DEEPSPEED_ZERO_OPT_FILTER in self.optimizer_class: # use deepspeed with zero1/2/3 + if not self.name2indices: + self.name2indices = self.optimizer_mon.get_param_index(self.param2name, self.name2index, optimizer) + mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices) + self.param2name = mv_result.grad + else: + mv_result = self.optimizer_mon.fetch_mv(self, optimizer, self.param2name) + context.param_exp_avg = mv_result.exp_avg + context.param_exp_avg_sq = mv_result.exp_avg_sq + context.param_adam_update = mv_result.update + context.param_adam_ratio = mv_result.ratio + self.generate_wgrad_metrics() self.generate_mv_metrics(context) self.generate_param_metrics(context) @@ -575,58 +633,186 @@ class TrainerMon: context.metric_dict = metric_dict return - def optimizer_post_step_hook(optimizer, args, kwargs): + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + optimizer_pre_step_hook(optimizer, args, kwargs) + out = func(*args, **kwargs) + return out + return wrapper + + if self.optimizer_hooked: + return + + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + + self.optimizer_hooked = True + return + + def dynamic_monitor(self, optimizer): + """ + If dynamic monitor enabled and config.json updated, + remove hooks and register new hooks according to new configuration. + """ + context = self.optimizer_context[optimizer] + if not self.dynamic_enable: + return + try: + # if the file timestamp is unchanged, skip re-reading to save time + config_timestamp = os.path.getmtime(self.config_file_path) + if config_timestamp == self.config_timestamp: + return + # record the latest modification timestamp of the config file + self.config_timestamp = config_timestamp + config = load_json(self.config_file_path) + except Exception as e: + logger.error(f"failed to load config.json because {e}, config not updated, please check!") + return + + if config.get("dynamic_on", False): + try: + validate_config(config) + self.config = config + self.set_config() + logger.warning(f"config is updated at step {context.step - 1}, " + f"new hooks will take effect at step {context.step}.") + except Exception as e: + logger.error(f"failed to apply the new config because {e}, config not updated, please check!") + return + + self._remove_all_hooks(optimizer) + self.register_hooks(optimizer) + + def hook_step_final(self, optimizer): + def step_final_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] rank = dist.get_rank() if dist.is_initialized() else None + # static monitoring can save at step 0; dynamic monitoring cannot, because it is designed to take effect on the step after a reset, so self.monitoring is still False at step 0 + if self.monitoring: + module_rank_valid = not self.module_rank_list or ( + dist.is_initialized() and dist.get_rank() in self.module_rank_list) + step_condition = (context.step >= self.start_step and ( + context.step - self.start_step) % self.step_interval == 0) + if module_rank_valid and step_condition: + self.has_collect_times += 1 + + if self.anomaly_data_factory: + self.anomaly_data_factory.set_call_id(self.param_name_call_id) + self.write_xy_tb(context.step) + self.write_grad_tb(context.step) + self.write_mv_tb(context) + self.write_param_tb(context) + self.write_adhoc_check(context.step) + + if self.ur_distribution: + for param_name, _ in context.param_adam_update.items(): + self.update_heatmap_visualizer[param_name].visualize( + get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, + self.summary_writer) + for param_name, _ in context.param_adam_ratio.items(): + self.ratio_heatmap_visualizer[param_name].visualize( + get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, + self.summary_writer) + + if context.metric_dict: + self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other') + context.metric_dict.clear() + + if self.anomaly_data_factory: + self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) + self.summary_writer.clear_anomalies() + self.call_id = 0 + self.param_name_call_id.clear() + + if self.has_collect_times >= self.collect_times: + self._remove_all_hooks_final(optimizer) - if self.anomaly_data_factory: - self.anomaly_data_factory.set_call_id(self.param_name_call_id) - self.write_xy_tb(context.step) - self.write_grad_tb(context.step) - self.write_mv_tb(context) - self.write_param_tb(context) - self.write_adhoc_check(context.step) - - if self.ur_distribution: - for param_name, _ in context.param_adam_update.items(): - self.update_heatmap_visualizer[param_name].visualize( - get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer) - for param_name, _ in context.param_adam_ratio.items(): - self.ratio_heatmap_visualizer[param_name].visualize( - get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) - - if context.metric_dict: - self.summary_writer.write_metrics(self.ops, context.metric_dict, context.step, 'other') - context.metric_dict.clear() context.step += 1 - if self.anomaly_data_factory: - self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) - self.summary_writer.clear_anomalies() - self.call_id = 0 - self.param_name_call_id.clear() - return + self.dynamic_monitor(optimizer) def patch_step(func, optimizer): def wrapper(*args, **kwargs): - optimizer_pre_step_hook(optimizer, args, kwargs) out = func(*args, **kwargs) - optimizer_post_step_hook(optimizer, args, kwargs) + step_final_hook(optimizer, args, kwargs) return out - return wrapper + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + self.origin_step_func = optimizer.__class__.step + + return + + def _remove_all_hooks(self, optimizer): + # remove registered hook handles + for handle in self.handles['xy']: + handle.remove() + self.handles['xy'].clear() + # clear the corresponding context caches + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + fwd_context.reset() + for _, bwd_context in self.module_bwd_hook_context_by_module.items(): + bwd_context.reset() + self.grad_context.reset() # holds both weight grads and activation grads + + if self.origin_start_grad_sync: # megatron + try: + from megatron.core.distributed.param_and_grad_buffer import Bucket + Bucket.start_grad_sync = self.origin_start_grad_sync + logger.info("remove Bucket start_grad_sync") + except ImportError: + pass + try: + from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup + _ParamAndGradBucketGroup.start_grad_sync = self.origin_start_grad_sync + logger.info("remove _ParamAndGradBucketGroup start_grad_sync") + except ImportError: + pass + else: # not megatron + for handle in self.handles['wgrads']: + handle.remove() + self.handles['wgrads'].clear() + self.weight_hooked = False + if self.optimizer_hooked: - return + optimizer.__class__.step = self.origin_step_func - if optimizer: - optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + for _, context in self.optimizer_context.items(): + context.reset() + self.optimizer_hooked = False - else: - if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - register_optimizer_step_pre_hook(optimizer_pre_step_hook) - register_optimizer_step_post_hook(optimizer_post_step_hook) - self.optimizer_hooked = True - return + for handle in self.handles['cc']: + handle.remove() + self.handles['cc'].clear() + for _, context in self.cc_context.items(): + context.reset() + + # clear per-parameter node caches + self.param2name.clear() + self.name2index.clear() + self.name2indices.clear() + self.name2param.clear() + self.duplicate_param.clear() + self.name2tag.clear() + self.module_struct.clear() + self.grad_accs.clear() + + # switch off the collecting state + self.monitoring = False + + def _remove_all_hooks_final(self, optimizer): + if self.dynamic_enable: + # after finishing, reset dynamic_on to False automatically and wait for the user to re-enable it + try: + config = load_json(self.config_file_path) + config['dynamic_on'] = False + save_json(self.config_file_path, config, indent=2) + config_timestamp = os.path.getmtime(self.config_file_path) + self.config_timestamp = config_timestamp + logger.info( + "Monitoring finished; dynamic_on in the config is reset to False. Set it to True and update the config to restart.") + except Exception as e: + logger.warning(f"Monitoring finished, but resetting dynamic_on to False failed because {e}, please check!") + logger.info("Monitoring finished") + self._remove_all_hooks(optimizer) def _smallest_rank_print(self, msg): if dist.is_initialized(): @@ -651,8 +837,8 @@ class TrainerMon: self.struct_printed = True def _is_target_param(self, param_name, param, prefix): - squash_name = prefix + squash_param_name(param_name) name = prefix + param_name + squash_name = prefix + squash_param_name(param_name, self.squash_name) for target in self.config['targets'].keys(): if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target): setattr(param, "zero_out_wgrad", True) return True @@ -666,7 +852,7 @@ if not param.requires_grad: continue if self._is_target_param(param_name, param, prefix): - name = prefix + squash_param_name(param_name) + name = prefix + squash_param_name(param_name, self.squash_name) if name in self.param2name.values(): name = prefix + param_name self.param2name[param] = name @@ -683,28 +869,16 @@ } index += 1 - def _register_param_name(self, model): - if self.param_registered: - return - - if not isinstance(model, list): - model = [model] - - if len(model) > 1: - self.vpp = True - self._smallest_rank_print('vpp enabled') - - for vpp_stage, model_chunk in enumerate(model): - prefix = f'{vpp_stage}{MonitorConst.VPP_SEP}' + def _register_param_name(self): + for vpp_stage, model_chunk in enumerate(self.model): + prefix = f'{vpp_stage}{MonitorConst.NAME_SEP}' self._register_chunk(model_chunk, prefix) - self.param_registered = True - def _is_target_module(self, module_name, targets, vpp_stage): if self.all_xy or self.print_struct: - return vpp_stage + squash_param_name(module_name) + return vpp_stage + squash_param_name(module_name, self.squash_name) for pattern in [ - vpp_stage + squash_param_name(module_name), + vpp_stage + squash_param_name(module_name, self.squash_name), vpp_stage + module_name, ]: if pattern in targets: @@ -726,76 +900,79 @@ context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] if not context.struct: context.struct = { - MonitorConst.ACTV_IN: get_param_struct(module_input), - MonitorConst.ACTV_OUT: get_param_struct(module_output) + Const.INPUT: get_param_struct(module_input), + Const.OUTPUT: get_param_struct(module_output) } if self.print_struct: self.module_struct[context.module_name].update(context.struct) return if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTV_IN, 
self.config['targets']) - context.set_format_by_arg(MonitorConst.ACTV_OUT, self.config['targets']) + context.set_format_by_arg(Const.INPUT, self.config['targets']) + context.set_format_by_arg(Const.OUTPUT, self.config['targets']) if not context.format_by_arg: return if not context.verified: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_IN], + context.focused_in_col = validate_config_spec(context.format_by_arg[Const.INPUT], module_input, context.module_name, - MonitorConst.ACTV_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTV_OUT], + Const.INPUT) + context.focused_out_col = validate_config_spec(context.format_by_arg[Const.OUTPUT], module_output, context.module_name, - MonitorConst.ACTV_OUT) + Const.OUTPUT) context.verified = True # expect output be tensor type tbtag_tensor_map = {} cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_IN, - cared_input)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, cared_input)) cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTV_OUT, - cared_output)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, cared_output)) get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv) context.micro_step += 1 if context.micro_step == self.micro_batch_number: context.micro_step = 0 - context.step += 1 return def bwd_hook_fun(module, input_grad, output_grad): context: ModuleHookContext = self.module_bwd_hook_context_by_module[module] if not context.struct: context.struct = { - MonitorConst.ACTVGRAD_IN: get_param_struct(input_grad), - MonitorConst.ACTVGRAD_OUT: get_param_struct(output_grad) + MonitorConst.INPUT_GRAD: get_param_struct(input_grad), + MonitorConst.OUTPUT_GRAD: get_param_struct(output_grad) } if self.print_struct: self.module_struct[context.module_name].update(context.struct) return if not context.format_by_arg: - context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.config['targets']) - context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.config['targets']) + context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.config['targets']) + context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.config['targets']) if not context.format_by_arg: return if not context.verified: - context.focused_in_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_IN], - input_grad, context.module_name, - MonitorConst.ACTVGRAD_IN) - context.focused_out_col = validate_config_spec(context.format_by_arg[MonitorConst.ACTVGRAD_OUT], - output_grad, context.module_name, - MonitorConst.ACTVGRAD_OUT) + context.focused_in_col = validate_config_spec( + context.format_by_arg[MonitorConst.INPUT_GRAD], + input_grad, context.module_name, MonitorConst.INPUT_GRAD) + context.focused_out_col = validate_config_spec( + context.format_by_arg[MonitorConst.OUTPUT_GRAD], + output_grad, context.module_name, MonitorConst.OUTPUT_GRAD) context.verified = True tbtag_tensor_map = {} cared_input_grad = input_grad if context.focused_in_col is None else 
input_grad[context.focused_in_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_IN, - cared_input_grad)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.INPUT}{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, cared_input_grad)) cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', MonitorConst.ACTVGRAD_OUT, - cared_output_grad)) + self.build_tbtag_tensor_map( + f'{context.module_name}.{Const.OUTPUT}{MonitorConst.NAME_SEP}{context.micro_step}', + MonitorConst.ACTV, cared_output_grad)) if context.micro_step == 0 and context.actvgrad: logger.warning(f"actvgrad context of {context.module_name} is not empty when first micro_step, " @@ -807,7 +984,6 @@ class TrainerMon: context.micro_step += 1 if context.micro_step == self.micro_batch_number: context.micro_step = 0 - context.step += 1 return if self.backward_only and self.forward_only: @@ -822,7 +998,7 @@ class TrainerMon: if not self.backward_only: handle = submodule.register_forward_hook(partial(fwd_hook_fun, name=name)) self.handles['xy'].append(handle) - if not self.forward_only: + if not self.forward_only and not self.has_register_backward_hook(name, submodule): handle = submodule.register_full_backward_hook(bwd_hook_fun) self.handles['xy'].append(handle) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) @@ -834,7 +1010,10 @@ class TrainerMon: def patch_sync(sync_grad_func): def wrapper(bucket): grad_dict = {} - bucket_params_id_list = [id(params) for params in bucket.params_list] + # For Megatron core_r0.6.0 through core_r0.8.0, this bucket is a Bucket. + # From core_r0.9.0 it is a _ParamAndGradBucketGroup, because + # start_grad_sync moved from Bucket to _ParamAndGradBucketGroup in core_r0.9.0. 
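The comment above is the whole compatibility story for this hook. As a standalone illustration of the version dispatch, here is a minimal sketch; the helper name `resolve_grad_sync_owner` is invented for illustration, only the two Megatron import paths come from this diff, and the order (newer class first) is one possible choice, whereas the patch itself simply tries both.

```python
# Sketch: locate the class that owns start_grad_sync across Megatron releases.
def resolve_grad_sync_owner():
    try:
        # core_r0.9.0: start_grad_sync lives on _ParamAndGradBucketGroup
        from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
        return _ParamAndGradBucketGroup
    except ImportError:
        pass
    try:
        # core_r0.6.0 .. core_r0.8.0: start_grad_sync lives on Bucket
        from megatron.core.distributed.param_and_grad_buffer import Bucket
        return Bucket
    except ImportError:
        return None  # Megatron not available: caller falls back to per-param hooks
```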
+ bucket_params_id_list = [id(params) for params in bucket.params] for param, name in self.param2name.items(): if id(param) not in bucket_params_id_list: continue @@ -853,18 +1032,28 @@ class TrainerMon: return wrapper + if not self.wg_distribution: + return + try: from megatron.core.distributed.param_and_grad_buffer import Bucket + self.origin_start_grad_sync = Bucket.start_grad_sync + Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) self.enable_megatron = True + logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0") except ImportError: self.enable_megatron = False - if not self.wg_distribution: - return + try: + from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup + self.origin_start_grad_sync = _ParamAndGradBucketGroup.start_grad_sync + _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync) + self.enable_megatron = True + logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0") + except ImportError: + self.enable_megatron = False - if self.enable_megatron: - Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync) # differ in different megatron version - else: + if not self.enable_megatron: self._hook_weights() def _hook_weights(self): @@ -881,6 +1070,7 @@ class TrainerMon: else: context_dict[key] = param.grad.clone() + logger.info("hooking weights.") for param, name in self.param2name.items(): key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) setattr(param, 'micro_step', 0) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py index 77473d32e4bc2d1bc537bb6d2ac2903f96014cbd..87963812006413a90fd33bc70d6172a7c73c3f10 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,17 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
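When neither Megatron import succeeds, the monitor hunk above falls back to `_hook_weights`, which registers per-parameter hooks and clones gradients into a context dict. A self-contained sketch of that accumulation pattern follows; the function and variable names here are illustrative, not the tool's API.

```python
import torch

def attach_acc_grad_hooks(model: torch.nn.Module, store: dict):
    """Accumulate per-parameter gradients across micro-batches (sketch)."""
    def make_hook(name):
        def hook(grad):
            prev = store.get(name)
            # clone so later in-place optimizer updates do not alias the stored value
            store[name] = grad.clone() if prev is None else prev + grad
        return hook

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.register_hook(make_hook(name))

model = torch.nn.Linear(4, 2)
grads = {}
attach_acc_grad_hooks(model, grads)
model(torch.randn(3, 4)).sum().backward()
print(sorted(grads))  # ['bias', 'weight']
```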
-import itertools -import math import re -import statistics import torch -from msprobe.core.common.const import MonitorConst -from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean -from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE +from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean +from msprobe.pytorch.monitor.utils import get_nan_tensor def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank): @@ -32,7 +27,9 @@ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank): return f"{module_or_param_name}/rank{rank}/{tag}" -def squash_param_name(param_name): +def squash_param_name(param_name, enable=True): + if not enable: + return param_name name = '' for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']: match = re.findall(pattern, param_name) @@ -64,7 +61,7 @@ class TensorMetrics: self.metrics = {} # tensor_tag --> [] self.cur_idx = {} - def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8): + def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank): """get stats and insert into metrics dictionary""" prefix = get_summary_writer_tag_name(module_name, tensor_name, rank) for stat_op in stat_ops: @@ -150,13 +147,13 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None): """ :param ops: ["op1", "op2"] :param tag2tensor: { - '0:fc_0/input': torch.randn([3, 4]), - '0:fc_0/output': torch.randn([3, 3]) + '0:fc.input:0/actv': torch.randn([3, 4]), + '0:fc.output:0/actv': torch.randn([3, 3]) } :param eps: float 1e-8 :param out_dict:{ - '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))} - '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))} + '0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))} + '0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))} } :return: out_dict """ @@ -167,7 +164,7 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None): out_dict[tag] = {} if not torch.is_tensor(tensor): # Non-tensor in/output filled with nan. 
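The regex list in `squash_param_name` above shortens fully qualified parameter names to stable suffixes, and the new `enable` flag lets callers keep the raw name. A quick illustration of the intended mapping; the tail of the function falls outside this hunk, so the return-on-first-match behavior is assumed from context.

```python
import re

def squash_param_name(param_name, enable=True):
    if not enable:
        return param_name
    # assumed continuation: first matching pattern wins, else the raw name
    for pattern in [r'layers?\.(.*)', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
        match = re.findall(pattern, param_name)
        if match:
            return match[0]
    return param_name

print(squash_param_name('module.layers.0.mlp.weight'))         # -> 0.mlp.weight
print(squash_param_name('module.layers.0.mlp.weight', False))  # -> unchanged
```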
- out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops}) + out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops}) continue for metric_name in ops: fun_metric = config_metric_registry.get(metric_name) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py index 9574ca100e3ede82c689c679ce109aa16085f2a9..72c35c90bf9540a31cfa1176274a3d2c66bc8946 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_spec_verifier.py @@ -79,6 +79,8 @@ class TupleValidator(ConfigValidator): def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str): focused_col = None + if not config_spec or not isinstance(config_spec, str): + return focused_col for _, validator_cls in config_validator_registry.items(): config_validator = validator_cls() pattern_match = config_validator.check_pattern_match(config_spec) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 152bffd73c06a270906aea9dd9a0b4ac6d5ecfea..602514836d2531ad4a6be3a23f56bc3b942ba199 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
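The `get_nan_tensor()` call above replaces the old module-level `NAN_TENSOR_ON_DEVICE` constant; the matching utils.py hunk further down sets the constant to `None` and allocates it lazily, so importing the module no longer requires a usable device. The lazy-singleton pattern, as a hedged sketch rather than the tool's exact code:

```python
import torch

_NAN_TENSOR = None  # allocated on first use, not at import time

def get_nan_tensor():
    global _NAN_TENSOR
    if _NAN_TENSOR is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _NAN_TENSOR = torch.tensor(torch.nan, device=device)
    return _NAN_TENSOR
```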
-from abc import ABC, abstractmethod from collections import defaultdict import torch @@ -24,16 +23,10 @@ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult class OptimizerMon(object): - wrapped_optimizer = None - def __init__(self) -> None: self.fp16_to_fp32_param = {} self.is_stage3 = False - @classmethod - def set_wrapped_optimizer(cls, wrapped_optimizer): - cls.wrapped_optimizer = wrapped_optimizer - def fetch_mv(self, monitor, torch_opt, params2name): pass @@ -83,11 +76,10 @@ class OptimizerMon(object): ratio_dict = defaultdict() param2name = defaultdict() fp32_partitioned_groups_flat_grad = defaultdict() - mix_prec_opt = OptimizerMon.wrapped_optimizer partition_id = dist.get_rank() def get_flatten_grad(self, optimizer, group_idx): - if fp32_partitioned_groups_flat[group_idx].grad is None: + if fp32_partitioned_groups_flat[group_idx].grad is None: if partition_id == dist.get_world_size() - 1 and not self.is_stage3: fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned( optimizer.averaged_gradients[group_idx], @@ -102,7 +94,7 @@ class OptimizerMon(object): return fp32_partitioned_groups_flat[group_idx].grad for group_idx in range(len(fp32_partitioned_groups_flat)): - fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx) + fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx) for name in params2name.values(): start_idx, end_idx, group_idx, group_with_rank = name2indices[name] @@ -111,9 +103,9 @@ class OptimizerMon(object): fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx] fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx] param2name[fp32_param] = name - if not mix_prec_opt.state: + if not torch_opt.state: continue - state_param = list(mix_prec_opt.state.values())[group_idx] + state_param = list(torch_opt.state.values())[group_idx] exp_avg = state_param.get("exp_avg", None) exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: @@ -150,36 +142,34 @@ class MixPrecisionOptimizerMon(OptimizerMon): 混合精度优化器监控类。在混合精度训练中监控和管理优化器。 混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。 """ - def map_fp16_tp_fp32_param(self, mix_prec_opt): - for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups): - for fp16_param, fp32_param in zip(fp16_group, fp32_group): - self.fp16_to_fp32_param[fp16_param] = fp32_param + + def map_fp16_tp_fp32_param(self, torch_opt): + for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + self.fp16_to_fp32_param[fp16_param] = fp32_param def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = self.wrapped_optimizer + if not self.fp16_to_fp32_param and torch_opt is not None: + self.map_fp16_tp_fp32_param(torch_opt) - if not self.fp16_to_fp32_param and mix_prec_opt is not None: - self.map_fp16_tp_fp32_param(mix_prec_opt) - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) class MegatronDistributedOptimizerMon(OptimizerMon): - def map_fp16_tp_fp32_param(self, mix_prec_opt): - if not (hasattr(mix_prec_opt, "model_float16_groups") and - hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")): + def map_fp16_tp_fp32_param(self, torch_opt): + if not (hasattr(torch_opt, "model_float16_groups") and + hasattr(torch_opt, "shard_fp32_from_float16_groups")): raise Exception( "megatron distributed optimizer should have 
model_float16_groups and shard_fp32_from_float16_groups, " "if not, please check megatron-lm version") - for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups, - mix_prec_opt.shard_fp32_from_float16_groups): + for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups, + torch_opt.shard_fp32_from_float16_groups): for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): self.fp16_to_fp32_param[fp16_param] = shard_fp32_param def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = self.wrapped_optimizer - if not self.fp16_to_fp32_param and mix_prec_opt is not None: - self.map_fp16_tp_fp32_param(mix_prec_opt) + if not self.fp16_to_fp32_param and torch_opt is not None: + self.map_fp16_tp_fp32_param(torch_opt) return self._fetch_mv_in_adam(monitor, torch_opt, params2name) @@ -191,33 +181,29 @@ class MegatronFP32OptimizerMon(OptimizerMon): class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = self.wrapped_optimizer - - if not self.fp16_to_fp32_param and mix_prec_opt is not None: - for opt in mix_prec_opt.chained_optimizers: + if not self.fp16_to_fp32_param and torch_opt is not None: + for opt in torch_opt.chained_optimizers: self.map_fp16_tp_fp32_param(opt) if not isinstance(torch_opt, torch.optim.Optimizer): torch_opt.state = {} - for opt in mix_prec_opt.chained_optimizers: + for opt in torch_opt.chained_optimizers: torch_opt.state.update(opt.optimizer.state) return self._fetch_mv_in_adam(monitor, torch_opt, params2name) class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = self.wrapped_optimizer - - if not self.fp16_to_fp32_param and mix_prec_opt is not None: - for opt in mix_prec_opt.chained_optimizers: + if not self.fp16_to_fp32_param and torch_opt is not None: + for opt in torch_opt.chained_optimizers: self.map_fp16_tp_fp32_param(opt) if not isinstance(torch_opt, torch.optim.Optimizer): torch_opt.state = {} - for opt in mix_prec_opt.chained_optimizers: + for opt in torch_opt.chained_optimizers: torch_opt.state.update(opt.optimizer.state) return self._fetch_mv_in_adam(monitor, torch_opt, params2name) - + class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name): @@ -225,9 +211,8 @@ class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon): class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon): - def get_param_index(self, params2name, name2index): - mix_prec_opt = OptimizerMon.wrapped_optimizer - fp16_groups = mix_prec_opt.fp16_partitioned_groups + def get_param_index(self, params2name, name2index, torch_opt): + fp16_groups = torch_opt.fp16_partitioned_groups name2indices = defaultdict() index_length = defaultdict() index = 0 @@ -246,13 +231,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): self.is_stage3 = True - mix_prec_opt = OptimizerMon.wrapped_optimizer - fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat + fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): - @staticmethod def get_group_index(fp32_length, world_size, index): for i in range(len(fp32_length) - 1): @@ -265,12 +248,11 @@ class 
DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): return sub_interval_start, min(sub_index, world_size - 1) return fp32_length[-1], 0 - def get_param_index(self, params2name, name2index): - mix_prec_opt = OptimizerMon.wrapped_optimizer - padding = mix_prec_opt.groups_padding + def get_param_index(self, params2name, name2index, torch_opt): + padding = torch_opt.groups_padding world_size = dist.get_world_size() fp32_length = [0] - for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups): + for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups): fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index]) bf16_groups = [] @@ -278,7 +260,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): index_length = defaultdict() index = 0 idx = 0 - for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups): + for group_idx, bf16_group in enumerate(torch_opt.bit16_groups): bf16_groups.extend(bf16_group) for param in bf16_group: param_length = len(param.flatten()) @@ -286,7 +268,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank) index += param_length idx += 1 - group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups) + group_length = len(bf16_groups) / len(torch_opt.bit16_groups) for _, name in params2name.items(): name_index = name2index[name] start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index] @@ -300,8 +282,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon): return name2indices def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): - mix_prec_opt = OptimizerMon.wrapped_optimizer - fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups + fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat) @@ -312,22 +293,23 @@ class DummyOptimizerMon(OptimizerMon): class OptimizerMonFactory: _optimizer_mon_map = { - "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, - "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon, - "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, - "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, - "Megatron_FP32Optimizer": MegatronFP32OptimizerMon, - "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon, - "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon, + "FP32Optimizer": MegatronFP32OptimizerMon, + "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon, + "DistributedOptimizer": MegatronDistributedOptimizerMon, + "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon, + "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon, + "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon, + "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon, "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon, - "unknown": DummyOptimizerMon + "Adam": DummyOptimizerMon } @staticmethod - def create_optimizer_mon(opt_ty: str): - if not opt_ty: - return DummyOptimizerMon() - optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty) - if not optimizer_mon_class: - raise Exception("opt_ty should be one of: " + ", 
".join(OptimizerMonFactory._optimizer_mon_map.keys())) - return optimizer_mon_class() + def create_optimizer_mon(optimizer): + # auto replace opt_ty + optimizer_class = optimizer.__class__.__name__ + if optimizer_class == "ChainedOptimizer": + optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__ + + optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon) + return optimizer_mon_class(), optimizer_class diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py b/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py index 78a1d3a263eb63282e27b2b633be32cfec9b945a..4d5c1a717d80ee30414f25b44a93ddc7257ef2c7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/unittest/test_monitor.py @@ -1,6 +1,21 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse import os import re -import argparse from glob import glob import pandas as pd @@ -21,19 +36,19 @@ def parse_logfile(logfile): def parse_monitor_output(output_dir): reduced = {} unreduced = {} - for dir in glob(output_dir + '*'): - rank = int(re.findall('(?<=rank)[\d]*', dir)[0]) + for directory in glob(output_dir + '*'): + rank = int(re.findall('(?<=rank)[\d]*', directory)[0]) unreduced[rank] = [] reduced[rank] = [] - for file in os.listdir(dir): - df = pd.read_csv(os.path.join(dir, file)) + for file in os.listdir(directory): + df = pd.read_csv(os.path.join(directory, file)) if '_unreduced_' in file: unreduced[rank].append(df) pass elif '_reduced_' in file: reduced[rank].append(df) else: - logger.info(f'unexpected file {file} in {dir}') + logger.info(f'unexpected file {file} in {directory}') return reduced, unreduced @@ -41,7 +56,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel): steps = len(reduced[0]) world_size = len(reduced) errors = [] - for index, row in unreduced[0][0].iterrows(): + for _, row in unreduced[0][0].iterrows(): param = row['param_name'] is_tp_duplicate = False for step in range(2): @@ -103,7 +118,7 @@ def valid_total_norm(total_norm, reduced, duplicate_embedding): if step == 0: logger.info(f'rank {rank} is duplicated in dp group') continue - for index, row in reduced[rank][step].iterrows(): + for _, row in reduced[rank][step].iterrows(): if duplicate_embedding and 'word_embedding' in row['param_name']: continue calculated_norm += row['norm'] ** 2 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 57f8baf7f58aa99bafebe675b8eb046a51983a93..94afe56ffcfe7571a189c5f6959b2eb9a2779d81 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -36,7 +36,7 @@ except ImportError: if torch.cuda.is_available(): device = "cuda" -NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, 
device=device) +NAN_TENSOR_ON_DEVICE = None FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 FILE_NAME_MAX_LENGTH = 255 DIRECTORY_MAX_LENGTH = 4096 @@ -57,6 +57,13 @@ def get_output_base_dir(): return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) +def get_nan_tensor(): + global NAN_TENSOR_ON_DEVICE + if not NAN_TENSOR_ON_DEVICE: + NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device) + return NAN_TENSOR_ON_DEVICE + + def filter_special_chars(func): @wraps(func) def func_level(msg): @@ -82,48 +89,6 @@ def get_param_struct(param): return res -def is_recomputation(): - """Check if the current operation is in the re-computation phase. - - This function inspects the current call stack to indicate whether the current operation is in the - re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework. - megatron: The 'backward' function is called by the 'torch/autograd/function.py' file. - mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py' - file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the - 'torch/_tensor.py' file. - - Returns: - bool: True if in the re-computation phase, False otherwise. - """ - backward_function_indices = [] - call_stack = inspect.stack() - - # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file. - for frame_info in call_stack: - if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'): - del call_stack - return True - - # Identify indices in the call stack where the specific function is being executed - for idx, frame_info in enumerate(call_stack): - if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward': - backward_function_indices.append(idx) - - # Check if the execution is within 'torch/autograd/function.py' file - for idx in backward_function_indices: - # The Megatron and MindSpeed L0&L1 scenes - if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'): - del call_stack - return True - # The latest MindSpeed L2 and ModelLink scenes - if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'): - del call_stack - return True - - del call_stack - return False - - def validate_ops(ops): if not isinstance(ops, list): raise TypeError("ops should be a list") @@ -208,6 +173,11 @@ def validate_cc_distribution(cc_distribution): raise TypeError(f'{key} of cc_distribution is not supported.') +def validate_squash_name(squash_name): + if not isinstance(squash_name, bool): + raise TypeError('squash_name should be a bool') + + def validate_alert(alert): if not isinstance(alert, dict): raise TypeError('alert should be a dictionary') @@ -276,6 +246,9 @@ def validate_config(config): step_count_per_record = config.get('step_count_per_record', 1) validate_step_count_per_record(step_count_per_record) + squash_name = config.get('squash_name', True) + validate_squash_name(squash_name) + if not targets: if xy_distribution: config["all_xy"] = True diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py index 4ccd8a79202c2758b5c34e6bcbe89ed1119db360..4a4632913cfbd92d274e9467929fdfcdf0e7ef0e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py @@ -1,8 
+1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import os import time -import numpy as np from collections import namedtuple -from msprobe.pytorch.parse_tool.lib.utils import Util + +import numpy as np + +from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException -from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files +from msprobe.pytorch.parse_tool.lib.utils import Util class Compare: @@ -126,7 +126,7 @@ class Compare: all_close = np.allclose(data_left, data_right, atol=al, rtol=rl) np.seterr(divide='raise') cos_sim = np.dot(data_left, data_right) / ( - np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right))) + np.sqrt(np.dot(data_left, data_left)) * np.sqrt(np.dot(data_right, data_right))) err_cnt = 0 total_cnt = data_left.shape[0] diff_table_columns = ['Index', 'Left', 'Right', 'Diff'] diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py index 53fb9ac4407b326d288e3249e204710a4ed30cfa..6dc70afe465bc81e57e22dbe5105d566979ebd03 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" import os + import numpy as np class Const: - MS_ACCU_CMP_PATH = '/usr/local/Ascend/ascend-toolkit/latest/tools/operator_cmp/compare/msaccucmp.py' MS_ACCU_CMP_FILE_NAME = 'msaccucmp.py' ROOT_DIR = "" diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py index 14ba27277168bc110b38287afbba957b69f8cdff..c883547251c6fafbfb5884511817f0a632d91ef4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py @@ -1,4 +1,18 @@ -# coding=utf-8 +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py index 1ea7dd30153e458b758dc0a79779b54a25fe8289..ac6f3d234e3a6681a580f16e56d94204223102f1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -import cmd + import argparse -from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool -from msprobe.pytorch.parse_tool.lib.utils import Util +import cmd + from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception +from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool +from msprobe.pytorch.parse_tool.lib.utils import Util class InteractiveCli(cmd.Cmd): @@ -81,7 +81,7 @@ class InteractiveCli(cmd.Cmd): self.util.check_files_in_path(args.my_dump_path) self.util.check_files_in_path(args.golden_dump_path) if self.util.dir_contains_only(args.my_dump_path, ".npy") and \ - self.util.dir_contains_only(args.golden_dump_path, ".npy"): + self.util.dir_contains_only(args.golden_dump_path, ".npy"): self.parse_tool.do_compare_converted_dir(args) else: self.parse_tool.do_vector_compare(args) diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py index 7525230cedc7ff11d4112a55998c6414e8f09217..d6ab6c708aa50a4d050a87b464d740d316065e1c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,13 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" + import logging + from msprobe.core.common.exceptions import FileCheckException class ParseException(Exception): - PARSE_INVALID_PATH_ERROR = 0 PARSE_NO_FILE_ERROR = 1 PARSE_NO_MODULE_ERROR = 2 @@ -51,4 +50,5 @@ def catch_exception(func): except FileCheckException: log.error("Command execution failed") return result + return inner diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py index 372ced0da1a0208dd15b1c571373b84f0c7e288b..ca508886f5324a436e357d5dd9598c8c4f0cd363 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,17 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + import argparse import os from collections import namedtuple +from msprobe.core.common.file_utils import create_directory +from msprobe.pytorch.parse_tool.lib.compare import Compare from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException from msprobe.pytorch.parse_tool.lib.utils import Util -from msprobe.pytorch.parse_tool.lib.compare import Compare from msprobe.pytorch.parse_tool.lib.visualization import Visualization -from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException -from msprobe.core.common.file_utils import create_directory + class ParseTool: def __init__(self): @@ -117,7 +117,8 @@ class ParseTool: self.util.check_path_valid(args.golden_dump_path) self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX) self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX) - compare_data_args = namedtuple('compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count']) + compare_data_args = namedtuple('compare_data_args', + ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count']) compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20) res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count) self.compare.compare_data(res) diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 3bdc419dd0426b6b9f4551dc176f0fd909cd741b..66229d36b8d0b532eea48f1aa5d96e178ed80cdc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. 
+# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,24 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +import hashlib import os import re -import sys import subprocess -import hashlib +import sys import time -import numpy as np from collections import namedtuple -from msprobe.pytorch.parse_tool.lib.config import Const -from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc -from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException -from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\ - check_path_executable, check_path_owner_consistent + +import numpy as np from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_utils import change_mode, check_other_user_writable, \ + check_path_executable, check_path_owner_consistent from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files from msprobe.pytorch.common.log import logger - +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc +from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException try: from rich.traceback import install @@ -135,7 +134,7 @@ class Util: zero_mask = (data == 0) data[zero_mask] += np.finfo(float).eps return data - + @staticmethod def dir_contains_only(path, endfix): files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH) @@ -143,11 +142,11 @@ class Util: if not file['file'].endswith(endfix): return False return True - + @staticmethod def localtime_str(): return time.strftime("%Y%m%d%H%M%S", time.localtime()) - + @staticmethod def change_filemode_safe(path): change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -208,7 +207,7 @@ class Util: def list_numpy_files(self, path, extern_pattern=''): return self.list_file_with_pattern(path, Const.NUMPY_PATTERN, extern_pattern, - self._gen_numpy_file_info) + self._gen_numpy_file_info) def create_columns(self, content): if not Columns: diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py index 8d8c151a48925e4d1ec44427698ba21e218b2b28..5b53831b1c6fb9280dbad5621ee222baa2712225 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
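One pattern from the parse_tool.py hunk above deserves a note: default values are attached to `compare_data_args` through `namedtuple.__new__.__defaults__`, which applies them right-aligned to the field list, so only the trailing fields become optional. A minimal demonstration with the same field names:

```python
from collections import namedtuple

compare_data_args = namedtuple('compare_data_args',
                               ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
# Right-aligned: the last four fields get defaults; both paths stay required.
compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)

res = compare_data_args('my.npy', 'golden.npy')
print(res.save, res.rtol, res.atol, res.count)  # False 0.001 0.001 20
```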
-""" + import json -import numpy as np +import numpy as np +from msprobe.core.common.file_utils import FileOpen, load_npy, save_npy_to_txt from msprobe.pytorch.parse_tool.lib.config import Const -from msprobe.pytorch.parse_tool.lib.utils import Util from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException -from msprobe.core.common.file_utils import FileOpen, load_npy, save_npy_to_txt +from msprobe.pytorch.parse_tool.lib.utils import Util class Visualization: @@ -77,7 +76,7 @@ class Visualization: self.util.log.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) self.util.log.info(" {}".format(item[3])) continue - if len(msg) > 5 and len(msg[5]) >= 3: + if len(msg) > 5 and len(msg[5]) >= 3: summery_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \ .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2]) if not title_printed: diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index 01cff973dfbcd007ab45d5845cbae568a3863b11..8293ac969490b103eef630081b6001234ca8bb07 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -303,28 +303,25 @@ class GradToolConfig(BaseConfig): check_bounds(self.bounds) +class StructureConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + + +TaskDict = { + Const.TENSOR: TensorConfig, + Const.STATISTICS: StatisticsConfig, + Const.OVERFLOW_CHECK: OverflowCheckConfig, + Const.FREE_BENCHMARK: FreeBenchmarkCheckConfig, + Const.RUN_UT: RunUTConfig, + Const.GRAD_PROBE: GradToolConfig, + Const.STRUCTURE: StructureConfig +} + + def parse_task_config(task, json_config): - default_dic = {} - if task == Const.TENSOR: - config_dic = json_config.get(Const.TENSOR, default_dic) - return TensorConfig(config_dic) - elif task == Const.STATISTICS: - config_dic = json_config.get(Const.STATISTICS, default_dic) - return StatisticsConfig(config_dic) - elif task == Const.OVERFLOW_CHECK: - config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic) - return OverflowCheckConfig(config_dic) - elif task == Const.FREE_BENCHMARK: - config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic) - return FreeBenchmarkCheckConfig(config_dic) - elif task == Const.RUN_UT: - config_dic = json_config.get(Const.RUN_UT, default_dic) - return RunUTConfig(config_dic) - elif task == Const.GRAD_PROBE: - config_dic = json_config.get(Const.GRAD_PROBE, default_dic) - return GradToolConfig(config_dic) - else: - return StatisticsConfig(default_dic) + task_map = json_config.get(task, dict()) + return TaskDict.get(task)(task_map) def parse_json_config(json_file_path, task): diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 93012b91c0649b161d84186d1bab22c5d59471b9..59091549588761277520830c7949a22781e64490 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -28,7 +28,7 @@ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutput from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData from msprobe.pytorch.common.log import logger -from 
msprobe.pytorch.common.utils import get_rank_if_initialized +from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_registry import api_register @@ -57,6 +57,7 @@ class Service: self.should_stop_service = False self.attl = None self.params_grad_info = {} + self.hook_handle_dict = {} # 提前注册,确保注册尽可能多的API hook self.register_api_hook() self.init_for_debug_level() @@ -65,6 +66,7 @@ class Service: def pre_hook(api_or_module_name, module, args, kwargs): if not self.should_execute_hook(module_type, module, True): return args, kwargs + is_recompute = is_recomputation() self.inner_switch = True if module_type == BaseScope.Module_Type_Module: @@ -79,7 +81,13 @@ class Service: return None, None if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) - self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output) + self.data_collector.forward_input_data_collect( + api_or_module_name, + module, + pid, + module_input_output, + is_recompute + ) self.inner_switch = False return args, kwargs @@ -103,7 +111,12 @@ class Service: if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode): for param_name, param in params_dict.items(): if param.requires_grad: - param.register_hook(grad_hook(module, ori_name, param_name)) + name = ori_name + Const.SEP + param_name + old_handle = self.hook_handle_dict.get(name) + if old_handle and hasattr(old_handle, "remove"): + old_handle.remove() + handle = param.register_hook(grad_hook(module, ori_name, param_name)) + self.hook_handle_dict[name] = handle def init_params_grad_info(module, params_dict): ''' @@ -127,6 +140,7 @@ class Service: def forward_hook(api_or_module_name, module, args, kwargs, output): if not self.should_execute_hook(module_type, module, True): return None + is_recompute = is_recomputation() self.inner_switch = True if self.config.online_run_ut: @@ -162,10 +176,15 @@ class Service: if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name[-1] self.data_collector.update_api_or_module_name(api_or_module_name) - params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)} - setattr(module_input_output, Const.PARAMS, params_dict) + params_dict = {} + if self.config.task != Const.STRUCTURE: + params_dict = { + key.split(Const.SEP)[-1]: value + for key, value in module.named_parameters(recurse=False) + } + setattr(module_input_output, Const.PARAMS, params_dict) # 判断是否需要注册参数hook - if not hasattr(module, 'params_grad_name') and params_dict: + if params_dict: ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0] grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook @@ -175,7 +194,8 @@ class Service: api_or_module_name, module, pid, - module_input_output + module_input_output, + is_recompute ) init_params_grad_info(module, params_dict) else: @@ -184,7 +204,8 @@ class Service: api_or_module_name, module, pid, - module_input_output + module_input_output, + is_recompute ) if self.data_collector.if_return_forward_new_output(): @@ -200,6 +221,7 @@ class Service: def backward_hook(api_or_module_name, module, grad_input, grad_output): if not self.should_execute_hook(module_type, module, 
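# ---- Reviewer sketch: illustrative only, not part of the patch. ----
# The hook_handle_dict bookkeeping added above removes the previous
# RemovableHandle before re-registering a parameter gradient hook, so
# repeated forward passes do not stack duplicate hooks. A standalone
# reproduction of the pattern:
import torch

handles = {}

def register_grad_hook_once(name, param, hook):
    old = handles.get(name)
    if old is not None:
        old.remove()  # drop the stale hook before registering a new one
    handles[name] = param.register_hook(hook)

p = torch.nn.Parameter(torch.ones(2))
register_grad_hook_once("w", p, lambda g: g * 2)
register_grad_hook_once("w", p, lambda g: g * 2)  # replaces, does not stack
p.sum().backward()
assert torch.equal(p.grad, torch.full((2,), 2.0))  # would be 4.0 if hooks stacked
# ---- end sketch ----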
False): return + is_recompute = is_recomputation() self.inner_switch = True if module_type == BaseScope.Module_Type_Module: @@ -213,7 +235,13 @@ class Service: if self.data_collector: # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) - self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output) + self.data_collector.backward_data_collect( + api_or_module_name, + module, + pid, + module_input_output, + is_recompute + ) self.inner_switch = False pid = os.getpid() @@ -248,6 +276,8 @@ class Service: if self.config.rank and self.current_rank not in self.config.rank: return self.register_module_hook() + if self.config.level == Const.LEVEL_MIX: + register_optimizer_hook(self.data_collector) self.first_start = False if self.config.online_run_ut and torch_version_above_or_equal_2: run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute) @@ -272,6 +302,10 @@ class Service: if self.config.online_run_ut and torch_version_above_or_equal_2: run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute) return + if self.config.async_dump: + self.data_collector.fill_stack_tensor_data() + if self.config.task == Const.TENSOR: + self.data_collector.data_processor.dump_async_data() self.data_collector.write_json() def step(self): @@ -279,6 +313,10 @@ class Service: return if self.should_stop_service: return + if self.config.async_dump: + self.data_collector.fill_stack_tensor_data() + if self.config.task == Const.TENSOR: + self.data_collector.data_processor.dump_async_data() self.data_collector.write_json() self.current_iter += 1 self.data_collector.update_iter(self.current_iter) @@ -341,7 +379,6 @@ class Service: dump_path_aggregation.dump_tensor_data_dir = dump_data_dir dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") self.data_collector.update_dump_paths(dump_path_aggregation) - self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) def register_api_hook(self): @@ -353,13 +390,10 @@ class Service: ) api_register.api_modularity() - if self.config.level == Const.LEVEL_MIX: - register_optimizer_hook(self.data_collector) - def register_module_hook(self): if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.") - self.module_processor.hook_modules(self.model, self.build_hook) + self.module_processor.register_module_hook(self.model, self.build_hook) def attl_init(self): if self.config.online_run_ut: @@ -395,7 +429,7 @@ class Service: def reset_status(self): ModuleProcesser.reset_module_stats() HOOKModule.reset_module_stats() - self.data_collector.data_writer.reset_cache() + self.data_collector.reset_status() self.params_grad_info.clear() if self.config.level == Const.LEVEL_L2: @@ -436,8 +470,6 @@ class Service: def save(self, variable, name, save_backward): if self.config.level != Const.LEVEL_DEBUG: return - - count = self.debug_variable_counter[name] self.debug_variable_counter[name] += 1 @@ -449,4 +481,4 @@ class Service: # backward save if save_backward: - self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) \ No newline at end of file + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py 
b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index 86b17c5686983ea0fe7b9447071561673bd495a2..ab8703dcd353ff32dc0722fc314ade6042d6f567 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -5,13 +5,16 @@ import os import shutil import unittest from unittest.mock import patch +import zlib + +import numpy as np from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.compare.utils import ApiItemInfo, _compare_parser, check_and_return_dir_contents, extract_json, \ count_struct, get_accuracy, append_stack_info, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \ op_item_parse, read_op, rename_api, resolve_api_special_parameters, result_item_init, stack_column_process, \ - table_value_is_valid, get_name_and_state + table_value_is_valid, get_name_and_state, reorder_op_name_list, reorder_op_x_list, gen_op_item # test_read_op_1 op_data = { @@ -149,16 +152,18 @@ o_result = [ [16], [16], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, '', '', 'None'], - ['Functional.conv2d.0.forward.output.0', 'Functional.conv2d.0.forward.output.0', 'torch.float32', 'torch.float32', - [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', - 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, - 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, '', '', 'None'], - ['Functional.conv2d.0.forward.parameters.weight', 'Functional.conv2d.0.forward.parameters.weight', 'torch.float32', 'torch.float32', + ['Functional.conv2d.0.forward.parameters.weight', 'Functional.conv2d.0.forward.parameters.weight', 'torch.float32', + 'torch.float32', [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], - ['Functional.conv2d.0.forward.parameters.bias', 'Functional.conv2d.0.forward.parameters.bias', 'torch.float32', 'torch.float32', + ['Functional.conv2d.0.forward.parameters.bias', 'Functional.conv2d.0.forward.parameters.bias', 'torch.float32', + 'torch.float32', [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], + ['Functional.conv2d.0.forward.output.0', 'Functional.conv2d.0.forward.output.0', 'torch.float32', 'torch.float32', + [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, '', '', 'None'], ['Functional.conv2d.0.parameters_grad.weight', 'Functional.conv2d.0.parameters_grad.weight', 'torch.float32', 'torch.float32', [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%', 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, '', '', 'None'], @@ -174,11 +179,13 @@ o_result_unmatch_1 = [ ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', 'N/A', 'N/A', 'N/A', 'None'], ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'None'], - ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 
'N/A', + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', + 'N/A', 'N/A', 'None'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'None'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'None'], ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'None'], @@ -195,15 +202,17 @@ o_result_unmatch_2 = [ ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], - ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', - 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], + ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'], @@ -221,13 +230,15 @@ o_result_unmatch_3 = [ ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', + 'N/A', 'N/A', + 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', + 'N/A', + 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], 
['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', @@ -631,3 +642,209 @@ class TestGetNameAndState(unittest.TestCase): with self.assertRaises(CompareException) as context: get_name_and_state(name) self.assertIn('Invalid name string', str(context.exception.code)) + + +class TestReorderOpNameList(unittest.TestCase): + def test_reorder_op_name_list(self): + # 标准顺序 + op_name_list = ["op.forward.input.0.0", "op.forward.output.0", "op.forward.output.1", "op.forward.parameters.1", "op.forward.parameters.2", "op.parameters_grad.0"] + result = reorder_op_name_list(op_name_list) + expected = ["op.forward.input.0.0", "op.forward.parameters.1", "op.forward.parameters.2", "op.forward.output.0", "op.forward.output.1", "op.parameters_grad.0"] + self.assertEqual(result, expected) + + # 只有输入元素 + op_name_list = ["op.forward.input.0", "op.forward.input.1"] + result = reorder_op_name_list(op_name_list) + expected = ["op.forward.input.0", "op.forward.input.1"] + self.assertEqual(result, expected) + + # 输入为空 + op_name_list = [] + result = reorder_op_name_list(op_name_list) + expected = [] + self.assertEqual(result, expected) + + +class TestReorderOpXList(unittest.TestCase): + def test_reorder_op_x_list(self): + # 标准顺序 + op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] + summary_list = ["summary1", "summary2", "summary3"] + data_name_list = ["data1", "data2", "data3"] + result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) + self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) + self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) + self.assertEqual(result_data_name, ["data1", "data3", "data2"]) + + # 空 op_name_list 或 summary_list + op_name_list = [] + summary_list = [] + data_name_list = ["data1", "data2", "data3"] + result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) + self.assertEqual(result_op_name, []) + self.assertEqual(result_summary, []) + self.assertEqual(result_data_name, ["data1", "data2", "data3"]) + + # 空 data_name_list + op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] + summary_list = ["summary1", "summary2", "summary3"] + data_name_list = [] + result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) + 
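# ---- Reviewer sketch: illustrative only, not part of the patch. ----
# The reorder tests above and below pin down the target ordering
# input -> parameters -> output -> parameters_grad, with the parallel
# summary/data lists permuted in lockstep. A minimal reimplementation of
# that idea (not the shipped reorder_op_x_list):
def reorder_aligned(op_names, *aligned_lists):
    def state_rank(name):
        if '.parameters_grad' in name:
            return 3
        if '.output.' in name:
            return 2
        if '.parameters.' in name:
            return 1
        return 0  # input / kwargs stay first
    order = sorted(range(len(op_names)), key=lambda i: state_rank(op_names[i]))
    # empty or None aligned lists pass through unchanged
    return (
        [op_names[i] for i in order],
        *[[lst[i] for i in order] if lst else lst for lst in aligned_lists],
    )

names, summaries = reorder_aligned(
    ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"],
    ["s1", "s2", "s3"],
)
assert names == ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]
assert summaries == ["s1", "s3", "s2"]
# ---- end sketch ----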
self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) + self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) + self.assertEqual(result_data_name, []) + + # data_name_list 为 None + op_name_list = ["op.forward.input.0", "op.forward.output.0", "op.forward.parameters.weight"] + summary_list = ["summary1", "summary2", "summary3"] + data_name_list = None + result_op_name, result_summary, result_data_name = reorder_op_x_list(op_name_list, summary_list, data_name_list) + self.assertEqual(result_op_name, ["op.forward.input.0", "op.forward.parameters.weight", "op.forward.output.0"]) + self.assertEqual(result_summary, ["summary1", "summary3", "summary2"]) + self.assertEqual(result_data_name, None) + + +class TestGenOpItem(unittest.TestCase): + def test_gen_op_item_with_data_name(self): + op_data = { + 'data_name': 'test_data', + 'type': 'torch.Tensor', + 'dtype': 'torch.int64', + 'shape': [3], + 'value': [1, 2, 3], + 'Max': 3, + 'Min': 1, + 'Mean': 2, + 'Norm': 2 + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['data_name'], 'test_data') + self.assertEqual(result['full_op_name'], 'test_data') + self.assertEqual(result['dtype'], 'torch.int64') + self.assertEqual(result['shape'], [3]) + self.assertEqual(result['Max'], 3) + self.assertEqual(result['Min'], 1) + self.assertEqual(result['Mean'], 2) + self.assertEqual(result['Norm'], 2) + self.assertEqual(result['md5'], f"{zlib.crc32(str(op_data['value']).encode()):08x}") + + def test_gen_op_item_with_empty_data_name(self): + op_data = { + 'data_name': '', + 'type': 'torch.Tensor', + 'value': [1, 2, 3] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + # data_name为空时,应该被设置为'-1' + self.assertEqual(result['data_name'], '-1') + self.assertEqual(result['full_op_name'], op_name) + + def test_gen_op_item_with_none_data_name(self): + op_data = { + 'data_name': None, + 'type': 'torch.Tensor', + 'value': [1, 2, 3] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + # data_name为None时,应该被设置为'-1' + self.assertEqual(result['data_name'], '-1') + self.assertEqual(result['full_op_name'], op_name) + + def test_gen_op_item_with_type_torch_size(self): + op_data = { + 'data_name': 'test_data', + 'type': 'torch.Size', + 'value': [2, 3, 4] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['dtype'], 'torch.Size') + self.assertEqual(result['shape'], '[2, 3, 4]') + self.assertEqual(result['Max'], None) + self.assertEqual(result['Min'], None) + self.assertEqual(result['Mean'], None) + self.assertEqual(result['Norm'], None) + + def test_gen_op_item_with_type_slice(self): + op_data = { + 'data_name': 'test_data', + 'type': 'slice', + 'value': [1, 2, 3] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['dtype'], 'slice') + self.assertEqual(result['shape'], str(np.shape(np.array(op_data['value'])))) + + def test_gen_op_item_with_type_ellipsis(self): + op_data = { + 'data_name': 'test_data', + 'type': 'ellipsis', + 'value': '...' 
+ } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['dtype'], 'ellipsis') + self.assertEqual(result['shape'], '[]') + self.assertEqual(result['Max'], '...') + self.assertEqual(result['Min'], '...') + self.assertEqual(result['Mean'], '...') + self.assertEqual(result['Norm'], '...') + + def test_gen_op_item_with_type_torch_process_group(self): + op_data = { + 'data_name': 'test_data', + 'type': 'torch.ProcessGroup', + 'group_ranks': [0, 1] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['dtype'], 'torch.ProcessGroup') + self.assertEqual(result['shape'], '[]') + self.assertEqual(result['Max'], '[0, 1]') + self.assertEqual(result['Min'], '[0, 1]') + self.assertEqual(result['Mean'], '[0, 1]') + self.assertEqual(result['Norm'], '[0, 1]') + + def test_gen_op_item_with_default_dtype(self): + op_data = { + 'data_name': 'test_data', + 'type': 'other_type', + 'value': [1, 2, 3] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + self.assertEqual(result['dtype'], str(type(op_data['value']))) + self.assertEqual(result['shape'], '[]') + + def test_gen_op_item_with_md5(self): + op_data = { + 'data_name': 'test_data', + 'type': 'torch.Tensor', + 'value': [1, 2, 3] + } + op_name = 'op_test' + + result = gen_op_item(op_data, op_name) + + expected_md5 = f"{zlib.crc32(str(op_data['value']).encode()):08x}" + self.assertEqual(result['md5'], expected_md5) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index f517e6cfceb1abd86b6b355250ecc7b057b620a6..f561a3e05ec84c3ee75dac50ed5aec2a2af7f7b5 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -226,8 +226,9 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(api_batch._state, Const.INPUT) self.assertEqual(api_batch.input_len, 2) - self.assertEqual(api_batch.output_end_index, 4) self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) def test_ApiBatch_increment_output(self): api_name = "functional.conv2d" @@ -238,8 +239,63 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(api_batch._state, Const.OUTPUT) self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_kwargs(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.KWARGS) + + self.assertEqual(api_batch._state, Const.KWARGS) + self.assertEqual(api_batch.input_len, 2) self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_params(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.PARAMS) + + self.assertEqual(api_batch._state, Const.PARAMS) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 4) + self.assertEqual(api_batch.output_end_index, 4) + self.assertEqual(api_batch.params_grad_end_index, 4) + + def test_ApiBatch_increment_multiple_input(self): + api_name = "functional.conv2d" + start = 2 + 
api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.INPUT) + api_batch.increment(Const.INPUT) + + self.assertEqual(api_batch._state, Const.INPUT) + self.assertEqual(api_batch.input_len, 3) + self.assertEqual(api_batch.params_end_index, 5) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) + + def test_ApiBatch_increment_multiple_output(self): + api_name = "functional.conv2d" + start = 2 + api_batch = ApiBatch(api_name, start) + + api_batch.increment(Const.OUTPUT) + api_batch.increment(Const.OUTPUT) + + self.assertEqual(api_batch._state, Const.OUTPUT) + self.assertEqual(api_batch.input_len, 1) + self.assertEqual(api_batch.params_end_index, 3) + self.assertEqual(api_batch.output_end_index, 5) + self.assertEqual(api_batch.params_grad_end_index, 5) @patch("msprobe.core.compare.highlight.logger") def test_value_check(self, mock_logger): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_merge_result_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_merge_result_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69b0e7ff01ba7304e977dfd9608ce8482c131d5b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_merge_result_utils.py @@ -0,0 +1,266 @@ +# coding=utf-8 +""" +# Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
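# ---- Reviewer sketch: illustrative only, not part of the patch; a minimal
# reimplementation inferred from the tests below, not the shipped helper.
# When a diff-style compare index holds a placeholder ('N/A', 'unsupported',
# 'Nan'), the merged sheet substitutes the raw statistics as
# 'NPU:<npu max> Bench:<bench max>'; the 'NPU max'/'Bench max' keys mirror
# CompareConst.NPU_MAX / CompareConst.BENCH_MAX.
PLACEHOLDERS = {'N/A', 'unsupported', 'Nan'}

def replace_placeholders(index_dict, compare_indexes, rank):
    npu = index_dict.get('NPU max')
    bench = index_dict.get('Bench max')
    if npu is None or bench is None:
        return index_dict  # nothing to substitute with, leave values as-is
    for idx in compare_indexes:
        for op, per_rank in index_dict.get(idx, {}).items():
            if per_rank.get(rank) in PLACEHOLDERS:
                per_rank[rank] = f"NPU:{npu[op][rank]} Bench:{bench[op][rank]}"
    return index_dict

d = {'Max diff': {'op': {0: 'N/A'}},
     'NPU max': {'op': {0: 1.0}}, 'Bench max': {'op': {0: 2.0}}}
assert replace_placeholders(d, ['Max diff'], 0)['Max diff']['op'][0] == 'NPU:1.0 Bench:2.0'
# ---- end sketch ----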
+""" + +import unittest +from unittest.mock import patch + +from msprobe.core.common.const import CompareConst +from msprobe.core.common.utils import CompareException +from msprobe.core.compare.merge_result.utils import replace_compare_index_dict, check_config + + +class TestReplaceCompareIndexDict(unittest.TestCase): + + def setUp(self): + # 初始化测试数据 + self.compare_index_dict = { + 'Max diff': { + 'op_name_1': {0: 'N/A'}, + 'op_name_2': {0: 'N/A'} + }, + 'L2norm diff': { + 'op_name_1': {0: 'N/A'}, + 'op_name_2': {0: 'N/A'} + }, + 'MeanRelativeErr': { + 'op_name_1': {0: 'N/A'}, + 'op_name_2': {0: 'N/A'} + }, + CompareConst.NPU_MAX: { + 'op_name_1': {0: 'tp-0-1-2-3'}, + 'op_name_2': {0: 'tp-0-1-2-3'} + }, + CompareConst.BENCH_MAX: { + 'op_name_1': {0: 'tp-0-1-2-3'}, + 'op_name_2': {0: 'tp-0-1-2-3'} + } + } + self.compare_index_list = ['Max diff', 'L2norm diff', 'MeanRelativeErr', 'NPU max', 'Bench max'] + self.rank_num = 0 + + def test_process_compare_index_dict_na(self): + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + # 检查是否替换了 N/A 值 + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3') + + def test_no_na_values(self): + # 修改测试数据,确保没有 N/A 值 + for index in self.compare_index_list[:-2]: # 排除 'NPU max' 和 'Bench max' + self.compare_index_dict[index] = { + 'op_name_1': {0: 'tp-0-1-2-3'}, + 'op_name_2': {0: 'tp-0-1-2-3'} + } + + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + # 验证返回值没有变化 + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], 'tp-0-1-2-3') + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], 'tp-0-1-2-3') + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], 'tp-0-1-2-3') + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], 'tp-0-1-2-3') + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], 'tp-0-1-2-3') + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], 'tp-0-1-2-3') + + def test_non_string_npu_bench(self): + # 修改 NPU 和 Bench 统计量为非字符串类型 + self.compare_index_dict[CompareConst.NPU_MAX] = { + 'op_name_1': {0: 123}, + 'op_name_2': {0: 123} + } + self.compare_index_dict[CompareConst.BENCH_MAX] = { + 'op_name_1': {0: 123}, + 'op_name_2': {0: 123} + } + + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + expected_value = 'NPU:123 Bench:123' + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], 
expected_value) + + def test_missing_npu_bench_max(self): + # 移除 NPU_MAX 和 BENCH_MAX 键 + del self.compare_index_dict[CompareConst.NPU_MAX] + del self.compare_index_dict[CompareConst.BENCH_MAX] + + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + # 验证原始数据未改变 + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], 'N/A') + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], 'N/A') + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], 'N/A') + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], 'N/A') + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], 'N/A') + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], 'N/A') + + def test_unsupported_values(self): + # 'unsupported' + self.compare_index_dict['Max diff'] = { + 'op_name_1': {0: 'unsupported'}, + 'op_name_2': {0: 'unsupported'} + } + self.compare_index_dict['L2norm diff'] = { + 'op_name_1': {0: 'unsupported'}, + 'op_name_2': {0: 'unsupported'} + } + self.compare_index_dict['MeanRelativeErr'] = { + 'op_name_1': {0: 'unsupported'}, + 'op_name_2': {0: 'unsupported'} + } + + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + # 检查是否替换了'unsupported' + expected_value = 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3' + + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], expected_value) + + def test_nan_values(self): + # 'Nan' + self.compare_index_dict['Max diff'] = { + 'op_name_1': {0: 'Nan'}, + 'op_name_2': {0: 'Nan'} + } + self.compare_index_dict['L2norm diff'] = { + 'op_name_1': {0: 'Nan'}, + 'op_name_2': {0: 'Nan'} + } + self.compare_index_dict['MeanRelativeErr'] = { + 'op_name_1': {0: 'Nan'}, + 'op_name_2': {0: 'Nan'} + } + + result = replace_compare_index_dict(self.compare_index_dict, self.compare_index_list, self.rank_num) + + # 检查是否替换了'Nan' + expected_value = 'NPU:tp-0-1-2-3 Bench:tp-0-1-2-3' + + self.assertEqual(result['Max diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['Max diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['L2norm diff']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['L2norm diff']['op_name_2'][self.rank_num], expected_value) + + self.assertEqual(result['MeanRelativeErr']['op_name_1'][self.rank_num], expected_value) + self.assertEqual(result['MeanRelativeErr']['op_name_2'][self.rank_num], expected_value) + + def test_empty_dict(self): + # 测试空字典的处理 + empty_dict = {} + result = replace_compare_index_dict(empty_dict, [], self.rank_num) + self.assertEqual(result, {}) + + def test_empty_compare_index_list(self): + # 测试空 compare_index_list 的情况 + result = replace_compare_index_dict(self.compare_index_dict, [], self.rank_num) + self.assertEqual(result, self.compare_index_dict) + + +class TestCheckConfig(unittest.TestCase): + + @patch('msprobe.core.common.file_utils.logger.error') + def test_check_config_empty(self, mock_logger_error): + config = None + + with self.assertRaises(CompareException): + 
check_config(config) + + mock_logger_error.assert_called_once_with('config.yaml is empty, please check.') + + @patch('msprobe.core.common.file_utils.logger.error') + def test_check_config_missing_api(self, mock_logger_error): + config = { + 'compare_index': ['index1', 'index2'] + } + + with self.assertRaises(CompareException): + check_config(config) + + mock_logger_error.assert_called_once_with('The APIs required to merge data were not found.') + + @patch('msprobe.core.common.file_utils.logger.error') + def test_check_config_api_is_not_list(self, mock_logger_error): + config = { + 'api': 'api1', + 'compare_index': ['index1', 'index2'] + } + + with self.assertRaises(CompareException): + check_config(config) + + mock_logger_error.assert_called_once_with("The config format of 'api' is incorrect, please check.") + + @patch('msprobe.core.common.file_utils.logger.error') + def test_check_config_compare_index_is_not_list(self, mock_logger_error): + config = { + 'api': ['api1', 'api2'], + 'compare_index': 'index1' + } + + with self.assertRaises(CompareException): + check_config(config) + + mock_logger_error.assert_called_once_with("The config format of 'compare_index' is incorrect, please check.") + + def test_check_config_compare_index_is_none(self): + config = { + 'api': ['api1', 'api2'], + 'compare_index': None + } + result_target = { + 'api': ['api1', 'api2'], + 'compare_index': [] + } + result = check_config(config) + + self.assertEqual(result, result_target) + + @patch('msprobe.core.common.file_utils.logger.error') + def test_check_config_success(self, mock_logger_error): + config = { + 'api': ['api1', 'api2'], + 'compare_index': ['index1', 'index2'] + } + + result = check_config(config) + + self.assertEqual(result, config) + mock_logger_error.assert_not_called() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 19e9c9173752799c14c9066e946bd84e83661349..8ff89437646ee203aaa4a3fac5bbfea1538e9409 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -113,13 +113,23 @@ class TestBaseDataProcessor(unittest.TestCase): expected = {'type': 'int', 'value': 1} self.assertEqual(result, expected) - def test_analyze_numpy(self): - result = BaseDataProcessor._analyze_numpy(5, 'int32') - self.assertEqual(result, {'type': 'int32', 'value': 5}) - def test_get_special_types(self): self.assertIn(int, BaseDataProcessor.get_special_types()) + def test_analyze_numpy(self): + ndarray = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) + result = BaseDataProcessor._analyze_numpy(ndarray, 'numpy.ndarray') + expected_result = { + 'type': 'numpy.ndarray', + 'dtype': 'int32', + 'shape': (2, 3), + 'Max': 6, + 'Min': 1, + 'Mean': 3.5, + 'Norm':9.539392014169456 + } + self.assertEqual(result, expected_result) + def test_recursive_apply_transform(self): transform = lambda x, _: x * 2 Test = namedtuple("Test", ['a']) @@ -143,7 +153,7 @@ class TestBaseDataProcessor(unittest.TestCase): self.assertEqual(BaseDataProcessor.recursive_apply_transform((1, 2), transform), [2, 4]) self.assertEqual(BaseDataProcessor.recursive_apply_transform({'a': 1}, transform), {'a': 2}) - @patch.object(logger, 'warning') + @patch.object(logger, 'debug') def test_recursive_apply_transform_with_warning(self, mock_logger): transform = lambda x, _: x * 2 
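# ---- Reviewer sketch: illustrative only, not part of the patch. ----
# A minimal recursion in the spirit of recursive_apply_transform as the
# tests above exercise it: tuples and lists map to lists, dicts map by
# value, scalars go straight through the transform. (The real helper also
# threads a suffix stack and logs unsupported containers such as sets.)
def apply_recursive(value, transform):
    if isinstance(value, (list, tuple)):
        return [apply_recursive(v, transform) for v in value]
    if isinstance(value, dict):
        return {k: apply_recursive(v, transform) for k, v in value.items()}
    return transform(value, None)

double = lambda x, _: x * 2
assert apply_recursive((1, 2), double) == [2, 4]
assert apply_recursive({'a': 1}, double) == {'a': 2}
# ---- end sketch ----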
BaseDataProcessor.recursive_apply_transform({1, 2, 3}, transform) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py index 5ecd508acf4f076f13325a71b3f8e8cdc446350c..b593d34c5d86c7fb3b4a0e8a3ff548c55555e09d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,8 +26,10 @@ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor from msprobe.core.data_dump.data_processor.mindspore_processor import ( MindsporeDataProcessor, TensorDataProcessor, - OverflowCheckDataProcessor + OverflowCheckDataProcessor, + KernelDumpDataProcessor, ) +from msprobe.mindspore.common.log import logger class TestMindsporeDataProcessor(unittest.TestCase): @@ -56,6 +58,7 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertEqual(result, expected_result) def test_get_stat_info_float(self): + self.config.async_dump = False tensor = ms.Tensor([1.0, 2.0, 3.0]) result = self.processor.get_stat_info(tensor) self.assertEqual(result.max, 3.0) @@ -63,7 +66,17 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2.0) self.assertEqual(result.norm, ms.ops.norm(tensor).item()) + def test_get_stat_info_float_async(self): + self.config.async_dump = True + tensor = ms.tensor([1.0, 2.0, 3.0]) + result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), 3.0) + self.assertEqual(result[1].item(), 1.0) + self.assertEqual(result[2].item(), 2.0) + self.assertEqual(result[3].item(), ms.ops.norm(tensor).item()) + def test_get_stat_info_int(self): + self.config.async_dump = False tensor = ms.Tensor([1, 2, 3], dtype=ms.int32) result = self.processor.get_stat_info(tensor) self.assertEqual(result.max, 3) @@ -71,7 +84,15 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2) self.assertEqual(result.norm, ms.ops.norm(tensor).item()) + def test_get_stat_info_int_async(self): + self.config.async_dump = True + tensor = ms.tensor([1, 2, 3]) + result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), 3.0) + self.assertEqual(result[1].item(), 1.0) + def test_get_stat_info_bool(self): + self.config.async_dump = False tensor = ms.Tensor([True, False, True]) result = self.processor.get_stat_info(tensor) self.assertEqual(result.max, True) @@ -79,11 +100,19 @@ class TestMindsporeDataProcessor(unittest.TestCase): self.assertIsNone(result.mean) self.assertIsNone(result.norm) + def test_get_stat_info_bool_async(self): + self.config.async_dump = True + tensor = ms.Tensor([True, False, True]) + result = self.processor.get_stat_info(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), True) + self.assertEqual(result[1].item(), False) + @patch.object(MindsporeDataProcessor, 'get_md5_for_tensor') def test__analyze_tensor(self, get_md5_for_tensor): 
get_md5_for_tensor.return_value = "test_md5" tensor = ms.Tensor(np.array([1, 2, 3], dtype=np.int32)) self.config.summary_mode = 'md5' + self.config.async_dump = False suffix = "test_tensor" expected_result = { 'type': 'mindspore.Tensor', @@ -112,6 +141,7 @@ class TestTensorDataProcessor(unittest.TestCase): @patch('msprobe.core.data_dump.data_processor.mindspore_processor.save_tensor_as_npy') def test_analyze_tensor(self, mock_save): self.config.framework = "mindspore" + self.config.async_dump = False tensor = ms.Tensor([1.0, 2.0, 3.0]) suffix = 'suffix' result = self.processor._analyze_tensor(tensor, suffix) @@ -239,3 +269,68 @@ class TestOverflowCheckDataProcessor(unittest.TestCase): return_value=True): self.data_processor._analyze_tensor("tensor", "suffix") mock_warning.assert_called_with("The file path file_path length exceeds limit.") + +class TestKernelDumpDataProcessor(unittest.TestCase): + def setUp(self): + self.config = MagicMock() + self.data_writer = MagicMock() + self.processor = KernelDumpDataProcessor(self.config, self.data_writer) + + @patch.object(logger, 'warning') + def test_print_unsupported_log(self, mock_logger_warning): + self.processor._print_unsupported_log("test_api_name") + mock_logger_warning.assert_called_with("The kernel dump does not support the test_api_name API.") + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.start_kernel_dump') + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.has_adump', new=True) + def test_analyze_pre_forward_with_adump(self, mock_start_kernel_dump): + self.processor.analyze_forward_input("test_api_name", None, None) + mock_start_kernel_dump.assert_called_once() + self.assertTrue(self.processor.enable_kernel_dump) + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.has_adump', new=False) + @patch.object(logger, 'warning') + def test_analyze_pre_forward_without_adump(self, mock_logger_warning): + self.processor.enable_kernel_dump = True + self.processor.analyze_forward_input("test_api_name", None, None) + mock_logger_warning.assert_called_with("The current msprobe package does not compile adump, and kernel dump cannot be used.") + self.assertFalse(self.processor.enable_kernel_dump) + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.stop_kernel_dump') + @patch.object(logger, 'info') + def test_analyze_forward_successfully(self, mock_logger_info, mock_stop_kernel_dump): + self.processor.enable_kernel_dump = True + self.processor.analyze_forward_output('test_api_name', None, None) + self.assertFalse(self.processor.enable_kernel_dump) + mock_stop_kernel_dump.assert_called_once() + mock_logger_info.assert_called_with("The kernel data of test_api_name is dumped successfully.") + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.has_adump', new=True) + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.start_kernel_dump') + def test_analyze_pre_backward_with_adump(self, mock_start_kernel_dump): + self.processor.enable_kernel_dump = True + self.processor.analyze_backward_input("test_api_name", None, None) + self.assertTrue(self.processor.enable_kernel_dump) + mock_start_kernel_dump.assert_called_once() + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.has_adump', new=False) + @patch.object(logger, 'warning') + def test_analyze_pre_backward_without_adump(self, mock_logger_warning): + self.processor.enable_kernel_dump = True + 
self.processor.analyze_backward_input("test_api_name", None, None) + self.assertFalse(self.processor.enable_kernel_dump) + mock_logger_warning.assert_called_with("The current msprobe package does not compile adump, and kernel dump cannot be used.") + + @patch('msprobe.core.data_dump.data_processor.mindspore_processor.KernelDumpDataProcessor.stop_kernel_dump') + @patch.object(logger, 'info') + def test_analyze_backward_successfully(self, mock_logger_info, mock_stop_kernel_dump): + self.processor.enable_kernel_dump = True + self.processor.analyze_backward('test_api_name', None, None) + self.assertFalse(self.processor.enable_kernel_dump) + mock_stop_kernel_dump.assert_called_once() + mock_logger_info.assert_called_with("The kernel data of test_api_name is dumped successfully.") + + def test_reset_status(self): + self.processor.enable_kernel_dump = False + self.processor.reset_status() + self.assertTrue(self.processor.enable_kernel_dump) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index b33bbef048f7fa0f23d8cc0a0871850f45490b34..34064e7cc2b9d0aa5c0c2e98806b8993137a589c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -70,6 +70,14 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2.0) self.assertEqual(result.norm, torch.norm(tensor).item()) + def test_get_stat_info_float_async(self): + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), 3.0) + self.assertEqual(result[1].item(), 1.0) + self.assertEqual(result[2].item(), 2.0) + self.assertEqual(result[3].item(), torch.norm(tensor).item()) + def test_get_stat_info_int(self): tensor = torch.tensor([1, 2, 3], dtype=torch.int32) result = self.processor.get_stat_info(tensor) @@ -78,6 +86,14 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertEqual(result.mean, 2) self.assertEqual(result.norm, torch.norm(tensor.float()).item()) + def test_get_stat_info_int_async(self): + tensor = torch.tensor([1, 2, 3]) + result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), 3.0) + self.assertEqual(result[1].item(), 1.0) + self.assertEqual(result[2].item(), 2.0) + self.assertEqual(result[3].item(), torch.norm(tensor.float()).item()) + def test_get_stat_info_empty(self): tensor = torch.tensor([]) result = self.processor.get_stat_info(tensor) @@ -94,6 +110,12 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertIsNone(result.mean) self.assertIsNone(result.norm) + def test_get_stat_info_bool_async(self): + tensor = torch.tensor([True, False, True]) + result = self.processor.get_stat_info_async(tensor).stack_tensor_stat[1] + self.assertEqual(result[0].item(), True) + self.assertEqual(result[1].item(), False) + def test_get_stat_info_with_scalar_tensor(self): scalar_tensor = torch.tensor(42.0) result = PytorchDataProcessor.get_stat_info(scalar_tensor) @@ -204,6 +226,23 @@ class TestPytorchDataProcessor(unittest.TestCase): } self.assertEqual(result, expected) + def test_analyze_reduce_op_successful(self): + arg = dist.ReduceOp.SUM + result = self.processor._analyze_reduce_op(arg) + expected = {'type': 'torch.distributed.ReduceOp', 'value': 'RedOpType.SUM'} 
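# ---- Reviewer sketch: illustrative only, not part of the patch. ----
# The two ReduceOp tests here pin down the behaviour: stringify the op for
# the dump record, and degrade to a warning (value stays None) if str()
# itself raises. That reduces to a guard like this (logger name assumed):
import logging

logger = logging.getLogger("msprobe")

def analyze_reduce_op(arg):
    op_value = None
    try:
        op_value = str(arg)
    except Exception as e:
        logger.warning(
            f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
    return {"type": "torch.distributed.ReduceOp", "value": op_value}
# ---- end sketch ----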
+ self.assertEqual(result, expected) + + @patch.object(logger, 'warning') + def test_analyze_reduce_op_failed(self, mock_logger_warning): + class TestReduceOp: + def __str__(self): + raise Exception("failed to convert str type") + arg = TestReduceOp() + self.processor._analyze_reduce_op(arg) + mock_logger_warning.assert_called_with( + "Failed to get value of torch.distributed.ReduceOp with error info: failed to convert str type." + ) + def test_get_special_types(self): special_types = self.processor.get_special_types() self.assertIn(torch.Tensor, special_types) @@ -232,7 +271,7 @@ class TestPytorchDataProcessor(unittest.TestCase): numpy_element = np.int64(1) converted_numpy, numpy_type = self.processor._convert_numpy_to_builtin(numpy_element) result = self.processor.analyze_single_element(numpy_element, []) - expected_result = self.processor._analyze_numpy(converted_numpy, numpy_type) + expected_result = {"type": numpy_type, "value": converted_numpy} self.assertEqual(result, expected_result) def test_analyze_single_element_tensor(self): @@ -257,6 +296,7 @@ class TestPytorchDataProcessor(unittest.TestCase): get_md5_for_tensor.return_value = 'mocked_md5' tensor = torch.tensor([1.0, 2.0, 3.0]) self.config.summary_mode = 'md5' + self.config.async_dump = False result = self.processor._analyze_tensor(tensor, 'suffix') expected = { 'type': 'torch.Tensor', @@ -299,6 +339,7 @@ class TestTensorDataProcessor(unittest.TestCase): @patch('torch.save') def test_analyze_tensor(self, mock_save): self.config.framework = "pytorch" + self.config.async_dump = False tensor = torch.tensor([1.0, 2.0, 3.0]) suffix = 'suffix' result = self.processor._analyze_tensor(tensor, suffix) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index 3d7f64171c309f536a4d5dbe920d01cbc2c1b612..b9d2e7abef7244fc12dc71e3113c26af52529ce9 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -116,4 +116,4 @@ class TestDataCollector(unittest.TestCase): self.data_collector.debug_data_collect_backward("variable", "name_with_count") mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug['data']) - self.data_collector.data_writer.cache_debug = None \ No newline at end of file + self.data_collector.data_writer.cache_debug = None diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py index bb4c8b197ef8362921858839ca3790224715a39a..9cfad00d8ff13e91eb84fff5f46ab434f9ed1d4d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py @@ -2,7 +2,8 @@ import unittest from unittest.mock import patch, mock_open, MagicMock import os from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import DataManager -from msprobe.core.common.const import MsCompareConst, CompareConst +from msprobe.core.common.const import CompareConst +from msprobe.mindspore.common.const import MsCompareConst class TestDataManager(unittest.TestCase): diff --git 
a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py index 6774b76e05a779620c124fe944b348a118177db1..066ff537ce6fba12f712ae3d4681115499be35a6 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py @@ -43,7 +43,8 @@ class TestPrecisionDebugger(unittest.TestCase): "dump_path": "/absolute_path", "rank": [], "step": [], - "level": "L1" + "level": "L1", + "async_dump": False } common_config = CommonConfig(json_config) @@ -128,3 +129,41 @@ class TestPrecisionDebugger(unittest.TestCase): debugger.service = MagicMock() debugger.forward_backward_dump_end() debugger.service.stop.assert_called_once() + + def test_is_graph_dump_level_not_kernel(self): + config = MagicMock() + config.level = "NOT_KERNEL" + config.list = ["some_value"] + result = PrecisionDebugger._is_graph_dump(config) + self.assertFalse(result) + + def test_is_graph_dump_empty_list(self): + config = MagicMock() + config.level = MsConst.KERNEL + config.list = [] + result = PrecisionDebugger._is_graph_dump(config) + self.assertTrue(result) + + def test_is_graph_dump_multiple_items_in_list(self): + config = MagicMock() + config.level = MsConst.KERNEL + config.list = ["item1", "item2"] + result = PrecisionDebugger._is_graph_dump(config) + self.assertTrue(result) + + def test_is_graph_dump_single_item_with_slash_or_dash(self): + config = MagicMock() + config.level = MsConst.KERNEL + config.list = ["item/with/slash"] + result = PrecisionDebugger._is_graph_dump(config) + self.assertTrue(result) + config.list = ["item-with-dash"] + result = PrecisionDebugger._is_graph_dump(config) + self.assertTrue(result) + + def test_is_graph_dump_single_item_without_dash_or_slash(self): + config = MagicMock() + config.level = MsConst.KERNEL + config.list = ["Functional.relu.1.forward"] + result = PrecisionDebugger._is_graph_dump(config) + self.assertFalse(result) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py new file mode 100644 index 0000000000000000000000000000000000000000..54c59b6409cb546384dcb50f47c7c27975fa1cb7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
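# ---- Reviewer sketch: illustrative only, not part of the patch; a pure
# reimplementation of what the tests below assert about
# create_kernel_config_json, not the shipped helper. The rank suffix is
# appended only when a rank is given, and the payload is a fixed
# kernel-dump config written with indent=4 (via save_json in the real code).
import os

def kernel_config_path(dump_path, cur_rank):
    suffix = f"_{cur_rank}" if cur_rank != '' else ''
    return os.path.join(dump_path, f"kernel_config{suffix}.json")

def kernel_config_info(dump_path):
    return {"dump": {"dump_list": [], "dump_path": dump_path,
                     "dump_mode": "all", "dump_op_switch": "on"}}

assert kernel_config_path("./step0", 0) == "./step0/kernel_config_0.json"
assert kernel_config_path("./step0", '') == "./step0/kernel_config.json"
# ---- end sketch ----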
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c59b6409cb546384dcb50f47c7c27975fa1cb7
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/dump/test_ms_kernel_config.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch
+
+from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
+
+
+class TestMsKernelConfig(unittest.TestCase):
+    @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json")
+    def test_create_kernel_config_json_with_rank(self, mock_save_json):
+        dump_path = "./step0"
+        cur_rank = 0
+        kernel_config_path = create_kernel_config_json(dump_path, cur_rank)
+        self.assertEqual(kernel_config_path, "./step0/kernel_config_0.json")
+        config_info = {
+            "dump": {
+                "dump_list": [],
+                "dump_path": dump_path,
+                "dump_mode": "all",
+                "dump_op_switch": "on"
+            }
+        }
+        mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4)
+
+    @patch("msprobe.mindspore.dump.kernel_dump.kernel_config.save_json")
+    def test_create_kernel_config_json_without_rank(self, mock_save_json):
+        dump_path = "./step0"
+        cur_rank = ''
+        kernel_config_path = create_kernel_config_json(dump_path, cur_rank)
+        self.assertEqual(kernel_config_path, "./step0/kernel_config.json")
+        config_info = {
+            "dump": {
+                "dump_list": [],
+                "dump_path": dump_path,
+                "dump_mode": "all",
+                "dump_op_switch": "on"
+            }
+        }
+        mock_save_json.assert_called_once_with(kernel_config_path, config_info, indent=4)
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
index a6ddae1f59b111774eae3ca2003fc4362e373d9d..e589dd4d58715d74644047f8c7e7a6ce79ccf225 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
@@ -90,7 +90,7 @@ class TestApiPyNativeSelfCheck(TestCase):
         mock_set_hook.assert_called_once()

     def test_build_hook(self):
-        _, forward_hook, backward_hook = self.checker.build_hook("Functional.add.")
+        _, forward_hook, backward_hook, _ = self.checker.build_hook("Functional.add.")

         cell = Cell()
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py
index 7a182869eedf07fdd360f5ba295c2100393e83a4..802769d9005916c8723d436349d13ca7f557a00a 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/grad_probe/test_grad_analyzer.py
@@ -6,7 +6,7 @@ import mindspore as ms
 from unittest import TestCase, mock
 from unittest.mock import patch
 from mindspore import Tensor, Parameter
-from msprobe.mindspore.grad_probe.grad_analyzer import CSVGenerator, grad_dump
+from msprobe.mindspore.grad_probe.grad_analyzer import CSVGenerator, grad_dump, GradDumpConfig
 from msprobe.mindspore.grad_probe.global_context import grad_context
 from msprobe.core.grad_probe.constant import GradConst

@@ -102,7 +102,9 @@ class TestGradAnalyzer(TestCase):

         # Run the grad_dump function
         try:
-            grad_dump(dump_dir, g_name, dump_step, grad, level, bounds)
+            conf = GradDumpConfig(dump_dir=dump_dir, g_name=g_name, dump_step=dump_step, grad=grad, level=level,
+                                  bounds=bounds)
+            grad_dump(conf)
         except RuntimeError as e:
             # If TensorDump fails due to environment, skip the file existence check
             self.skipTest(f"TensorDump operation failed: {e}")
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_wrap_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..325996ae0ee422f6e111ad831d20fea1e8344736
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/hook_module/test_ms_wrap_distributed.py
@@ -0,0 +1,112 @@
+import unittest
+from unittest.mock import Mock, patch
+import numpy as np
+import mindspore
+from mindspore import Tensor, ops
+
+from msprobe.mindspore.monitor.distributed.wrap_distributed import (
+    catch_data,
+    DistributedOPTemplate,
+    ApiRegistry,
+    get_distributed_ops,
+    is_target_line,
+    op_aggregate,
+    update_data
+)
+from msprobe.core.common.const import MonitorConst
+
+
+class TestWrapDistributed(unittest.TestCase):
+    def setUp(self):
+        self.mock_ops = ['min', 'max', 'norm']
+        self.mock_rank = '0'
+
+    def hook(self):
+        def forward_pre_hook(nope, inputs):
+            return inputs
+
+        def forward_hook(nope, inputs, output):
+            return 2
+
+        return [forward_pre_hook], [forward_hook]
+
+    def test_catch_data(self):
+        # Prepare test data
+        cc_context = Mock()
+        cc_context.data = {}
+        cc_name = "all_reduce"
+        args = [Tensor(np.array([1.0, 2.0, 3.0]))]
+
+        # Verify that input data is captured
+        catch_data(cc_context, cc_name, self.mock_ops, args, MonitorConst.PREFIX_PRE)
+        self.assertTrue('all_reduce/pre_0' in cc_context.data)
+
+        # Verify that output data is captured
+        catch_data(cc_context, cc_name, self.mock_ops, args, MonitorConst.PREFIX_POST)
+        self.assertTrue('all_reduce/post_0' in cc_context.data)
+
+    def test_distributed_op_template(self):
+        # Verify the distributed-op template
+        pre_hooks, post_hooks = self.hook()
+        op = DistributedOPTemplate("all_reduce", pre_hooks, post_hooks)
+
+        self.assertEqual(op.op_name_, "all_reduce")
+        self.assertEqual(len(op.cc_hooks), 2)
+
+    def test_api_registry(self):
+        # Verify the API registry
+        registry = ApiRegistry()
+
+        # Verify that original API attributes are stored
+        ori_api_group = Mock()
+        api_list = ["all_reduce", "all_gather"]
+        api_ori_attr = {}
+
+        ApiRegistry.store_ori_attr(ori_api_group, api_list, api_ori_attr)
+        self.assertEqual(len(api_ori_attr), 2)
+
+    def test_op_aggregate(self):
+        # Verify op aggregation
+        tensor_list = [Tensor(1.0), Tensor(2.0), Tensor(3.0)]
+
+        # 'min' op
+        result = op_aggregate('min', tensor_list)
+        self.assertEqual(result.asnumpy(), 1.0)
+
+        # 'max' op
+        result = op_aggregate('max', tensor_list)
+        self.assertEqual(result.asnumpy(), 3.0)
+
+        # 'mean' op
+        result = op_aggregate('mean', tensor_list)
+        self.assertEqual(result.asnumpy(), 2.0)
+
+    def test_update_data(self):
+        # Verify metric-dict merging
+        old_data = {}
+        new_data = {
+            'tag1': {
+                'min': Tensor(1.0),
+                'max': Tensor(2.0)
+            }
+        }
+
+        result = update_data(old_data, new_data)
+        self.assertTrue('tag1' in result)
+        self.assertTrue('min' in result['tag1'])
+        self.assertTrue('max' in result['tag1'])
+
+    def test_is_target_line(self):
+        # Verify target-line matching
+        # An empty pattern list should match any call stack
+        self.assertTrue(is_target_line([]))
+
+        # A pattern should be matched against the current call stack
+        codeline = ['test_pattern']
+        with patch('msprobe.mindspore.monitor.distributed.wrap_distributed.get_callstack') as mock_callstack:
+            mock_callstack.return_value = ['test_pattern_line']
+            self.assertTrue(is_target_line(codeline))
+
+    def test_get_distributed_ops(self):
+        # Verify retrieval of the distributed-op list
+        ops = get_distributed_ops()
+        self.assertIsInstance(ops, set)
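The aggregation semantics asserted in test_op_aggregate reduce a list of per-call scalar tensors to a single value. A sketch of the rule under test (inferred from the assertions above, not the shipped implementation):

    from mindspore import Tensor, ops

    def op_aggregate_sketch(op, tensors):
        # Collapse the per-call scalar tensors gathered by the hooks into one value.
        stacked = ops.stack(tensors)
        if op == 'min':
            return stacked.min()
        if op == 'max':
            return stacked.max()
        if op == 'mean':
            return stacked.mean()
        raise ValueError(f"unsupported aggregation op: {op}")

    print(op_aggregate_sketch('mean', [Tensor(1.0), Tensor(2.0), Tensor(3.0)]))  # 2.0
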
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
index ebb3c0985e3266e1a7701bbaae1f7ff20e9d82e2..912830ea1ab705aae63c69f5c240887d4b4ce5b7 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
@@ -60,16 +60,16 @@ class TestService(unittest.TestCase):

     def test_check_model_valid_with_valid_cell(self):
         model = nn.Cell()
+        model_list = [model]
         self.assertEqual(self.service.check_model_valid(model), model)
+        self.assertEqual(self.service.check_model_valid(model_list), model_list)

     def test_check_model_valid_with_invalid_type(self):
+        model = nn.Cell()
         with self.assertRaises(MsprobeException):
             self.service.check_model_valid("not a cell")
-
-    def test_check_level_valid_with_unsupported_level(self):
-        self.service.config.level = Const.LEVEL_L2
         with self.assertRaises(MsprobeException):
-            self.service.check_level_valid()
+            self.service.check_model_valid(["not a cell", model])

     def test_update_primitive_counters(self):
         self.service.primitive_counters = {}
@@ -222,7 +222,7 @@ class TestService(unittest.TestCase):
         self.service.step()
         self.assertEqual(self.service.current_iter, 1)
         self.service.data_collector.update_iter.assert_called_once_with(1)
-        self.service.data_collector.data_writer.reset_cache.assert_called_once()
+        self.service.data_collector.reset_status.assert_called_once()
         self.assertEqual(JitDump.jit_count, defaultdict(int))
         self.assertEqual((self.service.primitive_hook_service.primitive_counters), {})

@@ -240,7 +240,7 @@ class TestService(unittest.TestCase):
         mock_input = (MagicMock(),)
         mock_output = MagicMock()

-        _, forward_hook, backward_hook = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook")
+        _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook")
         forward_hook(mock_cell, mock_input, mock_output)

         self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell')
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
index a2e86e986eaee610262066e08606f559cd007d09..3cafd49f2c101c45dbb65a08803dd77c6bca485d 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
@@ -76,10 +76,6 @@ class TestService(unittest.TestCase):
         with self.assertRaises(MsprobeException) as context:
             self.service.check_model_valid(model)

-        # For the purpose of the test, let's also verify the expected exception message
-        expected_message = f"{MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)}model 参数必须是 mindspore.nn.Cell 类型。"
-        self.assertEqual(str(context.exception), expected_message)
-
     def test_update_primitive_counters(self):
         primitive_name = "test_primitive"
         self.service.primitive_hook_service.update_primitive_counters(primitive_name)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_apply_adam.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_apply_adam.py
new file mode 100644
index 0000000000000000000000000000000000000000..02631ec1fb487de28e9300934637b36791cf22ea
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_apply_adam.py
@@ -0,0 +1,85 @@
+import unittest
+import torch
+
+from msprobe.pytorch.bench_functions.apply_adam import npu_apply_adam
+
+
+class TestNPUApplyAdam(unittest.TestCase):
+    def setUp(self):
+        # Initialize test data
+        self.beta1_power = 0.9
+        self.beta2_power = 0.999
+        self.lr = 0.001
+        self.beta1 = 0.9
+        self.beta2 = 0.999
+        self.epsilon = 1e-8
+        self.grad = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        self.use_locking = False
+        self.var = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
+        self.m = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
+        self.v = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
+        self.out = (self.var, self.m, self.v)
+
+    def test_npu_apply_adam_without_nesterov(self):
+        # Case without Nesterov momentum
+        use_nesterov = False
+        var_t, m_t, v_t = npu_apply_adam(
+            self.beta1_power, self.beta2_power, self.lr, self.beta1, self.beta2,
+            self.epsilon, self.grad, self.use_locking, use_nesterov, self.out
+        )
+
+        # Check var_t
+        expected_var_t = torch.tensor([-0.0010, -0.0010, -0.0010], dtype=torch.float32)
+        self.assertTrue(torch.allclose(var_t, expected_var_t, atol=1e-4))
+
+        # Check m_t
+        expected_m_t = torch.tensor([0.1000, 0.2000, 0.3000], dtype=torch.float32)
+        self.assertTrue(torch.allclose(m_t, expected_m_t, atol=1e-4))
+
+        # Check v_t
+        expected_v_t = torch.tensor([0.0010, 0.0040, 0.0090], dtype=torch.float32)
+        self.assertTrue(torch.allclose(v_t, expected_v_t, atol=1e-4))
+
+    def test_npu_apply_adam_with_nesterov(self):
+        # Case with Nesterov momentum
+        use_nesterov = True
+        var_t, m_t, v_t = npu_apply_adam(
+            self.beta1_power, self.beta2_power, self.lr, self.beta1, self.beta2,
+            self.epsilon, self.grad, self.use_locking, use_nesterov, self.out
+        )
+
+        # Check var_t
+        expected_var_t = torch.tensor([-0.0019, -0.0019, -0.0019], dtype=torch.float32)
+        self.assertTrue(torch.allclose(var_t, expected_var_t, atol=1e-4))
+
+        # Check m_t
+        expected_m_t = torch.tensor([0.1000, 0.2000, 0.3000], dtype=torch.float32)
+        self.assertTrue(torch.allclose(m_t, expected_m_t, atol=1e-4))
+
+        # Check v_t
+        expected_v_t = torch.tensor([0.0010, 0.0040, 0.0090], dtype=torch.float32)
+        self.assertTrue(torch.allclose(v_t, expected_v_t, atol=1e-4))
+
+    def test_npu_apply_adam_with_non_zero_initial_values(self):
+        # Case with non-zero initial moment estimates
+        self.m = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        self.v = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        self.out = (self.var, self.m, self.v)
+
+        use_nesterov = False
+        var_t, m_t, v_t = npu_apply_adam(
+            self.beta1_power, self.beta2_power, self.lr, self.beta1, self.beta2,
+            self.epsilon, self.grad, self.use_locking, use_nesterov, self.out
+        )
+
+        # Check var_t
+        expected_var_t = torch.tensor([-0.0003, -0.0004, -0.0005], dtype=torch.float32)
+        self.assertTrue(torch.allclose(var_t, expected_var_t, atol=1e-4))
+
+        # Check m_t
+        expected_m_t = torch.tensor([1., 2., 3.], dtype=torch.float32)
+        self.assertTrue(torch.allclose(m_t, expected_m_t, atol=1e-4))
+
+        # Check v_t
+        expected_v_t = torch.tensor([1.0000, 2.0020, 3.0060], dtype=torch.float32)
+        self.assertTrue(torch.allclose(v_t, expected_v_t, atol=1e-4))
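The expected tensors above are consistent with the textbook Adam update with bias correction folded into the step size. A quick cross-check, assuming npu_apply_adam matches standard Adam semantics (only the non-Nesterov branch is sketched):

    import math
    import torch

    def adam_reference(beta1_power, beta2_power, lr, beta1, beta2, eps, grad, var, m, v):
        # Standard Adam: update biased moments, fold bias correction into lr.
        m_t = beta1 * m + (1 - beta1) * grad
        v_t = beta2 * v + (1 - beta2) * grad * grad
        lr_t = lr * math.sqrt(1 - beta2_power) / (1 - beta1_power)
        var_t = var - lr_t * m_t / (v_t.sqrt() + eps)
        return var_t, m_t, v_t

    # Reproduces the first test case: var_t ~ [-0.001, -0.001, -0.001]
    grad = torch.tensor([1.0, 2.0, 3.0])
    zero = torch.zeros(3)
    print(adam_reference(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, grad, zero, zero, zero))
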
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_group_norm_silu.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_group_norm_silu.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e4a447df1a10d57f47a427be672bbced2a5cffd
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_group_norm_silu.py
@@ -0,0 +1,23 @@
+import unittest
+import torch
+
+from msprobe.pytorch.bench_functions.group_norm_silu import npu_group_norm_silu
+
+
+class TestNPUGroupNormSILU(unittest.TestCase):
+    def setUp(self):
+        self.input0 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+        self.gama = torch.tensor([1.0])
+        self.beta = torch.tensor([0.0])
+        self.group = 1
+        self.eps = 1e-5
+
+    def test_npu_group_norm_silu_positive(self):
+        # Call npu_group_norm_silu
+        result = npu_group_norm_silu(self.input0, self.gama, self.beta, self.group, self.eps)
+
+        # Expected result
+        expected_result = torch.tensor([[[[-0.2780, -0.1744], [0.2728, 1.0636]]]])
+
+        # Approximate comparison via torch.allclose
+        self.assertTrue(torch.allclose(result[0], expected_result, atol=1e-4))
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_mish.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1684859d8cb48a12584f0c41f37687cb57e7e13
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_mish.py
@@ -0,0 +1,20 @@
+import unittest
+import torch
+
+from msprobe.pytorch.bench_functions.mish import npu_mish
+
+
+class TestNPUMish(unittest.TestCase):
+    def setUp(self):
+        self.input0 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+        self.eps = 1e-5
+
+    def test_npu_mish_positive(self):
+        # Call npu_mish
+        result = npu_mish(self.input0)
+
+        # Expected result
+        expected_result = torch.tensor([[[[0.8651, 1.9440], [2.9865, 3.9974]]]])
+
+        # Approximate comparison via torch.allclose
+        self.assertTrue(torch.allclose(result[0], expected_result, atol=1e-4))
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_npu_moe_gating_top_k_softmax.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_npu_moe_gating_top_k_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..e33915afb7455e647c7631e225e0f28668796e64
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_npu_moe_gating_top_k_softmax.py
@@ -0,0 +1,37 @@
+import unittest
+import torch
+
+from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax, softmax_func
+
+
+class TestNPUMoEGatingTopKSoftmax(unittest.TestCase):
+    def setUp(self):
+        self.input0 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+        self.finished_optional = None
+        self.k = 2
+
+    def test_npu_moe_gating_top_k_softmax(self):
+        # Call npu_moe_gating_top_k_softmax
+        result = npu_moe_gating_top_k_softmax(self.input0, self.finished_optional, self.k)
+
+        # Expected result
+        expected_result = (
+            torch.tensor([[[[0.7311, 0.2689], [0.7311, 0.2689]]]]),
+            torch.tensor([[[[1, 0], [1, 0]]]]),
+            torch.tensor([[0]])
+        )
+
+        # Approximate comparison via torch.allclose
+        self.assertTrue(torch.allclose(result[0], expected_result[0], atol=1e-4))
+        self.assertTrue(torch.allclose(result[1], expected_result[1], atol=1e-4))
+        self.assertTrue(torch.allclose(result[2], expected_result[2], atol=1e-4))
+
+    def test_softmax_func(self):
+        # Call softmax_func
+        result = softmax_func(self.input0, -1)
+
+        # Expected result
+        expected_result = torch.tensor([[[[0.2689, 0.7311], [0.2689, 0.7311]]]])
+
+        # Approximate comparison via torch.allclose
+        self.assertTrue(torch.allclose(result, expected_result, atol=1e-4))
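As a sanity check on the expected values in the gating tests: softmax over the last axis of [1.0, 2.0] gives e^1/(e^1+e^2) ~ 0.2689 and e^2/(e^1+e^2) ~ 0.7311, which is exactly what both expected tensors encode (the top-k variant returns the probabilities sorted descending, hence [0.7311, 0.2689]). A one-liner to confirm, using plain torch with no msprobe dependency:

    import torch
    print(torch.softmax(torch.tensor([1.0, 2.0]), dim=-1))  # tensor([0.2689, 0.7311])
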
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_sort_v2.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_sort_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c008172dbcfa6e3fd7c8e6c340a679e2ac3e9c8
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/bench_functions/test_sort_v2.py
@@ -0,0 +1,22 @@
+import unittest
+import torch
+
+from msprobe.pytorch.bench_functions.sort_v2 import npu_sort_v2
+
+
+class TestSortV2(unittest.TestCase):
+    def setUp(self):
+        self.input0 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+        self.dim = -1
+        self.descending = False
+        self.out = None
+
+    def test_npu_sort_v2(self):
+        # Call npu_sort_v2
+        result = npu_sort_v2(self.input0, self.dim, self.descending, self.out)
+
+        # Expected result
+        expected_result = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
+
+        # Approximate comparison via torch.allclose
+        self.assertTrue(torch.allclose(result, expected_result, atol=1e-4))
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
index e9c55bc162632fd07ebf0d4e7d437acef236e4cb..4fc27c267ebe65ea46ecf0f17bc47ff702eb241d 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/debugger/test_pt_debugger_config.py
@@ -15,6 +15,7 @@ class TestDebuggerConfig(unittest.TestCase):
         self.common_config.task = Const.STATISTICS
         self.common_config.level = "L1"
         self.common_config.enable_dataloader = True
+        self.common_config.async_dump = False

     def test_default_init(self):
         debugger = DebuggerConfig(self.common_config, self.task_config, None, None, None)
@@ -81,6 +82,7 @@ class TestDebuggerConfig(unittest.TestCase):
     def test_check_and_adjust_config_with_l2_list_empty(self):
         self.common_config.dump_path = "./dump_path"
         self.common_config.task = Const.TENSOR
+        self.common_config.async_dump = False
         self.task_config.scope = []
         self.task_config.list = []
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
similarity index 38%
rename from debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py
rename to debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
index 8a0ff72dd266e056f8a549b698b56b6e8c6e1041..63d6abc3a2430bb6f092820c4b97a02cdf675612 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_module_dump.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,16 +18,12 @@ from unittest.mock import patch, MagicMock
 import torch
 import torch.nn as nn

-from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.log import logger
 from msprobe.pytorch import PrecisionDebugger
-from msprobe.pytorch.service import torch_version_above_or_equal_2
-from msprobe.pytorch.functional.module_dump import module_dump, module_dump_end, \
-    hook_handle_list, remove_hook, register_hook
 from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.service import torch_version_above_or_equal_2


-class TestModuleDump(unittest.TestCase):
+class TestModuleDumper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         PrecisionDebugger._instance = None
@@ -40,44 +36,25 @@ class TestModuleDump(unittest.TestCase):

     def setUp(self):
         self.module = nn.Linear(8, 4)
-
-    def tearDown(self):
-        hook_handle_list.clear()
-
-    @patch.object(logger, 'error')
-    def test_module_dump(self, mock_error):
-        with self.assertRaises(MsprobeException) as context:
-            module_dump(1, "TestModule")
-        self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR)
-        mock_error.assert_called_with("The parameter module in module_dump must be a Module subclass.")
-
-        with self.assertRaises(MsprobeException) as context:
-            module_dump(self.module, 1)
-        self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR)
-        mock_error.assert_called_with("The parameter dump_name in module_dump must be a str type.")
-
-        with patch('msprobe.pytorch.functional.module_dump.register_hook') as mock_register_hook:
-            module_dump(self.module, "TestModule")
-            mock_register_hook.assert_called_with(self.module, "TestModule")
-
-    def test_module_dump_end(self):
-        hook_handle_list.extend([1, 2, 3])
-        with patch('msprobe.pytorch.functional.module_dump.remove_hook') as mock_remove_hook:
-            module_dump_end()
-            mock_remove_hook.assert_called_once()
-        self.assertEqual(hook_handle_list, [])
+        debugger = PrecisionDebugger(dump_path="./")
+        self.module_dumper = debugger.module_dumper
+
+    def test_stop_module_dump(self):
+        self.module_dumper.hook_handle_list.extend([1, 2, 3])
+        with patch('msprobe.pytorch.dump.module_dump.module_dump.api_register') as mock_api_register:
+            mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle)
+            mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle)
+            self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2])
+
+            self.module_dumper.stop_module_dump()
+            mock_handle1.remove.assert_called_once()
+            mock_handle2.remove.assert_called_once()
+            self.assertEqual(self.module_dumper.hook_handle_list, [])
+            mock_api_register.api_modularity.assert_called_once()

     def test_register_hook(self):
-        PrecisionDebugger(dump_path="./")
-        register_hook(self.module, "TestModule")
+        self.module_dumper.register_hook(self.module, "TestModule")
         if torch_version_above_or_equal_2:
-            self.assertEqual(len(hook_handle_list), 6)
+            self.assertEqual(len(self.module_dumper.hook_handle_list), 6)
         else:
-            self.assertEqual(len(hook_handle_list), 5)
-
-    def test_remove_hook(self):
-        mock_handle = MagicMock(spec=torch.utils.hooks.RemovableHandle)
-        hook_handle_list.append(mock_handle)
-        remove_hook()
-
-        mock_handle.remove.assert_called_once()
+            self.assertEqual(len(self.module_dumper.hook_handle_list), 5)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py
index e7ed4926c868b2bca4d53dd4046d3e3a71904557..f8a561b61b6a758a525675bdc59957e5c923b261 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py
@@ -25,32 +25,6 @@ class TestModuleProcesser(unittest.TestCase):
         processor = ModuleProcesser(scope)
         self.assertIsNone(processor.scope)

-    def test_filter_tensor_and_tuple(self):
-        def func(nope, x):
-            return x * 2
-
-        result_1 = ModuleProcesser.filter_tensor_and_tuple(func)(None, torch.tensor([1]))
-        self.assertEqual(result_1, torch.tensor([2]))
-        result_2 = ModuleProcesser.filter_tensor_and_tuple(func)(None, "test")
-        self.assertEqual(result_2, "test")
-
-    def test_filter_tensor_and_tuple_with_tensor(self):
-        class MockBackwardHook:
-            @staticmethod
-            def setup_output_hook(*args, **kwargs):
-                return args[1]
-
-        mock_hook = MockBackwardHook.setup_output_hook
-        wrapped_hook = ModuleProcesser.filter_tensor_and_tuple(mock_hook)
-
-        tensor = torch.tensor([1, 2, 3])
-        mock_obj = type('MockObj', (object,), {'tensor_attr': tensor})()
-        wrapped_hook(None, mock_obj)
-        self.assertIs(mock_obj.tensor_attr, tensor)
-        non_tensor_obj = type('MockObj', (object,), {'non_tensor_attr': 'non_tensor_value'})()
-        wrapped_hook(None, non_tensor_obj)
-        self.assertEqual(non_tensor_obj.non_tensor_attr, 'non_tensor_value')
-
     def test_clone_return_value_and_test_clone_if_tensor(self):
         def func(x):
             return x
@@ -118,9 +92,13 @@ class TestModuleProcesser(unittest.TestCase):
         self.assertEqual(module.mindstudio_reserved_name, [expected_name])
         self.assertIn(expected_name, ModuleProcesser.module_node)

-    def test_remove_deprecated_backward_hook_if_exist(self):
+    def test_has_register_backward_hook(self):
         module = MagicMock()
         module._backward_hooks = {0: lambda: None}
         module._is_full_backward_hook = False
-        self.processor.remove_deprecated_backward_hook_if_exist(module)
-        self.assertIsNone(module._is_full_backward_hook)
+        result = self.processor.has_register_backward_hook(module)
+        self.assertTrue(result)
+
+        module._is_full_backward_hook = True
+        result = self.processor.has_register_backward_hook(module)
+        self.assertFalse(result)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/all_config.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/all_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..9c2eb5b43a278a6e2e8104e3d9f8bce912930a0d
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/all_config.json
@@ -0,0 +1,13 @@
+{
+    "targets": {
+        "": {}
+    },
+    "param_distribution": true,
+    "xy_distribution": true,
+    "mv_distribution": true,
+    "wg_distribution": true,
+    "all_xy": true,
+    "format": "csv",
+    "ops": ["norm", "nans"],
+    "step_count_per_record": 3
+}
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py
index d25cf05a39c81af7e219c711b7941b886dcba377..f5de419440224cca261b62df2495e8ce28b8e2d4 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py
@@ -36,15 +36,13 @@ def monitor_demo(config: str = "./config/monitor_config.json"):

     hooker = TrainerMon(
         config,
-        params_have_main_grad=False,
-        opt_ty='Megatron_FP32Optimizer'
+        params_have_main_grad=False
     )
-    hooker.monitor_gnorm_with_ad(
+    hooker.set_monitor(
         model=net,
         grad_acc_steps=1,
         optimizer=optimizer
     )
-    hooker.set_wrapped_optimizer(optimizer)

     train_ds = ToyDataset()
     train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=10)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py
index 186a3e491ad5a460cf79b4673055950f2bb6ed7c..fa0960e2cc1842a138b47fad3f86c1ed0d089db8 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py
@@ -122,13 +122,13 @@ class TestAnomalyDataFactory(TestCase):

 class TestGradAnomalyData(TestCase):
     def setUp(self) -> None:
-        tag_name = "0:1.self_attention.core_attention_flash_0/rank0/output"
-        message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2."
+        tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv"
+        message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2."
         group_mates = [0]
         self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates)

     def test_get_train_stage(self):
-        tag_name_list = ["0:fc2_0/rank0/input", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/efxp_avg_sq", ""]
+        tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""]
         expected_train_stage_list = [0, 1, 2, -1]
         for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list):
             train_stage = GradAnomalyData.get_train_stage(tag_name)
@@ -142,15 +142,15 @@ class TestGradAnomalyData(TestCase):
             'pp_stage': 0,
             'vpp_stage': 0,
             'call_id': 0,
-            'tag_name': "0:1.self_attention.core_attention_flash_0/rank0/output",
-            'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2.",
+            'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv",
+            'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.",
             'group_mates': [0]
         }

         self.assertEqual(self.GradAnomalyData.to_dict(), expected)

     def test_get_key(self):
-        expected = "0:1.self_attention.core_attention_flash_0/rank0/output_step_0_call_0"
+        expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0"

         self.assertEqual(self.GradAnomalyData.get_key(), expected)

@@ -168,8 +168,8 @@ class TestGradAnomalyData(TestCase):

     def test_lt_same_step_same_micro_step_different_vpp_stage(self):
         # same forward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/input")
+        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
+        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv")
         self.assertLess(data1, data2)
         self.assertGreater(data2, data1)

@@ -180,15 +180,15 @@ class TestGradAnomalyData(TestCase):
         self.assertGreater(data1, data2)

         # diff train stage
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input")
+        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
         data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad")
         self.assertLess(data1, data2)
         self.assertGreater(data2, data1)

     def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self):
         # same forward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/input")
+        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
+        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv")
         self.assertLess(data1, data2)
         self.assertGreater(data2, data1)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2bc82ffafc2a1f10719d4a46669bc0050c12782
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py
@@ -0,0 +1,751 @@
+import os
+import shutil
+import random
+import unittest
+import pytest
+import torch
+import numpy as np
+import torch.nn as nn
+from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+from msprobe.pytorch import TrainerMon
+from msprobe.core.common.const import MonitorConst
+from msprobe.pytorch.monitor.csv2tb import parse_step_fn, csv2tensorboard_by_step
+
+
+base_dir = os.path.dirname(os.path.realpath(__file__))
+config_json_path = os.path.join(base_dir, "config", "all_config.json")
+monitor_output = os.path.join(base_dir, "./monitor_output_csv2tb")
+os.environ[MonitorConst.MONITOR_OUTPUT_DIR] = monitor_output
+timestamp_dirpath = None
+csv2tb_dirpath = None
+
+
+def seed_all(seed=1234, mode=False):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.use_deterministic_algorithms(mode)
+
+seed_all()
+
+
+inputs = [torch.rand(10, 10) for _ in range(10)]
+labels = [torch.randint(0, 5, (10,)) for _ in range(10)]
+
+
+class MockModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 5)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x1 = self.linear(x)
+        x2 = self.relu(x1)
+        return x2
+
+
+def data_collect():
+    loss_fun = nn.CrossEntropyLoss()
+    test_module = MockModule()
+    nn.init.constant_(test_module.linear.weight, 1.0)
+    nn.init.constant_(test_module.linear.bias, 1.0)
+    optimizer = torch.optim.Adam(test_module.parameters())
+
+    monitor = TrainerMon(config_json_path, params_have_main_grad=False)
+    monitor.set_monitor(test_module, grad_acc_steps=1, optimizer=optimizer)
+
+    for input_data, label in zip(inputs, labels):
+        output = test_module(input_data)
+        loss = loss_fun(output, label)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+    global timestamp_dirpath, csv2tb_dirpath
+    timestamp_dirpath = os.path.join(monitor_output, os.listdir(monitor_output)[0])
+    csv2tensorboard_by_step(monitor_output)
+    for dirname in os.listdir(monitor_output):
+        if "csv2tensorboard" in dirname:
+            csv2tb_dirpath = os.path.join(monitor_output, dirname, "rank0")
+
+
+def extract_scalars_from_tensorboard(log_dir):
+    # Initialize the EventAccumulator and load the event data
+    event_acc = EventAccumulator(log_dir)
+    event_acc.Reload()
+
+    # Collect all scalar tags
+    scalar_tags = event_acc.Tags()['scalars']
+
+    # Build a dict keyed by tag, holding the (step, value) list for each tag
+    scalars_dict = {}
+    for tag in scalar_tags:
+        scalar_events = event_acc.Scalars(tag)
+        scalars_dict[tag] = [(event.step, event.value) for event in scalar_events]
+
+    return scalars_dict
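A quick usage sketch for the helper above (the log directory and tag are illustrative; any TensorBoard event directory works):

    scalars = extract_scalars_from_tensorboard("./monitor_output_csv2tb/run/rank0/param")
    print(scalars["vp0:linear.weight/norm"][:3])  # first three (step, value) pairs
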
+
+
+def dict_equal(a, b):
+    if not isinstance(a, dict) or not isinstance(b, dict):
+        if np.isnan(a) and np.isnan(b):
+            return True
+        return a == b
+
+    if set(a.keys()) != set(b.keys()):
+        return False
+
+    for key in a:
+        if not dict_equal(a[key], b[key]):
+            return False
+
+    return True
+
+
+def compare_scalar_dicts(dict1, dict2):
+    if set(dict1.keys()) != set(dict2.keys()):
+        return False
+
+    for key in dict1:
+        list1 = dict1[key]
+        list2 = dict2[key]
+
+        if len(list1) != len(list2):
+            return False
+
+        # Compare each (step, value) pair
+        for (step1, value1), (step2, value2) in zip(list1, list2):
+            if step1 != step2:
+                return False
+
+            if not (value1 == value2 or (np.isnan(value1) and np.isnan(value2))):
+                return False
+    return True
+
+
+@pytest.fixture(scope="session")
+def setup_all():
+    data_collect()
+    yield
+    shutil.rmtree(monitor_output)
+
+@pytest.mark.usefixtures("setup_all")
+class TestGradMonitor(unittest.TestCase):
+
+    def setUp(self):
+        self.maxDiff = None
+
+    def test_actv(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "actv_0-2.csv"))
+        result = {
+            'vp0:.input:micro0': {
+                0: {'nans': 0.0, 'norm': 5.550016},
+                1: {'nans': 0.0, 'norm': 5.975112},
+                2: {'nans': 0.0, 'norm': 5.789881}
+            },
+            'vp0:.output:micro0': {
+                0: {'nans': 0.0, 'norm': 41.842655},
+                1: {'nans': 0.0, 'norm': 44.40981},
+                2: {'nans': 0.0, 'norm': 43.578354}
+            },
+            'vp0:linear.input:micro0': {
+                0: {'nans': 0.0, 'norm': 5.550016},
+                1: {'nans': 0.0, 'norm': 5.975112},
+                2: {'nans': 0.0, 'norm': 5.789881}
+            },
+            'vp0:linear.output:micro0': {
+                0: {'nans': 0.0, 'norm': 41.842655},
+                1: {'nans': 0.0, 'norm': 44.40981},
+                2: {'nans': 0.0, 'norm': 43.578354}
+            },
+            'vp0:relu.input:micro0': {
+                0: {'nans': 0.0, 'norm': 41.842655},
+                1: {'nans': 0.0, 'norm': 44.40981},
+                2: {'nans': 0.0, 'norm': 43.578354}
+            },
+            'vp0:relu.output:micro0': {
+                0: {'nans': 0.0, 'norm': 41.842655},
+                1: {'nans': 0.0, 'norm': 44.40981},
+                2: {'nans': 0.0, 'norm': 43.578354}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv"))
+        print(tb_data)
+        tb_result = {
+            'vp0:.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                       (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:.input:micro0/norm': [(0, 5.550015926361084), (1, 5.975111961364746),
+                                       (2, 5.789881229400635), (3, 6.052319049835205),
+                                       (4, 5.573315143585205), (5, 5.864360809326172),
+                                       (6, 5.292460918426514), (7, 5.477899074554443),
+                                       (8, 5.884613990783691), (9, 5.456457138061523)],
+            'vp0:.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                        (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:.output:micro0/norm': [(0, 41.842655181884766), (1, 44.40980911254883),
+                                        (2, 43.57835388183594), (3, 45.83631134033203),
+                                        (4, 42.0673828125), (5, 43.46839141845703),
+                                        (6, 39.77947235107422), (7, 40.200843811035156),
+                                        (8, 44.453147888183594), (9, 40.841522216796875)],
+            'vp0:linear.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                             (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.input:micro0/norm': [(0, 5.550015926361084), (1, 5.975111961364746),
+                                             (2, 5.789881229400635), (3, 6.052319049835205),
+                                             (4, 5.573315143585205), (5, 5.864360809326172),
+                                             (6, 5.292460918426514), (7, 5.477899074554443),
+                                             (8, 5.884613990783691), (9, 5.456457138061523)],
+            'vp0:linear.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                              (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.output:micro0/norm': [(0, 41.842655181884766), (1, 44.40980911254883),
+                                              (2, 43.57835388183594), (3, 45.83631134033203),
+                                              (4, 42.0673828125), (5, 43.46839141845703),
+                                              (6, 39.77947235107422), (7, 40.200843811035156),
+                                              (8, 44.453147888183594), (9, 40.841522216796875)],
+            'vp0:relu.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                           (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:relu.input:micro0/norm': [(0, 41.842655181884766), (1, 44.40980911254883),
+                                           (2, 43.57835388183594), (3, 45.83631134033203),
+                                           (4, 42.0673828125), (5, 43.46839141845703),
+                                           (6, 39.77947235107422), (7, 40.200843811035156),
+                                           (8, 44.453147888183594), (9, 40.841522216796875)],
+            'vp0:relu.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                            (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:relu.output:micro0/norm': [(0, 41.842655181884766), (1, 44.40980911254883),
+                                            (2, 43.57835388183594), (3, 45.83631134033203),
+                                            (4, 42.0673828125), (5, 43.46839141845703),
+                                            (6, 39.77947235107422), (7, 40.200843811035156),
+                                            (8, 44.453147888183594), (9, 40.841522216796875)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_actv_grad(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "actv_grad_0-2.csv"))
+        nan = np.nan
+        result = {
+            'vp0:.input:micro0': {
+                0: {'norm': nan, 'nans': nan},
+                1: {'norm': nan, 'nans': nan},
+                2: {'norm': nan, 'nans': nan}
+            },
+            'vp0:.output:micro0': {
+                0: {'norm': 0.282843, 'nans': 0.0},
+                1: {'norm': 0.282617, 'nans': 0.0},
+                2: {'norm': 0.282655, 'nans': 0.0}
+            },
+            'vp0:relu.input:micro0': {
+                0: {'norm': 0.282843, 'nans': 0.0},
+                1: {'norm': 0.282617, 'nans': 0.0},
+                2: {'norm': 0.282655, 'nans': 0.0}
+            },
+            'vp0:relu.output:micro0': {
+                0: {'norm': 0.282843, 'nans': 0.0},
+                1: {'norm': 0.282617, 'nans': 0.0},
+                2: {'norm': 0.282655, 'nans': 0.0}
+            },
+            'vp0:linear.input:micro0': {
+                0: {'norm': nan, 'nans': nan},
+                1: {'norm': nan, 'nans': nan},
+                2: {'norm': nan, 'nans': nan}
+            },
+            'vp0:linear.output:micro0': {
+                0: {'norm': 0.282843, 'nans': 0.0},
+                1: {'norm': 0.282617, 'nans': 0.0},
+                2: {'norm': 0.282655, 'nans': 0.0}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "actv_grad"))
+        tb_result = {
+            'vp0:.input:micro0/nans': [(0, nan), (1, nan), (2, nan), (3, nan), (4, nan),
+                                       (5, nan), (6, nan), (7, nan), (8, nan), (9, nan)],
+            'vp0:.input:micro0/norm': [(0, nan), (1, nan), (2, nan), (3, nan), (4, nan),
+                                       (5, nan), (6, nan), (7, nan), (8, nan), (9, nan)],
+            'vp0:.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                        (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:.output:micro0/norm': [(0, 0.2828429937362671), (1, 0.2826170027256012),
+                                        (2, 0.2826550006866455), (3, 0.2828519940376282),
+                                        (4, 0.2822929918766022), (5, 0.2826640009880066),
+                                        (6, 0.28316599130630493), (7, 0.28274500370025635),
+                                        (8, 0.2833530008792877), (9, 0.2825529873371124)],
+            'vp0:linear.input:micro0/nans': [(0, nan), (1, nan), (2, nan), (3, nan), (4, nan),
+                                             (5, nan), (6, nan), (7, nan), (8, nan), (9, nan)],
+            'vp0:linear.input:micro0/norm': [(0, nan), (1, nan), (2, nan), (3, nan), (4, nan),
+                                             (5, nan), (6, nan), (7, nan), (8, nan), (9, nan)],
+            'vp0:linear.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                              (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.output:micro0/norm': [(0, 0.2828429937362671), (1, 0.2826170027256012),
+                                              (2, 0.2826550006866455), (3, 0.2828519940376282),
+                                              (4, 0.2822929918766022), (5, 0.2826640009880066),
+                                              (6, 0.28316599130630493), (7, 0.28274500370025635),
+                                              (8, 0.2833530008792877), (9, 0.2825529873371124)],
+            'vp0:relu.input:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                           (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:relu.input:micro0/norm': [(0, 0.2828429937362671), (1, 0.2826170027256012),
+                                           (2, 0.2826550006866455), (3, 0.2828519940376282),
+                                           (4, 0.2822929918766022), (5, 0.2826640009880066),
+                                           (6, 0.28316599130630493), (7, 0.28274500370025635),
+                                           (8, 0.2833530008792877), (9, 0.2825529873371124)],
+            'vp0:relu.output:micro0/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                            (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:relu.output:micro0/norm': [(0, 0.2828429937362671), (1, 0.2826170027256012),
+                                            (2, 0.2826550006866455), (3, 0.2828519940376282),
+                                            (4, 0.2822929918766022), (5, 0.2826640009880066),
+                                            (6, 0.28316599130630493), (7, 0.28274500370025635),
+                                            (8, 0.2833530008792877), (9, 0.2825529873371124)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_param(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "param_0-2.csv"))
+        result = {
+            'vp0:linear.bias': {
+                0: {'nans': 0.0, 'norm': 2.236068},
+                1: {'nans': 0.0, 'norm': 2.236198},
+                2: {'nans': 0.0, 'norm': 2.235769}
+            },
+            'vp0:linear.weight': {
+                0: {'nans': 0.0, 'norm': 7.071068},
+                1: {'nans': 0.0, 'norm': 7.068808},
+                2: {'nans': 0.0, 'norm': 7.06771}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "param"))
+        tb_result = {
+            'vp0:linear.weight/norm': [(0, 7.071067810058594), (1, 7.068808078765869),
+                                       (2, 7.067709922790527), (3, 7.0673418045043945),
+                                       (4, 7.066926956176758), (5, 7.066311836242676),
+                                       (6, 7.065629959106445), (7, 7.065262794494629),
+                                       (8, 7.065001964569092), (9, 7.064840793609619)],
+            'vp0:linear.weight/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                       (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.bias/norm': [(0, 2.2360680103302), (1, 2.2361979484558105),
+                                     (2, 2.235769033432007), (3, 2.235903024673462),
+                                     (4, 2.2360129356384277), (5, 2.2359039783477783),
+                                     (6, 2.2357990741729736), (7, 2.2357349395751953),
+                                     (8, 2.2356700897216797), (9, 2.235619068145752)],
+            'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                     (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)]
+        }
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_exp_avg(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "exp_avg_0-2.csv"))
+        result = {
+            'vp0:linear.bias': {
+                1: {'nans': 0.0, 'norm': 0.024495},
+                2: {'nans': 0.0, 'norm': 0.052203}
+            },
+            'vp0:linear.weight': {
+                1: {'nans': 0.0, 'norm': 0.052394},
+                2: {'nans': 0.0, 'norm': 0.099221}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg"))
+        tb_result = {
+            'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0),
+                                     (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.bias/norm': [(1, 0.024495000019669533), (2, 0.05220299959182739),
+                                     (3, 0.06452500075101852), (4, 0.05751600116491318),
+                                     (5, 0.07189200073480606), (6, 0.07151799649000168),
+                                     (7, 0.053112998604774475), (8, 0.06187799945473671),
+                                     (9, 0.04195199906826019)],
+            'vp0:linear.weight/nans': [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0),
+                                       (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.weight/norm': [(1, 0.05239399895071983), (2, 0.09922099858522415),
+                                       (3, 0.12258800119161606), (4, 0.11325100064277649),
+                                       (5, 0.14186500012874603), (6, 0.14408400654792786),
+                                       (7, 0.11372199654579163), (8, 0.12264800071716309),
+                                       (9, 0.09017200022935867)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_exp_avg_sq(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "exp_avg_sq_0-2.csv"))
+        result = {
+            'vp0:linear.bias': {
+                1: {'nans': 0.0, 'norm': 4.2e-05},
+                2: {'nans': 0.0, 'norm': 9.6e-05}
+            },
+            'vp0:linear.weight': {
+                1: {'nans': 0.0, 'norm': 6.7e-05},
+                2: {'nans': 0.0, 'norm': 0.000126}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "exp_avg_sq"))
+        tb_result = {
+            'vp0:linear.bias/nans': [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0),
+                                     (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.bias/norm': [(1, 4.199999966658652e-05), (2, 9.600000339560211e-05),
+                                     (3, 0.00013099999341648072), (4, 0.00013099999341648072),
+                                     (5, 0.00016500000492669642), (6, 0.0001900000061141327),
+                                     (7, 0.00020199999562464654), (8, 0.00022899999748915434),
+                                     (9, 0.00024300000222865492)],
+            'vp0:linear.weight/nans': [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0),
+                                       (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.weight/norm': [(1, 6.70000008540228e-05), (2, 0.00012599999899975955),
+                                       (3, 0.00015799999528098851), (4, 0.00016599999798927456),
+                                       (5, 0.00021399999968707561), (6, 0.00024199999461416155),
+                                       (7, 0.00026000000070780516), (8, 0.00028700000257231295),
+                                       (9, 0.0003060000017285347)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_grad_reduced(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "grad_reduced_0-2.csv"))
+        result = {
+            'vp0:linear.bias': {
+                0: {'nans': 0.0, 'norm': 0.244949},
+                1: {'nans': 0.0, 'norm': 0.314345},
+                2: {'nans': 0.0, 'norm': 0.281475}
+            },
+            'vp0:linear.weight': {
+                0: {'nans': 0.0, 'norm': 0.523935},
+                1: {'nans': 0.0, 'norm': 0.595672},
+                2: {'nans': 0.0, 'norm': 0.497603}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_reduced"))
+        tb_result = {
+            'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                     (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.bias/norm': [(0, 0.24494899809360504), (1, 0.31434500217437744),
+                                     (2, 0.2814750075340271), (3, 0.006068999879062176),
+                                     (4, 0.2398650050163269), (5, 0.2817699909210205),
+                                     (6, 0.1456969976425171), (7, 0.2817710041999817),
+                                     (8, 0.15226399898529053), (9, 0.1355219930410385)],
+            'vp0:linear.weight/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                       (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.weight/norm': [(0, 0.5239350199699402), (1, 0.5956720113754272),
+                                       (2, 0.49760299921035767), (3, 0.23948900401592255),
+                                       (4, 0.5050320029258728), (5, 0.5136330127716064),
+                                       (6, 0.3642309904098511), (7, 0.4831080138683319),
+                                       (8, 0.3234719932079315), (9, 0.32385098934173584)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
+
+    def test_grad_unreduced(self):
+        data = parse_step_fn(os.path.join(timestamp_dirpath, "grad_unreduced_0-2.csv"))
+        result = {
+            'vp0:linear.bias': {
+                0: {'nans': 0.0, 'norm': 0.244949},
+                1: {'nans': 0.0, 'norm': 0.314345},
+                2: {'nans': 0.0, 'norm': 0.281475}
+            },
+            'vp0:linear.weight': {
+                0: {'nans': 0.0, 'norm': 0.523935},
+                1: {'nans': 0.0, 'norm': 0.595672},
+                2: {'nans': 0.0, 'norm': 0.497603}
+            }
+        }
+        self.assertEqual(dict_equal(data, result), True)
+
+        tb_data = extract_scalars_from_tensorboard(os.path.join(csv2tb_dirpath, "grad_unreduced"))
+        tb_result = {
+            'vp0:linear.bias/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                     (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.bias/norm': [(0, 0.24494899809360504), (1, 0.31434500217437744),
+                                     (2, 0.2814750075340271), (3, 0.006068999879062176),
+                                     (4, 0.2398650050163269), (5, 0.2817699909210205),
+                                     (6, 0.1456969976425171), (7, 0.2817710041999817),
+                                     (8, 0.15226399898529053), (9, 0.1355219930410385)],
+            'vp0:linear.weight/nans': [(0, 0.0), (1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0),
+                                       (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)],
+            'vp0:linear.weight/norm': [(0, 0.5239350199699402), (1, 0.5956720113754272),
+                                       (2, 0.49760299921035767), (3, 0.23948900401592255),
+                                       (4, 0.5050320029258728), (5, 0.5136330127716064),
+                                       (6, 0.3642309904098511), (7, 0.4831080138683319),
+                                       (8, 0.3234719932079315), (9, 0.32385098934173584)]}
+        self.assertEqual(compare_scalar_dicts(tb_data, tb_result), True)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py
index 76015f50497c08136974ee2dd4c24e793e1c0b92..eefacb73c8e76636086554775b0e6f2e916ddf6e 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py
@@ -5,7 +5,7 @@ from unittest.mock import patch, MagicMock
 import pandas as pd
 import torch

-from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.const import MonitorConst, Const
 from torch import distributed as dist

 from msprobe.pytorch.monitor.module_hook import CommunicationContext, GradContext, ModuleHookContext, \
@@ -39,8 +39,7 @@ class TestModuleHook(unittest.TestCase):
         xy_config = os.path.join(base_dir, "config/xy_config.json")
         hooker = TrainerMon(
             xy_config,
-            params_have_main_grad=False,
-            opt_ty='Megatron_FP32Optimizer'
+            params_have_main_grad=False
         )
         self.get_dist_mock(True)

@@ -56,7 +55,7 @@ class TestModuleHook(unittest.TestCase):
         with self.assertRaises(Exception) as context:
             monitor_demo(print_struct_config)

-        self.assertEqual(str(context.exception), "exit after first step when print model struct")
+        self.assertEqual(str(context.exception), "exit after first monitor step when print model struct")

     def test_xy_distribution(self):
         xy_monitor_output = "./test_xy_distribution"
@@ -73,13 +72,13 @@ class TestModuleHook(unittest.TestCase):
         self.assertTrue(os.path.exists(actv_grad_0_csv))
         # validate columns and lines
         actv_0 = pd.read_csv(actv_0_csv)
-        expect_columns = ['vpp_stage', 'module_name', 'step', 'input.norm', 'input.nans', 'output.norm', 'output.nans']
'output.norm', 'output.nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] self.assertListEqual(list(actv_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([3, 7])) + self.assertEqual(actv_0.shape, tuple([6, 6])) actv_grad_0 = pd.read_csv(actv_grad_0_csv) - expect_columns = ['vpp_stage', 'module_name', 'step', 'input_grad.norm', 'input_grad.nans', 'output_grad.norm', 'output_grad.nans'] + expect_columns = ['vpp_stage', 'name', 'step', 'micro_step', 'norm', 'nans'] self.assertListEqual(list(actv_grad_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([3, 7])) + self.assertEqual(actv_0.shape, tuple([6, 6])) def test_wg_distribution(self): self.get_dist_mock(False) @@ -96,7 +95,7 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(grad_reduced_0_csv)) self.assertTrue(os.path.exists(grad_unreduced_0_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "param_name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm"] grad_reduced_0 = pd.read_csv(grad_reduced_0_csv) self.assertListEqual(list(grad_reduced_0.columns), expect_columns) self.assertEqual(grad_reduced_0.shape, tuple([2, 4])) @@ -119,7 +118,7 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(exp_avg_1_csv)) self.assertTrue(os.path.exists(exp_avg_sq_1_csv)) # validate columns and lines - expect_columns = ["vpp_stage", "param_name", "step", "norm"] + expect_columns = ["vpp_stage", "name", "step", "norm"] exp_avg_1 = pd.read_csv(exp_avg_1_csv) self.assertListEqual(list(exp_avg_1.columns), expect_columns) self.assertEqual(exp_avg_1.shape, tuple([2, 4])) @@ -146,8 +145,7 @@ class TestModuleHook(unittest.TestCase): self.get_dist_mock(True) hooker = TrainerMon( cc_config, - params_have_main_grad=False, - opt_ty='Megatron_FP32Optimizer' + params_have_main_grad=False ) self.assertIsNotNone(hooker) @@ -160,7 +158,7 @@ class TestModuleHook(unittest.TestCase): rank_list = [1, 2] ops_list = ['max', 'min'] cc_config = os.path.join(base_dir, "config/cc_config.json") - hooker = TrainerMon(cc_config, params_have_main_grad=False, opt_ty='Megatron_FP32Optimizer') + hooker = TrainerMon(cc_config, params_have_main_grad=False) hooker.adhoc_check(target_tensor, module_name, tensor_name, rank_list, ops_list) def test_generate_cc_metrics(self): @@ -183,32 +181,11 @@ class TestModuleHook(unittest.TestCase): result = TrainerMon.generate_cc_metrics(cc_name, cc_tensor) self.assertDictEqual(result, expected_metrics) - def test_common_info_with_Exception(self): - xy_config = os.path.join(base_dir, "config/xy_config.json") - hooker = TrainerMon( - xy_config, - params_have_main_grad=False, - opt_ty=None - ) - hooker.forward_only = True - - hooker.ur_distribution = True - with self.assertRaises(Exception) as context: - hooker.common_info() - self.assertIn(str(context.exception), "ur_distribution cannot be enabled with unknown optimizer.") - - hooker.ur_distribution = False - hooker.mv_distribution = True - with self.assertRaises(Exception) as context: - hooker.common_info() - self.assertIn(str(context.exception), "mv_distribution cannot be enabled with unknown optimizer.") - def test_generate_xy_metrics(self): xy_config = os.path.join(base_dir, "config/xy_config.json") trainer_mon = TrainerMon( xy_config, - params_have_main_grad=False, - opt_ty='Megatron_FP32Optimizer' + params_have_main_grad=False ) fwd_context = ModuleHookContext("module1") @@ -224,8 +201,7 @@ class TestModuleHook(unittest.TestCase): xy_config = 
os.path.join(base_dir, "config/xy_config.json") trainer_mon = TrainerMon( xy_config, - params_have_main_grad=False, - opt_ty='Megatron_FP32Optimizer' + params_have_main_grad=False ) trainer_mon.rank = 0 trainer_mon.module_rank_list = [1, 2] @@ -272,59 +248,53 @@ class TestModuleHookContext(unittest.TestCase): self.module_name = "test_module" self.context = ModuleHookContext(self.module_name) self.context.struct = { - MonitorConst.ACTV_IN: { + Const.INPUT: { "config": "tuple[1]", "0": "size=(2, 784), dtype=torch.float32", }, - MonitorConst.ACTV_OUT: { + Const.OUTPUT: { "config": "tensor", "tensor": "size=(2, 10), dtype=torch.float32" }, - MonitorConst.ACTVGRAD_IN: { + MonitorConst.INPUT_GRAD: { "config": "tuple[1]", "0": "size=(2, 784), dtype=torch.float32" }, - MonitorConst.ACTVGRAD_OUT: { + MonitorConst.OUTPUT_GRAD: { "config": "tuple[1]", "0": "size=(2, 10), dtype=torch.float32" } } self.target_config = { self.module_name: { - MonitorConst.ACTV_IN: "tuple[1]:0", - MonitorConst.ACTV_OUT: "tensor", - MonitorConst.ACTVGRAD_IN: "tuple[1]:0" + Const.INPUT: "tuple[1]:0", + Const.OUTPUT: "tensor", + MonitorConst.INPUT_GRAD: "tuple[1]:0" } } - def test_set_format_by_arg_invalid_key(self): - with self.assertRaises(ValueError) as err: - self.context.set_format_by_arg('invalid_key', {}) - self.assertIn(str(err.exception), - "key(invalid_key) error, valid_key: ['input', 'output', 'input_grad', 'output_grad']") - def test_set_format_by_arg_module_name_in_target_config(self): - self.context.set_format_by_arg(MonitorConst.ACTV_IN, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTV_IN], "tuple[1]:0") - self.context.set_format_by_arg(MonitorConst.ACTV_OUT, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTV_OUT], "tensor") - self.context.set_format_by_arg(MonitorConst.ACTVGRAD_IN, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTVGRAD_IN], "tuple[1]:0") - self.context.set_format_by_arg(MonitorConst.ACTVGRAD_OUT, self.target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTVGRAD_OUT], "tuple[1]") + self.context.set_format_by_arg(Const.INPUT, self.target_config) + self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]:0") + self.context.set_format_by_arg(Const.OUTPUT, self.target_config) + self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") + self.context.set_format_by_arg(MonitorConst.INPUT_GRAD, self.target_config) + self.assertEqual(self.context.format_by_arg[MonitorConst.INPUT_GRAD], "tuple[1]:0") + self.context.set_format_by_arg(MonitorConst.OUTPUT_GRAD, self.target_config) + self.assertEqual(self.context.format_by_arg[MonitorConst.OUTPUT_GRAD], "tuple[1]") def test_set_format_by_arg_module_name_not_in_target_config(self): target_config = {} - self.context.set_format_by_arg(MonitorConst.ACTV_IN, target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTV_IN], "tuple[1]") - self.context.set_format_by_arg(MonitorConst.ACTV_OUT, target_config) - self.assertEqual(self.context.format_by_arg[MonitorConst.ACTV_OUT], "tensor") + self.context.set_format_by_arg(Const.INPUT, target_config) + self.assertEqual(self.context.format_by_arg[Const.INPUT], "tuple[1]") + self.context.set_format_by_arg(Const.OUTPUT, target_config) + self.assertEqual(self.context.format_by_arg[Const.OUTPUT], "tensor") @patch('msprobe.pytorch.monitor.module_hook.logger') def test_set_format_by_arg_target_module_config_error(self, mock_logger): - target_config = 
{self.module_name: {MonitorConst.ACTV_IN: 123}} - self.context.set_format_by_arg(MonitorConst.ACTV_IN, target_config) - self.assertNotIn(MonitorConst.ACTV_IN, self.context.format_by_arg) + target_config = {self.module_name: {Const.INPUT: 123}} + self.context.set_format_by_arg(Const.INPUT, target_config) + self.assertIsNone(self.context.format_by_arg.get(Const.INPUT)) mock_logger.warning_on_rank_0.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py index 82a280357b5d387d425380d94dfeadc279674c69..0462ac3f39531119b40d3cc5051fad77f687b9b5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py @@ -8,7 +8,9 @@ from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.utils import filter_special_chars, MsgConst, get_param_struct, validate_ops, \ validate_ranks, validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \ - is_recomputation, get_output_base_dir + get_output_base_dir +from msprobe.pytorch.common.utils import is_recomputation + class TestValidationFunctions(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py index 99a26897b3356d6d448a6e4ac14761548ba48aae..793b086b02db03f8a04b159f35f1df55fc1a9d2c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py @@ -13,7 +13,6 @@ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult class TestOptimizerMon(unittest.TestCase): - def setUp(self) -> None: # 初始化需要的monitor, torch_opt, params2name等对象 self.monitor = Mock() @@ -56,12 +55,11 @@ class TestOptimizerMon(unittest.TestCase): mock_dist.get_world_size.return_value = 1 # Mocking the wrapped_optimizer - self.optimizer_mon.wrapped_optimizer = MagicMock() - self.optimizer_mon.wrapped_optimizer.state = defaultdict(dict) - self.optimizer_mon.wrapped_optimizer.averaged_gradients = defaultdict(torch.Tensor) - self.optimizer_mon.wrapped_optimizer.partition_size = defaultdict(int) - self.optimizer_mon.wrapped_optimizer.flatten_dense_tensors_aligned = MagicMock() - self.optimizer_mon.wrapped_optimizer.flatten = MagicMock() + self.torch_opt.state = defaultdict(dict) + self.torch_opt.averaged_gradients = defaultdict(torch.Tensor) + self.torch_opt.partition_size = defaultdict(int) + self.torch_opt.flatten_dense_tensors_aligned = MagicMock() + self.torch_opt.flatten = MagicMock() # Mocking the torch_opt.param_groups self.torch_opt.param_groups = [{'step': 1, 'betas': (0.9, 0.999)}, @@ -83,7 +81,6 @@ class TestOptimizerMon(unittest.TestCase): class TestMixPrecisionOptimizerMon(unittest.TestCase): - def test_fetch_mv_with_fp16_to_fp32_param_and_mix_prec_opt(self): # init monitor, torch_opt ... 
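# the monitors now read optimizer state (float16_groups, fp32_from_float16_groups, ...) directly off torch_opt, so the tests mock attributes on torch_opt instead of a separate wrapped_optimizer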
self.monitor = MagicMock() @@ -93,7 +90,6 @@ class TestMixPrecisionOptimizerMon(unittest.TestCase): self.mix_prec_opt.float16_groups = [MagicMock()] self.mix_prec_opt.fp32_from_float16_groups = [MagicMock()] self.optimizer = MixPrecisionOptimizerMon() - self.optimizer.wrapped_optimizer = self.mix_prec_opt self.optimizer.fp16_to_fp32_param = {} # Mock _fetch_mv_in_adam method and set a fixed return value @@ -105,19 +101,17 @@ class TestMixPrecisionOptimizerMon(unittest.TestCase): self.mock_fetch_mv_in_adam.assert_called_once_with(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) -class TestChainedMixPrecisionOptimizerMon(unittest.TestCase): +class TestChainedMixPrecisionOptimizerMon(unittest.TestCase): def test_fetch_mv_with_fp16_to_fp32_param_and_mix_prec_opt(self): # init monitor, torch_opt ... self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - self.mix_prec_opt = MagicMock() - self.mix_prec_opt.float16_groups = [MagicMock()] - self.mix_prec_opt.fp32_from_float16_groups = [MagicMock()] + self.torch_opt.float16_groups = [MagicMock()] + self.torch_opt.fp32_from_float16_groups = [MagicMock()] self.optimizer = MegatronChainedMixPrecisionOptimizerMon() self.optimizer.optimizer = [MagicMock(), MagicMock()] - self.optimizer.wrapped_optimizer = self.mix_prec_opt self.optimizer.fp16_to_fp32_param = {} # Mock _fetch_mv_in_adam method and set a fixed return value @@ -135,22 +129,22 @@ class TestMegatronChainedDistributedOptimizerMon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - self.mock_wrapped_optimizer = MagicMock() mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) self.optimizer = MegatronChainedDistributedOptimizerMon() def test_fetch_mv_with_valid_optimizer(self): - self.mock_wrapped_optimizer.model_float16_groups = [MagicMock()] - self.mock_wrapped_optimizer.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer.wrapped_optimizer = self.mock_wrapped_optimizer + self.torch_opt.model_float16_groups = [MagicMock()] + self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.optimizer.wrapped_optimizer = Mock() + self.torch_opt = Mock() + self.torch_opt.model_float16_groups = None + self.torch_opt.shard_fp32_from_float16_groups = None self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam with self.assertRaises(Exception): @@ -162,22 +156,22 @@ class TestMegatronDistributedOptimizerMon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - self.mock_wrapped_optimizer = MagicMock() mv_result = MVResult(exp_avg={}, exp_avg_sq={}, update={}, ratio={}) self.mock_fetch_mv_in_adam = MagicMock(return_value=mv_result) self.optimizer = MegatronDistributedOptimizerMon() def test_fetch_mv_with_valid_optimizer(self): - self.mock_wrapped_optimizer.model_float16_groups = [MagicMock()] - self.mock_wrapped_optimizer.shard_fp32_from_float16_groups = [MagicMock()] - self.optimizer.wrapped_optimizer = self.mock_wrapped_optimizer + self.torch_opt.model_float16_groups = [MagicMock()] + self.torch_opt.shard_fp32_from_float16_groups = [MagicMock()] self.optimizer._fetch_mv_in_adam = 
self.mock_fetch_mv_in_adam res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) def test_fetch_mv_with_invalid_optimizer(self): - self.optimizer.wrapped_optimizer = Mock() + self.torch_opt = Mock() + self.torch_opt.model_float16_groups = None + self.torch_opt.shard_fp32_from_float16_groups = None self.optimizer._fetch_mv_in_adam = self.mock_fetch_mv_in_adam with self.assertRaises(Exception): @@ -189,31 +183,27 @@ class TestCommonFetchMv(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - self.mock_wrapped_optimizer = MagicMock() def test_megatron_fp32_optimizer_mon(self): self.optimizer = MegatronFP32OptimizerMon() - self.optimizer.wrapped_optimizer = self.mock_wrapped_optimizer res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) def test_deepspeed_zero_optimizer_stage0_mon(self): self.optimizer = DeepSpeedZeroOptimizerStage0Mon() - self.optimizer.wrapped_optimizer = self.mock_wrapped_optimizer res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) def test_dummy_optimizer_mon(self): self.optimizer = DummyOptimizerMon() - self.optimizer.wrapped_optimizer = self.mock_wrapped_optimizer res = self.optimizer.fetch_mv(self.monitor, self.torch_opt, self.params2name) self.assertIsInstance(res, MVResult) class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): def test_get_param_index(self): - OptimizerMon.wrapped_optimizer = Mock() - OptimizerMon.wrapped_optimizer.fp16_partitioned_groups = [ + self.torch_opt = Mock() + self.torch_opt.fp16_partitioned_groups = [ [Mock(flatten=lambda: [1, 2, 3]), Mock(flatten=lambda: [4, 5])], [Mock(flatten=lambda: [6, 7, 8, 9])] @@ -222,7 +212,7 @@ class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): self.name2index = {'weight1': 0, 'weight2': 2} optimizer_stage3_mon = DeepSpeedZeroOptimizerStage3Mon() - name2indices = optimizer_stage3_mon.get_param_index(self.params2name, self.name2index) + name2indices = optimizer_stage3_mon.get_param_index(self.params2name, self.name2index, self.torch_opt) expected_name2indices = {'weight1': (0, 3, 0, None), 'weight2': (5, 9, 1, None)} self.assertDictEqual(dict(name2indices), expected_name2indices) @@ -231,9 +221,7 @@ class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - OptimizerMon.wrapped_optimizer = Mock() - OptimizerMon.wrapped_optimizer.fp16_partitioned_groups = MagicMock() - + self.torch_opt.fp16_partitioned_groups = MagicMock() self.optimizer = DeepSpeedZeroOptimizerStage3Mon() # mock _fetch_mv_grad_in_adam @@ -246,7 +234,6 @@ class TestDeepSpeedZeroOptimizerStage3Mon(unittest.TestCase): class TestDeepSpeedZeroOptimizerStage1or2Mon(unittest.TestCase): - def test_get_group_index(self): self.fp32_length = [10, 20, 30, 40] self.world_size = 4 @@ -266,15 +253,15 @@ class TestDeepSpeedZeroOptimizerStage1or2Mon(unittest.TestCase): self.optimizer_monitor = DeepSpeedZeroOptimizerStage1or2Mon() - OptimizerMon.wrapped_optimizer = MagicMock() - OptimizerMon.wrapped_optimizer.groups_padding = [1, 2, 3] - OptimizerMon.wrapped_optimizer.single_partition_of_fp32_groups = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])] - OptimizerMon.wrapped_optimizer.bit16_groups = [ + self.torch_opt = MagicMock() + self.torch_opt.groups_padding = [1, 2, 3] + self.torch_opt.single_partition_of_fp32_groups 
= [torch.tensor([1, 2]), torch.tensor([3, 4, 5])] + self.torch_opt.bit16_groups = [ [torch.tensor([6, 7]), torch.tensor([8])], [torch.tensor([9, 10, 11])] ] - name2indices = self.optimizer_monitor.get_param_index(self.params2name, self.name2index) + name2indices = self.optimizer_monitor.get_param_index(self.params2name, self.name2index, self.torch_opt) for name, indices in name2indices.items(): self.assertIn(name, self.params2name.values()) self.assertIsInstance(indices, tuple) @@ -284,9 +271,7 @@ class TestDeepSpeedZeroOptimizerStage1or2Mon(unittest.TestCase): self.monitor = MagicMock() self.torch_opt = MagicMock() self.params2name = MagicMock() - OptimizerMon.wrapped_optimizer = Mock() - OptimizerMon.wrapped_optimizer.fp16_partitioned_groups = MagicMock() - + self.torch_opt.fp16_partitioned_groups = MagicMock() self.optimizer = DeepSpeedZeroOptimizerStage1or2Mon() # mock _fetch_mv_grad_in_adam @@ -302,32 +287,52 @@ class TestOptimizerMonFactory(unittest.TestCase): def test_create_optimizer_mon(self): # 测试已知的优化器类型 - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("Megatron_Float16OptimizerWithFloat16Params"), + mix_optimizer = MagicMock() + mix_optimizer_class = MagicMock() + mix_optimizer_class.__name__ = "Float16OptimizerWithFloat16Params" + mix_optimizer.__class__ = mix_optimizer_class + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(mix_optimizer)[0], MixPrecisionOptimizerMon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("Megatron_DistributedOptimizer"), + dis_optimizer = MagicMock() + dis_optimizer_class = MagicMock() + dis_optimizer_class.__name__ = "DistributedOptimizer" + dis_optimizer.__class__ = dis_optimizer_class + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(dis_optimizer)[0], MegatronDistributedOptimizerMon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("Megatron_ChainedDistributedOptimizer"), - MegatronChainedDistributedOptimizerMon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("Megatron_ChainedFloat16OptimizerWithFloat16Params"), - MegatronChainedMixPrecisionOptimizerMon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("Megatron_FP32Optimizer"), + fp32_optimizer = MagicMock() + fp32_optimizer_class = MagicMock() + fp32_optimizer_class.__name__ = "FP32Optimizer" + fp32_optimizer.__class__ = fp32_optimizer_class + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(fp32_optimizer)[0], MegatronFP32OptimizerMon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("DeepSpeedZeroOptimizer_Stage0"), + chained_optimizer = MagicMock() + chained_optimizer_class = MagicMock() + chained_optimizer_class.__name__ = "ChainedOptimizer" + chained_optimizer.__class__ = chained_optimizer_class + chained_optimizer.chained_optimizers = [mix_optimizer, mix_optimizer] + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + MegatronChainedMixPrecisionOptimizerMon) + chained_optimizer.chained_optimizers = [dis_optimizer, dis_optimizer] + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(chained_optimizer)[0], + MegatronChainedDistributedOptimizerMon) + deepspeed_optimizer = MagicMock() + deepspeed_optimizer_class = MagicMock() + deepspeed_optimizer_class.__name__ = "BF16_Optimizer" + deepspeed_optimizer.__class__ = deepspeed_optimizer_class + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], DeepSpeedZeroOptimizerStage0Mon) - 
self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("DeepSpeedZeroOptimizer_Stage1_or_2"), + deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer" + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], DeepSpeedZeroOptimizerStage1or2Mon) - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("DeepSpeedZeroOptimizer_Stage3"), + deepspeed_optimizer_class.__name__ = "DeepSpeedZeroOptimizer_Stage3" + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(deepspeed_optimizer)[0], DeepSpeedZeroOptimizerStage3Mon) - # 测试未知的优化器类型,应该返回DummyOptimizerMon - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon("unknown"), DummyOptimizerMon) - - # 测试空的优化器类型,应该返回DummyOptimizerMon - self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(""), DummyOptimizerMon) - - # 测试异常情况,如果输入的优化器类型不在已知类型列表中,应该抛出异常 - with self.assertRaises(Exception): - OptimizerMonFactory.create_optimizer_mon("nonexistent") + unknown_optimizer = MagicMock() + unknown_optimizer_class = MagicMock() + unknown_optimizer_class.__name__ = "unknown" + unknown_optimizer.__class__ = unknown_optimizer_class + self.assertIsInstance(OptimizerMonFactory.create_optimizer_mon(unknown_optimizer)[0], DummyOptimizerMon) if __name__ == '__main__': diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py index 51b343fc2c6c429a8c3da73a53d4388aa35f55fa..6687f3111050ea53e14e62f3afd55ae1eff2b8c0 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py @@ -78,10 +78,10 @@ class TestService(unittest.TestCase): self.service.build_hook = MagicMock() self.config.level = "L0" with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.ModuleProcesser.hook_modules") as mock_hook_modules: + patch("msprobe.pytorch.service.ModuleProcesser.register_module_hook") as mock_register_module_hook: self.service.register_module_hook() self.assertEqual(mock_logger.call_count, 1) - mock_hook_modules.assert_called_once() + mock_register_module_hook.assert_called_once() def test_register_api_hook_with_level1(self): self.service.build_hook = MagicMock() diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py index e76331992ec19c0a0545b10d58b3013ca988e763..706dc8bf82e59f413c3fd559a39af89c6a70be47 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py @@ -66,6 +66,17 @@ class TestGraphBuilder(unittest.TestCase): self.assertEqual(node_id_b, 'Module.root.backward.0') node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None) self.assertIsNone(node_id_c) + construct_dict = {'Module.module.a.forward': 'Module.root.forward', 'Module.module.a.backward': None, + 'Module.root.forward': None, 'Module.root.backward': None, + 'Module.module.b.forward': 'Module.root.forward', + 'Module.module.b.backward': 'Module.root.backward', 'Module.module.c.backward': None} + node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward', None) + self.assertEqual(node_id_a, 'Module.root.backward') + node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 
'Module.module.b.backward', + 'Module.root.backward') + self.assertEqual(node_id_b, 'Module.root.backward') + node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward', None) + self.assertIsNone(node_id_c) def test__collect_apis_between_modules_only_apis(self): graph = Graph('TestNet') diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_node_op.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_node_op.py index 4e0bc926b1ecb493deee5f30a680a086f220c739..8cc51126cd76b09ac2abcb20bfb3f2adb2d606cb 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_node_op.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_node_op.py @@ -10,8 +10,7 @@ class TestNodeOp(unittest.TestCase): def test_get_node_op_invalid(self): node_name = "InvalidNodeName" - with self.assertRaises(Exception): - NodeOp.get_node_op(node_name) + self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module) def test_get_node_op_all(self): test_cases = [ diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py index 860bbe4e7ac47a27f753214efcf054f8197c8e5e..814882e6b819e9e6b6b421aec5f8f0b89f03f7c6 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,17 +14,19 @@ # limitations under the License. import re + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_json +from msprobe.visualization.builder.msprobe_adapter import get_input_output +from msprobe.visualization.builder.msprobe_adapter import op_patterns from msprobe.visualization.graph.graph import Graph from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import save_json_file, GraphConst -from msprobe.visualization.builder.msprobe_adapter import get_input_output -from msprobe.core.common.file_utils import load_json -from msprobe.core.common.const import Const -from msprobe.visualization.builder.msprobe_adapter import op_patterns class GraphBuilder: backward_pattern = re.compile(r"(\.backward\.)(\d+)$") + forward_pattern = re.compile(r"(\.forward\.)(\d+)$") # 匹配以大写字母开头,后接任意字母,并以Template(结尾 template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(') @@ -112,12 +114,17 @@ class GraphBuilder: 如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点 """ # 匹配以.backward.后跟一个或多个数字结尾的模式 - backward_pattern = r"(\.backward\.)(\d+)$" - forward_pattern = r"(\.forward\.)(\d+)$" - if re.search(backward_pattern, subnode_id) and not upnode_id: - forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id)) + if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id: + forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id)) + if forward_upnode_id: + new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id) + if new_upnode_id in construct_dict: + return new_upnode_id + # 匹配以.backward结尾的节点 + if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id: + forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD)) if forward_upnode_id: - 
new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id) + new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD) if new_upnode_id in construct_dict: return new_upnode_id return upnode_id @@ -147,6 +154,8 @@ class GraphBuilder: input_data, output_data = get_input_output(node_data, node.id) # 更新数据 node.set_input_output(input_data, output_data) + if GraphConst.BATCH_P2P in name: + GraphBuilder._extract_batch_p2p_info(node, node_data) # 反向节点使用对应前向节点的堆栈信息 # 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward if (not node_stack_info and @@ -163,6 +172,24 @@ class GraphBuilder: node.add_upnode(upnode) return node + @staticmethod + def _is_valid_batch_p2p_output(param_list): + if not isinstance(param_list, list) or not param_list: + return False + if not isinstance(param_list[0], list) or not param_list[0]: + return False + return True + + @staticmethod + def _extract_batch_p2p_info(node, node_data): + param_list = node_data.get(Const.OUTPUT, []) + # 数据格式:"output": [[{param1}, {param2}, ...]] + if GraphBuilder._is_valid_batch_p2p_output(param_list): + for param in param_list[0]: + info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER), + GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)} + node.batch_p2p_info.append(info) + @staticmethod def _collect_apis_between_modules(graph): """ diff --git a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py index da309c765980aeeacbf5326ce520987e74b473e3..ee5e3f519ed126b2aaa493e0d3a3b7fce33313e4 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/msprobe_adapter.py @@ -23,7 +23,7 @@ from msprobe.core.compare.acc_compare import ModeConfig # 用于将节点名字解析成对应的NodeOp的规则 op_patterns = [ # NodeOp.module - r'^(Module.|Cell.)', + r'^(Module.|Cell.|optimizer|clip_grad)', # NodeOp.function_api r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)' ] @@ -57,8 +57,8 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False): from msprobe.pytorch.compare.pt_compare import PTComparator return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path) else: - from msprobe.mindspore.compare.ms_compare import MSComparator - ms_comparator = MSComparator(mode_config) + from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig + ms_comparator = MSComparator(mode_config, MappingConfig()) ms_comparator.cross_frame = is_cross_frame return ms_comparator.do_multi_process(dump_path_param, csv_path) @@ -120,11 +120,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2): return True -def format_node_data(data_dict): +def format_node_data(data_dict, node_id=None): """ - 批量进行节点数据的输出 + 删除节点数据中不需要展示的字段 """ del_list = ['requires_grad', 'full_op_name'] + if node_id and GraphConst.BATCH_P2P in node_id: + del_list.extend(['op', 'peer', 'tag', 'group_id']) for _, value in data_dict.items(): if not isinstance(value, dict): continue diff --git a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py index 58ee17ba9d3d37497d04cb8114be09431396f105..2642ff1e97ebcc055212d4d776eb7c8a08866dc8 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py @@ -34,6 +34,7 @@ 
class BaseNode: self.micro_step_id = None self.overflow_level = None self.matched_distributed = {} + self.batch_p2p_info = [] def __str__(self): info = f'id:\t{self.id}' @@ -92,8 +93,8 @@ class BaseNode: result = { 'id': self.id, 'node_type': self.op.value, - 'output_data': format_node_data(self.output_data), - 'input_data': format_node_data(self.input_data), + 'output_data': format_node_data(self.output_data, self.id), + 'input_data': format_node_data(self.input_data, self.id), 'upnode': self.upnode.id if self.upnode else 'None', 'subnodes': [node.id for node in self.subnodes], 'matched_node_link': self.matched_node_link, diff --git a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py index 482d936ea551473165a4f56693e583218e2ddf77..5e68d6b2528aea4d6645da2885fa76a7b9bb97b2 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/distributed_analyzer.py @@ -107,6 +107,15 @@ class DistributedAnalyzer: return None, None return group_ranks, group_id + @staticmethod + def _get_batch_group_info(node, rank): + for data in node.input_data.values(): + group_id = data.get('group_id') + if group_id is not None: + return group_id + logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}') + return None + def distributed_match(self): for rank, graph in self.graphs.items(): nodes = graph.node_map @@ -115,7 +124,9 @@ class DistributedAnalyzer: if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed: continue api_name, distributed_type = self._get_distributed_name_and_type(node_id) - if distributed_type == DistributedType.P2P: + if api_name == GraphConst.BATCH_P2P: + self._batch_p2p_match(node, rank) + elif distributed_type == DistributedType.P2P: self._p2p_match(node, rank, api_name) else: self._collective_match(node, rank, api_name) @@ -138,12 +149,16 @@ class DistributedAnalyzer: for rank, graph in self.graphs.items(): group_count = {} group_info = {} + batch_p2p_count = {} nodes = graph.node_map for node_id, node in nodes.items(): if not node_id.startswith(Const.DISTRIBUTED): continue api_name, distributed_type = self._get_distributed_name_and_type(node_id) - if distributed_type == DistributedType.P2P: + if api_name == GraphConst.BATCH_P2P: + self._make_batch_p2p_mapping(node, rank, batch_p2p_count) + continue + elif distributed_type == DistributedType.P2P: config_info = self.config.get(api_name) target_rank = self._get_target_rank(node, rank, config_info[1]) if target_rank is None: @@ -162,7 +177,32 @@ class DistributedAnalyzer: unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id)) group_info[unique_group_id] = node_id group_info[node_id] = unique_group_id - self.group_node_mapping[rank] = group_info + if rank not in self.group_node_mapping: + self.group_node_mapping[rank] = {} + self.group_node_mapping[rank].update(group_info) + + def _make_batch_p2p_mapping(self, node, rank, batch_p2p_count): + """ + 给batch_isend_irecv接口的每个p2p内容赋予唯一标识 + """ + if rank not in self.group_node_mapping: + self.group_node_mapping[rank] = {} + params = [] + for info_dict in node.batch_p2p_info: + op = info_dict.get(GraphConst.OP) + target_rank = info_dict.get(GraphConst.PEER) + if op is None or target_rank is None: + logger.warning('Cannot get param op or peer.') + continue + group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \ + 
Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '') + batch_p2p_count[group_id] = batch_p2p_count.get(group_id, 0) + 1 + # 例如: isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1 + unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(batch_p2p_count.get(group_id)) + params.append(unique_group_id) + self.group_node_mapping.get(rank)[unique_group_id] = node.id + if params: + self.group_node_mapping.get(rank)[node.id] = params def _get_distributed_name_and_type(self, node_id): if Const.SEP not in node_id: @@ -316,3 +356,40 @@ class DistributedAnalyzer: if nodes_info: matched_distributed['nodes_info'] = nodes_info node.matched_distributed = matched_distributed + + def _batch_p2p_match(self, node, rank): + """ + 批量点对点匹配 + + 针对torch.distributed.batch_isend_irecv接口,其入参是一个包含点对点通信信息的集合,需要遍历集合对每个点对点通信信息进行匹配 + :param node: 当前集体通信节点 + :param rank: 当前节点所属rank + :return: + """ + unique_group_ids = self.group_node_mapping.get(rank, {}).get(node.id) + if not unique_group_ids: + return + matched_distributed = [] if len(unique_group_ids) > 1 else {} + for unique_group_id in unique_group_ids: + try: + id_info = unique_group_id.split(Const.REPLACEMENT_CHARACTER) + api_name = id_info[0] + target_api_name = self.config.get(api_name)[0] + target_rank = int(id_info[1].replace(Const.RANK, '')) + except Exception as e: + logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.') + continue + target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name) + if not target_node: + continue + communications_type = self.config.get(api_name)[2] + index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \ + else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN) + matched_info = { + 'communications_type': communications_type, + 'nodes_info': {target_rank: [str(index), target_node.id]} + } + matched_distributed.append(matched_info) if isinstance(matched_distributed, list) \ + else matched_distributed.update(matched_info) + if matched_distributed: + node.matched_distributed = matched_distributed diff --git a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py index 26839398ca3b15ab6b8cffa9137999befb6541b2..33bfa9cc2e34a0960c3ff236a1bd183a5753a0ab 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/node_op.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/node_op.py @@ -16,6 +16,7 @@ from enum import Enum import re from msprobe.visualization.builder.msprobe_adapter import op_patterns +from msprobe.core.common.log import logger class NodeOp(Enum): @@ -32,8 +33,9 @@ class NodeOp(Enum): for op in NodeOp: index = op.value if index < 0 or index >= len(op_patterns): - raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match") + continue pattern = op_patterns[index] if re.match(pattern, node_name): return op - raise Exception(f"Cannot parse node_name {node_name} into NodeOp") + logger.warning(f"Cannot parse node_name {node_name} into NodeOp, defaulting to module.") + return NodeOp.module diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py index f9bc3e97337ae0263e7c49941031153c57469538..75b0014c1c09abb8dfecf285fed5eed3063827a0 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py @@ -16,8 +16,8 @@ import os import time import json
-from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker, - check_file_or_directory_path) +from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker, + check_file_or_directory_path, load_json) from msprobe.core.common.const import FileCheckConst, Const from msprobe.core.common.utils import CompareException from msprobe.core.overflow_check.checker import AnomalyDetector @@ -130,19 +130,21 @@ def _compare_graph_ranks(input_param, args, step=None): output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis' result = _compare_graph(input_param, args) result.output_file_name = output_file_name - try: - result.rank = int(nr.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e + if nr != Const.RANK: + try: + result.rank = int(nr.replace(Const.RANK, "")) + except Exception as e: + logger.error('The folder name format is incorrect, expected rank+number.') + raise CompareException(CompareException.INVALID_PATH_ERROR) from e # 暂存所有rank的graph,用于匹配rank间的分布式节点 compare_graph_results.append(result) # 匹配rank间的分布式节点 - DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, - args.overflow_check).distributed_match() - DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, - args.overflow_check).distributed_match() + if len(compare_graph_results) > 1: + DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results}, + args.overflow_check).distributed_match() + DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results}, + args.overflow_check).distributed_match() for result in compare_graph_results: _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator, @@ -177,14 +179,17 @@ def _build_graph_ranks(dump_ranks_path, args, step=None): output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis' result = _build_graph(dump_path, args) result.output_file_name = output_file_name - try: - result.rank = int(rank.replace(Const.RANK, "")) - except Exception as e: - logger.error('The folder name format is incorrect, expected rank+number.') - raise CompareException(CompareException.INVALID_PATH_ERROR) from e + if rank != Const.RANK: + try: + result.rank = int(rank.replace(Const.RANK, "")) + except Exception as e: + logger.error('The folder name format is incorrect, expected rank+number.') + raise CompareException(CompareException.INVALID_PATH_ERROR) from e build_graph_results.append(result) - DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, args.overflow_check).distributed_match() + if len(build_graph_results) > 1: + DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results}, + args.overflow_check).distributed_match() for result in build_graph_results: _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check, @@ -215,8 +220,7 @@ def _graph_service_parser(parser): def _graph_service_command(args): - with FileOpen(args.input_path, "r") as file: - input_param = json.load(file) + input_param = load_json(args.input_path) npu_path = input_param.get("npu_path") bench_path = input_param.get("bench_path") check_file_or_directory_path(npu_path, isdir=True) diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py 
b/debug/accuracy_tools/msprobe/visualization/utils.py index b808ddb1bca779b9daa4bd22767d7d6558b64b89..623bcd11c45f1ff8e9c283d30a982af239706ce4 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -97,6 +97,10 @@ def check_directory_content(input_path): if all(os.path.isfile(os.path.join(input_path, item)) for item in contents): return GraphConst.FILES + # 单卡只有一个rank文件夹 + if contents == [Const.RANK]: + return GraphConst.RANKS + rank_pattern = re.compile(r'^rank\d+$') step_pattern = re.compile(r'^step\d+$') @@ -151,6 +155,7 @@ class GraphConst: SUMMARY_COMPARE = 0 MD5_COMPARE = 1 REAL_DATA_COMPARE = 2 + STRUCTURE_COMPARE = 3 JSON_NPU_KEY = 'NPU' JSON_BENCH_KEY = 'Bench' JSON_TIP_KEY = 'ToolTip' @@ -196,13 +201,15 @@ class GraphConst: DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING = { Const.ALL: REAL_DATA_COMPARE, Const.SUMMARY: SUMMARY_COMPARE, - Const.MD5: MD5_COMPARE + Const.MD5: MD5_COMPARE, + Const.STRUCTURE: STRUCTURE_COMPARE } GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING = { REAL_DATA_COMPARE: Const.ALL, SUMMARY_COMPARE: Const.SUMMARY, - MD5_COMPARE: Const.MD5 + MD5_COMPARE: Const.MD5, + STRUCTURE_COMPARE: Const.STRUCTURE } RANKS = 'ranks' @@ -211,3 +218,8 @@ class GraphConst: SRC = 'src' DST = 'dst' + + BATCH_P2P = 'batch_isend_irecv' + OP = 'op' + PEER = 'peer' + GROUP_ID = 'group_id' diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index 4b46216d7b10afd44797cb752ea015f5545eae90..2da7fcf667765a841b9db1bbf5628fad5b1cf8a9 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,7 @@ # limitations under the License. 
-__version__ = '1.2.0' +__version__ = '1.2.2' import subprocess import platform @@ -53,21 +53,20 @@ EXCLUDE_PKGS = [ if "--plat-name" in sys.argv or "--python-tag" in sys.argv: raise SystemError("Specifing platforms or python version is not supported.") -if (platform.system() != "Linux"): +if platform.system() != "Linux": raise SystemError("MsProbe is only supported on Linux platforms.") - -mod_list_range = {"adump",} +mod_list_range = {"adump", } mod_list = [] -for i in range(len(sys.argv)): - if sys.argv[i].startswith("--include-mod"): - if sys.argv[i].startswith("--include-mod="): - mod_list = sys.argv[i][len("--include-mod="):].split(',') - sys.argv.remove(sys.argv[i]) +for i, arg in enumerate(sys.argv): + if arg.startswith("--include-mod"): + if arg.startswith("--include-mod="): + mod_list = arg[len("--include-mod="):].split(',') + sys.argv.remove(arg) elif i + 1 < len(sys.argv) and not sys.argv[i + 1].startswith("--"): mod_list = sys.argv[i + 1].split(',') sys.argv.remove(sys.argv[i + 1]) - sys.argv.remove(sys.argv[i]) + sys.argv.remove(arg) mod_list = list(set(mod_list) & mod_list_range) break diff --git a/dynolog_npu/README.md b/dynolog_npu/README.md index 08bb22ba4d131c78eee821361e6a36ac96f4654c..bc7fe2dbf9d7ec9a25ced566e27c1260955a16cb 100644 --- a/dynolog_npu/README.md +++ b/dynolog_npu/README.md @@ -85,32 +85,67 @@ nputrace子命令支持的参数选项 | 子命令 | 参数类型 | 说明 | |-------|-------|-------| -| record_shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | -| profile_memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | -| with_stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | -| with_flops | action | 是否采集算子flops,设置参数采集,默认不采集 | -| with_modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | +| job-id | u64 | 采集任务的job id,默认值0,dynolog原生参数 | +| pids | String | 采集任务的pid列表,多个pid用逗号分隔,默认值0,dynolog原生参数 | +| process-limit | u64 | 最大采集进程的数量,默认值3,dynolog原生参数 | +| profile-start-time | u64 | 用于同步采集的Unix时间戳,单位毫秒,默认值0,dynolog原生参数 | +| duration-ms | u64 | 采集的周期,单位毫秒,默认值500,dynolog原生参数 | +| iterations | i64 | 采集总迭代数,默认值-1,dynolog原生参数 | +| log-file | String | 采集落盘的路径,必选值 | +| start-step | u64 | 开始采集的迭代数,默认值0 | +| record-shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | +| profile-memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | +| with-stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | +| with-flops | action | 是否采集算子flops,设置参数采集,默认不采集 | +| with-modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | | analyse | action | 采集后是否自动解析,设置参数解析,默认不解析 | -| l2_cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | -| op_attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | -| data_simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | +| l2-cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | +| op-attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | +| msprof-tx | action | 是否使能MSTX,设置参数采集,默认使能 | +| data-simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | | activities | String | 控制CPU、NPU事件采集范围,可选值范围[`CPU,NPU`, `NPU,CPU`, `CPU`, `NPU`],默认值`CPU,NPU` | -| profiler_level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0`| -| aic_metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| -| export_type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| -| gc_detect_threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | +| profiler-level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, 
`Level2`],默认值`Level0`| +| aic-metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| +| export-type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| +| gc-detect-threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | -- nputrace示例命令 +- nputrace使用方法 + +Step1: 拉起dynolog daemon进程 +```bash +# 方法1:使用systemd拉起service +# 修改配置文件/etc/dynolog.gflags, 使能ipc_monitor +echo "--enable_ipc_monitor" | sudo tee -a /etc/dynolog.gflags +sudo systemctl start dynolog + +# 方法2:命令行执行 +dynolog --enable-ipc-monitor + +#dynolog daemon的日志路径为:/var/log/dynolog.log +``` + +Step 2:使能dynolog trace dump环境变量 +```bash +export KINETO_USE_DAEMON=1 +``` + +Step 3: 拉起训练任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` + +Step 4:使用dyno CLI动态触发trace dump ```bash -# 示例1:采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities CPU,NPU --analyse --data_simplification false --log-file /tmp/profile_data +# 示例1:从第10个step开始采集,采集2个step,采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities CPU,NPU --analyse --data-simplification false --log-file /tmp/profile_data -# 示例2:只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --analyse --data_simplification true --log-file /tmp/profile_data +# 示例2:从第10个step开始采集,采集2个step,只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities NPU --analyse --data-simplification true --log-file /tmp/profile_data -# 示例3:只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --log-file /tmp/profile_data +# 示例3:从第10个step开始采集,采集2个step,只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities NPU --log-file /tmp/profile_data ``` ### NPU Monitor功能 @@ -129,20 +164,50 @@ dyno npu-monitor [SUBCOMMANDS] npu-monitor子命令支持的参数选项 | 子命令 | 参数类型 | 说明 | |-------|-------|-------| -| npu_monitor_start | action | 开启性能监控,设置参数开启,默认不采集 | -| npu_monitor_stop | action | 停止性能监控,设置参数开启,默认不采集 | -| report_interval_s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | -| mspti_activity_kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| +| npu-monitor-start | action | 开启性能监控,设置参数开启,默认不采集 | +| npu-monitor-stop | action | 停止性能监控,设置参数开启,默认不采集 | +| report-interval-s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | +| mspti-activity-kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| -- npu-monitor示例命令 +- npu-monitor使用方法 +Step1: 拉起dynolog daemon进程 +```bash +# 方法1:使用systemd拉起service +# 修改配置文件/etc/dynolog.gflags, 使能ipc_monitor +echo "--enable_ipc_monitor" | sudo tee -a /etc/dynolog.gflags +sudo systemctl start dynolog + +# 方法2:命令行执行 +dynolog --enable-ipc-monitor + +#dynolog daemon的日志路径为:/var/log/dynolog.log +``` + +Step 2:使能dynolog trace dump环境变量 +```bash +export KINETO_USE_DAEMON=1 +``` + +Step 3: 拉起训练任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` + +Step 4:使用dyno CLI使能npu-monitor ```bash # 示例1:开启性能监控,使用默认配置 -dyno npu-monitor --npu_monitor_start +dyno npu-monitor --npu-monitor-start # 示例2:暂停性能监控 -dyno npu-monitor --npu_monitor_stop +dyno npu-monitor --npu-monitor-stop + +# 示例3:性能监控过程中修改配置 +# 上报周期30s, 
上报数据类型Marker和Kernel +dyno npu-monitor --report-interval-s 30 --mspti-activity-kind Marker,Kernel -# 示例3:开启性能监控,上报周期30s, 上报数据类型Marker和Kernel -dyno npu-monitor --npu_monitor_start 30 --mspti_activity_kind Marker,Kernel +# 示例4:性能监控开启时修改配置 +# 上报周期30s, 上报数据类型Marker和Kernel +dyno npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Marker,Kernel ``` \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs b/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs index 4bf7132de338d8eee0de556449269712617772e2..f70923bca4cc5ce29a8855a464c411b63a930ef0 100644 --- a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs +++ b/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs @@ -55,6 +55,7 @@ pub struct NpuTraceOptions { pub aic_metrics: String, pub l2_cache: bool, pub op_attr: bool, + pub msprof_tx: bool, pub gc_detect_threshold: Option, pub data_simplification: String, pub export_type: String, @@ -75,6 +76,7 @@ PROFILE_PROFILER_LEVEL={} PROFILE_AIC_METRICS={} PROFILE_L2_CACHE={} PROFILE_OP_ATTR={} +PROFILE_MSPROF_TX={} PROFILE_GC_DETECT_THRESHOLD={} PROFILE_DATA_SIMPLIFICATION={} PROFILE_EXPORT_TYPE={}"#, @@ -89,6 +91,7 @@ PROFILE_EXPORT_TYPE={}"#, self.aic_metrics, self.l2_cache, self.op_attr, + self.msprof_tx, self.gc_detect_threshold.map_or("None".to_string(), |v| v.to_string()), self.data_simplification, self.export_type @@ -213,6 +216,7 @@ ACTIVITIES_ITERATIONS=1000"# aic_metrics: "AiCoreNone".to_string(), l2_cache: true, op_attr: true, + msprof_tx: true, gc_detect_threshold: 0.1, data_simplification: "true", export_type: "Text".to_string(), @@ -234,6 +238,7 @@ PROFILE_PROFILER_LEVEL=Level0 PROFILE_AIC_METRICS=AiCoreNone PROFILE_L2_CACHE=true PROFILE_OP_ATTR=true +PROFILE_MSPROF_TX=true PROFILE_GC_DETECT_THRESHOLD=0.1 PROFILE_DATA_SIMPLIFICATION=true PROFILE_EXPORT_TYPE=Text"# diff --git a/dynolog_npu/dynolog_npu/cli/src/main.rs b/dynolog_npu/dynolog_npu/cli/src/main.rs index 8bc4a2af0e2c19d6e783663924578e3c2ad7408a..9fdea3d1254467081356b2e0daeb8ed3ca05a16d 100644 --- a/dynolog_npu/dynolog_npu/cli/src/main.rs +++ b/dynolog_npu/dynolog_npu/cli/src/main.rs @@ -172,6 +172,9 @@ enum Command { /// Whether to collect op attributes. #[clap(long, action)] op_attr: bool, + /// Whether to enable MSTX. + #[clap(long, action)] + msprof_tx: bool, /// GC detect threshold. 
#[clap(long)] gc_detect_threshold: Option, @@ -290,6 +293,7 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, @@ -318,6 +322,7 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp index 4aea4b8aa1d7a43e652f4a2e02c5379f71856788..e52bfece61f35e2cd38346e374a9bb4d3b4b2004 100644 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp +++ b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp @@ -1,73 +1,81 @@ #include "DynoLogNpuMonitor.h" - -#include - +#include +#include +#include #include "utils.h" -#include "log.h" namespace dynolog_npu { namespace ipc_monitor { bool DynoLogNpuMonitor::Init() { + if (isInitialized_) { - std::cout << "[WRARNING] DynoLog npu monitor already initialized" << std::endl; + LOG(WARNING) << "DynoLog npu monitor already initialized"; return true; } bool res = ipcClient_.RegisterInstance(npuId_); if (res) { isInitialized_ = true; - std::cout << "[INFO] DynoLog npu monitor initialized success !" << std::endl; + LOG(INFO) << "DynoLog npu monitor initialized successfully"; } return res; } -ErrCode DynoLogNpuMonitor::DealMonitorReq(const MsptiMonitorCfg& cmd) +ErrCode DynoLogNpuMonitor::DealMonitorReq(const MsptiMonitorCfg& cmd) { if (cmd.monitorStart && !msptiMonitor_.IsStarted()) { - PRINT_INFO("Start Mspti Monitor thread to collect, reportTimes: %s, enableActivity: %s", reportTimes, enableActivity); + LOG(INFO) << "Start mspti monitor thread successfully"; msptiMonitor_.Start(); } if (cmd.monitorStop && msptiMonitor_.IsStarted()) { - PRINT_INFO("End Mpsit Monitor thread"); + LOG(INFO) << "Stop mspti monitor thread successfully"; msptiMonitor_.Stop(); } - for (auto activity : cmd.enableActivities) { - if (activity > MSPTI_ACTIVITY_KIND_INVALID) { + if (!cmd.enableActivities.empty()) { + auto curActivities = msptiMonitor_.GetEnabledActivities(); + std::vector enableKinds, disableKinds; + std::set_difference(cmd.enableActivities.begin(), cmd.enableActivities.end(), curActivities.begin(), curActivities.end(), + std::back_inserter(enableKinds)); + std::set_difference(curActivities.begin(), curActivities.end(), cmd.enableActivities.begin(), cmd.enableActivities.end(), + std::back_inserter(disableKinds)); + for (auto activity : enableKinds) { msptiMonitor_.EnableActivity(activity); } + for (auto activity : disableKinds) { + msptiMonitor_.DisableActivity(activity); + } } msptiMonitor_.SetFlushInterval(cmd.reportIntervals); - return ErrCode.SUC; + return ErrCode::SUC; } std::string DynoLogNpuMonitor::Poll() { std::string res = ipcClient_.IpcClientNpuConfig(); if (res.size() == 4) { // res为4,表示dynolog注册进程成功 - PRINT_INFO("Regist dynolog success"); + LOG(INFO) << "Registered to dynolog daemon successfully"; return ""; } if (res.empty()) { - std::cout << "[INFO] Request for dynolog server is empty !" 
<< std::endl; + LOG(INFO) << "Request from dynolog daemon is empty"; return ""; } - std::cout << "[INFO] Received NPU configuration successfully" << std::endl; + LOG(INFO) << "Received NPU configuration successfully"; return res; } -void DynoLogNpuMonitor::EnableMsptiMonitor(std::unordered_map& cfg_map) +void DynoLogNpuMonitor::EnableMsptiMonitor(std::unordered_map& cfg_map) { auto cmd = InputParser::GetInstance()->DynoLogGetOpts(cfg_map); if (cmd.isMonitor) { auto ans = DealMonitorReq(cmd); - if (ans != ErrCode.SUC) { - PRINT_ERROR("deal monitor request fail, because" + IPC_ERROR(ans)); + if (ans != ErrCode::SUC) { + LOG(ERROR) << "Deal monitor request failed, because " << IPC_ERROR(ans); } } } - } // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/InputParser.cpp b/dynolog_npu/plugin/ipc_monitor/InputParser.cpp index 7a1f0d20b302e4917a02ee386f9925f6ec1d988b..6f77b2ba86a6b835cef07f564c3dc57f7e77d71e 100644 --- a/dynolog_npu/plugin/ipc_monitor/InputParser.cpp +++ b/dynolog_npu/plugin/ipc_monitor/InputParser.cpp @@ -1,11 +1,6 @@ #include "InputParser.h" - -#include -#include #include #include - -#include "log.h" #include "utils.h" namespace dynolog_npu { @@ -18,8 +13,8 @@ const std::string NPU_MONITOR_STOP_KEY = "NPU_MONITOR_STOP"; const std::unordered_set cfgMap { "MSPTI_ACTIVITY_KIND", - "REPORT_INTERVAL_S", - "NPU_MONITOR_START", + "REPORT_INTERVAL_S", + "NPU_MONITOR_START", "NPU_MONITOR_STOP", "REQUEST_TRACE_ID" }; @@ -33,18 +28,18 @@ const std::unordered_map kindStrMap { {"MemCpy", MSPTI_ACTIVITY_KIND_MEMCPY} }; -std::vector str2Kinds(const std::string& kindStrs) +std::set str2Kinds(const std::string& kindStrs) { - std::vector ans; + std::set res; auto kindStrList = split(kindStrs, ','); for (auto& kindStr : kindStrList) { auto kind = kindStrMap.find(kindStr); if (kind == kindStrMap.end()) { return {MSPTI_ACTIVITY_KIND_INVALID}; } - ans.push_back(kind); + res.insert(kind->second); } - return ans; + return res; } MsptiMonitorCfg InputParser::DynoLogGetOpts(std::unordered_map& cmd) @@ -52,7 +47,7 @@ MsptiMonitorCfg InputParser::DynoLogGetOpts(std::unordered_map activityKinds = str2Kinds(cmd[MSPTI_ACTIVITY_KIND_KEY]); + auto activityKinds = str2Kinds(cmd[MSPTI_ACTIVITY_KIND_KEY]); uint32_t reportTimes = 0; Str2Uint32(reportTimes, cmd[REPORT_INTERVAL_S_KEY]); bool startSwitch = false; diff --git a/dynolog_npu/plugin/ipc_monitor/InputParser.h b/dynolog_npu/plugin/ipc_monitor/InputParser.h index 4b498d76243dc11997e9c14c3edbd5320956efb4..e5f674e1605b3721a75372113ee5d7f012c5e506 100644 --- a/dynolog_npu/plugin/ipc_monitor/InputParser.h +++ b/dynolog_npu/plugin/ipc_monitor/InputParser.h @@ -3,7 +3,7 @@ #include #include -#include +#include #include namespace dynolog_npu { @@ -11,7 +11,7 @@ namespace ipc_monitor { struct MsptiMonitorCfg { - std::vector enableActivities; + std::set enableActivities; uint32_t reportIntervals; bool monitorStart; bool monitorStop; diff --git a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h b/dynolog_npu/plugin/ipc_monitor/MonitorBase.h index 108023c7624b747e5987be9184d6c594decd360a..29be0b6be04083babb8d20e5386e93c053a41357 100644 --- a/dynolog_npu/plugin/ipc_monitor/MonitorBase.h +++ b/dynolog_npu/plugin/ipc_monitor/MonitorBase.h @@ -1,5 +1,6 @@ #ifndef MONITOR_BASE_H #define MONITOR_BASE_H + #include namespace dynolog_npu { @@ -14,5 +15,4 @@ public: } // namespace ipc_monitor } // namespace dynolog_npu - -#endif \ No newline at end of file +#endif // MONITOR_BASE_H
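The reworked `DealMonitorReq` above reconciles the requested activity kinds against the currently enabled set with two `std::set_difference` passes: kinds requested but not yet enabled get enabled, and kinds enabled but no longer requested get disabled. A minimal Python sketch of that reconciliation logic (the function and variable names here are illustrative, not part of the plugin's API):

```python
def reconcile_activities(current: set, requested: set):
    """Compute (to_enable, to_disable) so that `current` becomes `requested`.

    Mirrors the two std::set_difference passes in DealMonitorReq.
    """
    to_enable = requested - current      # requested \ current
    to_disable = current - requested     # current \ requested
    return to_enable, to_disable


# Example: the monitor currently reports Marker; a new request asks for Marker,Kernel.
assert reconcile_activities({"Marker"}, {"Marker", "Kernel"}) == ({"Kernel"}, set())
```

Set semantics is what makes repeated requests idempotent here: replaying the same MSPTI_ACTIVITY_KIND list enables nothing new and disables nothing, which is also why the config struct switched from `std::vector` to `std::set`.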
diff --git a/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.cpp b/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.cpp index 8fd4f8292b966e1c08d23c739fd01dac55863d4e..693a91b2eeeb59beb5bd3b50db9fd5b48c5f6cb3 100644 --- a/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.cpp +++ b/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "DynoLogNpuMonitor.h" namespace { @@ -163,21 +164,24 @@ void MsptiMonitor::Start() } SetThreadName("MsptiMonitor"); if (Thread::Start() != 0) { - std::cout << "MsptiMonitor::Start failed" << std::endl; + LOG(ERROR) << "MsptiMonitor start failed"; return; } start_.store(true); + LOG(INFO) << "MsptiMonitor start successfully"; } void MsptiMonitor::Stop() { if (!start_.load()) { + LOG(WARNING) << "MsptiMonitor is not running"; return; } Uninit(); if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { - std::cout << "MsptiMonitor::Stop msptiActivityFlushAll failed" << std::endl; + LOG(WARNING) << "MsptiMonitor stop msptiActivityFlushAll failed"; } + LOG(INFO) << "MsptiMonitor stop successfully"; } void MsptiMonitor::Uninit() @@ -197,7 +201,7 @@ void MsptiMonitor::EnableActivity(msptiActivityKind kind) if (msptiActivityEnable(kind) == MSPTI_SUCCESS) { enabledActivities_.insert(kind); } else { - std::cout << "MsptiMonitor::EnableActivity failed, kind: " << static_cast(kind) << std::endl; + LOG(ERROR) << "MsptiMonitor enableActivity failed, kind: " << static_cast(kind); } } } @@ -209,7 +213,7 @@ void MsptiMonitor::DisableActivity(msptiActivityKind kind) if (msptiActivityDisable(kind) == MSPTI_SUCCESS) { enabledActivities_.erase(kind); } else { - std::cout << "MsptiMonitor::DisableActivity failed, kind: " << static_cast(kind) << std::endl; + LOG(ERROR) << "MsptiMonitor disableActivity failed, kind: " << static_cast(kind); } } } @@ -228,14 +232,20 @@ bool MsptiMonitor::IsStarted() return start_.load(); } +std::set MsptiMonitor::GetEnabledActivities() +{ + std::lock_guard lock(activityMtx_); + return enabledActivities_; +} + void MsptiMonitor::Run() { if (msptiSubscribe(&subscriber_, nullptr, nullptr) != MSPTI_SUCCESS) { - std::cout << "MsptiMonitor::Run failed, msptiSubscribe failed" << std::endl; + LOG(ERROR) << "MsptiMonitor run failed, msptiSubscribe failed"; return; } if (msptiActivityRegisterCallbacks(BufferRequest, BufferComplete) != MSPTI_SUCCESS) { - std::cout << "MsptiMonitor::Run failed, msptiActivityRegisterCallbacks failed" << std::endl; + LOG(ERROR) << "MsptiMonitor run failed, msptiActivityRegisterCallbacks failed"; return; } while (true) @@ -255,12 +265,12 @@ void MsptiMonitor::Run() } if (flushInterval_.load() > 0) { if (msptiActivityFlushAll(1) != MSPTI_SUCCESS) { - std::cout << "MsptiMonitor::Run msptiActivityFlushAll failed" << std::endl; + LOG(ERROR) << "MsptiMonitor run msptiActivityFlushAll failed"; } } } if (msptiUnsubscribe(subscriber_) != MSPTI_SUCCESS) { - std::cout << "MsptiMonitor::Run failed, msptiUnsubscribe failed" << std::endl; + LOG(ERROR) << "MsptiMonitor run failed, msptiUnsubscribe failed"; } { std::lock_guard lock(activityMtx_); @@ -284,7 +294,7 @@ void MsptiMonitor::BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumR if (allocCnt.load() >= MAX_ALLOC_CNT) { *buffer = nullptr; *size = 0; - std::cout << "MsptiMonitor::BufferRequest failed, allocCnt: " << allocCnt.load() << std::endl; + LOG(ERROR) << "MsptiMonitor BufferRequest failed, allocCnt: " << allocCnt.load(); return; } uint8_t *pBuffer = ReinterpretConvert(MsptiMalloc(DEFAULT_BUFFER_SIZE, ALIGN_SIZE)); @@ -295,12 +305,14 @@ void 
MsptiMonitor::BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumR *buffer = pBuffer; *size = DEFAULT_BUFFER_SIZE; allocCnt++; + LOG(INFO) << "MsptiMonitor BufferRequest, size: " << *size; } } void MsptiMonitor::BufferComplete(uint8_t *buffer, size_t size, size_t validSize) { if (validSize > 0 && buffer != nullptr) { + LOG(INFO) << "MsptiMonitor BufferComplete, size: " << size << ", validSize: " << validSize; msptiActivity *record = nullptr; msptiResult status = MSPTI_SUCCESS; do { @@ -310,7 +322,7 @@ void MsptiMonitor::BufferComplete(uint8_t *buffer, size_t size, size_t validSize } else if (status == MSPTI_ERROR_MAX_LIMIT_REACHED) { break; } else { - std::cout << "MsptiMonitor::BufferComplete failed, status: " << static_cast(status) << std::endl; + LOG(ERROR) << "MsptiMonitor BufferComplete failed, status: " << static_cast(status); break; } } while (true); @@ -353,24 +365,23 @@ void MsptiMonitor::BufferConsume(msptiActivity *record) void MsptiMonitor::SendMessage(const std::string &message) { if (message.empty()) { - std::cout << "MsptiMonitor::SendMessage message is empty" << std::endl; + LOG(WARNING) << "MsptiMonitor SendMessage message is empty"; return; } static const std::string destName = DYNO_IPC_NAME + "_data"; static const int maxRetry = 5, retryWaitTimeUs = 1000; auto msg = Message::ConstructStrMessage(message, MSG_TYPE_DATA); if (!msg) { - std::cout << "MsptiMonitor::ConstructStrMessage failed, message: " << message << std::endl; + LOG(ERROR) << "MsptiMonitor ConstructStrMessage failed, message: " << message; return; } auto ipcClient = DynoLogNpuMonitor::GetInstance()->GetIpcClient(); if (!ipcClient) { - std::cout << "DynoLogNpuMonitor ipcClient is nullptr" << std::endl; + LOG(ERROR) << "DynoLogNpuMonitor ipcClient is null"; return; } if (!ipcClient->SyncSendMessage(*msg, destName, maxRetry, retryWaitTimeUs)) { - std::cout << "send mspti message failed: " << message << std::endl; - perror("send mspti message failed"); + PLOG(ERROR) << "send mspti message failed" << message; } } } // namespace ipc_monitor diff --git a/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.h b/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.h index 82ba15c8a2e9416ff11e0519e28fdc67dfa3109a..0a73356065374ac3a52dd4beb2e61b5655b0cbab 100644 --- a/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.h +++ b/dynolog_npu/plugin/ipc_monitor/MsptiMonitor.h @@ -21,6 +21,7 @@ public: void DisableActivity(msptiActivityKind kind); void SetFlushInterval(uint32_t interval); bool IsStarted(); + std::set GetEnabledActivities(); private: static void BufferRequest(uint8_t **buffer, size_t *size, size_t *maxNumRecords); diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp index 3c8eec276a32b882ff6e0247f83712ca7f1326bd..938393ffdaf5796456710f64a106efa91d21e064 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp +++ b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp @@ -1,6 +1,5 @@ #include "NpuIpcClient.h" - -#include +#include namespace dynolog_npu { namespace ipc_monitor { @@ -15,20 +14,21 @@ bool IpcClient::RegisterInstance(int32_t id) std::unique_ptr message = Message::ConstructMessage(context, MSG_TYPE_CONTEXT); try { if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { - std::cout << "[WARNING]Failed to send register ctxt for pid " << context.pid << " with dyno" << std::endl; + LOG(WARNING) << "Failed to send register ctxt for pid " << context.pid << " with dyno"; return false; } } catch (const std::exception &e) { - std::cout << "[WARNING] Error when 
SyncSendMessage: " << e.what() << std::endl; + LOG(WARNING) << "Error when SyncSendMessage: " << e.what(); return false; } - std::cout << "[INFO] Resigter pid " << context.pid << " for dynolog success !" << std::endl; + LOG(INFO) << "Resigter pid " << context.pid << " for dynolog success!"; return true; } + std::string IpcClient::IpcClientNpuConfig() { auto size = pids_.size(); - auto *req = (NpuRequest *)malloc(sizeof(NpuRequest) + sizeof(int32_t) * size); + auto *req = ReinterpretConvert(malloc(sizeof(NpuRequest) + sizeof(int32_t) * size)); req->type = DYNO_IPC_TYPE; req->pidSize = size; req->jobId = JOB_ID; @@ -37,7 +37,7 @@ std::string IpcClient::IpcClientNpuConfig() } std::unique_ptr message = Message::ConstructMessage(*req, MSG_TYPE_REQUEST, size); if (!SyncSendMessage(*message, DYNO_IPC_NAME)) { - std::cout << "[WARNING] Failed to send config to dyno server fail !" << std::endl; + LOG(WARNING) << "Failed to send config to dyno server"; free(req); req = nullptr; return ""; @@ -45,13 +45,13 @@ std::string IpcClient::IpcClientNpuConfig() free(req); message = PollRecvMessage(MAX_IPC_RETRIES, MAX_SLEEP_US); if (!message) { - std::cout << "[WARNING] Failed to receive on-demand config !" << std::endl; + LOG(WARNING) << "Failed to receive on-demand config"; return ""; } std::string res = std::string(ReinterpretConvert(message->buf.get()), message->metadata.size); - return res; } + std::unique_ptr IpcClient::ReceiveMessage() { std::lock_guard wguard(dequeLock_); @@ -62,10 +62,11 @@ std::unique_ptr IpcClient::ReceiveMessage() msgDynoDeque_.pop_front(); return message; } + bool IpcClient::SyncSendMessage(const Message &message, const std::string &destName, int numRetry, int seepTimeUs) { if (destName.empty()) { - std::cout << "[WARNING] Can not send to empty socket name !" 
<< std::endl; + LOG(WARNING) << "Can not send to empty socket name!"; return false; } int i = 0; @@ -79,11 +80,12 @@ bool IpcClient::SyncSendMessage(const Message &message, const std::string &destN seepTimeUs *= 2; // 2: double sleep time } } catch (const std::exception &e) { - std::cout << "[ERROR] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when SyncSendMessage: " << e.what(); return false; } return i < numRetry; } + bool IpcClient::Recv() { try { @@ -94,7 +96,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryPeekMessage(*peekCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryPeekMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryPeekMessage: " << e.what(); return false; } if (successFlag) { @@ -108,7 +110,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryRcvMessage(*recvCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryRecvMsg: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryRecvMsg: " << e.what(); return false; } if (successFlag) { @@ -118,11 +120,12 @@ bool IpcClient::Recv() } } } catch (std::exception &e) { - std::cout << "[ERROR] Error in Recv(): " << e.what() << std::endl; + LOG(ERROR) << "Error in Recv(): " << e.what(); return false; } return false; } + std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs) { for (int i = 0; i < maxRetry; i++) { @@ -133,6 +136,5 @@ std::unique_ptr IpcClient::PollRecvMessage(int maxRetry, int sleeTimeUs } return nullptr; } - } // namespace ipc_monitor -} // namespace dynolog_npu \ No newline at end of file +} // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h index 5054e4d635ca039b59ff55c5e11efd9dbe111818..90827777a91eeac49251e246ce75f4eec4f24942 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h +++ b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.h @@ -1,14 +1,9 @@ #ifndef NPU_IPC_CLIENT_H #define NPU_IPC_CLIENT_H -#include -#include +#include #include #include -#include -#include -#include -#include #include "NpuIpcEndPoint.h" #include "utils.h" @@ -118,7 +113,6 @@ private: bool Recv(); std::unique_ptr PollRecvMessage(int maxRetry, int sleeTimeUs); }; - } // namespace ipc_monitor } // namespace dynolog_npu diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h b/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h index a3186736acba7efd0777caf87114957c8a7365a0..1fedaeb62d096004ded64abf3f1627f85c66f6f7 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h +++ b/dynolog_npu/plugin/ipc_monitor/NpuIpcEndPoint.h @@ -1,15 +1,12 @@ #ifndef NPU_IPC_ENDPOINT_H #define NPU_IPC_ENDPOINT_H -#include + #include #include #include #include #include #include -#include -#include -#include #include "utils.h" namespace dynolog_npu { @@ -59,10 +56,12 @@ public: chmod(address.sun_path, SOCKET_FD_CHMOD); } } + ~NpuIpcEndPoint() { close(socketFd); } + [[nodiscard]] auto BuildSendNpuCtxt(const std::string &desAddrName, const std::vector &npuPayLoad, const std::vector &fileDes) { @@ -197,8 +196,7 @@ protected: return ctxt; } }; - } // namespace ipc_monitor } // namespace dynolog_npu -#endif +#endif // NPU_IPC_ENDPOINT_H diff --git a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h index 690a76eb84a1de0eae1009003996861832fd3771..66deb94afe4cae140ca694bd1a5c85ab0b04c38e 100644 --- a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h +++ 
b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h @@ -1,8 +1,7 @@ #ifndef PYDYNAMIC_MONITOR_PROXY_H #define PYDYNAMIC_MONITOR_PROXY_H -#include -#include +#include #include "MonitorBase.h" #include "DynoLogNpuMonitor.h" @@ -14,23 +13,29 @@ public: PyDynamicMonitorProxy() = default; bool InitDyno(int npuId) { - try { - monitor_ = DynoLogNpuMonitor::GetInstance(); - monitor_->SetNpuId(npuId); - bool res = monitor_->Init(); - return res; - } catch (const std::exception &e) { - std::cout << "[ERROR] Error when init dyno " << e.what() << std::endl; - return false; - } + try { + if (!google::IsGoogleLoggingInitialized()) { + google::InitGoogleLogging("DynoLogNpuMonitor"); + google::SetLogDestination(google::GLOG_INFO, "/var/log/dynolog_npu_"); + google::SetLogFilenameExtension(".log"); + } + monitor_ = DynoLogNpuMonitor::GetInstance(); + monitor_->SetNpuId(npuId); + bool res = monitor_->Init(); + return res; + } catch (const std::exception &e) { + LOG(ERROR) << "Error when init dyno " << e.what(); + return false; + } } std::string PollDyno() { - return monitor_->Poll(); - }; + return monitor_->Poll(); + } - void EnableMsptiMonitor(std::unordered_map& config_map) { + void EnableMsptiMonitor(std::unordered_map& config_map) + { DynoLogNpuMonitor::GetInstance()->EnableMsptiMonitor(config_map); } @@ -40,5 +45,4 @@ private: } // namespace ipc_monitor } // namespace dynolog_npu - -#endif +#endif // PYDYNAMIC_MONITOR_PROXY_H diff --git a/dynolog_npu/plugin/ipc_monitor/log.cpp b/dynolog_npu/plugin/ipc_monitor/log.cpp deleted file mode 100644 index 208b0284de4f0efc8c4b29db94c52f6fc372271e..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/log.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "log.h" -#include -#include - -void Log::PrintMsg(const std::string& msg, const std::string& level) const -{ - std::ostringstream oss; - oss << msg << "\n"; - std::cout << oss.str(); -} \ No newline at end of file diff --git a/dynolog_npu/plugin/ipc_monitor/log.h b/dynolog_npu/plugin/ipc_monitor/log.h deleted file mode 100644 index 27e0e0f0a11e726f051d0fc3f53f4eca0fa53c30..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/ipc_monitor/log.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once -#include "singleton.h" - -#define PRINT_INFO(format, ...) \ - do { \ - Log::GetInstance()->PrintMsg(format, "[INFO]"); \ - } while(0) \ - -#define PRINT_WARNING(format, ...) \ - do { \ - Log::GetInstance()->PrintMsg(format, "[WARNING]"); \ - } while(0) \ - -#define PRINT_ERROR(format, ...) 
\ - do { \ - Log::GetInstance()->PrintMsg(format, "[ERROR]"); \ - } while(0) \ - -class Log : public dynolog_npu::ipc_monitor::Singleton -{ -public: - void PrintMsg(const std::string& msg, const std::string& level) const; -}; diff --git a/dynolog_npu/plugin/ipc_monitor/utils.cpp b/dynolog_npu/plugin/ipc_monitor/utils.cpp index 28bdd40f9568e3c97475c89cea808fe92bb39cb9..fce7103f4896b112e058961c0cce27901f70bab1 100644 --- a/dynolog_npu/plugin/ipc_monitor/utils.cpp +++ b/dynolog_npu/plugin/ipc_monitor/utils.cpp @@ -1,11 +1,14 @@ #include "utils.h" - +#include +#include #include #include -#include +#include +#include #include - -#include "log.h" +#include +#include +#include namespace dynolog_npu { namespace ipc_monitor { @@ -52,11 +55,9 @@ std::string formatErrorCode(SubModule submodule, ErrCode errorCode) oss << "ERR" << std::setw(2) << std::setfill('0') << static_cast(submodule); // 2: 字段宽度 oss << std::setw(3) << std::setfill('0') << static_cast(errorCode); // 3: 字段宽度 oss << " " << submoduleMap[submodule] << " " << errCodeMap[errorCode]; - return oss.str(); }; - int32_t GetProcessId() { return static_cast(getpid()); } @@ -75,11 +76,10 @@ std::pair GetParentPidAndCommand(int32_t pid) if (std::getline(statFile, line)) { int ret = sscanf(line.c_str(), "%*d (%[^)]) %*c %d", command.data(), &parentPid); if (ret == 2) { // 2: 接收到2个字符 - std::cout << "[INFO] Success to get parent pid: " << parentPid << std::endl; return std::make_pair(parentPid, command); } } - std::cout << "[WARNING] Failed to parse /proc/" << pid << "/stat" << std::endl; + LOG(WARNING) << "Failed to parse /proc/" << pid << "/stat"; return std::make_pair(0, ""); } @@ -104,8 +104,10 @@ std::vector GetPids() for (const auto &pidPair : pids) { res.push_back(pidPair.first); } + LOG(INFO) << "Succeeded to get pids, count: " << res.size(); return res; } + std::string GenerateUuidV4() { static std::random_device randomDevice; @@ -138,27 +140,28 @@ std::string GenerateUuidV4() return stringStream.str(); } -bool Str2Uint32(uint32_t& dest, const std::string& str) +bool Str2Uint32(uint32_t& dest, const std::string& str) { if (str.empty()) { - PRINT_INFO("Str to uint32 fail, input string is null"); + LOG(ERROR) << "Str to uint32 failed, input string is null"; return false; } size_t pos = 0; try { dest = static_cast(std::stoul(str, &pos)); } catch(...) 
{ - PRINT_INFO("Str to uint32 fail, input string is %s", numStr.c_str()); + LOG(ERROR) << "Str to uint32 failed, input string is " << str; return false; } if (pos != str.size()) { - PRINT_INFO("Str to uint32 fail, input string is %s", numStr.c_str()); + LOG(ERROR) << "Str to uint32 failed, input string is " << str; return false; } return true; } -bool Str2Bool(bool& dest, const std::string& str) { +bool Str2Bool(bool& dest, const std::string& str) +{ std::string lower_str = str; std::transform(lower_str.begin(), lower_str.end(), lower_str.begin(), ::tolower); @@ -171,7 +174,7 @@ bool Str2Bool(bool& dest, const std::string& str) { dest = false; return true; } - PRINT_ERROR("Invalid boolean string: %s", std.c_str()); + LOG(ERROR) << "Str to bool failed, input string is " << str; return false; } diff --git a/dynolog_npu/plugin/ipc_monitor/utils.h b/dynolog_npu/plugin/ipc_monitor/utils.h index 7a0ef7ec2522aaf4e340af9e3745c61d8ae7a9e1..728fcb2608fa835f649675fe9a325e956114bbdf 100644 --- a/dynolog_npu/plugin/ipc_monitor/utils.h +++ b/dynolog_npu/plugin/ipc_monitor/utils.h @@ -1,17 +1,10 @@ #ifndef IPC_MONITOR_UTILS_H #define IPC_MONITOR_UTILS_H -#include -#include + #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include namespace dynolog_npu { @@ -48,7 +41,6 @@ enum class ErrCode { PERMISSION = 12, }; - std::string formatErrorCode(SubModule submodule, ErrCode errorCode); #define IPC_ERROR(error) formatErrorCode(SubModule::IPC, error) @@ -58,8 +50,6 @@ inline T ReinterpretConvert(V ptr) { return reinterpret_cast(ptr); } - } // namespace ipc_monitor } // namespace dynolog_npu - -#endif +#endif // IPC_MONITOR_UTILS_H diff --git a/dynolog_npu/plugin/setup.py b/dynolog_npu/plugin/setup.py index c6d42cf984dae1f6c79c0968e0f179cbf4df3080..28e558b2f20bfd4cf55fc088da762a1b8e62e79f 100644 --- a/dynolog_npu/plugin/setup.py +++ b/dynolog_npu/plugin/setup.py @@ -18,6 +18,10 @@ from setuptools import setup from pybind11.setup_helpers import Pybind11Extension BASE_DIR = os.path.dirname(os.path.realpath(__file__)) +DYNOLOG_PATH = os.path.join(os.path.dirname(BASE_DIR), "third_party", "dynolog") +JSON_INC_PATH = os.path.join(DYNOLOG_PATH, "third_party", "json", "single_include") +GLOG_INC_PATH = os.path.join(DYNOLOG_PATH, "third_party", "glog", "src") +GLOG_LIB_PATH = os.path.join(DYNOLOG_PATH, "build", "third_party", "glog") # Define the extension module ext_modules = [ @@ -25,12 +29,11 @@ ext_modules = [ "IPCMonitor", # Name of the Python module sources=["bindings.cpp"] + list(glob("ipc_monitor/*.cpp")), # Source files include_dirs=[os.path.join(BASE_DIR, "ipc_monitor"), # Include directories - os.path.join(os.path.dirname(BASE_DIR), - "third_party", "dynolog", "third_party", "json", "single_include")], + JSON_INC_PATH, GLOG_INC_PATH, GLOG_LIB_PATH], extra_compile_args=["-std=c++14", "-fPIC", "-fstack-protector-all", "-fno-strict-aliasing", "-fno-common", "-fvisibility=hidden", "-fvisibility-inlines-hidden", "-Wfloat-equal", "-Wextra", "-O2"], - library_dirs=[os.path.join(BASE_DIR, "stub")], - libraries=["mspti", "pthread"], + library_dirs=[os.path.join(BASE_DIR, "stub"), GLOG_LIB_PATH], + libraries=["mspti", "pthread", "glog"], language="c++", # Specify the language ), ] diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/api/lossApi.ts b/plugins/mindstudio-insight-plugins/Scalar/front/src/api/lossApi.ts index 199b41517553b31a0d8f5b4fabe68e9af9cbb783..478a45b61478ec271ec3f376ba4fa5cd9bacc292 100644 --- 
a/plugins/mindstudio-insight-plugins/Scalar/front/src/api/lossApi.ts +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/api/lossApi.ts @@ -27,6 +27,16 @@ interface NewFileItem { dir: string; fileList: string[]; } +export interface tokenParam { + file: string; + globalBatchSize: number; + seqLength: number; +} +export interface fileMergeParam { + action: 'merge' | 'unset'; + name: string; + fileList: string[]; +} export interface NewFileResponseBody { data: NewFileItem[]; } @@ -41,18 +51,18 @@ export interface FileItem { fileList: Array<{ name: string; path: string, dirs: string[] }>; }; export interface NormalChart { - type: 'normal' + type: 'normal'; + enable: boolean; } export interface TokenChart { type: 'token'; - globalBatchSize: number; - seqLength: number; + enable: boolean; } export interface SmoothingChart { type: 'smoothing'; + enable: boolean; algorithm: string; weight: number; - offset: number; } export interface DateItem { step: number; @@ -63,10 +73,15 @@ export interface ChartsDataRequestParams { graphList: Array<{ tag: string; file: string; - offset: number; + start: number; + end: number; graphConfig: [NormalChart, TokenChart, SmoothingChart] }>; }; +export interface QueryParamItem { + start: number; + end: number; +} export interface ChartsDataResponseBody { data: ChartsDataItem[]; }; @@ -77,19 +92,35 @@ export interface LineItem { name: string; color: string; } +export interface TokenItem { + file: string; + tag: string[]; + globalBatchSize: number; + seqLength: number; +} +export interface TokenResponseBody { + data: TokenItem[] +} +export interface FileMergeResponseBody { + data: { + action: 'merge' | 'unset'; + tags: string[]; + file: string; + fileList: string[]; + } +} interface ChartsDataItem { tag: string; file: string; - normal: DataItem[]; - token: DataItem[]; - normalSmoothing: DataItem[]; - tokenSmoothing: DataItem[]; + normal: DataItem; + token: DataItem; + normalSmoothing: DataItem; + tokenSmoothing: DataItem; dateConfig: DateItem[] }; export interface DataItem { [key: string]: { value: number, date: string, wallTime: number } }; - let controller = new AbortController(); export const getChartsData = async (params: ChartsDataRequestParams): Promise> => { controller.abort(); @@ -127,5 +158,17 @@ export const getIncrementalTag = async (): Promise> => { + return await request({ + url: '/ScalarVisually/TokenParamSet', + method: 'post', + data: { params } + }); +}; +export const fileMergeOrUnset = async (params: fileMergeParam): Promise> => { + return await request({ + url: '/ScalarVisually/FileMerge', + method: 'post', + data: params + }); +}; diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Chart.tsx b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Chart.tsx index 59803a9eb724bc03c297ee6e3ed4c14e8edea6b2..3546012b73cd7e14ed8f10223ac2032d4b45892e 100644 --- a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Chart.tsx +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Chart.tsx @@ -13,257 +13,21 @@ import { ColorPicker, Input, Tooltip, + message, type TablePaginationConfig, type TableProps, } from 'antd'; -import type { DataItem } from '@/api/lossApi'; +import { setTokenParam, fileMergeOrUnset, tokenParam, fileMergeParam } from '@/api/lossApi'; import Chart from '../Echarts'; -import type { LossShowInfo } from '@/entity/lossShow'; -import { downloadCsv, keepDecimals, notZero, type HeaderType } from '@/utils/common'; -import type { ECBasicOption } 
from 'echarts/types/dist/shared'; +import type { LossShowInfo, CheckListItem } from '@/entity/lossShow'; import { Resizer } from '../common/Resizer'; import { useTranslation } from 'react-i18next'; import { Smoothing } from './Smoothing'; import { FileTreeList } from './FileTreeList'; -import { TFunction } from 'i18next'; import eventBus from '@/eventBus'; -interface MaxMin { - max: number; - min: number; -} -interface ChartDataLenInfo { - [key: string]: MaxMin; - dataLen: MaxMin; - smoothingDataLen: MaxMin; -}; -interface ChartParam { - columns: string[]; - tableDataSource: Array<{ [key: string]: string | number | null }>; - tableDataSourceSmoothing: Array<{ [key: string]: string | number | null }>; - smoothingColumns: string[]; - dateConfig: { [key: string]: { coord: [string, string], value: string }[] }; - lineCallback: (name: string) => { fullName: string, abbName: string }; - colorCallback: (key: string, blur: string) => string; -} -const creatDataChart = (chartParam: ChartParam) => { - const { columns, tableDataSource, tableDataSourceSmoothing, smoothingColumns, dateConfig, lineCallback, colorCallback } = chartParam; +import { creatDataChart, creatComparisonChart, getTableData, getData, initTokenOption, exportFile, type ChartParam, ChartDataLenInfo } from './utils'; - const source = tableDataSource.map((item, index) => { - return { ...tableDataSourceSmoothing[index], ...item }; - }); - - const dimensions: string[] = []; - columns.forEach((item, index) => { - dimensions.push(item, smoothingColumns[index]); - }); - - return { - animation: false, - tooltip: { - trigger: 'axis', - className: 'chartTooltip', - axisPointer: { - type: 'cross', - }, - formatter: (seriesList: any[]) => chartTooltipFormat(seriesList, dimensions), - }, - xAxis: { - type: 'category', - }, - yAxis: { - }, - series: dimensions.map(_item => { - if (_item.endsWith('smoothing')) { - return { - symbolSize: 5, - type: 'line', - animation: false, - progressive: 0, - showSymbol: false, - lineStyle: { - color: colorCallback(_item, '1'), - }, - itemStyle: { - color: colorCallback(_item, '1'), - }, - }; - } else { - return { - symbolSize: 5, - type: 'line', - animation: false, - progressive: 0, - showSymbol: false, - lineStyle: { - color: colorCallback(_item, '0.4'), - }, - itemStyle: { - color: colorCallback(_item, '0.4'), - }, - markPoint: { - data: dateConfig[_item], - symbolSize: [95, 50], - label: { - position: 'inside', - show: true, - distance: 5, - fontWeight: 'lighter', - fontSize: 10, - borderType: 'solid', - heigth: 50 - } - } - }; - } - }), - dataset: { - dimensions: ['step', ...dimensions], - source, - }, - legend: { - show: true, - textStyle: { - color: 'rgb(141, 152, 170)', - }, - data: [...columns], - formatter: (name: string) => { - const { abbName } = lineCallback(name); - return abbName; - }, - ellipsis: true, - tooltip: { - show: true, - formatter: (params: { name: string }) => { - const { fullName } = lineCallback(params.name); - return fullName; - } - } - }, - }; -}; - -const chartTooltipFormat = (seriesList: any[], dimensions: string[]) => { - if (seriesList.length < 1) { - return; - } - const addNum = (seriesList.length === 2 && dimensions.length === 4) ? 
1 : 2; //不展示smoothing的图例,只展示原始图例,最多两条曲线 - const div = document.createElement('div'); - div.className = 'tooltip'; - div.append(`Step: ${seriesList[0].data.step}`); - const tooltipItem = document.createElement('div'); - tooltipItem.className = 'tooltipItem'; - for (let i = 0; i < seriesList.length; i += addNum) { - const seriesItem = seriesList[i]; - const circle = document.createElement('div'); - circle.className = 'circle'; - circle.setAttribute('style', `background-color:${seriesList[seriesList.length % 2 ? i : (dimensions.length === 4 ? i : i + 1)].color};`); - const keyDom = document.createElement('div'); - const valueDom = document.createElement('div'); - keyDom.className = 'value'; - valueDom.className = 'smoothing'; - if (seriesList.length === 2 && dimensions.length === 4) { - keyDom.append(`value: ${seriesItem.data[dimensions[i === 0 ? i : i + 1]]}`); - valueDom.append(`smoothing: ${seriesItem.data[dimensions[i === 0 ? i + 1 : i + 2]]}`); - } else { - keyDom.append(`value: ${seriesItem.data[dimensions[i]]}`); - valueDom.append(`smoothing: ${seriesItem.data[dimensions[i + 1]]}`); - } - tooltipItem.appendChild(circle); - tooltipItem.appendChild(keyDom); - tooltipItem.appendChild(valueDom); - } - div.appendChild(tooltipItem); - - return div; -}; - -const creatComparisonChart = (source: Array<{ [key: string]: string | number | null }> = [], t: TFunction): ECBasicOption => { - - return { - animation: false, - tooltip: { - trigger: 'axis', - axisPointer: { - type: 'cross', - }, - }, - xAxis: { - type: 'category', - }, - yAxis: [ - { - type: 'value', - }, - { - type: 'value', - position: 'left', - axisLabel: { - formatter: '{value} %', - } - }, - ], - series: ['comparisonNormal', 'comparisonAbsolute', 'comparisonRelative'].map(name => ({ - name: t(name), - symbolSize: 5, - type: 'line', - animation: false, - progressive: 0, - yAxisIndex: 0, - showSymbol: false, - })), - dataset: { - dimensions: ['step', 'Comparison Normal', 'Comparison Absolute', 'Comparison Relative'], - source, - }, - legend: { - textStyle: { - color: 'rgb(141, 152, 170)', - }, - selectedMode: 'single', // 每次打开一个图例 - selected: {}, - }, - }; -}; - -const getTableData = (data: { [key: string]: DataItem }) => { - const tableDataSource: Array<{ key: number, step: string, [key: string]: number | string | null }> = []; - const keys = Object.keys(data); - - if (keys.length === 0) { - return { tableDataSource, len: { max: -1, min: -1 } }; - } - - let len: number = 0; - let steps: string[] = []; - const keys0 = Object.keys(data[keys[0]]); - let keys1 = []; - if (keys.length < 2) { - steps = keys0; - len = steps.length; - } else { - keys1 = Object.keys(data[keys[1]]); - steps = [...new Set([...keys0, ...keys1])]; - len = steps.length; - } - for (let i = 0; i < len; i++) { - tableDataSource.push({ step: String(steps[i]), key: i }); - } - for (let i = 0; i < tableDataSource.length; i++) { - for (let j = 0; j < keys.length; j++) { - const key = tableDataSource[i].step; - tableDataSource[i][keys[j]] = data[keys[j]][key]?.value || null; - } - } - return { - tableDataSource, len: { - max: len, - min: keys.length < 2 - ? 
-1 : Math.min(keys0.length, keys1.length), - }, - }; -}; - -const DataChart = observer(({ lossShowInfo, tag, isExpand }: +export const DataChart = observer(({ lossShowInfo, tag, isExpand }: { lossShowInfo: LossShowInfo; tag: string | string[]; isExpand: boolean }): JSX.Element => { const { t } = useTranslation('lossShow'); const [columns, setColumns] = useState['columns']>([]); @@ -277,25 +41,16 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: const [enableToken, setToken] = useState(false); const [showModal, setShowModal] = useState(false); const [showLineModal, setShowLineModal] = useState(false); - const initTokenOption = () => { - const res: { [key: string]: { [key: string]: { globalBatchSize: number, seqLength: number } } } = {}; //记录每一个tag的modal框内的globalBatchSize和seqLength - const fileList = lossShowInfo.getFileListByModel('isChecked'); //所有文件列表 - fileList.forEach(item => { - res[item.tag] = {}; - item.file.forEach(file => { - const oneConfig = lossShowInfo.tokenOptions?.[item.tag]?.files?.[file.filePath]; - res[item.tag][file.filePath] = { globalBatchSize: oneConfig ? oneConfig.globalBatchSize : -1, seqLength: oneConfig ? oneConfig.seqLength : -1 }; - }); - }); - return res; - }; + const [showMergeModal, setShowMergeModal] = useState(false); + const [messageApi, contextHolder] = message.useMessage(); const [lineConfig, setLineConfig] = useState<{ name: string, color: string, key: string }[]>(lossShowInfo.getLineConfig(tag)); //名称和color的配置 - const [changeKey, setChangeKey] = useState(''); //要修改颜色和图例的文件的key值 - const [tokenOption, setTokenOption] = useState(initTokenOption()); //token配置 + const [changeKey, setChangeKey] = useState(''); //要修改颜色和图例或token的文件的key值 + const [tokenOption, setTokenOption] = useState(initTokenOption(lossShowInfo)); //token配置 const [batchTokenOption, setBatchTokenOption] = useState<{ globalBatchSize: number, seqLength: number }>({ globalBatchSize: -1, seqLength: -1 }); //初始化批量设置token的输入框的值 const [originToken, setOriginToken] = useState(tokenOption); //每次修改token前做储存,以便取消时恢复原样 const [checkedKeys, setCheckedKeys] = useState([]); //记录当前选择的所有文件 const [changeConfig, setChangeConfig] = useState(''); //修改的配置类型 + const [mergeName, setMergeName] = useState(''); //合并的文件名称 const pagination: TablePaginationConfig = { simple: true, align: 'center', @@ -319,7 +74,7 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: */ const updateModalConfig = () => { setLineConfig(lossShowInfo.getLineConfig(tag)); - setTokenOption(initTokenOption()); + setTokenOption(initTokenOption(lossShowInfo)); }; const resizeChart = () => { setResizeChartFlag(Math.random()); @@ -396,18 +151,21 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: } }; - const exportFile = (tag: string | string[]): void => { - const fileName = `${typeof tag === 'string' ? 
tag : 'tagAggregation'}_${Date.now()}.csv`; - downloadCsv(columns?.map(item => ({ key: item.key, title: item.title })) as HeaderType[], tableData || [], fileName); - }; - /** * @description: token的modal框确认事件 * @return {*} */ - const handleOk = () => { + const handleOk = async () => { + let tokenParams: tokenParam[] = []; //批量操作token时需要特殊处理 if (changeConfig === 'tokenBatch') { + if (judgeIsMerge()) { + messageApi.open({ + type: 'error', + content: t('batchTokenTip'), + }); + return; + } const { globalBatchSize, seqLength } = batchTokenOption; checkedKeys.forEach(fileKey => { if (typeof tag === 'string') { @@ -422,9 +180,26 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: } }); } - setBatchTokenOption({ globalBatchSize: -1, seqLength: -1 }); - setOriginToken(JSON.parse(JSON.stringify(tokenOption))); lossShowInfo.modifyTokenOption(tag, tokenOption, enableToken); + if (changeConfig === 'tokenBatch') { + tokenParams = lossShowInfo.getTokenParams(tag, checkedKeys); + } else { + tokenParams = lossShowInfo.getTokenParams(tag, changeKey); + } + const fileTagMap = (await setTokenParam(tokenParams)).body.data || []; + //同一文件可能对应多个tag,将不同tag下的同一文件的token统一 + if (fileTagMap.length) { + fileTagMap.forEach(item => { + item.tag.forEach(tag => { + tokenOption[tag][item.file].globalBatchSize = item.globalBatchSize; + tokenOption[tag][item.file].seqLength = item.seqLength; + }); + lossShowInfo.modifyTokenOption(item.tag, tokenOption, enableToken); + }); + } + setOriginToken(JSON.parse(JSON.stringify(tokenOption))); + //将批量操作token弹框内的值初始化 + setBatchTokenOption({ globalBatchSize: -1, seqLength: -1 }); setShowModal(false); if (enableToken) { eventBus.emit('updataChartData'); @@ -508,7 +283,7 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: {newTag} : <>} {changeConfig !== 'tokenBatch' ?
- FilePath: + {t('filePath')}: {file}
: <>}
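
The handleOk flow above now pushes the per-file token settings to the backend via setTokenParam and then reconciles the echoed configuration across tags. A minimal sketch of that round trip, assuming the tokenParam and TokenItem shapes declared in lossApi.ts; the file path and numbers here are illustrative only:

```ts
// Illustrative round trip for the token settings, assuming the
// tokenParam / TokenItem shapes from lossApi.ts; path and values are made up.
import { setTokenParam, type tokenParam } from '@/api/lossApi';

async function applyTokenConfig(): Promise<void> {
  const params: tokenParam[] = [
    { file: '/logs/run1/events.log', globalBatchSize: 64, seqLength: 4096 },
  ];
  // The backend echoes the unified settings; the same file can sit under
  // several tags, so each returned item carries the full tag list to update.
  const fileTagMap = (await setTokenParam(params)).body.data ?? [];
  for (const item of fileTagMap) {
    for (const tag of item.tag) {
      console.log(`${tag} / ${item.file}: gbs=${item.globalBatchSize}, seq=${item.seqLength}`);
    }
  }
}
```
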
@@ -598,6 +373,92 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: setOriginToken(JSON.parse(JSON.stringify(tokenOption))); } }; + /** + * @description: 判断已选择文件中是否包含虚拟文件 + * @return {*} + */ + const judgeIsMerge = () => { + if (typeof tag === 'string') { + return checkedKeys.some(key => lossShowInfo.mergeFileList[tag].some(i => i.file === key)); + } else { + return tag.some(t => checkedKeys.some(key => lossShowInfo.mergeFileList[t].some(i => `${t}&&${i.file}:${i.mergeName}` === key))); + } + }; + /** + * @description: 处理聚合时传递给lossShowInfo.modifyCheck的参数 + * @param {string} keys 路径相关信息的字符串数组 + * @return {*} + */ + const handleAggCheckParam = (keys: string[]) => { + const param: CheckListItem[] = keys.map(item => { + const index = item.lastIndexOf(':'); + const tag = item.slice(0, index).split('&&')[0]; + const fileName = `${tag}:${item.slice(0, index).split('&&')[1]}`; + const filePath = `${tag}:${item.slice(index + 1)}`; + return { tag, filePath, fileName }; + }); + lossShowInfo.modifyCheck(param); + }; + /** + * @description: 合并或取消合并 + * @param {fileMergeParam} action merge|unset + * @param {string} filePath 虚拟文件路径 + * @return {*} + */ + const mergeOrUnset = async (action: fileMergeParam['action'], filePath?: string) => { + //虚拟文件名称不能为空 + if (action === 'merge' && !mergeName) { + messageApi.open({ + type: 'error', + content: t('mergeNameTip'), + }); + return; + } + //判断用户是否在合并时选择了虚拟文件 + if (action === 'merge' && judgeIsMerge()) { + messageApi.open({ + type: 'error', + content: t('mergeTip'), + }); + return; + } + const handleTagAggFile = (file: string) => { + const index = file.lastIndexOf(':'); + return file.slice(index + 1); + }; + const fileList: string[] = checkedKeys.map(item => handleTagAggFile(item)); + const param: fileMergeParam = { + action, + name: action === 'merge' ? mergeName : handleTagAggFile(filePath as string), + fileList: action === 'merge' ? fileList : [] + }; + const body = (await fileMergeOrUnset(param)).body; + lossShowInfo.handleMergeFile(mergeName, body); + if (action === 'merge') { + mergeHandleCancel(); + } else { + //处理勾选框 + if (!filePath) return; + const index = checkedKeys.findIndex(key => key === filePath); + if (index > -1) { + checkedKeys.splice(index, 1); + } + setCheckedKeys(checkedKeys); + if (typeof tag === 'string') { + lossShowInfo.modifyCheck({ tag, fileList: checkedKeys }); + } else { + handleAggCheckParam(checkedKeys); + } + } + }; + /** + * @description: merge的modal框取消事件 + * @return {*} + */ + const mergeHandleCancel = () => { + setMergeName(''); + setShowMergeModal(false); + }; /** * @description: 根据tag数量动态生成LineModal中的设置 * @param {string} key 文件名 @@ -616,17 +477,17 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: {tag}
- FilePath: + {t('filePath')}: {filePath}
- Name: - { + {t('legendName')}: + { nameColorChange(line.key, 'name', e.target.value); }}>
- Color: + {t('legendColor')}: { nameColorChange(line.key, 'color', color); }}> @@ -647,10 +508,11 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: //初始化lineConfig和tokenOption updateModalConfig(); }, []); + const updateConfig = { updateModalConfig, updateCheckedKeys: setCheckedKeys, showMergeModal: setShowMergeModal, mergeOrUnset }; return
- +
@@ -679,22 +541,33 @@ const DataChart = observer(({ lossShowInfo, tag, isExpand }: {generateLineModal(changeKey)}
+ { mergeOrUnset('merge'); }} width='800px' style={{ top: 280 }} className='tokenConfigModal' onCancel={mergeHandleCancel} maskClosable={false} key='mergeConfig' okText={t('ok')} cancelText={t('cancel')}> +
+
+ {t('mergeFileName')}: + { + setMergeName(e.target.value); + }}> +
+
+
{ showFileList.length > 0 && <>
- {(tableData?.length ?? 0 > 0) ? : <>} } + {contextHolder} ; }); -const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossShowInfo; tag: string | string[] }): JSX.Element => { +export const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossShowInfo; tag: string | string[] }): JSX.Element => { const { t } = useTranslation('lossShow'); const [chartOption, setChartOption] = useState({}); const [columns, setColumns] = useState['columns']>([]); @@ -740,45 +613,16 @@ const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossSho return res; }; - const getData = (fileList: string[], data: { [key: string]: DataItem }) => { - const dataSource: Array<{ [key: string]: string | number }> = []; - const keys = Object.keys(data); - if (keys.length < 2) { - return { dataSource, len: -1 }; - } - const keys0 = Object.keys(data[keys[0]]); - const keys1 = Object.keys(data[keys[1]]); - const steps: string[] = keys0.filter(item => keys1.includes(item)); - const len = steps.length; - for (let i = 0; i < len; i++) { - dataSource.push({ step: String(steps[i]), key: i }); - } - for (let i = 0; i < len; i++) { - const baseData = data[fileList[1]][steps[i]].value || 0; - const comparisonData = data[fileList[0]][steps[i]].value || 1; - if (typeof baseData === 'string' && typeof comparisonData === 'string') { - continue; - } - const dif = baseData - comparisonData; - dataSource[i]['Comparison Normal'] = dif; - dataSource[i]['Comparison Absolute'] = Math.abs(dif); - dataSource[i]['Comparison Relative'] = keepDecimals(Math.abs(dif) / notZero(comparisonData) * 100); - } - return { dataSource, len }; - }; - - const exportFile = (tag: string | string[]): void => { - const fileName = `${typeof tag === 'string' ? tag : 'tagAggregation'}_comparison_${Date.now()}.csv`; - downloadCsv(columns?.map(item => ({ key: item.key, title: item.title })) as HeaderType[], tableData as any[], fileName); - }; - const init = () => { const { data, columns } = lossShowInfo.getDataByTag(tag); const { dataSource } = getData(columns, data); - setTableData(dataSource); - setColumns(getColumns()); - setDataInfo({ baseLine: columns[0], comparison: columns[1] }); - setChartOption(creatComparisonChart(dataSource, t)); + //判断图表是否需要刷新 + if (JSON.stringify(dataSource) !== JSON.stringify(tableData)) { + setTableData(dataSource); + setColumns(getColumns()); + setDataInfo({ baseLine: columns[0], comparison: columns[1] }); + setChartOption(creatComparisonChart(dataSource, t)); + } }; const resizeChart = () => { @@ -793,17 +637,19 @@ const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossSho setColumns(getColumns()); setChartOption(creatComparisonChart(tableData, t)); }, [t]); - + useEffect(() => { + setTableData([]); //切换聚合状态时恢复默认值 + }, [lossShowInfo.isTagPolymerize]); return
{t('dataComparison')}
-
{t('baselineData')}:
+
{t('baselineData')}:
{dataInfo.baseLine}
-
{t('comparativeData')}:
+
{t('comparativeData')}:
{dataInfo.comparison}
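
For reference, the three comparison series rendered here are computed step by step in getData (moved to utils.tsx further down). A small sketch of the per-step arithmetic with made-up values; keepDecimals and notZero are the helpers this code already imports from '@/utils/common':

```ts
// Per-step comparison arithmetic as in getData (utils.tsx below).
// The two input values are illustrative.
import { keepDecimals, notZero } from '@/utils/common';

const baseData = 0.92;       // baseline curve value at this step
const comparisonData = 0.87; // comparison curve value at this step

const dif = baseData - comparisonData;
const comparisonNormal = dif;             // signed difference, ≈ 0.05
const comparisonAbsolute = Math.abs(dif); // absolute difference, ≈ 0.05
// Relative difference in percent; notZero guards the divisor against zero.
const comparisonRelative = keepDecimals(Math.abs(dif) / notZero(comparisonData) * 100);
```
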
@@ -816,7 +662,7 @@ const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossSho
- {(tableData?.length ?? 0 > 0) ? : <>} @@ -824,41 +670,4 @@ const ComparisonChart = observer(({ lossShowInfo, tag }: { lossShowInfo: LossSho ; }); -export const ChartsContainer = observer(({ lossShowInfo }: { lossShowInfo: LossShowInfo }) => { - const { t } = useTranslation('lossShow'); - const tags = lossShowInfo.showTagList; - const isShowComparisonChart = (tag: string): boolean => { - const { columns } = lossShowInfo.getDataByTag(tag); - return columns.length === 2; - }; - - return <> - { - tags.length > 0 ? tags.map(tag => (
-
{tag}
- - {isShowComparisonChart(tag) && } -
)) :
{t('noData')}
- } - ; -}); - -export const ChartsContainerTagAggregation = observer(({ lossShowInfo }: { lossShowInfo: LossShowInfo }) => { - const { t } = useTranslation('lossShow'); - - const tags = lossShowInfo.showTagList; - const isShowComparisonChart = (): boolean => { - const { columns } = lossShowInfo.getDataByTag(tags); - return columns.length === 2; - }; - return <> - { - tags.length > 0 ?
-
{t('tagAggregation')}
- - {isShowComparisonChart() && } -
:
{t('noData')}
- } - ; -}); diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/ChartsContainer.tsx b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/ChartsContainer.tsx new file mode 100644 index 0000000000000000000000000000000000000000..8a391a79185e87fff850b6972d08b8539b85db9c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/ChartsContainer.tsx @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +import { observer } from 'mobx-react'; +import type { LossShowInfo } from '@/entity/lossShow'; +import { useTranslation } from 'react-i18next'; +import { DataChart, ComparisonChart } from './Chart'; +export const ChartsContainer = observer(({ lossShowInfo }: { lossShowInfo: LossShowInfo }) => { + const { t } = useTranslation('lossShow'); + const tags = lossShowInfo.showTagList; + + const isShowComparisonChart = (tag: string): boolean => { + const { columns } = lossShowInfo.getDataByTag(tag); + return columns.length === 2; + }; + + return <> + { + tags.length > 0 ? tags.map(tag => (
+
{tag}
+ + {isShowComparisonChart(tag) && } +
)) :
{t('noData')}
+ } + ; +}); + +export const ChartsContainerTagAggregation = observer(({ lossShowInfo }: { lossShowInfo: LossShowInfo }) => { + const { t } = useTranslation('lossShow'); + + const tags = lossShowInfo.showTagList; + const isShowComparisonChart = (): boolean => { + const { columns } = lossShowInfo.getDataByTag(tags); + return columns.length === 2; + }; + return <> + { + tags.length > 0 ?
+
{t('tagAggregation')}
+ + {isShowComparisonChart() && } +
:
{t('noData')}
+ } + ; +}); \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/FileTreeList.tsx b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/FileTreeList.tsx index 09800af1c404d76f30a060bab6a4ade8b10ba8c4..636eef9dd7e0d5901bcdb55621e9bd6b913ac698 100644 --- a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/FileTreeList.tsx +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/FileTreeList.tsx @@ -1,7 +1,6 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. */ - import React, { useEffect, useRef, useState } from 'react'; import styled from '@emotion/styled'; import { @@ -12,15 +11,17 @@ import type { CheckListItem, LossShowInfo } from '@/entity/lossShow'; import { useTranslation } from 'react-i18next'; import eventBus from '@/eventBus'; import { DataNode } from 'antd/es/tree'; +import { type fileMergeParam } from '@/api/lossApi'; interface Position { left: string; top: string; } interface TreeItem { - title: string; + title: string | JSX.Element; key: string; checkable?: boolean; selectable?: boolean; + isMerge?: boolean; children?: TreeItem[]; } @@ -32,8 +33,20 @@ interface MenuItemModel { visible: boolean; title?: string; } -export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalConfig }: - { lossShowInfo: LossShowInfo; tag: string | string[], setConfigChange: (type: string, key: string, checkedKeys: React.Key[]) => void, updateModalConfig: () => void }): JSX.Element => { +interface UpdateConfig { + updateModalConfig: () => void; + updateCheckedKeys: React.Dispatch> + showMergeModal: React.Dispatch>; + mergeOrUnset: (action: fileMergeParam['action'], filePath?: string) => void; +} +interface FileTreeList { + lossShowInfo: LossShowInfo; + tag: string | string[]; + setConfigChange: (type: string, key: string, checkedKeys: React.Key[]) => void; + updateConfig: UpdateConfig; +} +export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateConfig }: FileTreeList): JSX.Element => { + const { updateModalConfig, updateCheckedKeys, showMergeModal, mergeOrUnset } = updateConfig; const { t } = useTranslation('lossShow'); const [line, setLine] = useState(''); //记录文件key值 const MenuContainer = styled.div` @@ -77,25 +90,25 @@ export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalCo } else { newCheckedKeys = checked; } + updateModalConfig();//及时更新lineConfig和tokenOption if (typeof tag === 'string') { - lossShowInfo.modifyCheck({ tag, fileList: newCheckedKeys.slice(-2) as string[] }); + lossShowInfo.modifyCheck({ tag, fileList: newCheckedKeys as string[] }); } else { - handleAggCheckParam(newCheckedKeys as string[], -2); + handleAggCheckParam(newCheckedKeys as string[]); } //新增勾选标签后,及时更新数据 if (checkedKeys.length < newCheckedKeys.length) { eventBus.emit('updataChartData'); } setCheckedKeys(newCheckedKeys as string[]); + updateCheckedKeys(newCheckedKeys as string[]); }; /** * @description: 处理聚合时传递给lossShowInfo.modifyCheck的参数 * @param {string} keys 路径相关信息的字符串数组 - * @param {number} start 起始索引 - * @param {number} end 结束索引 * @return {*} */ - const handleAggCheckParam = (keys: string[], start: number, end?: number) => { + const handleAggCheckParam = (keys: string[]) => { const param: CheckListItem[] = keys.map(item => { const index = item.lastIndexOf(':'); const tag = item.slice(0, index).split('&&')[0]; @@ -103,11 +116,7 
@@ export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalCo const filePath = `${tag}:${item.slice(index + 1)}`; return { tag, filePath, fileName }; }); - if (end) { - lossShowInfo.modifyCheck(param.slice(start, end)); - } else { - lossShowInfo.modifyCheck(param.slice(start)); - } + lossShowInfo.modifyCheck(param); }; /** * @description: 生成树状列表的数据源 @@ -122,15 +131,20 @@ export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalCo const name = file.fileName; const path = file.filePath; const len = file.dirs.length; + const isMerge = file.isMerge; let cur = tree; file.dirs.forEach((item: string, index: number) => { const exist: false | TreeItem = cur.find(i => i.title === item) || false; if (!exist) { const obj: TreeItem = { title: '', selectable: false, key: '' }; if (index === len - 1) { - obj.title = name; + obj.title = isMerge ? ({name}) : name; obj.key = path; - cur.push(obj); + if (isMerge) { + cur.unshift(obj); + } else { + cur.push(obj); + } } else { obj.title = item; obj.checkable = false; @@ -181,15 +195,44 @@ export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalCo setConfigChange(type, filePath, checkedKeys); } }; + /** + * @description: 处理文件合并或取消合并 + * @param {string} type + * @return {*} + */ + const handelMergeOrUnset = async (type: string) => { + setContextMenuVisible(false); + if (type === 'merge') { + showMergeModal(true); + } else { + mergeOrUnset('unset', line); + } + }; + /** + * @description: 判断是否是虚拟文件 + * @return {*} + */ + const judgeIsMerge = () => { + if (typeof tag === 'string') { + return lossShowInfo.mergeFileList[tag].some(i => i.file === line); + } else { + const index = line.lastIndexOf(':'); + const key = line.slice(index + 1); + return tag.some(t => lossShowInfo.mergeFileList[t].some(i => i.file === key)); + } + }; /** * @description: 获取右键菜单选项 * @return {*} */ const getMenuItems = (): JSX.Element => { + const isMerge = judgeIsMerge(); const menuItems: MenuItemModel[] = [ { name: t('setLine'), key: 'setLine', event: () => { handleConfig('line'); }, visible: true }, //更改图例和颜色 - { name: t('tokenConfigChange'), key: 'tokenConfigChange', event: () => { handleConfig('token'); }, visible: true }, //更改单个文件token配置 - { name: t('batchTokenConfigChange'), key: 'batchTokenConfigChange', event: () => { handleConfig('tokenBatch'); }, visible: checkedKeys.length >= 2, disabled: !checkedKeys.includes(line) } //批量修改token配置 + { name: t('tokenConfigChange'), key: 'tokenConfigChange', event: () => { handleConfig('token'); }, visible: true, disabled: isMerge }, //更改单个文件token配置 + { name: t('batchTokenConfigChange'), key: 'batchTokenConfigChange', event: () => { handleConfig('tokenBatch'); }, visible: checkedKeys.length >= 2, disabled: !checkedKeys.includes(line) || isMerge }, //批量修改token配置 + { name: t('fileMerge'), key: 'fileMerge', event: () => { handelMergeOrUnset('merge'); }, visible: checkedKeys.length >= 2, disabled: !checkedKeys.includes(line) || isMerge }, //文件合并 + { name: t('fileUnset'), key: 'fileUnset', event: () => { handelMergeOrUnset('unset'); }, visible: isMerge } //取消文件合并 ]; return <> {menuItems.filter(menuItem => menuItem.visible).map(item => ( @@ -286,15 +329,21 @@ export const FileTreeList = ({ lossShowInfo, tag, setConfigChange, updateModalCo } tag.forEach(oneTag => { fileList.find(item => item.tag === oneTag)?.file.forEach(file => { - tree.push({ title: `${oneTag}:${file.fileName}`, key: `${oneTag}&&${file.fileName}:${file.filePath}`, selectable: false, checkable: true }); + const isMerge = 
file.isMerge; + if (isMerge) { + tree.unshift({ title: ({`${oneTag}:${file.fileName}`}), key: `${oneTag}&&${file.fileName}:${file.filePath}`, selectable: false, checkable: true }); + } else { + tree.push({ title: `${oneTag}:${file.fileName}`, key: `${oneTag}&&${file.fileName}:${file.filePath}`, selectable: false, checkable: true }); + } if (file.value) { res.push(`${oneTag}&&${file.fileName}:${file.filePath}`); } }); }); - handleAggCheckParam(res, 0, 2); + handleAggCheckParam(res); } setCheckedKeys(res); + updateCheckedKeys(res); setTree(tree); }; useEffect(() => { diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Smoothing.tsx b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Smoothing.tsx index 8052b4e75673279d1f3da89b69a89f247e57b5f3..eca378ba68cc8d241c5f820f2fd14dcd4c0fc4eb 100644 --- a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Smoothing.tsx +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/Smoothing.tsx @@ -54,7 +54,7 @@ export const Smoothing = ({ lossShowInfo, tag }: { lossShowInfo: LossShowInfo; t return
-
{t('algorithm')}:
+
{t('algorithm')}:
-
{t('tagAggregation')}:
+
{t('tagAggregation')}:
lossShowInfo.modifyIsTagPolymerize(e.target.checked)} checked={lossShowInfo.isTagPolymerize} />
-
{t('updateFrequency')}:
+
{t('updateFrequency')}:
-
{t('parseState')}
+
{t('parseState')}:
diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/utils.tsx b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/utils.tsx new file mode 100644 index 0000000000000000000000000000000000000000..c0ad73e32c357450f26462bd82a4706eae69469c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/components/LossShow/utils.tsx @@ -0,0 +1,282 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + */ +import type { ECBasicOption } from 'echarts/types/dist/shared'; +import { TFunction } from 'i18next'; +import { type DataItem } from '@/api/lossApi'; +import type { LossShowInfo } from '@/entity/lossShow'; +import { type TableProps, } from 'antd'; +import { downloadCsv, keepDecimals, notZero, type HeaderType } from '@/utils/common'; +export interface ChartParam { + columns: string[]; + tableDataSource: Array<{ [key: string]: string | number | null }>; + tableDataSourceSmoothing: Array<{ [key: string]: string | number | null }>; + smoothingColumns: string[]; + dateConfig: { [key: string]: { coord: [string, string], value: string }[] }; + lineCallback: (name: string) => { fullName: string, abbName: string }; + colorCallback: (key: string, blur: string) => string; +} +export interface ChartDataLenInfo { + [key: string]: MaxMin; + dataLen: MaxMin; + smoothingDataLen: MaxMin; +}; +interface MaxMin { + max: number; + min: number; +} +const chartTooltipFormat = (seriesList: any[], dimensions: string[]) => { + if (seriesList.length < 1) { + return; + } + const addNum = (seriesList.length === 2 && dimensions.length === 4) ? 1 : 2; //不展示smoothing的图例,只展示原始图例,最多两条曲线 + const div = document.createElement('div'); + div.className = 'tooltip'; + div.append(`Step: ${seriesList[0].data.step}`); + const tooltipItem = document.createElement('div'); + tooltipItem.className = 'tooltipItem'; + for (let i = 0; i < seriesList.length; i += addNum) { + const seriesItem = seriesList[i]; + const circle = document.createElement('div'); + circle.className = 'circle'; + circle.setAttribute('style', `background-color:${seriesList[seriesList.length % 2 ? i : (dimensions.length === 4 ? i : i + 1)].color};`); + const keyDom = document.createElement('div'); + const valueDom = document.createElement('div'); + keyDom.className = 'value'; + valueDom.className = 'smoothing'; + if (seriesList.length === 2 && dimensions.length === 4) { + keyDom.append(`value: ${seriesItem.data[dimensions[i === 0 ? i : i + 1]]}`); + valueDom.append(`smoothing: ${seriesItem.data[dimensions[i === 0 ? 
i + 1 : i + 2]]}`); + } else { + keyDom.append(`value: ${seriesItem.data[dimensions[i]]}`); + valueDom.append(`smoothing: ${seriesItem.data[dimensions[i + 1]]}`); + } + tooltipItem.appendChild(circle); + tooltipItem.appendChild(keyDom); + tooltipItem.appendChild(valueDom); + } + div.appendChild(tooltipItem); + + return div; +}; +export const creatDataChart = (chartParam: ChartParam) => { + const { columns, tableDataSource, tableDataSourceSmoothing, smoothingColumns, dateConfig, lineCallback, colorCallback } = chartParam; + + const source = tableDataSource.map((item, index) => { + return { ...tableDataSourceSmoothing[index], ...item }; + }); + + const dimensions: string[] = []; + columns.forEach((item, index) => { + dimensions.push(item, smoothingColumns[index]); + }); + + return { + animation: false, + tooltip: { + trigger: 'axis', + className: 'chartTooltip', + axisPointer: { + type: 'cross', + }, + formatter: (seriesList: any[]) => chartTooltipFormat(seriesList, dimensions), + }, + xAxis: { + type: 'category', + }, + yAxis: { + }, + series: dimensions.map(_item => { + if (_item.endsWith('smoothing')) { + return { + symbolSize: 5, + type: 'line', + animation: false, + progressive: 0, + showSymbol: false, + lineStyle: { + color: colorCallback(_item, '1'), + }, + itemStyle: { + color: colorCallback(_item, '1'), + }, + }; + } else { + return { + symbolSize: 5, + type: 'line', + animation: false, + progressive: 0, + showSymbol: false, + lineStyle: { + color: colorCallback(_item, '0.4'), + }, + itemStyle: { + color: colorCallback(_item, '0.4'), + }, + markPoint: { + data: dateConfig[_item], + symbolSize: [95, 50], + label: { + position: 'inside', + show: true, + distance: 5, + fontWeight: 'lighter', + fontSize: 10, + borderType: 'solid', + heigth: 50 + } + } + }; + } + }), + dataset: { + dimensions: ['step', ...dimensions], + source, + }, + legend: { + show: true, + textStyle: { + color: 'rgb(141, 152, 170)', + }, + data: [...columns], + formatter: (name: string) => { + const { abbName } = lineCallback(name); + return abbName; + }, + ellipsis: true, + tooltip: { + show: true, + formatter: (params: { name: string }) => { + const { fullName } = lineCallback(params.name); + return fullName; + } + } + }, + }; +}; +export const creatComparisonChart = (source: Array<{ [key: string]: string | number | null }> = [], t: TFunction): ECBasicOption => { + return { + animation: false, + tooltip: { + trigger: 'axis', + axisPointer: { + type: 'cross', + }, + }, + xAxis: { + type: 'category', + }, + yAxis: [ + { + type: 'value', + }, + { + type: 'value', + position: 'left', + axisLabel: { + formatter: '{value} %', + } + }, + ], + series: ['comparisonNormal', 'comparisonAbsolute', 'comparisonRelative'].map(name => ({ + name: t(name), + symbolSize: 5, + type: 'line', + animation: false, + progressive: 0, + yAxisIndex: 0, + showSymbol: false, + })), + dataset: { + dimensions: ['step', 'Comparison Normal', 'Comparison Absolute', 'Comparison Relative'], + source, + }, + legend: { + textStyle: { + color: 'rgb(141, 152, 170)', + }, + selectedMode: 'single', // 每次打开一个图例 + selected: {}, + }, + }; +}; +export const getTableData = (data: { [key: string]: DataItem }) => { + const tableDataSource: Array<{ key: number, step: string, [key: string]: number | string | null }> = []; + const keys = Object.keys(data); + const keysLen = keys.length; + if (keysLen === 0) { + return { tableDataSource, len: { max: -1, min: -1 } }; + } + + let len: number = 0; + let steps: string[] = []; + const keys0 = keysLen < 2 ? 
Object.keys(data[keys[0]]) : Object.keys(data[keys[keysLen - 2]]); + let keys1 = []; + if (keysLen < 2) { + steps = keys0; + len = steps.length; + } else { + keys1 = Object.keys(data[keys[keysLen - 1]]); + steps = [...new Set([...keys0, ...keys1])]; + len = steps.length; + } + for (let i = 0; i < len; i++) { + tableDataSource.push({ step: String(steps[i]), key: i }); + } + for (let i = 0; i < tableDataSource.length; i++) { + for (let j = 0; j < keysLen; j++) { + const key = tableDataSource[i].step; + tableDataSource[i][keys[j]] = data[keys[j]][key]?.value || null; + } + } + return { + tableDataSource, len: { + max: len, + min: keysLen < 2 + ? -1 : Math.min(keys0.length, keys1.length), + }, + }; +}; +export const getData = (fileList: string[], data: { [key: string]: DataItem }) => { + const dataSource: Array<{ [key: string]: string | number }> = []; + const keys = Object.keys(data); + if (keys.length < 2) { + return { dataSource, len: -1 }; + } + const keys0 = Object.keys(data[keys[0]]); + const keys1 = Object.keys(data[keys[1]]); + const steps: string[] = keys0.filter(item => keys1.includes(item)); + const len = steps.length; + for (let i = 0; i < len; i++) { + dataSource.push({ step: String(steps[i]), key: i }); + } + for (let i = 0; i < len; i++) { + const baseData = data[fileList[1]][steps[i]]?.value || 0; + const comparisonData = data[fileList[0]][steps[i]]?.value || 1; + if (typeof baseData === 'string' && typeof comparisonData === 'string') { + continue; + } + const dif = baseData - comparisonData; + dataSource[i]['Comparison Normal'] = dif; + dataSource[i]['Comparison Absolute'] = Math.abs(dif); + dataSource[i]['Comparison Relative'] = keepDecimals(Math.abs(dif) / notZero(comparisonData) * 100); + } + return { dataSource, len }; +}; +export const initTokenOption = (lossShowInfo: LossShowInfo) => { + const res: { [key: string]: { [key: string]: { globalBatchSize: number, seqLength: number } } } = {}; //记录每一个tag的modal框内的globalBatchSize和seqLength + const fileList = lossShowInfo.getFileListByModel('isChecked'); //所有文件列表 + fileList.forEach(item => { + res[item.tag] = {}; + item.file.forEach(file => { + const oneConfig = lossShowInfo.tokenOptions?.[item.tag]?.files?.[file.filePath]; + res[item.tag][file.filePath] = { globalBatchSize: oneConfig ? oneConfig.globalBatchSize : -1, seqLength: oneConfig ? oneConfig.seqLength : -1 }; + }); + }); + return res; +}; +export const exportFile = (isAgg: boolean, tag: string | string[], columns: TableProps['columns'], tableData: Array<{ [key: string]: string | number | null }> | undefined): void => { + const fileName = isAgg ? `${typeof tag === 'string' ? tag : 'tagAggregation'}_comparison_${Date.now()}.csv` : `${typeof tag === 'string' ? 
tag : 'tagAggregation'}_${Date.now()}.csv`; + downloadCsv(columns?.map(item => ({ key: item.key, title: item.title })) as HeaderType[], tableData || [], fileName); +}; \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/entity/lossShow.ts b/plugins/mindstudio-insight-plugins/Scalar/front/src/entity/lossShow.ts index cdc1d90b108c489c59838b70185dbe6f61b37d00..4ef98eb9583594114c0a39f12339faed9a855465 100644 --- a/plugins/mindstudio-insight-plugins/Scalar/front/src/entity/lossShow.ts +++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/entity/lossShow.ts @@ -10,25 +10,26 @@ import type { NewFileResponseBody, TokenChart, DateItem, - LineItem + LineItem, + tokenParam, + QueryParamItem, + FileMergeResponseBody } from '@/api/lossApi'; import { makeAutoObservable } from 'mobx'; import { getRandomRGBColor, getChartColor } from '@/utils/common'; interface FileInfoItem { - [key: string]: string | number | boolean | string[] | DataItem | DateItem[] | LineItem; + [key: string]: string | number | boolean | string[] | DataItem | DateItem[] | LineItem | QueryParamItem; name: string; data: DataItem; tokenData: DataItem; - offset: number; - tokenOffset: number; isChecked: boolean; smoothingData: DataItem; tokenSmoothingData: DataItem; - sampleOffset: number; - tokenSampleOffset: number; dirs: string[]; dateConfig: DateItem[]; lineConfig: LineItem; + isMerge: boolean; + queryParam: QueryParamItem }; interface TagItem { [key: string]: FileInfoItem; @@ -42,13 +43,16 @@ export interface FileListResult { filePath: string; fileName: string; dirs: string[]; - value: string | number | boolean | string[] | DataItem | DateItem[] | LineItem; + isMerge: boolean; + queryParam: QueryParamItem; + value: string | number | boolean | string[] | DataItem | DateItem[] | LineItem | QueryParamItem; }[]; }; export interface CheckListItem { tag: string; filePath: string; fileName: string; + isMerge?: boolean; dirs?: string[]; }; interface SmoothingConfig { @@ -68,10 +72,12 @@ interface TokenOptions { enable: boolean; } } +interface LineType { + line: ['normal', 'normalSmoothing', 'token', 'tokenSmoothing'] +} export class LossShowInfo { renderChart: boolean = false; showList: { [key: string]: string[]; } = {}; - private showListCopy: { [key: string]: string[]; } = {}; tagAggregationShowList: CheckListItem[] = []; fileListChange: number = 0; stepGetData: number = 5; @@ -83,6 +89,7 @@ export class LossShowInfo { showTagList: string[] = []; chartInit: boolean = false; tokenOptions: TokenOptions = {}; //token相关配置 + mergeFileList: { [key: string]: { file: string, mergeName: string }[] } = {}; //虚拟文件列表 constructor() { makeAutoObservable(this); }; @@ -91,7 +98,6 @@ export class LossShowInfo { if (!this.hasBaseData.includes(tag)) { this.hasBaseData.push(tag); this.showList[tag] = [file.path]; - this.showListCopy[tag] = [file.path]; if (this.showTagList.length < 3) { this.showTagList.push(tag); } @@ -107,6 +113,7 @@ export class LossShowInfo { this.showList[item.tag] = []; this.smoothingConfigList[item.tag] = { sampleAlgorithm: '', sampleWeight: 0 }; this.tokenOptions[item.tag] = { files: {}, enable: false }; + this.mergeFileList[item.tag] = []; } item.fileList.forEach((file: { name: string; path: string, dirs: string[] }, index: number) => { if (!Object.keys(this.fileList[item.tag]).includes(file.path)) { @@ -120,18 +127,17 @@ export class LossShowInfo { name: file.name, data: {}, tokenData: {}, - offset: 0, - tokenOffset: 0, isChecked: this.getIsChecked(item.tag, file), smoothingData: {}, 
             tokenSmoothingData: {},
-            sampleOffset: 0,
-            tokenSampleOffset: 0,
             dirs: file.dirs,
             dateConfig: [],
-            lineConfig: { name: `${item.tag}:${file.name}`, color }
+            lineConfig: { name: `${item.tag}:${file.name}`, color },
+            isMerge: false,
+            queryParam: { start: 0, end: -1 }
           };
           this.tokenOptions[item.tag].files[file.path] = { globalBatchSize: -1, seqLength: -1 };
+          this.mergeFileList[item.tag] = [];
         }
       });
     });
@@ -139,17 +145,18 @@
   };

   addData(dataList: ChartsDataResponseBody['data']) {
+    const dataMap: { [key: string]: string } = { 'normal': 'data', 'normalSmoothing': 'smoothingData', 'token': 'tokenData', 'tokenSmoothing': 'tokenSmoothingData' };
+    const dataMapArr: LineType['line'] = ['normal', 'normalSmoothing', 'token', 'tokenSmoothing'];
     dataList.forEach(item => {
       const file = this.fileList[item.tag][item.file];
-      file.data = Object.assign(file.data, item.normal);
-      file.offset = Object.keys(file.data).length;
-      file.smoothingData = Object.assign(file.smoothingData, item.normalSmoothing);
-      file.sampleOffset = Object.keys(file.smoothingData).length;
+      dataMapArr.forEach(type => {
+        file[dataMap[type]] = item[type] || {};
+        if (item[type] && Object.keys(item[type]).length) {
+          const keys = Object.keys(item[type]).map(Number);
+          file.queryParam = { start: keys[0], end: keys[keys.length - 1] };
+        }
+      });
       file.dateConfig.push(...(item.dateConfig ?? []));
-      file.tokenData = Object.assign(file.tokenData, item.token);
-      file.tokenOffset = Object.keys(file.tokenData).length;
-      file.tokenSmoothingData = Object.assign(file.tokenSmoothingData, item.tokenSmoothing);
-      file.tokenSampleOffset = Object.keys(file.tokenSmoothingData).length;
     });
     this.renderChart = !this.renderChart;
   };
@@ -209,9 +216,28 @@
   modifyCheck(value: { tag: string, fileList: string[] } | CheckListItem[]) {
     if (Array.isArray(value)) {
-      this.tagAggregationShowList = value;
+      this.tagAggregationShowList = value.slice(-2);
       const checkFilePathList = value.map(item => item.filePath);
-      this.modifyTagAggregationCheck(this.checkedListFormat(checkFilePathList));
+      const files = this.checkedListFormat(checkFilePathList);
+      if (value.length) {
+        Object.keys(files).forEach(tag => {
+          Object.keys(this.fileList[tag]).forEach(filePath => {
+            const fileList = files[tag];
+            if (fileList.includes(filePath)) {
+              this.fileList[tag][filePath].isChecked = true;
+            } else {
+              this.fileList[tag][filePath].isChecked = false;
+            }
+          });
+        });
+      } else {
+        // when nothing is selected, undo the merge selection
+        Object.keys(this.fileList).forEach(tag => {
+          Object.keys(this.fileList[tag]).forEach(filePath => {
+            this.fileList[tag][filePath].isChecked = false;
+          });
+        });
+      }
     } else {
       this.modifyOneCheck(value);
     }
@@ -224,19 +250,27 @@
       } else {
         this.fileList[tag][filePath].isChecked = false;
       }
-      this.showList[tag] = fileList;
+      this.showList[tag] = fileList.slice(-2);
     });
   };

-  private modifyTagAggregationCheck(value: { [key: string]: string[] }) {
+  private modifyTagAggregationCheck(isAgg: boolean) {
+    // when toggling aggregation, select the first non-virtual file by default
+    let checkFlag = false;
     Object.keys(this.fileList).forEach(tag => {
+      if (!isAgg) {
+        checkFlag = false;
+      }
       Object.keys(this.fileList[tag]).forEach(filePath => {
-        this.fileList[tag][filePath].isChecked = false;
+        const isMerge = this.fileList[tag][filePath].isMerge;
+        if (isMerge || checkFlag) {
+          this.fileList[tag][filePath].isChecked = false;
+        } else {
+          checkFlag = true;
+          this.fileList[tag][filePath].isChecked = true;
+          this.showList[tag] = [filePath];
+        }
       });
-      this.showList[tag] = [];
-    });
-    Object.keys(value).forEach(tag => {
-      this.modifyOneCheck({ tag, fileList: value[tag] });
     });
   };
@@ -249,14 +283,16 @@
   };

   modifyIsTagPolymerize(value: boolean) {
-    const copyMid = JSON.parse(JSON.stringify(this.showList));
-    this.modifyTagAggregationCheck((this.showListCopy));
-    this.showListCopy = copyMid;
+    this.modifyTagAggregationCheck(value);
     this.isTagPolymerize = value;
-    // when toggling aggregation, force the token switch off so the charts show the normal data
-    for (const tag in this.tokenOptions) {
+    Object.keys(this.fileList).forEach(tag => {
+      // when toggling aggregation, force the token switch off so the charts show the normal data
       this.tokenOptions[tag].enable = false;
-    }
+      Object.keys(this.fileList[tag]).forEach(filePath => {
+        // when toggling aggregation, reset every file to the default query parameters
+        this.fileList[tag][filePath].queryParam = { start: 0, end: -1 };
+      });
+    });
   };

   modifySmoothingConfig(tag: string | string[], value: { sampleAlgorithm: string, sampleWeight: number }) {
@@ -268,10 +304,8 @@
         const file = this.fileList[tag][filePath];
         if (this.tokenOptions[tag].enable) {
           file.tokenSmoothingData = {};
-          file.tokenSampleOffset = 0;
         } else {
           file.smoothingData = {};
-          file.sampleOffset = 0;
         }
       });
     });
@@ -279,15 +313,21 @@
   modifyTokenOption = (tag: string | string[], option: { [key: string]: { [key: string]: { globalBatchSize: number, seqLength: number } } }, enable: boolean) => {
     if (typeof tag === 'string') {
       Object.keys(option[tag]).forEach(item => {
+        if (!this.tokenOptions[tag].files[item]) {
+          this.tokenOptions[tag].files[item] = { globalBatchSize: -1, seqLength: -1 };
+        }
         this.tokenOptions[tag].files[item].globalBatchSize = option[tag][item].globalBatchSize;
         this.tokenOptions[tag].files[item].seqLength = option[tag][item].seqLength;
         // the previous data has to be cleared
-        const file = this.fileList[tag][item];
-        file.tokenData = {};
-        file.tokenOffset = 0;
-        file.tokenSmoothingData = {};
-        file.tokenSampleOffset = 0;
-        file.dateConfig = [];
+        const file = this.fileList[tag]?.[item];
+        if (file) {
+          file.tokenData = {};
+          file.tokenOffset = 0;
+          file.tokenSmoothingData = {};
+          file.tokenSampleOffset = 0;
+          file.dateConfig = [];
+          file.queryParam = { start: 0, end: -1 };
+        }
       });
       this.tokenOptions[tag].enable = enable;
     } else {
@@ -296,12 +336,15 @@
         this.tokenOptions[t].files[item].globalBatchSize = option[t][item].globalBatchSize;
         this.tokenOptions[t].files[item].seqLength = option[t][item].seqLength;
         // the previous data has to be cleared
-        const file = this.fileList[t][item];
-        file.tokenData = {};
-        file.tokenOffset = 0;
-        file.tokenSmoothingData = {};
-        file.tokenSampleOffset = 0;
-        file.dateConfig = [];
+        const file = this.fileList[t]?.[item];
+        if (file) {
+          file.tokenData = {};
+          file.tokenOffset = 0;
+          file.tokenSmoothingData = {};
+          file.tokenSampleOffset = 0;
+          file.dateConfig = [];
+          file.queryParam = { start: 0, end: -1 };
+        }
       });
       this.tokenOptions[t].enable = enable;
     });
@@ -318,7 +361,9 @@
           value: this.fileList[tag][filePath][model],
           filePath,
           fileName: this.fileList[tag][filePath].name,
-          dirs: this.fileList[tag][filePath].dirs
+          dirs: this.fileList[tag][filePath].dirs,
+          isMerge: this.fileList[tag][filePath].isMerge,
+          queryParam: this.fileList[tag][filePath].queryParam
         }));
         res.push({
           tag,
@@ -390,20 +435,19 @@
     if (this.showTagList.includes(tag)) {
       const { sampleAlgorithm: algorithm, sampleWeight: weight } = this.smoothingConfigList[tag];
       Object.keys(this.fileList[tag]).forEach((file: string) => {
-        const tokenOption: TokenChart = { type: 'token', globalBatchSize: -1, seqLength: -1 };
+        const tokenOption: TokenChart = { type: 'token', enable: false };
         const oneFile = this.fileList[tag][file];
-        let offset = oneFile.sampleOffset;
+        const tokenEnable = this.tokenOptions[tag].enable;
+        const start = oneFile.queryParam.start;
+        const end = oneFile.queryParam.end;
         if (oneFile.isChecked) {
-          if (this.tokenOptions[tag].enable) {
-            tokenOption.globalBatchSize = this.tokenOptions[tag].files[file].globalBatchSize;
-            tokenOption.seqLength = this.tokenOptions[tag].files[file].seqLength;
-            offset = oneFile.tokenSampleOffset;
-          }
+          tokenOption.enable = tokenEnable;
           graphList.push({
             tag,
             file,
-            offset: this.tokenOptions[tag].enable ? oneFile.tokenOffset : oneFile.offset,
-            graphConfig: [{ type: 'normal' }, tokenOption, { type: 'smoothing', algorithm, weight, offset }]
+            start,
+            end,
+            graphConfig: [{ type: 'normal', enable: !tokenEnable }, tokenOption, { type: 'smoothing', enable: algorithm.length > 0, algorithm, weight }]
           });
         }
       });
@@ -411,7 +455,35 @@
     });
     return graphList;
   };
-
+  getTokenParams(tag: string | string[], key: string | string[]) {
+    let params: tokenParam[] = [];
+    if (typeof tag === 'string') {
+      if (typeof key === 'string') {
+        const index = key.lastIndexOf(':');
+        const file = key.slice(index + 1);
+        params.push({ file, globalBatchSize: this.tokenOptions[tag].files[file].globalBatchSize, seqLength: this.tokenOptions[tag].files[file].seqLength });
+      } else {
+        params = key.map(file => ({ file, globalBatchSize: this.tokenOptions[tag].files[file].globalBatchSize, seqLength: this.tokenOptions[tag].files[file].seqLength }));
+      }
+    } else {
+      if (typeof key === 'string') {
+        const index = key.lastIndexOf(':');
+        const tagName = key.slice(0, index).split('&&')[0];
+        const file = key.slice(index + 1);
+        const { globalBatchSize, seqLength } = this.tokenOptions[tagName].files[file];
+        params.push({ file, globalBatchSize, seqLength });
+      } else {
+        params = key.map(fileKey => {
+          const index = fileKey.lastIndexOf(':');
+          const tagName = fileKey.slice(0, index).split('&&')[0];
+          const file = fileKey.slice(index + 1);
+          const { globalBatchSize, seqLength } = this.tokenOptions[tagName].files[file];
+          return { file, globalBatchSize, seqLength };
+        });
+      }
+    }
+    return params;
+  };
   getTagList() {
     return Object.keys(this.fileList);
   };
@@ -431,7 +503,36 @@
       return showFileList;
     }
   };
-
+  handleMergeFile(mergeName: string, body: FileMergeResponseBody) {
+    const { action, tags, file } = body.data;
+    if (action === 'merge') {
+      tags.forEach(t => {
+        this.fileList[t][file] = {
+          name: mergeName,
+          data: {},
+          tokenData: {},
+          isChecked: this.getIsChecked(t, { name: mergeName, path: file }),
+          smoothingData: {},
+          tokenSmoothingData: {},
+          dirs: [file],
+          dateConfig: [],
+          lineConfig: { name: `${t}:${mergeName}`, color: getRandomRGBColor() },
+          isMerge: true,
+          queryParam: { start: 0, end: -1 }
+        };
+        this.mergeFileList[t].push({ file, mergeName });
+        this.tokenOptions[t].files[file] = { globalBatchSize: -1, seqLength: -1 };
+      });
+      this.fileListChange += 1;
+    } else {
+      tags.forEach(t => {
+        delete this.fileList[t][file];
+        const index = this.mergeFileList[t].findIndex(i => i.file === file);
+        this.mergeFileList[t].splice(index, 1);
+      });
+      this.fileListChange -= 1;
+    }
+  };
   clearFileList() {
     this.renderChart = false;
     this.fileList = {};
@@ -441,9 +542,10 @@
     this.newFileList = [];
     this.fileListChange += 1;
     this.isTagPolymerize = false;
-    this.showListCopy = {};
     this.tagAggregationShowList = [];
     this.smoothingConfigList = {};
     this.showTagList = [];
+    this.tokenOptions = {};
+    this.mergeFileList = {};
   };
 };
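For orientation (not part of the patch): `getTokenParams` above assumes legend keys of the shape `tag:filePath`, or `tag&&suffix:filePath` in tag-aggregation mode, where the file path follows the last `:`. A minimal sketch of that decomposition:

    // Illustrative only — the assumed key shapes come from reading the code above
    function splitLegendKey(key: string): { tagName: string; file: string } {
      const index = key.lastIndexOf(':');              // the file path follows the last ':'
      const tagName = key.slice(0, index).split('&&')[0]; // only the part before '&&' is the tag name
      const file = key.slice(index + 1);
      return { tagName, file };
    }

    // splitLegendKey('loss&&agg:/logs/run1.log') -> { tagName: 'loss', file: '/logs/run1.log' }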
diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/en.json b/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/en.json
index 21b1dd365de5959c771b39baa0784d3d12d44fee..8fae21d92a53fe461e4330ed037f281e20c7c207 100644
--- a/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/en.json
+++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/en.json
@@ -24,16 +24,24 @@
     "zoomReset": "Zoom Reset",
     "restore": "Restore",
     "algorithm": "Smoothing Algorithm",
-    "parseState": "Data parsing status: ",
+    "parseState": "Data parsing status",
     "parseFailed": "Failed to parse data",
     "tokenEnable": "Enable token",
     "tokenConfigChange": "Modify Token Config",
     "setBatchLength": "GlobalBatch & SeqLength Setting",
     "setLine": "Custom polyline",
     "fileMerge": "File Merge",
+    "fileUnset": "Unset File Merge",
     "ok": "OK",
     "cancel": "Cancel",
     "tokenTip": "After the token is enabled, if the configuration is incorrect, no data may be displayed in the chart.",
-    "batchTokenConfigChange":"Batch modify Token Conifg"
+    "batchTokenConfigChange": "Batch Modify Token Config",
+    "filePath": "FilePath",
+    "legendName": "Legend Name",
+    "legendColor": "Legend Color",
+    "mergeFileName": "Virtual File Name",
+    "mergeTip": "The virtual file cannot be merged again. Please deselect the virtual file.",
+    "mergeNameTip": "The virtual file name cannot be empty.",
+    "batchTokenTip": "The token parameter cannot be set for the virtual file. Please deselect the virtual file."
   }
 }
\ No newline at end of file
diff --git a/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/zh.json b/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/zh.json
index cb44aab9b7e73e946d1e2fd1583594844c757b77..c5c8895b4f6b5c47ab893045fa020e0c6e5e8aa3 100644
--- a/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/zh.json
+++ b/plugins/mindstudio-insight-plugins/Scalar/front/src/i18n/lossShow/zh.json
@@ -24,16 +24,24 @@
     "zoomReset": "区域缩放还原",
     "restore": "还原",
     "algorithm": "平滑算法",
-    "parseState": "数据解析状态: ",
+    "parseState": "数据解析状态",
     "parseFailed": "数据解析失败",
     "tokenEnable": "开启Token开关",
     "tokenConfigChange": "修改Token配置",
     "setBatchLength": "设置GlobalBatch和SeqLength",
     "setLine": "自定义折线",
     "fileMerge": "文件合并",
+    "fileUnset": "取消文件合并",
     "ok": "确认",
     "cancel": "取消",
     "tokenTip": "开启Token后,如果配置不正确,图表可能无数据!",
-    "batchTokenConfigChange":"批量修改Token配置"
+    "batchTokenConfigChange": "批量修改Token配置",
+    "filePath": "文件路径",
+    "legendName": "图例名称",
+    "legendColor": "图例颜色",
+    "mergeFileName": "虚拟文件名称",
+    "mergeTip": "虚拟文件不能被再次合并,请取消选中虚拟文件",
+    "mergeNameTip": "虚拟文件名不能为空",
+    "batchTokenTip": "虚拟文件不能被设置Token参数,请取消选中虚拟文件"
   }
 }
\ No newline at end of file
"@vaadin/tabsheet": "^23.5.11", "@vaadin/tooltip": "^23.5.11", diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css index a28ea1aebe630fe879bb4672047bc8bae03ec274..ac81d67c4f2d92e40d801031358fc6c14066cdaa 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css @@ -4,4 +4,14 @@ graph-app { height: 100%; margin: 0; font-family: Roboto, sans-serif; +} + +vaadin-combo-box-scroller { + overflow: scroll; + font-size: 14px; +} + +vaadin-combo-box-item { + overflow: unset; + font-size: 14px; } \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/components/legend/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/components/legend/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..5b5db77050a470954dfc722ca2a60b26cd162e7c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/components/legend/index.ts @@ -0,0 +1,98 @@ +/* Copyright (c) 2025, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { PolymerElement, html } from '@polymer/polymer'; +import { customElement } from '@polymer/decorators'; +@customElement('scene-legend') +class legend extends PolymerElement { + static get template() { + return html` + +
+      <!-- style block and legend markup not recoverable here; the legend lists:
+           "Module or Operators";
+           "Unexpanded Module or Operators", with the tooltip "Unexpandable Node:
+           It can be an Api, operator or module. It cannot be expanded because it
+           has no subnodes";
+           and "Api List", with the tooltip "Apis between modules". -->
+    `;
+  }
+}
diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts
index 91e2f9f8837ebc71fec169639288942d945b57c4..af08ab79b2a700fec671e4972352405e02d53af1 100644
--- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts
+++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts
@@ -27,10 +27,12 @@ import * as tf_graph_scene from '../tf_graph_common/scene';
 import * as tf_graph_util from '../tf_graph_common/util';
 import * as tf_graph_layout from '../tf_graph_common/layout';
 import './tf-graph-scene';
+import './components/legend/index';
 import { Selection } from '../tf_graph_controls/tf-graph-controls';
 import { fetchPbTxt, parseGraphPbTxt } from '../tf_graph_common/parser';
 import * as tf_hierarchy from '../tf_graph_common/hierarchy';
 import * as tf_graph_parser from '../tf_graph_common/parser';
+
 import { BENCH_PREFIX } from '../tf_graph_common/common';

 let _isRankJump = '';
@@ -47,7 +49,6 @@ class TfGraph extends LegacyElementMixin(PolymerElement) {
         width: 100%;
         height: 100%;
         background: white;
-        box-shadow: 0 1px 5px rgba(0, 0, 0, 0.2);
         display: flex;
       }
@@ -77,6 +78,7 @@
         text-transform: none;
       }
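For orientation (not part of the patch): the `./components/legend/index` import above runs the `@customElement('scene-legend')` registration as a side effect, after which the tag can be used in any Polymer template. A minimal sketch (`legend-demo` is a made-up element name):

    // Illustrative only — embedding the registered scene-legend element
    import { PolymerElement, html } from '@polymer/polymer';
    import { customElement } from '@polymer/decorators';
    import './components/legend/index'; // side effect: registers <scene-legend>

    @customElement('legend-demo')
    class LegendDemo extends PolymerElement {
      static get template() {
        return html`<scene-legend></scene-legend>`;
      }
    }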
[The remaining template hunks of tf-graph.ts and the opening of the tf_graph_controls/tf-graph-controls.ts diff are not recoverable here. They remove the inline legend markup — "Module or Operators"; "Unexpanded Module or Operators" with the tooltip "Unexpandable Node: It can be an Api, operator or module. It cannot be expanded because it has no subnodes"; "Api List" with "Apis between modules" — which the new scene-legend component now provides.]
+
+    `; // core part
@@ -955,84 +756,16 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
   selectedSide: string = '0';

   // color data
-  @property({ type: Object, observer: '_updateColorItems' })
-  colors: any;
-
-  // custom color settings
-  @property({ type: Array })
-  standardColorList = ['#FFFCF3', '#FFEDBE', '#FFDC7F', '#FFC62E', '#FF9B3D', '#FF704D', '#FF4118'];
-  colorList = _.cloneDeep(this.standardColorList);
-  hiddenSelects = [{ key: 'NaN', values: [NaN, NaN] }];
-
-  // precision filtering
-  @property({ type: Array })
-  selectColor: any = [];
-  selectedPrecisionNode: string = '';
-  selectedOverflowNode: string = '';
-  @property({ type: Object })
-  precisionmenu: any = [];
-
-  // overflow filtering
-  @property({ type: Array })
-  overflowLevel: any = [];
   @property({ type: Object })
-  overflowmenu: any = [];
-
-  // unmatched nodes on the benchmark side
-  @property({ type: String, notify: true })
-  selectedBenchUnmatched: string = '';
+  colors: any;

   // color legend
   @property({ type: Object })
   colorset;
-  colorSetChanged;
-
-  // default data of the overflow legend
-  @property({ type: Object })
-  overFlowSet: any = [
-    ['#B6C7FC', 'medium'],
-    ['#7E96F0', 'high'],
-    ['#4668B8', 'critical'],
-  ];
-
-  // node matching
-  @property({ type: Object })
-  matchednodeset: any = [];
-  unMatchednodeset: any = [];

   // node matching: the unmatched nodes
-  @property({ type: String, notify: true })
-  selectedNPUNode: string = '';
-  selectedBenchNode: string = '';
   @property({ type: Object })
   unmatched: any = [];
-  NPU_unmatched: any = [];
-  Bench_unmatched: any = [];
-  @property({ type: Number })
-  _selectedNpuMatchMenu: number = -1;
-  _selectedBenchMatchMenu: number = -1;
-
-  // node matching: the matched nodes
-  @property({ type: String, notify: true })
-  selectedMatchedNPUNode: string = '';
-  selectedMatchedBenchNode: string = '';
-  @property({ type: Object })
-  NPU_matched: any = [];
-  Bench_matched: any = [];
-  matched: any = [];
-  @property({ type: Number })
-  _selectedUnMatchMenu: number = -1;
-
-  // list of the nodes the user matched before, passed in from the backend
-  matchedlist: any = [];
-
-  // controls whether the legend is expanded
-  @property({ type: Boolean })
-  _expanded: boolean = true;
-  _legendOpened: boolean = true; // legend panel
-  _colors: boolean = true; // color legend
-  _overFlowLevel: boolean = true; // overflow filter legend
-  _colorSetting: boolean = true; // color settings button

   // file upload
   @property({ type: Object, notify: true })
@@ -1051,11 +784,9 @@
     this._showTabContent('设置', 'nodes-content');
     document.addEventListener('contextMenuTag-changed', this._getTagChanged.bind(this));
   }
-
   _getTagChanged(contextMenuTag) {
     this.set('_selectedTagIndex', contextMenuTag.detail);
   }
-
   _showTabContent(buttonText, contentId) {
     // Remove 'active' class from all buttons
     this.shadowRoot?.querySelectorAll('.tab-button').forEach((button) => {
@@ -1081,27 +812,13 @@
       selectedContent.classList.remove('hidden');
     }
   }
-  // usage example
   _showNodeControls() {
     this._showTabContent('设置', 'nodes-content');
   }
-
-  _showDirectoryStructure() {
-    this._showTabContent('目录', 'directory-content');
+  _showMatch() {
+    this._showTabContent('匹配', 'match-content');
   }
-
-  _showSearchStructure() {
-    this._showTabContent('搜索', 'search-content');
-  }
-
-  _onGraphTypeChangedByUserGesture() {
-    tf_graph_util.notifyDebugEvent({
-      actionId: tb_debug.GraphDebugEventId.GRAPH_TYPE_CHANGED,
-      eventLabel: this._selectedGraphType,
-    });
-  }
-
   _numTags(datasets: Dataset, _selectedRunIndex: number) {
     return this._getTags(datasets, _selectedRunIndex).length;
   }
@@ -1114,41 +831,6 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
   _fit() {
     this.fire('fit-tap');
   }
-  @observe('colorset')
-  _observe() {
-    if (this.colorset.length !== 0) {
-      const colorsets = this.colorset;
-      for (const item of colorsets) {
-        if (item[1].value.length === 0) {
-          item[1].value.push('无匹配节点');
-        }
-      }
-      this.colorSetChanged = colorsets;
-    } else {
-      return;
-    }
-  }
-  @observe('unmatched')
-  _observeUnmatchedNode() {
-    this.set('NPU_unmatched', this.unmatched[0]);
-    this.set('Bench_unmatched', this.unmatched[1]);
-  }
-  @observe('matchedlist', 'selection')
-  _observeMatchedList() {
-    this.set('NPU_matched', []);
-    this.set('Bench_matched', []);
-    this.set('matched', []);
-    if (this.matchedlist) {
-      for (const item of this.matchedlist) {
-        this.NPU_matched = [...this.NPU_matched, item[0]];
-        this.Bench_matched = [...this.Bench_matched, item[1]];
-        this.matched = [...this.matched, [item[0], item[1]]];
-      }
-      this.set('NPU_matched', this.NPU_matched);
-      this.set('Bench_matched', this.Bench_matched);
-      this.set('matched', this.matched);
-    }
-  }
   _observeMenuNode() {
     let prefix = '';
     const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
@@ -1161,440 +843,12 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
   _observeMenuSideItem() {
     this.set('selectedMenuNode', '');
   }
-  _observePrecsionNode() {
-    let prefix = '';
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    if (hasBNode) {
-      prefix = NPU_PREFIX;
-    }
-    const node = prefix + this.selectedPrecisionNode;
-    this.set('selectedNode', node);
-  }
-  _observeNPUUnMatchedNode(event) {
-    let prefix = '';
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    if (hasBNode) {
-      prefix = NPU_PREFIX;
-    }
-    const node = prefix + event.model.item;
-    this.set('selectedNPUNode', node);
-    this.set('selectedNode', node);
-  }
-  _observeBenchUnMatchedNode() {
-    let prefix = '';
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    if (hasBNode) {
-      prefix = BENCH_PREFIX;
-    }
-    const node = prefix + this.selectedBenchUnmatched;
-    this.set('selectedBenchNode', node);
-    this.set('selectedNode', node);
-  }
-  _observeNPUMatchedNode(event) {
-    let prefix = '';
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    if (hasBNode) {
-      prefix = NPU_PREFIX;
-    }
-    const node = prefix + event.model.item;
-    this.set('selectedMatchedNPUNode', node);
-    const matched_node = this.findNodeInMatched(node.slice(4), 0);
-    if (matched_node) {
-      this.set('selectedMatchedBenchNode', `${BENCH_PREFIX}${matched_node}`);
-    }
-    this.set('selectedNode', node);
-  }
-  _observeBenchMatchedNode(event) {
-    let prefix = '';
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    if (hasBNode) {
-      prefix = BENCH_PREFIX;
-    }
-    const node = prefix + event.model.item;
-    this.set('selectedMatchedBenchNode', node);
-    const matched_node = this.findNodeInMatched(node.slice(4), 1);
-    if (matched_node) {
-      this.set('selectedMatchedNPUNode', `${NPU_PREFIX}${matched_node}`);
-    }
-    this.set('selectedNode', node);
-  }
-  findNodeInMatched(node, side) {
-    // iterate over the matched array
-    for (let i = 0; i < this.matched.length; i++) {
-      // take the current pair
-      const pair = this.matched[i];
-      // make sure the pair has length 2, to avoid out-of-bounds access
-      if (pair.length >= 2 && pair[side] === node) {
-        if (side === 0) {
-          return pair[1]; // return the second item
-        } else {
-          return pair[0]; // return the first item
-        }
-      }
-    }
-    // node not found
-    return null; // return null (or another value) to signal "not found"
-  }
   @observe('menu', 'selectedSide')
   _getMenuItem() {
     if (this.menu) {
       this.set('menuItem', this.menu[Number(this.selectedSide)]);
     }
   }
-  showDynamicDialog(message) {
-    // check whether a dialog is already shown, to avoid adding it twice
-    let existingDialog = this.shadowRoot?.querySelector('#dynamicDialog');
-    if (existingDialog) {
-      existingDialog.remove(); // remove the old dialog
-    }
-    // create a new dialog
-    const dialog = document.createElement('paper-dialog');
-    dialog.id = 'dynamicDialog';
-    // add a title
-    const title = document.createElement('h2');
-    title.textContent = '提示';
-    dialog.appendChild(title);
-    // add the message body
-    const content = document.createElement('div');
-    content.textContent = message;
-    dialog.appendChild(content);
-    // add the buttons
-    const buttonContainer = document.createElement('div');
-    buttonContainer.classList.add('buttons');
-    const closeButton = document.createElement('paper-button');
-    closeButton.setAttribute('dialog-dismiss', '');
-    closeButton.textContent = '关闭';
-    buttonContainer.appendChild(closeButton);
-    dialog.appendChild(buttonContainer);
-    // attach to the shadow DOM
-    this.shadowRoot?.appendChild(dialog);
-    // open the dialog
-    dialog.open();
-  }
-
-  _handleNodeSearch(event, type: 'unmatched' | 'precision' | 'overflow') {
-    const action = event.target.getAttribute('data-action');
-    const isUnmatched = type === 'unmatched';
-    const menuFirstRow = isUnmatched ? this.menu[1] : this.menu[0];
-    const selectedNode = this.selectedNode;
-    let nodeList;
-    let colorSet;
-    if (isUnmatched) {
-      nodeList = this.Bench_unmatched;
-    } else if (type === 'overflow') {
-      nodeList = this.overflowmenu;
-      colorSet = this.overflowLevel;
-    } else {
-      nodeList = this.precisionmenu;
-      if (type === 'precision') {
-        colorSet = this.selectColor;
-      } else {
-        colorSet = null;
-      }
-    }
-    const prefix = isUnmatched ? BENCH_PREFIX : NPU_PREFIX;
-    const hasBNode = this.renderHierarchy.bench?.renderedOpNames.some((name: string) => name.startsWith(BENCH_PREFIX));
-    const showDialog = (message: string) => {
-      this.showDynamicDialog(message);
-    };
-
-    const setDefaultNode = () => {
-      const defaultNode = hasBNode && !isUnmatched ? `${prefix}${nodeList[0]}` : `${prefix}${nodeList[0]}`;
-      this.set('selectedNode', defaultNode);
-    };
-
-    // validation logic
-    if (isUnmatched && nodeList.length === 0) {
-      showDialog('标杆侧没有未匹配节点');
-      return;
-    }
-    if (!isUnmatched && colorSet.length === 0) {
-      showDialog('请选择颜色');
-      return;
-    }
-    if (!isUnmatched && nodeList.length === 0) {
-      showDialog('选择的颜色没有节点存在');
-      return;
-    }
-
-    // if the user has not selected a node, fall back to the default node
-    if (!selectedNode) {
-      setDefaultNode();
-      return;
-    }
-
-    // index of selectedNode within menuFirstRow
-    const slicedNode = hasBNode ? selectedNode.slice(4) : selectedNode;
-    const startIndex = menuFirstRow.indexOf(slicedNode);
-    if (startIndex === -1) {
-      setDefaultNode();
-      return;
-    }
-
-    // find the next node
-    const findNextNode = () => {
-      if (nodeList.includes(selectedNode)) {
-        const currentIndex = nodeList.indexOf(selectedNode);
-        if (currentIndex + 1 >= nodeList.length) {
-          showDialog(isUnmatched ? '没有下一个未匹配节点' : '没有下一个问题节点');
-          return null;
-        }
-        return nodeList[currentIndex + 1];
-      }
-      for (let i = startIndex + 1; i < menuFirstRow.length; i++) {
-        if (nodeList.includes(menuFirstRow[i])) {
-          return menuFirstRow[i];
-        }
-      }
-      showDialog(isUnmatched ? '没有下一个未匹配节点' : '没有下一个问题节点');
-      return null;
-    };
-
-    // find the previous node
-    const findPreviousNode = () => {
-      if (nodeList.includes(selectedNode)) {
-        const currentIndex = nodeList.indexOf(selectedNode);
-        if (currentIndex === 0) {
-          showDialog(isUnmatched ? '没有上一个未匹配节点' : '没有上一个问题节点');
-          return null;
-        }
-        return nodeList[currentIndex - 1];
-      }
-      for (let i = startIndex - 1; i >= 0; i--) {
-        if (nodeList.includes(menuFirstRow[i])) {
-          return menuFirstRow[i];
-        }
-      }
-      showDialog(isUnmatched ? '没有上一个未匹配节点' : '没有上一个问题节点');
-      return null;
-    };
-
-    // run the lookup
-    const nextNode = action === 'next' ? findNextNode() : findPreviousNode();
-
-    if (nextNode) {
-      let selectedNode;
-      if (isUnmatched || hasBNode) {
-        selectedNode = `${prefix}${nextNode}`;
-      } else {
-        selectedNode = nextNode;
-      }
-      this.set('selectedNode', selectedNode);
-    }
-  }
-
-  _handleUnmatchSearch(event) {
-    this._handleNodeSearch(event, 'unmatched');
-  }
-
-  _handlePrecisonSearch(event) {
-    this._handleNodeSearch(event, 'precision');
-  }
-
-  _handleOverflowSearch(event) {
-    this._handleNodeSearch(event, 'overflow');
-  }
-
-  async _handleMatchedNodesClick(this) {
-    // open the dialog
-    if (this.selectedNPUNode === '' || this.selectedBenchNode === '') {
-      this.showDynamicDialog('节点不可匹配');
-      return;
-    }
-    // request parameters
-    const params = new URLSearchParams();
-    const run = this.datasets[this._selectedRunIndex].name;
-    const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag;
-    params.set('run', run);
-    if (tag) params.set('tag', tag);
-    params.set('NPU', this.selectedNPUNode);
-    params.set('Bench', this.selectedBenchNode);
-    // issue the request
-    const precisionPath = 'match?' + String(params);
-    const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // await the ArrayBuffer
-    const decoder = new TextDecoder();
-    const decodedStr = decoder.decode(precisionStr); // decode the ArrayBuffer into a string
-    // handle the response
-    const mactchResult: MactchResult = JSON.parse(decodedStr);
-
-    if (mactchResult.success) {
-      const mactchData = mactchResult.data;
-      this.push('matchednodeset', mactchData);
-      if (this.unMatchednodeset.length !== 0) {
-        const has = this.unMatchednodeset.indexOf(this.selectedNPUNode);
-        if (has !== -1) {
-          this.unMatchednodeset = [...this.unMatchednodeset.slice(0, has), ...this.unMatchednodeset.slice(has + 1)];
-        }
-      }
-      tf_graph_node.setMatched(this.matchednodeset); // color the nodes
-      const index_N = this.NPU_unmatched.indexOf(this.selectedNPUNode.slice(4));
-      if (index_N !== -1) {
-        this.splice('NPU_unmatched', index_N, 1);
-        this.notifyPath('NPU_unmatched');
-      }
-      const index_B = this.Bench_unmatched.indexOf(this.selectedBenchNode.slice(4));
-      if (index_B !== -1) {
-        this.splice('Bench_unmatched', index_B, 1);
-        this.notifyPath('Bench_unmatched');
-      }
-      this.set('_selectedNpuMatchMenu', -1);
-      this.set('_selectedBenchMatchMenu', -1);
-      this.NPU_matched = [...this.NPU_matched, this.selectedNPUNode.slice(4)];
-      this.Bench_matched = [...this.Bench_matched, this.selectedBenchNode.slice(4)];
-      this.matched = [...this.matched, [this.selectedNPUNode.slice(4), this.selectedBenchNode.slice(4)]];
-      this.showDynamicDialog('节点匹配成功');
-      this.set('selectedNode', '');
-      this.set('selectedNode', this.selectedNPUNode);
-      this.selectedNPUNode = '';
-      this.selectedBenchNode = '';
-    } else {
-      this.showDynamicDialog(mactchResult.error);
-    }
-  }
-
-  async _handleUnMatchedNodesClick(this) {
-    // open the dialog
-    if (this.selectedMatchedNPUNode === '' || this.selectedMatchedBenchNode === '') {
-      this.showDynamicDialog('取消匹配失败,请核对选择节点');
-      return;
-    }
-    const existsInMatch = this.matched.some(
-      ([NPU_matched, Bench_matched]) =>
-        NPU_matched === this.selectedMatchedNPUNode.slice(4) &&
-        Bench_matched === this.selectedMatchedBenchNode.slice(4),
-    );
-    if (!existsInMatch) {
-      this.showDynamicDialog('取消匹配失败,请核对选择节点');
-      return;
-    }
-    this.NPU_unmatched.push(this.selectedMatchedNPUNode.slice(4));
-    this.NPU_unmatched = [...this.NPU_unmatched];
-    this.notifyPath('NPU_unmatched');
-    this.Bench_unmatched.push(this.selectedMatchedBenchNode.slice(4));
-    this.Bench_unmatched = [...this.Bench_unmatched];
-    this.notifyPath('Bench_unmatched');
-    const index_N = this.NPU_matched.indexOf(this.selectedMatchedNPUNode.slice(4));
-    if (index_N !== -1) {
-      this.NPU_matched = [...this.NPU_matched.slice(0, index_N), ...this.NPU_matched.slice(index_N + 1)];
-    }
-    const index_B = this.Bench_matched.indexOf(this.selectedMatchedBenchNode.slice(4));
-    if (index_B !== -1) {
-      this.Bench_matched = [...this.Bench_matched.slice(0, index_B), ...this.Bench_matched.slice(index_B + 1)];
-    }
-    const index_M = this.matched.findIndex(
-      (item) => item[0] === this.selectedMatchedNPUNode.slice(4) && item[1] === this.selectedMatchedBenchNode.slice(4),
-    );
-    if (index_M !== -1) {
-      this.matched = [...this.matched.slice(0, index_M), ...this.matched.slice(index_M + 1)];
-    }
-    const index_U = this.matchednodeset.findIndex((item) => item[0] === this.selectedMatchedNPUNode);
-    if (index_U !== -1) {
-      this.matchednodeset = [...this.matchednodeset.slice(0, index_U), ...this.matchednodeset.slice(index_U + 1)];
-    } else {
-      this.unMatchednodeset.push(this.selectedMatchedNPUNode);
-    }
-    this.set('_selectedUnMatchMenu', -1);
-    const params = new URLSearchParams();
-    const run = this.datasets[this._selectedRunIndex].name;
-    const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag;
-    params.set('run', run);
-    if (tag) params.set('tag', tag);
-    params.set('NPU', this.selectedMatchedNPUNode);
-    params.set('Bench', this.selectedMatchedBenchNode);
-    const precisionPath = 'unmatch?' + String(params);
-    const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // await the ArrayBuffer
-    const decoder = new TextDecoder();
-    const decodedStr = decoder.decode(precisionStr); // decode the ArrayBuffer into a string
-    tf_graph_node.setMatched(this.matchednodeset);
-    tf_graph_node.setUnMatched(this.unMatchednodeset);
-    this.set('selectedNode', '');
-    this.showDynamicDialog('已取消匹配');
-  }
-  // when the data selection changes, clear every checkbox and this.selectColor
-  @observe('selection')
-  _clearAllToggleCheckbox() {
-    this.set('selectedSide', '0');
-    const allCheckboxes = this.shadowRoot?.querySelectorAll('paper-checkbox');
-    if (allCheckboxes) {
-      allCheckboxes.forEach((checkbox) => {
-        checkbox.checked = false; // clear the checked state of every checkbox
-      });
-    }
-    this.selectColor = [];
-    this.precisionmenu = [];
-    this.overflowLevel = [];
-    this.set('selectedNode', '');
-  }
-
-  async _toggleCheckbox(this, event) {
-    const { batch, step } = this.selection;
-    const item = event.model.item;
-    let checkbox, overflowCheckbox;
-    if (item[1].value) {
-      checkbox = this.shadowRoot?.getElementById(`checkbox-${event.model.index}`) as HTMLInputElement;
-    } else {
-      overflowCheckbox = this.shadowRoot?.getElementById(`overflowCheckbox-${event.model.index}`) as HTMLInputElement;
-    }
-    const run = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].run;
-    const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag;
-    const params = new URLSearchParams();
-    if (run) params.set('run', run);
-    if (tag) params.set('tag', tag);
-    params.set('batch', String(batch === -1 ? -1 : batch - 1));
-    params.set('step', String(step === -1 ? -1 : step - 1));
-    // update the selectColor array
-    if (checkbox) {
-      if (checkbox.checked) {
-        this.selectColor.push(item[1].value); // add the selected color
-      } else {
-        const index = this.selectColor.findIndex(
-          (color) => color[0] === item[1].value[0] && color[1] === item[1].value[1],
-        );
-        if (index !== -1) {
-          this.selectColor.splice(index, 1); // remove the deselected color
-        }
-      }
-      if (this.selectColor.length === 0) {
-        this.precisionmenu = [];
-        return;
-      }
-      params.set('precision_index', this.selectColor.join(','));
-      const screenPath = 'screen?' + String(params);
-      try {
-        const screenStr = tf_graph_parser.fetchPbTxt(screenPath);
-        const precisionmenu = JSON.parse(new TextDecoder().decode(await screenStr).replace(/'/g, '"')) as object;
-        this.set('precisionmenu', precisionmenu);
-      } catch (e) {
-        console.error('Parse tooltips failed, please check the format of tooltips in the input vis file');
-      }
-      // refresh the data binding
-      this.notifyPath(`menu.${event.model.index}.checked`, checkbox.checked);
-    } else {
-      if (overflowCheckbox.checked) {
-        this.overflowLevel.push(item[1]); // add the selected level
-      } else {
-        const index = this.overflowLevel.findIndex((overflow) => overflow === item[1]);
-        if (index !== -1) {
-          this.overflowLevel.splice(index, 1); // remove the deselected level
-        }
-      }
-      if (this.overflowLevel.length === 0) {
-        this.overflowmenu = [];
-        return;
-      }
-      params.set('overflow_level', this.overflowLevel.join(','));
-      const screenPath = 'screen?' + String(params);
-      try {
-        const screenStr = tf_graph_parser.fetchPbTxt(screenPath);
-        this.overflowmenu = JSON.parse(new TextDecoder().decode(await screenStr).replace(/'/g, '"')) as object;
-      } catch (e) {
-        console.error('Parse tooltips failed, please check the format of tooltips in the input vis file');
-      }
-      // refresh the data binding
-      this.notifyPath(`menu.${event.model.index}.checked`, overflowCheckbox.checked);
-    }
-  }
-
   download() {
     this.fire('download-image-requested', this._downloadFilename);
   }
@@ -1688,274 +942,13 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
   _statsNotNull(stats: tf_graph_proto.StepStats) {
     return stats !== null;
   }
-  _toggleLegendOpen(): void {
-    this.set('_legendOpened', !this._legendOpened);
-  }
-  _toggleColorsOpen(): void {
-    this.set('_colors', !this._colors);
-  }
-
-  toggleVisibility(): void {
-    this.set('_colorSetting', !this._colorSetting);
-  }
-
-  _clickSetting(event) {
-    event.stopPropagation();
-    this.set('_colors', true);
-    this.toggleVisibility();
-  }
-
-  _cancelAction() {
-    this.toggleVisibility();
-  }
-
-  _confirmAction() {
-    const newColorsList = {};
-    const len = this.hiddenSelects.length;
-    if (len === 0) {
-      this.showDynamicDialog('配置失败,请添加配置项。');
-      return;
-    }
-
-    // walk every entry and build the newColorsList object
-    for (let i = 0; i < len; i++) {
-      const color = this.hiddenSelects[i].key;
-      const leftValue = this.hiddenSelects[i].values[0];
-      const rightValue = this.hiddenSelects[i].values[1];
-      // check that every input in the group has a value
-      if (isNaN(leftValue) || isNaN(rightValue) || color === 'NaN') {
-        this.showDynamicDialog('配置失败,存在未配置项。');
-        return;
-      }
-      // store each color with its [leftValue, rightValue] range in the colors object
-      newColorsList[color] = {
-        value: [leftValue, rightValue],
-        description:
-          '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-      };
-    }
-    // the "unmatched node" legend entry always exists
-    newColorsList[UNMATCHED_COLOR] = {
-      value: [],
-      description: '对比过程中节点未匹配上',
-    };
-    // update the color list
-    this.set('colors', newColorsList);
-    let newColorSetChanged: any[] = [];
-    Object.entries(newColorsList).forEach(([color, details]) => {
-      let detailsTyped = details as { value: string };
-      if (color === UNMATCHED_COLOR) {
-        detailsTyped.value = '无匹配节点';
-      }
-      const colorset: any[] = [color, detailsTyped];
-      newColorSetChanged.push(colorset);
-    });
-    this.set('colorSetChanged', newColorSetChanged);
-    const params = new URLSearchParams();
-    params.set('colors', JSON.stringify(newColorsList));
-    const colorsPath = 'setNewColors?' + String(params);
-    tf_graph_parser.fetchPbTxt(colorsPath);
-    // repaint according to the color list
-    let nodeDataSet = Object.entries(this.renderHierarchy.npu.getIndex());
-    tf_graph_node.getColors(this.colors);
-    for (let [_key, value] of Object.entries(nodeDataSet)) {
-      const renderInfo = value[1];
-      const getElementBySelectors = (element, selectors) => {
-        let currentElement = element;
-        for (const selector of selectors) {
-          currentElement = currentElement?.shadowRoot?.querySelector(selector);
-          if (!currentElement) return null;
-        }
-        return currentElement;
-      };
-      const graph = document.querySelector('graph-app');
-      const svgRoot = getElementBySelectors(graph, ['tf-graph-board', 'tf-graph', 'tf-graph-scene', 'svg']);
-      const sceneElement = getElementBySelectors(graph, ['tf-graph-board', 'tf-graph', 'tf-graph-scene']);
-      const nodeGroup = d3.select(svgRoot).select(`.node[data-name="${renderInfo.node.name}"]`);
-      tf_graph_node.stylize(nodeGroup, renderInfo, sceneElement);
-    }
-    this.toggleVisibility();
-  }
-
-  _changeColor(event) {
-    const selectedColor = event.target.value;
-    const index = event.model.index;
-    this.set(`hiddenSelects.${index}.key`, selectedColor);
-    this.notifyPath('hiddenSelects');
-    this._setColorList();
-  }
-
-  // show an empty string instead of NaN
-  _formatValue(value) {
-    return isNaN(value) ? '' : value;
-  }
-
-  _validateInputs(event: any) {
-    const index = event.model.index;
-    const { values } = this.hiddenSelects[index];
-
-    // declare leftInputSet and rightInputSet explicitly as number[]
-    const [leftInputSet, rightInputSet] = this.hiddenSelects.reduce<[number[], number[]]>(
-      (acc, item) => {
-        acc[0].push(item.values[0]);
-        acc[1].push(item.values[1]);
-        return acc;
-      },
-      [[], []], // start from two empty arrays
-    );
-
-    let value = parseFloat(event.target.value);
-    // input validation: guard against NaN and limit the accepted range
-    if (isNaN(value) || value < 0 || value > 1) {
-      this._clearInput(event, index);
-      return;
-    }
-
-    const valueStr = value.toString();
-
-    // check for a decimal point
-    const parts = valueStr.split('.');
-    const hasDecimal = parts.indexOf('.') !== -1;
-
-    // if there is a decimal part longer than the maximum
-    if (hasDecimal && parts[1].length > 5) {
-      // keep at most 5 decimal places via toFixed
-      value = parseFloat(value.toFixed(5));
-    }
-
-    const isLeftInput = event.target.id === 'input-left';
-    const otherSide = isLeftInput ? values[1] : values[0];
-    const [left, right] = isLeftInput ? [value, otherSide] : [otherSide, value];
-
-    // check that the input value is valid
-    if ((isLeftInput && left > right) || (!isLeftInput && right < left)) {
-      this._clearInput(event, index);
-      return;
-    }
-
-    // check that the value does not conflict with the other intervals
-    const isConflict = this.hiddenSelects.some((item, i) => {
-      // skip the current input
-      if (i === index) return false;
-
-      const [leftInput, rightInput] = item.values;
-      return (
-        (isLeftInput && left !== leftInput && left >= leftInput && left < rightInput) ||
-        (!isLeftInput && right !== rightInput && right > leftInput && right <= rightInput) ||
-        (isLeftInput && leftInputSet.includes(left)) ||
-        (!isLeftInput && rightInputSet.includes(right))
-      );
-    });
-
-    if (isConflict) {
-      this._clearInput(event, index);
-      return;
-    }
-
-    // parseFloat would also turn input like "0!@#¥" into 0; prevent that
-    event.target.value = value;
-    // store the value
-    this.set(`hiddenSelects.${index}.values.${isLeftInput ? 0 : 1}`, value);
-  }
-
-  _clearInput(event: any, index: number) {
-    event.target.value = ''; // clear the input
-    this.set(`hiddenSelects.${index}.values.${event.target.id === 'input-left' ? 0 : 1}`, NaN); // update hiddenSelects
-  }
-
-  _addOption() {
-    if (this.hiddenSelects.length < 5) {
-      const obj = {
-        key: 'NaN',
-        values: [NaN, NaN],
-      };
-      this.push('hiddenSelects', obj);
-    }
-    // make sure this runs after the synchronous this.push() above.
-    this.async(() => {
-      this._setColorList();
-    }, 0);
-  }
-
-  _removeOption(event) {
-    const index = event.model.index;
-
-    // remove the entry
-    this.splice('hiddenSelects', index, 1);
-
-    // restore the values of the remaining inputs
-    this.hiddenSelects.forEach((item, i) => {
-      if (i >= index) {
-        this.set(`hiddenSelects.${i}.values`, item.values);
-      }
-    });
-    this._setColorList();
-  }
-
-  _setColorList() {
-    let colorSelectElements = this.shadowRoot?.querySelectorAll('[id^="color-select"]');
-    let backgroundColors: string[] = [];
-    this.hiddenSelects.forEach((item) => {
-      // take the computed background color
-      const backgroundColor = item.key;
-      backgroundColors.push(backgroundColor);
-    });
-    let newColorList = this.standardColorList.filter((color) => !backgroundColors.includes(color));
-    this.set('colorList', newColorList);
-    // clear the selection, otherwise picking the same position in a different list will not fire on-change
-    this.async(() => {
-      colorSelectElements?.forEach((element) => {
-        if (element instanceof HTMLSelectElement) {
-          element.selectedIndex = -1;
-        }
-      });
-    }, 0);
-  }
-
-  _toggleOverflowLevelOpen(): void {
-    this.set('_overFlowLevel', !this._overFlowLevel);
-  }
-  _getToggleLegendIcon(legendOpened: boolean): string {
-    // This seems counter-intuitive, but actually makes sense because the
-    // expand-more button points downwards, and the expand-less button points
-    // upwards. For most collapsibles, this works because the collapsibles
-    // expand in the downwards direction. This collapsible expands upwards
-    // though, so we reverse the icons.
-    return legendOpened ? 'expand-more' : 'expand-less';
-  }
-  _getSelectionOpGraphDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) {
-    return (
-      !datasets[_selectedRunIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex].opGraph
-    );
-  }
-  _getSelectionProfileDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) {
-    return (
-      !datasets[_selectedRunIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex].profile
-    );
-  }
-  _getSelectionConceptualGraphDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) {
-    return (
-      !datasets[_selectedRunIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex] ||
-      !datasets[_selectedRunIndex].tags[_selectedTagIndex].conceptualGraph
-    );
-  }
-  _getToggleIcon(expanded) {
-    return expanded ? 'expand-less' : 'expand-more';
-  }
-  _toggleExpanded() {
-    this._expanded = !this._expanded;
-  }
   triggerMenuExpandEvent(newName) {
     const detailsElement = this.shadowRoot?.getElementById(newName) as HTMLDetailsElement;
     if (detailsElement?.open) {
       const event = new CustomEvent('menu-expand-node-changed', {
         detail: { name: newName, open: 'unexpand' },
       });
+      document.dispatchEvent(event);
     } else {
       const event = new CustomEvent('menu-expand-node-changed', {
@@ -1981,7 +974,6 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
     const subnode_list = 'subgraph?' + String(params);
     fetchPbTxt(subnode_list).then((arrayBuffer: ArrayBuffer) => {
       parseGraphPbTxt(arrayBuffer).then((graphDef) => {
-        this.updateGraphData(graphDef, nodeName);
         return graphDef;
       });
     });
@@ -1993,211 +985,9 @@ class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement))
     const summary = document.createElement('summary');
     const detail = document.createElement('details');
     summary.id = 'root';
-    summary.textContent = '目录';
+    summary.textContent = 'Root';
     summary.addEventListener('click', this._getdata.bind(this));
     detail.appendChild(summary);
     menubox?.appendChild(detail);
   }
-  _updateColorItems() {
-    const coloritems = this.shadowRoot?.getElementById('coloritems');
-    const tbody = coloritems?.querySelector('tbody');
-    if (Object.entries(this.colors).length !== 0) {
-      if (tbody) {
-        tbody.innerHTML = '';
-        Object.entries(this.colors).forEach(([color, details]) => {
-          let detailsArray: any[] = [];
-          detailsArray = [details];
-          if (detailsArray) {
-            const tr = document.createElement('tr');
-            const td = document.createElement('td');
-            const div = document.createElement('div');
-            div.className = 'rectangle';
-            div.style.backgroundColor = color;
-            const td2 = document.createElement('td');
-            const td3 = document.createElement('td');
-            const divInTd3 = document.createElement('div');
-            const paperTooltip = document.createElement('paper-tooltip');
-            const divInPaperTooltip = document.createElement('div');
-            divInTd3.className = 'legend-clarifier';
-            paperTooltip.setAttribute('animation-delay', '0');
-            paperTooltip.setAttribute('position', 'right');
-            paperTooltip.setAttribute('offset', '0');
-            divInPaperTooltip.className = 'custom-tooltip';
-            divInPaperTooltip.textContent = detailsArray[0].description;
-            if (detailsArray[0].value === '无匹配节点') {
-              td2.textContent = '无匹配节点';
-            } else {
-              td2.textContent = detailsArray[0].value[0] + '-' + detailsArray[0].value[1];
-            }
-            tbody.appendChild(tr);
-            tr.appendChild(td);
-            td.appendChild(div);
-            tr.appendChild(td2);
-            tr.appendChild(td3);
-            td3.appendChild(divInTd3);
-            divInTd3.appendChild(paperTooltip);
-            paperTooltip.appendChild(divInPaperTooltip);
-          }
-        });
-      }
-    } else {
-      if (tbody) {
-        tbody.innerHTML = '';
-        const rows = [
-          { color: '#FFFCF3', text: '1-0.8' },
-          { color: '#FFEDBE', text: '0.8-0.6' },
-          { color: '#FFDC7F', text: '0.6-0.4' },
-          { color: '#FFC62E', text: '0.4-0.2' },
-          { color: '#ff704d', text: '0.2-0' },
-          { color: '#C7C7C7', text: 'Not Connected' },
-        ];
-        rows.forEach(({ color, text }) => {
-          const tr = document.createElement('tr');
-          tr.innerHTML = `
-            <td><div class="rectangle" style="background-color: ${color}"></div></td><td>${text}</td>`;
-          tbody.appendChild(tr);
-        });
-      }
-      this.colorset = [
-        [
-          '#FFFCF3',
-          {
-            description:
-              '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            value: [0.8, 1],
-          },
-        ],
-        [
-          '#FFEDBE',
-          {
-            description:
-              '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            value: [0.6, 0.8],
-          },
-        ],
-        [
-          '#FFDC7F',
-          {
-            description:
-              '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            value: [0.4, 0.6],
-          },
-        ],
-        [
-          '#FFC62E',
-          {
-            description:
-              '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            value: [0.2, 0.4],
-          },
-        ],
-        [
-          '#ff704d',
-          {
-            description:
-              '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            value: [0, 0.2],
-          },
-        ],
-        [
-          '#C7C7C7',
-          {
-            description: '对比过程中节点未匹配上',
-            value: [],
-          },
-        ],
-      ];
-    }
-  }
-  updateGraphData(graphDef, nodeName) {
-    this.graphDef = graphDef;
-    let menubox;
-    let detailsElement;
-    if (nodeName === 'root') {
-      menubox = this.shadowRoot?.getElementById('menubox');
-      detailsElement = menubox?.querySelector('details');
-    } else {
-      detailsElement = this.shadowRoot?.getElementById(nodeName);
-    }
-    this.graphDef.node.forEach((node) => {
-      const detail = document.createElement('details');
-      const summary = document.createElement('summary');
-      let nodeNameWithoutPrefix;
-      detail.id = node.name;
-      summary.id = node.name;
-      summary.title = node.name;
-      if (!this.shadowRoot?.getElementById(node.name)) {
-        if (nodeName === 'root' && node.name.substring(0, 4) === NPU_PREFIX) {
-          nodeNameWithoutPrefix = node.name.substring(4) + '(对比)';
-        } else if (nodeName === 'root' && node.name.substring(0, 4) === BENCH_PREFIX) {
-          nodeNameWithoutPrefix = node.name.substring(4) + '(标杆)';
-        } else {
-          if (node.name.substring(0, 4) === BENCH_PREFIX || node.name.substring(0, 4) === NPU_PREFIX) {
-            nodeNameWithoutPrefix = node.name.substring(4);
-          } else {
-            nodeNameWithoutPrefix = node.name;
-          }
-        }
-
-        if (node.isLeaf) {
-          summary.classList.add('no-arrow');
-          detail.style.paddingLeft = '22px';
-        } else {
-          summary.classList.remove('no-arrow');
-        }
-        summary.style.backgroundColor = 'white';
-        if (Object.keys(this.colors).length === 0) {
-          this.colors = {
-            '#FFFCF3': {
-              value: [0.8, 1],
-              description:
-                '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            },
-            '#FFEDBE': {
-              value: [0.6, 0.8],
-              description:
-                '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            },
-            '#FFDC7F': {
-              value: [0.4, 0.6],
-              description:
-                '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            },
-            '#FFC62E': {
-              value: [0.2, 0.4],
-              description:
-                '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            },
-            '#ff704d': {
-              value: [0, 0.2],
-              description:
-                '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|',
-            },
-            '#C7C7C7': {
-              value: [],
-              description: '对比过程中节点未匹配上',
-            },
-          };
-        }
-        for (const [color, details] of Object.entries(this.colors)) {
-          let detailsArray: any[] = [];
-          detailsArray = [details];
-          const [start, end] = detailsArray[0].value;
-          if (
-            (start === end && node.precision_index === start) ||
-            (node.precision_index >= start && node.precision_index < end) ||
-            (node.precision_index === end && end === 1)
-          ) {
-            summary.style.backgroundColor = color;
-            break;
-          } else {
-            summary.style.backgroundColor = UNMATCHED_COLOR;
-          }
-        }
-        summary.textContent = nodeNameWithoutPrefix;
-        summary.addEventListener('click', this._getdata.bind(this));
-        detail.appendChild(summary);
-        detailsElement?.appendChild(detail);
-      }
-    });
-  }
 }
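For orientation (not part of the patch): `_initialMenu` above (and the removed `updateGraphData`) builds the directory tree out of nested `<details>`/`<summary>` elements and loads children when a node is first clicked. A skeleton of that pattern, with a hypothetical `fetchChildren` loader standing in for the plugin's `subgraph` request:

    // Illustrative only — lazy <details> tree; fetchChildren is an assumed loader
    function appendTreeNode(parent: HTMLElement, name: string, fetchChildren: (name: string) => Promise<string[]>): void {
      const detail = document.createElement('details');
      const summary = document.createElement('summary');
      summary.id = name;
      summary.textContent = name;
      summary.addEventListener('click', async () => {
        if (detail.childElementCount > 1) return;       // children already loaded
        const children = await fetchChildren(name);     // lazy-load on first expand
        children.forEach((child) => appendTreeNode(detail, child, fetchChildren));
      });
      detail.appendChild(summary);
      parent.appendChild(detail);
    }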
diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/utils.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/utils.ts
new file mode 100644
index 0000000000000000000000000000000000000000..2f48ef9a3a32a1e7ca7026a8a3680eee84a1dc7d
--- /dev/null
+++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/utils.ts
@@ -0,0 +1,24 @@
+/* Copyright (c) 2025, Huawei Technologies.
+ * All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+export const getElementBySelectors = (selectors) => {
+  let currentElement = document.querySelector('graph-app');
+  for (const selector of selectors) {
+    currentElement = currentElement?.shadowRoot?.querySelector(selector);
+    if (!currentElement) return null;
+  }
+  return currentElement;
+};
diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts
index 90a46b12c973b8a8719101b9d746e366c17af439..222d4596c84bbd5bc8842eb856d71b8d8874a3e0 100644
--- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts
+++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts
@@ -34,7 +34,7 @@ interface GraphRunTag {

 interface Components {
   Menu: object;
-  ToolTip: object;
+  ToolTip: '';
   Colors: object;
   MicroSteps: number;
   StepList: [];
@@ -167,18 +167,18 @@ class TfGraphDashboardLoader extends LegacyElementMixin(PolymerElement) {
         value: 0,
         msg: '',
       });
-
+
       const tracker = tf_graph_util.getTracker(this);
       const dataTracker = tf_graph_util.getSubtaskTracker(tracker, 100, 'Data');
       dataTracker.setMessage('Initialization in progress');
-
+
       let timer = 0;
       let shouldBreak = false; // flag used to break out of the loop
-
+
       // start the timer task
-      const timerTask = async function() {
-        let previousProgress = 0; // progress reported by the previous update
-
+      const timerTask = async function () {
+        let previousProgress = 0; // progress reported by the previous update
+
         while (timer <= DATA_LOAD_TIME && !shouldBreak) {
           if (timer < DATA_NOTICE_TIME) {
             const progress = Math.log(timer + 1) / Math.log(DATA_NOTICE_TIME);
@@ -188,33 +188,33 @@ class TfGraphDashboardLoader extends LegacyElementMixin(PolymerElement) {
           } else {
             dataTracker.setMessage('File data too large, still reading');
           }
-          await new Promise(resolve => setTimeout(resolve, 100));
+          await new Promise((resolve) => setTimeout(resolve, 100));
           timer++;
         }
       }.bind(this);
-
-      const fetchTask = async function() {
+
+      const fetchTask = async function () {
         let componentsStr;
         try {
           componentsStr = await tf_graph_parser.fetchPbTxt(componentsPath);
         } catch (e) {
-          shouldBreak = true; // catch fetchPbTxt errors and stop the timer
-          dataTracker.reportError('Fetch error, please check first file in file path', e as Error)
+          shouldBreak = true; // catch fetchPbTxt errors and stop the timer
+          dataTracker.reportError('Fetch error, please check first file in file path', e as Error);
           return;
         }
-
-        shouldBreak = true; // the normal path also stops the timer
-
+
+        shouldBreak = true; // the normal path also stops the timer
+
        let components: Components = {
          Menu: [],
-          ToolTip: {},
+          ToolTip: '',
          Colors: {},
          MicroSteps: 0,
          StepList: [],
          UnMatchedNode: [],
          match: [],
        };
-
+
        try {
          if (componentsStr) {
            components = {
@@ -223,11 +223,14 @@
            };
          }
        } catch (e) {
-          shouldBreak = true; // stop the timer on parse errors
-          dataTracker.reportError('Parse components failed, please check the format of config data in the input vis file', e as Error)
+          shouldBreak = true; // stop the timer on parse errors
+          dataTracker.reportError(
+            'Parse components failed, please check the format of config data in the input vis file',
+            e as Error,
+          );
          return;
        }
-
+
        // follow-up processing...
        const entries = Object.entries(components.ToolTip);
        const toolTipObject = Object.fromEntries(entries);
@@ -251,14 +254,13 @@
        const steplistCount = Number(components.MicroSteps);
        this._setSteplist(steplistCount ? components.StepList : []);
        resolve();
-      }.bind(this);
-
+      }.bind(this);
+
      // run the timer and the fetch task concurrently
      await Promise.all([timerTask(), fetchTask()]);
    });
  }
-
+
  _load(selection: tf_graph_controls.Selection) {
    const { run, tag, type: selectionType, batch, step } = selection;
    switch (selectionType) {
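For orientation (not part of the patch): the loader races a log-scaled progress timer against the components fetch; whichever way the fetch settles, a shared flag stops the timer loop. A condensed sketch of the pattern, with a hypothetical `report` callback in place of the tracker and `600` standing in for a DATA_NOTICE_TIME-style constant:

    // Illustrative only — timer/fetch race with a shared stop flag
    async function loadWithProgress<T>(fetchTask: () => Promise<T>, report: (pct: number) => void): Promise<T> {
      let shouldBreak = false;
      const timerTask = async (): Promise<void> => {
        let tick = 0;
        while (!shouldBreak) {
          report(Math.min(99, Math.round((Math.log(tick + 1) / Math.log(600)) * 100))); // log-scaled progress
          await new Promise((resolve) => setTimeout(resolve, 100));
          tick++;
        }
      };
      const wrapped = fetchTask().finally(() => {
        shouldBreak = true; // stop the timer on success or failure
      });
      const [, result] = await Promise.all([timerTask(), wrapped]);
      report(100);
      return result;
    }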
diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts
index 07f8f6c5cdc214c24a4d02ae59f6619363ec3f12..4ec13f59daa46596f5a5935cf1cfe0ce5f283e93 100644
--- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts
+++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/components/tf_vaadin_table/index.ts
@@ -39,6 +39,10 @@
         align-items: center;
       }

+      vaadin-grid::part(header-cell) {
+        border-bottom: 1px solid rgb(66, 66, 66);
+      }
+
       .highlight-cell {
         border: 1px solid #005fdb;
         border-radius: 4px;
@@ -67,6 +71,9 @@
         margin-right: 10px;
         border-radius: 50%;
       }
+      .splitter {
+        border-bottom: 1px solid rgb(66, 66, 66);
+      }
@@ -98,6 +108,9 @@
   @property({ type: Boolean })
   isSingleGraphNode = false; // whether this is a single-node graph

+  @property({ type: Object })
+  tooltips: any;
+
   @property({ type: Object })
   handleCellClick!: (e: MouseEvent, syncGrid: HTMLElement) => void;

@@ -128,11 +141,11 @@
     }
     const ignoreDataIndex = ['data_name', 'isBench', 'isMatched', 'value'];
     const headers = Array.from(
-      data.slice(0, 5).reduce((keys, item) => { // only scan the first 5 rows, to avoid performance problems
+      data.reduce((keys, item) => {
         Object.keys(item).forEach((key) => {
           if (!ignoreDataIndex.includes(key)) {
-            keys.add(key);
+            keys.add(key + '');
           }
         });
         return keys;
@@ -148,15 +161,22 @@
   _renderDefaultValue(root: HTMLElement, column: any, rowData: any) {
     const selectedColor = this._getCssVariable('--selected-color');
     const matchedColor = this._getCssVariable('--matched-color');
-    if (rowData.item['isBench']) root.style.backgroundColor = matchedColor;
-    else root.style.backgroundColor = selectedColor;
+    root.classList.remove('splitter');
+    if (rowData.item['isBench']) {
+      root.style.backgroundColor = matchedColor;
+      if (rowData.item['isMatched']) root.classList.add('splitter');
+    } else root.style.backgroundColor = selectedColor;
     if (column.path === 'name' && !this.isSingleGraphNode) {
       const className = rowData.item['isMatched'] ? 'avater-matched' : 'avater-unmatched';
       root.innerHTML = `<span class="${className}"></span>${rowData.item[column.path]}`;
       return;
     }
-    root.title = rowData.item[column.path] || '-';
-    root.textContent = rowData.item[column.path] || '-';
+    let tooltip = rowData.item[column.path] ?? '-';
+    if (this.tooltips && this.tooltips[column.path]) {
+      tooltip = this.tooltips[column.path] + ':\n' + tooltip;
+    }
+    root.title = tooltip;
+    root.textContent = rowData.item[column.path] ?? '-';
   }

   handleGridClick(e: MouseEvent) {
diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts
index 9f282e3f31ddd8cc579cb43c951508572cde64f8..682979a4a13c495d923685a03b7b0b1133e3d3ae 100644
--- a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts
+++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_info/index.ts
@@ -130,10 +130,12 @@ class TfGraphNodeInfo extends PolymerElement {
+