diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md index fc9e09f35030f71a8b23b5bc9fe86b120820b8bc..e9cc1deb82ff0498f1a8267cd288ecde798f308c 100644 --- a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md +++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md @@ -17,6 +17,11 @@ --- +## 3. 分支合并要求 +- [ ] **代码合并**(请确保将 master 分支的最新代码同步合并至 poc 分支及 pre-research 分支,同时保证 poc 分支的代码也已正确合并到 pre-research 分支。) + +--- + ## 3. 代码检视 - **要求:** - 合入代码超过 200 行,需三人以上会议检视。 diff --git a/OWNERS b/OWNERS index 415d737ed907c577bc61e71c2839a485395b899c..775490f4b708f61c57619326b579a52fa49a9b86 100644 --- a/OWNERS +++ b/OWNERS @@ -20,4 +20,6 @@ reviewers: - TAJh - czr9775 - kali20gakki -- wjchuee \ No newline at end of file +- wjchuee +- chenhao_1209 +- feng123www \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md index 0e68d1f8d9bdaba93a2f65220f85d08eb45f8586..6b7d483078a6a744ce935591ced0971dea2f5b2f 100644 --- a/debug/accuracy_tools/msprobe/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x} - msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。 - msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。 +- msprobe支持MSAdapter 2.1.0。 - msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。 @@ -69,35 +70,37 @@ export MSPROBE_LOG_LEVEL={x} ### 1 数据采集 -msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作,对应 config.json 中的 task 为 statistics 或 tensor。 +msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics" 或 "tensor" task。 [PyTorch 
场景的数据采集](./docs/05.data_dump_PyTorch.md) [MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md) +[MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md) + ### 2 精度预检 -精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task 为 run_ut。 +精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。 PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md) MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md) -### 3 精度比对 +### 3 分级可视化构图比对 -该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 +该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 -[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) +[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) -[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) +[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) -### 4 溢出检测与解析 +### 4 精度比对 -溢出检测与解析是在执行精度数据 dump 时,判断是否存在输入正常但输出存在溢出的 API,从而判断是否为正常溢出。对应 config.json 中的 overflow_check。 +该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。 -[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) +[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md) -[MindSpore 场景的溢出检测与解析](./docs/13.overflow_check_MindSpore.md) +[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md) ### 5 数据解析 @@ -129,26 +132,28 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore. 
[兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md) -### 10 分级可视化构图比对 +### 10 单算子API自动生成脚本 -该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。 +该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 -[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md) +[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) -[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md) +### 11 数码关联 +该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 -### 11 单算子API自动生成脚本 +[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) -该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。 +### 12 溢出检测与解析 -[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md) +溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。 +推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。 -### 12 数码关联 +[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md) -该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。 +[MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md) -[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md) +[MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md) ## 📑 补充材料 diff --git a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp index 15ea9e6fda47c0380d9718f135a1baf0658788eb..d56191443f8e6a7819c2bfbf402a5937bacd92ff 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/base/DebuggerConfig.hpp @@ -199,7 +199,7 @@ public: OverflowCheckCfg() = default; ~OverflowCheckCfg() = default; - uint32_t overflowNums{1}; + int32_t overflowNums{1}; DebuggerOpCheckLevel checkMode{DebuggerOpCheckLevel::CHECK_LEVEL_ALL}; private: diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp index 
0fe3443fa1f9286fe77c710c955d543d94c4b3a4..3374aa0be3117702ccde546fee14d7276d2e2c22 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp @@ -56,23 +56,30 @@ constexpr const char* kStatsHeaderShape = "Shape"; constexpr const char* kStatsHeaderMax = "Max Value"; constexpr const char* kStatsHeaderMin = "Min Value"; constexpr const char* kStatsHeaderAvg = "Avg Value"; -constexpr const char* kStatsHeaderL2Norm = "L2 Norm Value"; +constexpr const char* kStatsHeaderL2Norm = "l2norm"; +constexpr const char* kStatsHeaderL2NormInCsv = "L2Norm Value"; constexpr const char* kStatsHeaderMD5 = "MD5 Value"; constexpr const char* kStatsHeaderNan = "Nan Count"; +constexpr const char* kStatsHeaderNanInCsv = "NaN Count"; constexpr const char* kStatsHeaderNegInf = "Negative Inf Count"; constexpr const char* kStatsHeaderPosInf = "Positive Inf Count"; constexpr const char* kRankId = "RANK_ID"; constexpr const char* kDigitalNumbers = "0123456789"; -static const std::map summaryOptionHeaderStrMap = { - {DebuggerSummaryOption::MAX, kStatsHeaderMax}, - {DebuggerSummaryOption::MIN, kStatsHeaderMin}, - {DebuggerSummaryOption::MEAN, kStatsHeaderAvg}, - {DebuggerSummaryOption::L2NORM, kStatsHeaderL2Norm}, - {DebuggerSummaryOption::NAN_CNT, kStatsHeaderNan}, - {DebuggerSummaryOption::NEG_INF_CNT, kStatsHeaderNegInf}, - {DebuggerSummaryOption::POS_INF_CNT, kStatsHeaderPosInf}, - {DebuggerSummaryOption::MD5, kStatsHeaderMD5}, +static const std::map> summaryOptionHeaderStrMap = { + {DebuggerSummaryOption::MAX, {kStatsHeaderMax, kStatsHeaderMax}}, + {DebuggerSummaryOption::MIN, {kStatsHeaderMin, kStatsHeaderMin}}, + {DebuggerSummaryOption::MEAN, {kStatsHeaderAvg, kStatsHeaderAvg}}, + {DebuggerSummaryOption::L2NORM, {kStatsHeaderL2Norm, kStatsHeaderL2NormInCsv}}, + {DebuggerSummaryOption::NAN_CNT, {kStatsHeaderNan, kStatsHeaderNanInCsv}}, + {DebuggerSummaryOption::NEG_INF_CNT, {kStatsHeaderNegInf, 
kStatsHeaderNegInf}}, + {DebuggerSummaryOption::POS_INF_CNT, {kStatsHeaderPosInf, kStatsHeaderPosInf}}, + {DebuggerSummaryOption::MD5, {kStatsHeaderMD5, kStatsHeaderMD5}}, +}; + +const static std::map kDtypeTransMap = { + {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, + {AclDtype::DT_INT4, AclDtype::DT_INT8}, }; class AclTensorStats { @@ -170,7 +177,7 @@ static std::map ParseTensorSummaryHeaderOrder(c for (uint32_t pos = 0; pos < segs.size(); ++pos) { const std::string& opt = segs[pos]; for (auto it = summaryOptionHeaderStrMap.begin(); it != summaryOptionHeaderStrMap.end(); ++it) { - if (opt == it->second) { + if (opt == it->second.first) { ret[pos] = it->first; break; } @@ -233,7 +240,7 @@ std::string AclTensorStats::GetCsvHeader() const ret.append("Op Type,Op Name,Task ID,Stream ID,Timestamp,Input/Output,Slot,Data Size,Data Type,Format,Shape"); for (auto it = stats.begin(); it != stats.end(); it++) { ret.append(","); - ret.append(summaryOptionHeaderStrMap.at(it->first)); + ret.append(summaryOptionHeaderStrMap.at(it->first).second); } ret.append("\n"); @@ -603,7 +610,7 @@ static std::string GenDataPath(const std::string& path) { inline std::string GetTensorInfoSuffix(AclTensorInfo& tensor) { return "." + tensor.inout + "." + std::to_string(tensor.slot) + - "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.dtype); + "." + DataUtils::GetFormatString(tensor.hostFmt) + "." 
+ DataUtils::GetDTypeString(tensor.oriDtype); } static DebuggerErrno DumpOneAclTensorFmtBin(AclTensorInfo& tensor) @@ -640,10 +647,13 @@ static DebuggerErrno DumpOneAclTensorFmtNpy(AclTensorInfo& tensor) return DebuggerErrno::OK; } - if (tensor.dtype == AclDtype::DT_BF16) { - ret = AclTensor::TransDtype(tensor, AclDtype::DT_FLOAT); + auto it = kDtypeTransMap.find(tensor.dtype); + if (it != kDtypeTransMap.end()) { + AclDtype dstDtype = it->second; + ret = AclTensor::TransDtype(tensor, dstDtype); if (ret != DebuggerErrno::OK) { - LOG_ERROR(ret, tensor + ": Failed to transform dtype from bf16 to fp32."); + LOG_ERROR(ret, tensor + ": Failed to transform dtype from " + DataUtils::GetDTypeString(it->first) + " to " + + DataUtils::GetDTypeString(it->second)+ "."); return ret; } } @@ -736,7 +746,9 @@ static DebuggerErrno DumpOneAclTensor(AclTensorInfo& tensor, std::vector overflowNums; +} + +void AclDumper::CountOverflowNumbers(const acldumpChunk* chunk) +{ + if (IsOverflowCompleted() || !isOverflowDump || !chunk->isLastChunk) { + return; + } + const std::string fileName = chunk->fileName; + auto separator = fileName.rfind("/"); + auto fileBaseName = fileName.substr(separator + 1); + if (fileBaseName.rfind("Opdebug.Node_OpDebug.") == 0) { + // count according to the first file: Node_OpDebug + realOverflowNums++; + } + return; +} + std::string AclDumper::GetDumpPath(uint32_t curStep) const { if (!initialized || foreDumpPath.empty()) { @@ -357,6 +377,11 @@ DebuggerErrno AclDumper::Initialize() void AclDumper::OnAclDumpCallBack(const acldumpChunk* chunk, int32_t len) { DEBUG_FUNC_TRACE(); + CountOverflowNumbers(chunk); + if (IsOverflowCompleted()) { + return; + } + std::string dumpPath = FileUtils::GetAbsPath(chunk->fileName); auto it = dataProcessors.find(dumpPath); if (it == dataProcessors.end()) { @@ -424,6 +449,8 @@ void AclDumper::SetDump(uint32_t rank, uint32_t curStep, ExtArgs& args) ret = AclDumpGenStatJson(statisticsCfg, rank, curStep, kernels); } else if 
(overflowCheckCfg != nullptr) { ret = AclDumpGenOverflowJson(overflowCheckCfg, rank, curStep); + overflowNums = overflowCheckCfg->overflowNums; + isOverflowDump = true; } if (ret != DebuggerErrno::OK) { diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp index dcfad5fafcabdf944e1d4b0b0a3cd77251ce047d..6985df65e166101c08501e5e206e003bda494b9a 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumper.hpp @@ -58,11 +58,17 @@ private: uint32_t curStep, const char** kernels); DebuggerErrno AclDumpGenOverflowJson(std::shared_ptr overflowCfg, uint32_t rank, uint32_t curStep); + void CountOverflowNumbers(const acldumpChunk* chunk); + bool IsOverflowCompleted(); + bool initialized{false}; bool aclDumpHasSet{false}; std::string foreDumpPath; std::vector hostAnalysisOpt; std::map> dataProcessors; + bool isOverflowDump{false}; + int32_t overflowNums{1}; + int32_t realOverflowNums{0}; }; void KernelInitDump(); diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp index 45adff4962156f87f52c17166bc3b381f07f2978..4a5ec4c555198015603d7cc1446be66fda05765d 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.cpp @@ -291,7 +291,11 @@ static inline void AssertDim(const AclShape& shape, size_t dim) static inline void AssertConsis(const AclTensorInfo& tensor) { - if (EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor) != tensor.dataSize) { + size_t tensor_size = EleNumOfTensor(tensor, false) * SizeOfAclDType(tensor); + // Processing dtype whose size < 1 + // The ele num of quantization type(qint4*2) in MindSpore must be even. 
+ if (tensor.dtype == AclDtype::DT_INT4) tensor_size = EleNumOfTensor(tensor, false) / 2; + if (tensor_size != tensor.dataSize) { throw std::runtime_error(tensor + ": The internal data of Tensor is inconsistent."); } } @@ -343,7 +347,7 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string& dumpPath, const uint8_t* } int32_t subFormat = tensor.sub_format(); - return AclTensorInfo{dumpPath, data, dtype, dFmt, hFmt, dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; + return AclTensorInfo{dumpPath, data, dtype, dtype, dFmt, hFmt, dShape, hShape, dataSize, subFormat, io, slot, dumpOriginData}; } template AclTensorInfo ParseAttrsFromDumpData( @@ -763,34 +767,80 @@ static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, s } } -DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +static void TransInt4ToInt8(const uint8_t* input, size_t elemNums, uint8_t* output, size_t bufferSize) { + if (bufferSize < elemNums * sizeof(int8_t)) { + LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "Insufficient space for converting data from int4 to int8."); + return; + } + const int8_t *srcData = reinterpret_cast(input); + int8_t *dstData = reinterpret_cast(output); + size_t inputLength = elemNums / 2; + int maxValue = 7; + int minValue = -8; + int signBitShift = 3; + int signBitMask = 0x08; + for (size_t i = 0; i < inputLength; ++i) { + int8_t s = *srcData; + int8_t t = s & 0xf; + // keep the sign bit not change + int8_t signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 value."); + } + *dstData = t; + ++dstData; + + int highByteShift = 4; + t = s >> highByteShift; + signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 
value."); + } + *dstData = t; + ++dstData; + ++srcData; + } + return; +} - const static std::set> kSupportedDtypeTrans = { - {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, - }; +DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +{ if (tensor.dtype == to) { return DebuggerErrno::OK; } - if (kSupportedDtypeTrans.find({tensor.dtype, to}) == kSupportedDtypeTrans.end()) { - return DebuggerErrno::ERROR_UNKNOWN_TRANS; - } - + tensor.oriDtype = tensor.dtype; std::vector buffer; AssertConsis(tensor); size_t bufferSize = EleNumOfTensor(tensor) * SizeOfAclDType(to); - buffer.reserve(bufferSize); + buffer.resize(bufferSize); const uint8_t* input = tensor.transBuf.empty() ? tensor.aclData : tensor.transBuf.data(); uint8_t* output = buffer.data(); - /* 目前仅支持bf16->fp32,若有通用转换需求再用更泛化的方式重写 */ if (tensor.dtype == AclDtype::DT_BF16 && to == AclDtype::DT_FLOAT) { TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + } else if (tensor.dtype == AclDtype::DT_INT4 && to == AclDtype::DT_INT8) { + TransInt4ToInt8(input, EleNumOfTensor(tensor), output, bufferSize); + } else { + LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_TRANS, tensor + ": Trans " + DataUtils::GetDTypeString(tensor.dtype) + + " to " + DataUtils::GetDTypeString(to) + " is not supported."); + return DebuggerErrno::ERROR_UNKNOWN_TRANS; } tensor.transBuf = std::move(buffer); + tensor.dtype = to; return DebuggerErrno::OK; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp index 8b5ba5b06d935d5aaa2dff35e921b9072db6aa1a..f2ac429a7f14370ea1721369c7f9089cb971bb6e 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp @@ -40,6 +40,7 @@ struct AclTensorInfo { std::string dumpPath; const uint8_t* aclData; AclDtype dtype; + AclDtype oriDtype; AclFormat deviceFmt; AclFormat hostFmt; AclShape deviceShape; @@ -52,7 +53,7 @@ struct AclTensorInfo { std::vector transBuf; std::string 
ToString() const { - return "AclTensor(path=" + dumpPath + ",dtype=" + std::to_string(dtype) + ",inout=" + inout + ")"; + return "AclTensor(path=" + dumpPath + ",dtype=" + DataUtils::GetDTypeString(dtype) + ",inout=" + inout + ")"; } }; @@ -71,6 +72,7 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string &dumpPath, const uint8_t* const std::string& io, uint32_t slot); DebuggerErrno TransFormatD2H(AclTensorInfo& tensor); DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to); +bool IsDtypeSupportTrans(AclDtype dtype); } } diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index d9623b807121ea129484a535fe8a9e2293e662f3..f7ba9f90d0e27df9a40ee37b838fbd2919f0d25b 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -230,6 +230,92 @@ class Const: TENSOR_STAT_LEN = 2 + SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml" + + PT_API_TYPE_FUNCTIONAL = "functional" + PT_API_TYPE_TENSOR = "tensor" + PT_API_TYPE_TORCH = "torch" + PT_API_TYPE_VF = "_VF" + PT_API_TYPE_NPU = "torch_npu" + PT_API_TYPE_ATEN = "aten" + PT_API_TYPE_DIST = "distributed" + PT_API_TYPE_NPU_DIST = "npu_distributed" + + MS_API_TYPE_OPS = "ops" + MS_API_TYPE_TENSOR = "tensor" + MS_API_TYPE_STUB_TENSOR = "stubtensor" + MS_API_TYPE_MINT = "mint.ops" + MS_API_TYPE_MINT_FUNC = "mint.nn.functional" + MS_API_TYPE_COM = "communication.comm_func" + + FUNCTIONAL_API_TYPE_PREFIX = "Functional" + TENSOR_API_TYPE_PREFIX = "Tensor" + DIST_API_TYPE_PREFIX = "Distributed" + + TORCH_API_TYPE_PREFIX = "Torch" + NPU_API_TYPE_PREFIX = "NPU" + ATEN_API_TYPE_PREFIX = "Aten" + VF_API_TYPE_PREFIX = "VF" + + MINT_API_TYPE_PREFIX = "Mint" + MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + + SUPPORT_API_DICT_KEY_MAP = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_VF: 
PT_API_TYPE_VF, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_ATEN: PT_API_TYPE_ATEN, + PT_API_TYPE_DIST: PT_API_TYPE_DIST, + PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: MS_API_TYPE_OPS, + MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, + MS_API_TYPE_MINT: MS_API_TYPE_MINT, + MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, + MS_API_TYPE_COM: MS_API_TYPE_COM + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, + PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR, + PT_API_TYPE_TORCH: PT_API_TYPE_TORCH, + PT_API_TYPE_NPU: PT_API_TYPE_NPU, + PT_API_TYPE_DIST: PT_API_TYPE_DIST + } + } + + API_DATA_PREFIX = { + PT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_VF: VF_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX, + PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX + }, + MS_FRAMEWORK: { + MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX, + MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, + MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, + MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX + }, + MT_FRAMEWORK: { + PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, + PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX, + PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX, + PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX, + PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX + } + } + class CompareConst: """ @@ -256,6 +342,7 @@ class CompareConst: MEAN_DIFF = "Mean diff" NORM_DIFF = "L2norm diff" COSINE = "Cosine" + EUC_DIST = "EucDist" MAX_ABS_ERR = "MaxAbsErr" MAX_RELATIVE_ERR = "MaxRelativeErr" MIN_RELATIVE_ERR = "MinRelativeErr" @@ -330,8 +417,8 @@ class CompareConst: ULP_ERR_STATUS = "ulp_err_status" COMPARE_RESULT_HEADER = [ - NPU_NAME, 
BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, - ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST, + MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE ] @@ -357,18 +444,16 @@ class CompareConst: Const.MD5: MD5_COMPARE_RESULT_HEADER } - ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO] + ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, + FIVE_THOUSANDTHS_ERR_RATIO] SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR] # dtype match - MS_TYPE = [ - [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16], - [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16] - ] - TORCH_TYPE = [ - [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16], - [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16] + + DTYPE_MATCH_GROUPS = [ + {Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16}, + {Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16} ] # read_op @@ -467,7 +552,7 @@ class CompareConst: BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: '' } MS_GRAPH_NPY = { - COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, + COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, FIVE_THOUSANDTHS_ERR_RATIO: None } MS_GRAPH_STATISTIC = { @@ -538,61 +623,6 @@ class OverflowConst: OVERFLOW_DEBUG_MODE = 1 -class MsCompareConst: - # api_info field - MINT = "Mint" - MINT_FUNCTIONAL = "MintFunctional" - TENSOR_API = "Tensor" - - 
API_NAME_STR_LENGTH = 4 - MAX_RECURSION_DEPTH = 20 - - # Mindtorch api_info field - MINDTORCH_TENSOR = "Tensor" - MINDTORCH = "Torch" - MINDTORCH_FUNC = "Functional" - MINDTORCH_NPU = "NPU" - MINDTORCH_DIST = "Distributed" - - - - MT_VALID_API_TYPES = [ - MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR - ] - - TASK_FIELD = "task" - STATISTICS_TASK = "statistics" - FRAMEWORK = "framework" - TENSOR_TASK = "tensor" - DUMP_DATA_DIR_FIELD = "dump_data_dir" - DATA_FIELD = "data" - - # supported api yaml - SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" - SUPPORTED_TENSOR_LIST_KEY = "tensor" - - # detail_csv - DETAIL_CSV_API_NAME = "API Name" - DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" - DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" - DETAIL_CSV_SHAPE = "Shape" - DETAIL_CSV_PASS_STATUS = "Status" - DETAIL_CSV_MESSAGE = "Message" - DETAIL_CSV_FILE_NAME = "accuracy_checking_details" - - # result_csv - RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" - RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" - RESULT_CSV_FILE_NAME = "accuracy_checking_result" - - EPSILON = 1e-8 - - class ProcessStatus: - SUCCESS = "success" - API_NOT_FOUND = "api_not_found" - EXCEPTION_SKIP = "exception_skip" - - class MsgConst: """ Class for log messages const diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index c06b5b64927bf47da1573df3b1d4db34dfa24cb1..963ca82545650c71f451ecc5624416a9e7b316df 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -75,6 +75,7 @@ class MsprobeBaseException(Exception): MERGE_COMPARE_RESULT_ERROR = 33 NAMES_STRUCTS_MATCH_ERROR = 34 INVALID_STATE_ERROR = 35 + INVALID_API_NAME_ERROR = 36 def __init__(self, code, error_info: str = ""): super(MsprobeBaseException, self).__init__() @@ -247,6 +248,10 @@ def md5_find(data): def detect_framework_by_dump_json(file_path): + json_data = load_json(file_path) + framework = 
json_data.get("framework", None) + if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]: + return framework pattern_ms = r'"type":\s*"mindspore' pattern_pt = r'"type":\s*"torch' with FileOpen(file_path, 'r') as file: @@ -424,6 +429,15 @@ def get_real_step_or_rank(step_or_rank_input, obj): return real_step_or_rank +def check_init_step(step): + if not is_int(step): + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be an integer") + if not step >= 0: + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, + f"{step} must be greater than or equal to 0") + + def check_seed_all(seed, mode, rm_dropout): if is_int(seed): if seed < 0 or seed > Const.MAX_SEED_VALUE: @@ -472,13 +486,13 @@ recursion_depth = defaultdict(int) # 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。 -def recursion_depth_decorator(func_info): +def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): func_id = id(func) recursion_depth[func_id] += 1 - if recursion_depth[func_id] > Const.MAX_DEPTH: + if recursion_depth[func_id] > max_depth: msg = f"call {func_info} exceeds the recursion limit." 
logger.error_log_with_exp( msg, diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py index 55229d72657c67428186bcb233371e3b9eee73e0..f2aa8c479ecd0e40f3585708f82ff48a0e74832c 100644 --- a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -311,9 +311,9 @@ class Comparator: ] if self.dump_mode == Const.SUMMARY: - result_item = base_result_item + [" "] * 8 + result_item = base_result_item + [" "] * 8 # 8个统计量数据情况的比对指标 else: - result_item = base_result_item + [" "] * 5 + result_item = base_result_item + [" "] * 6 # 6个真实数据情况的比对指标 npu_summary_data = npu_ops_all.get(ms_op_name).get("summary") result_item.extend(npu_summary_data) @@ -329,7 +329,9 @@ class Comparator: else: result_item.append(CompareConst.NONE) if self.dump_mode == Const.ALL: - result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None)) + ms_data_name = npu_ops_all.get(ms_op_name).get("data_name", None) + pt_data_name = bench_ops_all.get(bench_op_name).get("data_name", None) + result_item.append([ms_data_name, pt_data_name]) result.append(result_item) elif ms_op_name not in npu_ops_all: logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.') @@ -349,47 +351,48 @@ class Comparator: result_df = self.make_result_table(result) return result_df - def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data): + def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param): """ :param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0 :param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0 :param op_name_mapping_dict: op_name和npy或pt文件的映射关系 :param input_param: npu_json_path/bench_json_path/stack_json_path等参数 - :param bench_data: bench的dump数据中"data"字段 :return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息 - 
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、 + 用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、欧式距离 最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息 """ - npu_bench_name_list = op_name_mapping_dict[npu_op_name] - data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list") error_file, relative_err, error_flag = None, None, False - bench_data_name = get_bench_data_name(bench_op_name, bench_data) - if data_name == '-1' or data_name == -1: # 没有真实数据路径 - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - elif not bench_data_name: + + data_name_pair = op_name_mapping_dict.get(npu_op_name) + npu_data_name = data_name_pair[0] + bench_data_name = data_name_pair[1] + + if str(npu_data_name) == '-1': # 没有npu真实数据 + n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True + elif str(bench_data_name) == '-1': # 没有bench真实数据 n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True error_file = 'no_bench_data' else: + npu_dir = input_param.get("npu_dump_data_dir") + bench_dir = input_param.get("bench_dump_data_dir") try: - read_npy_data = getattr(self, "read_npy_data") frame_name = getattr(self, "frame_name") + read_npy_data = getattr(self, "read_npy_data") if frame_name == "MSComparator": - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX) + n_value = read_npy_data(npu_dir, npu_data_name) if self.cross_frame: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name, - load_pt_file=True) + b_value = read_npy_data(bench_dir, bench_data_name, load_pt_file=True) else: - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name) + b_value = read_npy_data(bench_dir, bench_data_name) else: - n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX) - b_value = read_npy_data(input_param.get("bench_dump_data_dir"), 
bench_data_name) + n_value = read_npy_data(npu_dir, npu_data_name) + b_value = read_npy_data(bench_dir, bench_data_name) except IOError as error: error_file = error.filename n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE error_flag = True except (FileCheckException, CompareException): - error_file = data_name + error_file = npu_data_name n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE error_flag = True @@ -456,21 +459,23 @@ class Comparator: def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param): cos_result = [] + euc_dist_result = [] max_err_result = [] max_relative_err_result = [] - err_mess = [] one_thousand_err_ratio_result = [] five_thousand_err_ratio_result = [] + err_mess = [] + is_print_compare_log = input_param.get("is_print_compare_log") - bench_data = load_json(input_param.get("bench_json_path")).get('data') + for i in range(len(result_df)): npu_op_name = result_df.iloc[i, 0] bench_op_name = result_df.iloc[i, 1] if is_print_compare_log: logger.info("start compare: {}".format(npu_op_name)) - cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \ - self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data) + cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg \ + = self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param) if is_print_compare_log: logger.info( @@ -479,71 +484,30 @@ class Comparator: "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) cos_result.append(cos_sim) + euc_dist_result.append(euc_dist) max_err_result.append(max_abs_err) max_relative_err_result.append(max_relative_err) - err_mess.append(err_msg) one_thousand_err_ratio_result.append(one_thousand_err_ratio) five_thousand_err_ratio_result.append(five_thousand_err_ratio) + 
err_mess.append(err_msg) cr = ComparisonResult( cos_result=cos_result, + euc_dist_result=euc_dist_result, max_err_result=max_err_result, max_relative_err_result=max_relative_err_result, - err_msgs=err_mess, one_thousand_err_ratio_result=one_thousand_err_ratio_result, - five_thousand_err_ratio_result=five_thousand_err_ratio_result + five_thousand_err_ratio_result=five_thousand_err_ratio_result, + err_msgs=err_mess ) return _save_cmp_result(idx, cr, result_df, lock) - def do_multi_process(self, input_parma, result_df): + def do_multi_process(self, input_param, result_df): try: - result_df = _handle_multi_process(self.compare_ops, input_parma, result_df, + result_df = _handle_multi_process(self.compare_ops, input_param, result_df, multiprocessing.Manager().RLock()) return result_df except ValueError as e: logger.error('result dataframe is not found.') raise CompareException(CompareException.INVALID_DATA_ERROR) from e - - -def get_bench_data_name(bench_op_name, bench_data): - bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name) - if len(bench_name_list) > 1 and bench_name_list[1] == Const.PARAMS_GRAD: - bench_data_bundle = bench_data.get(bench_name_list[0] + Const.SEP + bench_name_list[1], {}) - else: - bench_data_bundle = bench_data.get(bench_name_list[0], {}) - if not bench_data_bundle or len(bench_name_list) < 3: - return None - layers = bench_name_list[2].split(Const.SEP) - - def _get(key, container): - if isinstance(container, dict): - return container.get(key) - if isinstance(container, list): - try: - return container[int(key)] - except (ValueError, IndexError): - return None - return None - - def get_by_layer(container, params_grad=False): - data = container - # dump.json中parameters_grad的结构为key:[{}], 如果存在key,有且只有一个列表元素,而op_name中只命名到了key,因此加'0' - if params_grad: - layers.append('0') - for layer in layers: - data = _get(layer, data) - return _get(CompareConst.DATA_NAME.lower(), data) - - if Const.INPUT == 
bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS))) - elif Const.KWARGS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS)) - elif Const.OUTPUT == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.OUTPUT)) - elif Const.PARAMS == bench_name_list[1]: - return get_by_layer(bench_data_bundle.get(Const.PARAMS)) - elif Const.PARAMS_GRAD == bench_name_list[1]: - return get_by_layer(bench_data_bundle, params_grad=True) - else: - return None diff --git a/debug/accuracy_tools/msprobe/core/compare/check.py b/debug/accuracy_tools/msprobe/core/compare/check.py index 653823e20b29b14b6e7ede929f3bd2865bffaa18..9429d7ffa1a3c1feffb0bc68f5cde777e5f8d460 100644 --- a/debug/accuracy_tools/msprobe/core/compare/check.py +++ b/debug/accuracy_tools/msprobe/core/compare/check.py @@ -82,12 +82,8 @@ def check_type_shape_match(npu_struct, bench_struct): f'should both be 2, please check!') raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error shape_match = npu_shape == bench_shape - type_match = npu_type == bench_type - if not type_match: - if ([npu_type, bench_type] in CompareConst.MS_TYPE) or ([npu_type, bench_type] in CompareConst.TORCH_TYPE): - type_match = True - else: - type_match = False + type_match = ((npu_type == bench_type) or + any(npu_type in group and bench_type in group for group in CompareConst.DTYPE_MATCH_GROUPS)) struct_match = shape_match and type_match if not struct_match: return False diff --git a/debug/accuracy_tools/msprobe/core/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py index cf3e1c4c03e9553f5566870b7c5ebe2d890e9774..1983313249f34680a8f25c3a2466d8871fe0a693 100644 --- a/debug/accuracy_tools/msprobe/core/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -146,11 +146,13 @@ class HighlightRules: } # 用于比较输入和输出的规则 + # 真实数据检查规则 compare_rules = { 
"check_order_magnitude": CheckOrderMagnitude(), "check_one_thousand_error": CheckOneThousandErrorRatio(), "check_cosine_similarity": CheckCosineSimilarity() } + # 统计量数据检查规则 summary_compare_rules = { "check_order_magnitude": CheckOrderMagnitude(), "check_max_relative_diff": CheckMaxRelativeDiff(), diff --git a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py index c2c1461e452f9d2c7f4e0e2803dfe51be2a132c0..71b0f29d64f717adc87b74cf48e891652e9e753f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py +++ b/debug/accuracy_tools/msprobe/core/compare/multiprocessing_compute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,14 +15,17 @@ import multiprocessing from dataclasses import dataclass +from functools import partial + import pandas as pd from tqdm import tqdm + from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.const import CompareConst -def _handle_multi_process(func, input_parma, result_df, lock): +def _handle_multi_process(func, input_param, result_df, lock): process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1) op_name_mapping_dict = read_dump_data(result_df) @@ -44,7 +47,7 @@ def _handle_multi_process(func, input_parma, result_df, lock): progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100) - def update_progress(size, progress_lock): + def update_progress(size, progress_lock, extra_param=None): with progress_lock: progress_bar.update(size) @@ -52,10 +55,12 @@ def _handle_multi_process(func, input_parma, result_df, lock): idx = df_chunk_size * process_idx chunk_size = len(df_chunk) result = pool.apply_async(func, - args=(idx, 
op_name_mapping_dict, df_chunk, lock, input_parma), + args=(idx, op_name_mapping_dict, df_chunk, lock, input_param), error_callback=err_call, - callback=update_progress(chunk_size, lock)) + callback=partial(update_progress, chunk_size, lock) + ) results.append(result) + final_results = [r.get() for r in results] pool.close() pool.join() @@ -92,12 +97,12 @@ def _ms_graph_handle_multi_process(func, result_df, mode): def read_dump_data(result_df): try: npu_dump_name_list = result_df.iloc[0:, 0].tolist() - npu_dump_tensor_list = result_df.iloc[0:, -1].tolist() + dump_tensor_pair_list = result_df.iloc[0:, -1].tolist() op_name_mapping_dict = {} for index, _ in enumerate(npu_dump_name_list): npu_dump_name = npu_dump_name_list[index] - npu_dump_tensor = npu_dump_tensor_list[index] - op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor] + dump_tensor_pair = dump_tensor_pair_list[index] + op_name_mapping_dict[npu_dump_name] = dump_tensor_pair return op_name_mapping_dict except ValueError as e: logger.error('result dataframe is not found.') @@ -110,11 +115,12 @@ def read_dump_data(result_df): @dataclass class ComparisonResult: cos_result: list + euc_dist_result: list max_err_result: list max_relative_err_result: list - err_msgs: list one_thousand_err_ratio_result: list five_thousand_err_ratio_result: list + err_msgs: list def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): @@ -135,15 +141,16 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): for i, _ in enumerate(result.cos_result): process_index = i + offset result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i] + result_df.loc[process_index, CompareConst.EUC_DIST] = result.euc_dist_result[i] result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] - result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = 
result.err_msgs[i] - result_df.loc[process_index, CompareConst.ACCURACY] = ( - check_accuracy(result.cos_result[i], result.max_err_result[i])) result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = ( result.one_thousand_err_ratio_result)[i] result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = ( result.five_thousand_err_ratio_result)[i] + result_df.loc[process_index, CompareConst.ACCURACY] = ( + check_accuracy(result.cos_result[i], result.max_err_result[i])) + result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] return result_df except ValueError as e: logger.error('result dataframe is not found.') diff --git a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py index c551985780cb9b56e32573727f9bf88f274da24e..4103d361fec14284fc38f97e1418e5405e939cd9 100644 --- a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -70,7 +70,7 @@ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None): error_flag = True return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg if not n_value.shape: # 判断数据是否为0维张量 - err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', " + err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', '{CompareConst.EUC_DIST}', " f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. 
") error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理 return n_value, b_value, error_flag, err_msg @@ -168,8 +168,9 @@ def statistics_data_check(result_dict): class TensorComparisonBasic(abc.ABC): """NPU和bench中npy数据的比较模板""" + @abc.abstractmethod - def apply(self, n_value, b_value, relative_err): + def apply(self, n_value, b_value, relative_err, err_msg): raise NotImplementedError @@ -190,6 +191,7 @@ def get_relative_err(n_value, b_value): class GetCosineSimilarity(TensorComparisonBasic): """计算cosine相似度""" + @staticmethod def correct_data(result): if result == CompareConst.NAN: @@ -198,9 +200,9 @@ class GetCosineSimilarity(TensorComparisonBasic): return round(float(result), 6) return result - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg with np.errstate(divide="ignore", invalid="ignore"): if len(n_value) == 1: @@ -224,9 +226,22 @@ class GetCosineSimilarity(TensorComparisonBasic): return result, "" +class GetEuclideanDistance(TensorComparisonBasic): + """计算欧式距离""" + + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg + + distance = np.linalg.norm(n_value - b_value, ord=2) + + return distance, "" + + class GetMaxAbsErr(TensorComparisonBasic): """计算最大绝对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): temp_res = n_value - b_value max_value = np.max(np.abs(temp_res)) if np.isnan(max_value): @@ -237,7 +252,8 @@ class GetMaxAbsErr(TensorComparisonBasic): class GetMaxRelativeErr(TensorComparisonBasic): """计算最大相对误差""" - def apply(self, n_value, b_value, relative_err): + + def apply(self, n_value, b_value, relative_err, err_msg): max_relative_err = 
np.max(np.abs(relative_err)) if np.isnan(max_relative_err): msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data." @@ -247,12 +263,13 @@ class GetMaxRelativeErr(TensorComparisonBasic): class GetErrRatio(TensorComparisonBasic): """计算相对误差小于指定阈值(千分之一、千分之五)的比例""" + def __init__(self, threshold): self.threshold = threshold - def apply(self, n_value, b_value, relative_err): - if not n_value.shape: - return CompareConst.UNSUPPORTED, "" + def apply(self, n_value, b_value, relative_err, err_msg): + if "This is type of 0-d tensor" in err_msg: + return CompareConst.UNSUPPORTED, err_msg if not np.size(relative_err): return CompareConst.NAN, "" @@ -264,6 +281,7 @@ class GetErrRatio(TensorComparisonBasic): class CompareOps: compare_ops = { "cosine_similarity": GetCosineSimilarity(), + "euclidean_distance": GetEuclideanDistance(), "max_abs_error": GetMaxAbsErr(), "max_relative_error": GetMaxRelativeErr(), "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD), @@ -295,7 +313,7 @@ def compare_ops_apply(n_value, b_value, error_flag, err_msg): n_value, b_value = reshape_value(n_value, b_value) for op in CompareOps.compare_ops.values(): - result, msg = op.apply(n_value, b_value, relative_err) + result, msg = op.apply(n_value, b_value, relative_err, err_msg) result_list.append(result) err_msg += msg return result_list, err_msg diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index a2edf57e5bb91400675fe01734ea7fbf0e1df893..66dc9ba94ee168f2ea7dbba15f4c76d5e6ef6f13 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -285,9 +285,9 @@ def result_item_init(n_info, b_info, dump_mode): md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result]) elif dump_mode == Const.SUMMARY: - 
result_item.extend([" "] * 8) + result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标 else: - result_item.extend([" "] * 5) + result_item.extend([" "] * 6) # 6个真实数据情况的比对指标 else: err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \ f"npu_info_struct is {n_info.struct}\n" \ @@ -321,8 +321,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): has_stack = npu_stack_info and bench_stack_info if dump_mode == Const.ALL: - npu_data_name = n_dict.get("data_name", None) - bench_data_name = b_dict.get("data_name", None) + npu_data_name_list = n_dict.get("data_name", None) + bench_data_name_list = b_dict.get("data_name", None) for index in range(min_len): n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name") @@ -353,7 +353,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list") + result_item.append([npu_data_name, bench_data_name]) result.append(result_item) @@ -371,7 +373,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): continue result_item = [ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN, - " ", " ", " ", " ", " " + " ", " ", " ", " ", " ", " " ] summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index] result_item.extend(summary_data) @@ -388,7 +390,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode): result_item.append(err_msg) result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) if dump_mode == Const.ALL: - result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name")) + npu_data_name = 
safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list") + result_item.append([npu_data_name, "-1"]) result.append(result_item) @@ -453,9 +456,9 @@ def get_un_match_accuracy(result, n_dict, dump_mode): result.append(result_item) continue if dump_mode == Const.SUMMARY: - result_item.extend([CompareConst.N_A] * 8) + result_item.extend([CompareConst.N_A] * 8) # 8个统计量数据情况的比对指标 if dump_mode == Const.ALL: - result_item.extend([CompareConst.N_A] * 5) + result_item.extend([CompareConst.N_A] * 6) # 6个真实数据情况的比对指标 npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder") bench_summary_data = [CompareConst.N_A] * 4 @@ -467,7 +470,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode): result_item.append(err_msg) append_stack_info(result_item, npu_stack_info, index) if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A: - result_item.extend(["-1"]) + result_item.extend([["-1", "-1"]]) result.append(result_item) @@ -542,10 +545,17 @@ def get_name_and_state(name): state type: input, output, kwargs, parameters, parameters_grad """ + if not isinstance(name, str): + logger.error(f'Invalid name: {name}, type should be string, please check.') + raise CompareException(CompareException.INVALID_API_NAME_ERROR) + if Const.PARAMS_GRAD in name.split(Const.SEP): return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD split = re.split(Const.REGEX_FORWARD_BACKWARD, name) + if len(split) < 3: + logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.') + raise CompareException(CompareException.INVALID_API_NAME_ERROR) api = f'{split[0]}.{split[1]}.' 
state_str = split[2] match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..1bef962232e47bc1eed399093e6812baa8f18f9c --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Any, Optional, Callable, Union, List, Tuple + +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import load_yaml + + +def _get_attr(module, attr_name): + if Const.SEP in attr_name: + sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1) + sub_module = getattr(module, sub_module_name, None) + attr = getattr(sub_module, sub_attr, None) + else: + attr = getattr(module, attr_name, None) + return attr + + +class ApiWrapper: + def __init__( + self, api_types: Dict[str, Dict[str, Any]], + api_list_paths: Union[str, List[str], Tuple[str]] + ): + self.api_types = api_types + if not isinstance(api_list_paths, (list, tuple)): + api_list_paths = [api_list_paths] * len(self.api_types) + elif len(api_list_paths) != len(self.api_types): + raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', " + "when api_list_paths is a list or tuple.") + self.api_list_paths = api_list_paths + self.api_names = self._get_api_names() + self.wrapped_api_functions = dict() + + def wrap_api( + self, api_templates, hook_build_func: Optional[Callable] + ): + api_types_num = sum([len(v) for v in self.api_types.values()]) + if not isinstance(api_templates, (list, tuple)): + api_templates = [api_templates] * api_types_num + elif len(api_templates) != api_types_num: + raise RuntimeError("The number of api_templates must be equal to the number of api_types, " + "when api_templates is a list or tuple.") + + self.wrapped_api_functions.clear() + index = 0 + for framework, api_types in self.api_types.items(): + wrapped_functions_in_framework = dict() + for api_type, api_modules in api_types.items(): + wrapped_functions = dict() + name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API") + api_template = api_templates[index] + index += 1 + for api_name in self.api_names.get(framework, {}).get(api_type, []): + ori_api = _get_attr(api_modules[0], api_name) + if callable(ori_api): 
+ def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + def api_function(*args, **kwargs): + return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + api_function.__name__ = api_name + return api_function + wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix, + hook_build_func, api_template) + wrapped_functions_in_framework[api_type] = wrapped_functions + self.wrapped_api_functions[framework] = wrapped_functions_in_framework + return self.wrapped_api_functions + + def _get_api_names(self): + api_names = dict() + + for index, framework in enumerate(self.api_types.keys()): + api_list = load_yaml(self.api_list_paths[index]) + valid_names = dict() + for api_type, api_modules in self.api_types.get(framework, {}).items(): + api_from_file = api_list.get(Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type), []) + names = set() + for api_name in api_from_file: + target_attr = api_name + target_module = api_modules[0] + if Const.SEP in api_name: + sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1) + target_module = getattr(api_modules[0], sub_module_name, None) + if target_module and target_attr in dir(target_module): + names.add(api_name) + valid_names[api_type] = names + api_names[framework] = valid_names + + return api_names + + +class ApiRegistry: + """ + Base class for api registry. 
+ """ + + def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates): + self.ori_api_attr = dict() + self.wrapped_api_attr = dict() + self.inner_used_ori_attr = dict() + self.inner_used_wrapped_attr = dict() + self.api_types = api_types + self.inner_used_api = inner_used_api + self.supported_api_list_path = supported_api_list_path + self.api_templates = api_templates + + @staticmethod + def store_ori_attr(ori_api_group, api_list, api_ori_attr): + for api in api_list: + api_ori_attr[api] = _get_attr(ori_api_group, api) + + @staticmethod + def set_api_attr(api_group, attr_dict): + for api, api_attr in attr_dict.items(): + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, api_attr) + else: + setattr(api_group, api, api_attr) + + def register_all_api(self): + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.wrapped_api_attr.get(api_type_with_framework, {})) + + def register_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {})) + + def restore_all_api(self): + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + api_type_with_framework = framework + Const.SEP + api_type + for module in api_modules[1]: + self.set_api_attr(module, self.ori_api_attr.get(api_type_with_framework, {})) + + def restore_inner_used_api(self): + for api_type in self.inner_used_api.keys(): + self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {})) + + def initialize_hook(self, hook_build_func): + api_wrapper = ApiWrapper(self.api_types, 
self.supported_api_list_path) + wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func) + + for framework, api_types in self.api_types.items(): + for api_type, api_modules in api_types.items(): + ori_attr = dict() + self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr) + api_type_with_framework = framework + Const.SEP + api_type + self.ori_api_attr[api_type_with_framework] = ori_attr + self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type) + + for inner_used_api_type, inner_used_api_list in self.inner_used_api.items(): + ori_attr = dict() + wrapped_attr = dict() + for api_name in inner_used_api_list[1:]: + if self.ori_api_attr.get(inner_used_api_type, {}).get(api_name): + ori_attr[api_name] = self.ori_api_attr.get(inner_used_api_type).get(api_name) + wrapped_attr[api_name] = self.wrapped_api_attr.get(inner_used_api_type).get(api_name) + self.inner_used_ori_attr[inner_used_api_type] = ori_attr + self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 8c4542a1917b76809aad21971e148ec17bd6045e..c6ab0293cf3edafab06a5bf03e1a429d86e92720 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -26,7 +26,7 @@ from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, Tenso from msprobe.core.common.file_utils import path_len_exceeds_limit, save_npy from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register 
has_adump = True try: @@ -44,6 +44,7 @@ class MindsporeDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self.api_register = get_api_register() @staticmethod def get_md5_for_tensor(x): @@ -74,46 +75,29 @@ class MindsporeDataProcessor(BaseDataProcessor): else: if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.max = get_max_value(data).item() - tensor_stat.min = get_min_value(data).item() - tensor_stat.mean = get_mean_value(data).item() + get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm + tensor_stat.max = mint.max(data).item() + tensor_stat.min = mint.min(data).item() + tensor_stat.mean = mint.mean(data).item() tensor_stat.norm = get_norm_value(data).item() - api_register.norm_inner_op_set_hook_func() return tensor_stat @staticmethod def get_stat_info_async(data): tensor_stat = TensorStatInfo() - stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack) if data.dtype == ms.complex64 or data.dtype == ms.complex128: logger.warning("Async dump do not support complex data!") return tensor_stat elif data.dtype == ms.bool_: - tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()])) + tensor_stat.stack_tensor_stat = (["Max", "Min"], ops.stack([data.any(), data.all()])) elif not data.shape: - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data])) + tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], 
ops.stack([data, data, data, data])) else: if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) - api_register.norm_inner_op_set_ori_func() - get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max) - get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min) - get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean) - if hasattr(mint, "norm"): - get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm) - else: - get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm) - tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method( - [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)])) - api_register.norm_inner_op_set_hook_func() + get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm + tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], ops.stack( + [mint.max(data), mint.min(data), mint.mean(data), get_norm_value(data)])) return tensor_stat @staticmethod @@ -125,14 +109,17 @@ class MindsporeDataProcessor(BaseDataProcessor): return super().get_special_types() + cls.mindspore_special_type def get_stat_info(self, data): + self.api_register.restore_inner_used_api() tensor_stat = TensorStatInfo() if data.numel() == 0: - return tensor_stat + stat_info = tensor_stat else: if self.config.async_dump: - return MindsporeDataProcessor.get_stat_info_async(data) + stat_info = MindsporeDataProcessor.get_stat_info_async(data) else: - return MindsporeDataProcessor.get_stat_info_sync(data) + stat_info = MindsporeDataProcessor.get_stat_info_sync(data) + self.api_register.register_inner_used_api() + return stat_info def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: @@ -191,7 +178,7 @@ class TensorDataProcessor(MindsporeDataProcessor): else: save_tensor_as_npy(tensor, file_path) return single_arg - + def _analyze_numpy(self, ndarray, 
suffix): dump_data_name, file_path = self.get_save_file_path(suffix) save_npy(ndarray, file_path) @@ -244,7 +231,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data() return api_info_struct if self.has_overflow else None - + def analyze_params(self, name, param_name, grad): self.has_overflow = False api_info_struct = super().analyze_params(name, param_name, grad) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 64253aa4260cab608e5ca84a5d006b28b94a33ab..4c56419dcb17b78918e4d46a3aaa50b12ef32777 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -78,14 +78,16 @@ class PytorchDataProcessor(BaseDataProcessor): def analyze_device_in_kwargs(element): single_arg = {} single_arg.update({'type': "torch.device"}) - if not isinstance(element, str): + if isinstance(element, (int, str)): + single_arg.update({"value": element}) + elif isinstance(element, torch.device): if hasattr(element, "index"): device_value = element.type + ":" + str(element.index) else: device_value = element.type single_arg.update({"value": device_value}) else: - single_arg.update({"value": element}) + logger.debug(f"Device type {type(element)} is not supported.") return single_arg @staticmethod @@ -143,7 +145,7 @@ class PytorchDataProcessor(BaseDataProcessor): if data.is_meta: return tensor_stat data_clone = data.detach() - if data_clone.numel() == 0: + if not data_clone.numel() or not data_clone.data_ptr(): return tensor_stat else: if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump: @@ -228,7 +230,7 @@ class PytorchDataProcessor(BaseDataProcessor): if isinstance(element, dist.ProcessGroup): return 
self._analyze_process_group(element) if isinstance(element, dist.P2POp): - return self._analyze_p2pop(element) + return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, dist.ReduceOp): return self._analyze_reduce_op(element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) @@ -247,10 +249,10 @@ class PytorchDataProcessor(BaseDataProcessor): module_input_output.update_output_with_args_and_kwargs() return super().analyze_forward_output(name, module, module_input_output) - def _analyze_p2pop(self, arg): + def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: - tensor_info = self._analyze_tensor(arg.tensor, []) + tensor_info = self._analyze_tensor(arg.tensor, suffix) p2pop_info.update({"tensor": tensor_info}) p2pop_info.update({"op": arg.op.__name__}) p2pop_info.update({"peer": arg.peer}) @@ -311,7 +313,7 @@ class TensorDataProcessor(PytorchDataProcessor): saved_tensor = tensor.clone().contiguous().detach() save_pt(saved_tensor, file_path) return single_arg - + def _analyze_numpy(self, ndarray, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) save_pt(torch.tensor(ndarray), file_path) diff --git a/debug/accuracy_tools/msprobe/docs/01.installation.md b/debug/accuracy_tools/msprobe/docs/01.installation.md index 1ab5f6419ba07ec749bad139f874fbc7301fd8b3..530783e87d0bdadd51856cb1ae08160cb081da80 100644 --- a/debug/accuracy_tools/msprobe/docs/01.installation.md +++ b/debug/accuracy_tools/msprobe/docs/01.installation.md @@ -16,7 +16,7 @@ pip install mindstudio-probe |版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码| |:--:|:--:|:--:|:--:|:--:|:--:| -|1.2.2|2025.2.26|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|1db0cf4572bc0305c68705b74775f652c6cb2c2bedb6c6e57f43e31ab273b288| 
+|1.2.2|2025.3.03|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.2-py3-none-any.whl)|961411bb460d327ea51d6ca4d0c8e8c5565f07c0852d7b8592b781ca35b87212| |1.2.1|2025.2.07|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.1-py3-none-any.whl)|b64b342118558e0339b39237f88a49b93fd24551b0cb202c872fbfef4260c86b| |1.2.0|2025.1.13|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.2/mindstudio_probe-1.2.0-py3-none-any.whl)|1e3aeea1706112f6ee52fd1165037936bb209138f0b9ec42ea21e2c1c8942cdc| |1.1.1|2024.12.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.1/mindstudio_probe-1.1.1-py3-none-any.whl)|577b597555dc155b76ba1a62d575c3546004644e140a456c3ba0824d46283735| diff --git a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md index f134bd4536294d209e7b3e6e73fd80b9be61041d..a5f17637dae817de6eba091e9eef602ca95f091a 100644 --- a/debug/accuracy_tools/msprobe/docs/02.config_introduction.md +++ b/debug/accuracy_tools/msprobe/docs/02.config_introduction.md @@ -12,23 +12,23 @@ | 参数 | 解释 | 是否必选 | | ----------------- 
|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- | -| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对;
"grad_probe":梯度监控;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe)。
**配置示例**:"task": "tensor"。 | 否 | +| task | dump 的任务类型,str 类型。可选参数:
"statistics":仅采集统计信息,默认值;
"tensor":采集统计信息和完全复刻整网的真实数据;
"run_ut":精度预检,仅 PyTorch 场景支持,采集数据时勿选;
"overflow_check":溢出检测;
"free_benchmark":无标杆比对,不支持 MSAdapter 场景;
"grad_probe":梯度监控, 不支持 MSAdapter 场景;
"structure":仅采集模型结构以及调用栈信息,不采集具体数据。
根据 task 参数取值的不同,可以配置不同场景参数,详见:
[1.2 task 配置为 statistics](#12-task-配置为-statistics),
[1.3 task 配置为 tensor](#13-task-配置为-tensor),
[1.4 task 配置为 run_ut](#14-task-配置为-run_ut),
[1.5 task 配置为 overflow_check](#15-task-配置为-overflow_check),
[1.6 task 配置为 free_benchmark](#16-task-配置为-free_benchmark),
[1.7 task 配置为 grad_probe](#17-task-配置为-grad_probe)。
**配置示例**:"task": "tensor"。 | 否 | | dump_path | 设置 dump 数据目录路径,str 类型。
**配置示例**:"dump_path": "./dump_path"。 | 是 | | rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型,默认未配置(表示采集所有卡的数据),应配置元素为 ≥0 的整数或类似"4-6"的字符串,且须配置实际可用的 Rank ID。
PyTorch 场景: Rank ID 从 0 开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的 Rank ID,则 dump 数据为空,比如当前环境 Rank ID 为 0 到 7,实际训练运行 0 到 3 卡,此时若配置 Rank ID 为 4 或不存在的 10 等其他值,dump 数据为空。
MindSpore 场景:所有节点的 Rank ID 均从 0 开始计数,最大取值为每个节点可用卡总数-1,config.json 配置一次 rank 参数对所有节点同时生效。
注意,单卡训练时,rank必须为[],即空列表,不能指定rank。
**配置示例**:"rank": [1, "4-6"]。 | 否 | | step | 指定采集某个 step 的数据,list[Union[int, str]] 类型。默认未配置,表示采集所有 step 数据。采集特定 step 时,须指定为训练脚本中存在的 step,可逐个配置,也可以指定范围。
**配置示例**:"step": [0, 1 , 2, "4-6"]。 | 否 | -| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,仅 PyTorch 与 MindSpore 动态图场景支持,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch 与 MindSpore 动态图场景支持;
"L2":dump kernel 级精度数据,PyTorch场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore场景详细介绍见 [MindSpore 场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch 与 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | +| level | dump 级别,str 类型,根据不同级别采集不同数据。可选参数:
"L0":dump 模块级精度数据,仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持,使用背景详见 [1.1.1 模块级精度数据 dump 说明](#111-模块级精度数据-dump-说明);
"L1":dump API 级精度数据,默认值,仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持;
"L2":dump kernel 级精度数据,PyTorch 场景详细介绍见 [PyTorch 场景的 kernel dump 说明](./04.kernel_dump_PyTorch.md);MindSpore 动态图场景详细介绍见 [MindSpore 动态图场景的 kernel dump 说明](./28.kernel_dump_MindSpore.md);MindSpore 静态图场景详细介绍见《MindSpore 场景的数据采集》中的 ["**8.1 静态图场景**"](./06.data_dump_MindSpore.md#81-静态图场景)小节;
"mix":dump module 模块级和 API 级精度数据,即"L0"+"L1",仅 PyTorch、MSAdapter 以及 MindSpore 动态图场景支持。
"debug":单点保存功能,细节详见[单点保存工具 README](./28.debugger_save_instruction.md)
**配置示例**:"level": "L1"。 | 否 | | enable_dataloader | 自动控制开关,bool 类型,仅 PyTorch 场景支持。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后自动识别 step 参数指定的迭代,并在该迭代执行完成后退出训练,此时 start、stop 和 step 函数可不配置,开启该开关要求训练脚本是通过 torch.utils.data.dataloader 方式加载数据。仅支持 PyTorch 单卡训练使用,分布式训练场景下存在数据 dump 不全问题。 **这个特性下个版本将被废弃** | 否 | | async_dump | 异步 dump 开关,bool 类型。可选参数 true(开启)或 false(关闭),默认为 false。配置为 true 后开启异步 dump,即采集的精度数据会在当前 step 训练结束后统一落盘,训练过程中工具不触发同步操作。由于使用该模式有**显存溢出**的风险,当 task 配置为 tensor 时,即真实数据的异步dump模式,必须配置 [list](#13-task-配置为-tensor) 参数,指定需要 dump 的 tensor 。该模式暂不支持复数类型 tensor
的统计量计算。 | 否 | #### 1.1.1 模块级精度数据 dump 说明 -仅 PyTorch 与 MindSpore 动态图场景支持。 +仅 PyTorch、MSAdapter以及 MindSpore 动态图场景支持。 大模型场景下,通常不是简单的利用自动迁移能力实现从 GPU 到 NPU 的训练脚本迁移,而是会对 NPU 网络进行一系列针对性的适配,因此,常常会造成迁移后的 NPU 模型存在部分子结构不能与 GPU 原始模型完全对应。模型结构不一致导致 API 调用类型及数量不一致,若直接按照 API 粒度进行精度数据 dump 和比对,则无法完全比对所有的 API。 本小节介绍的功能是对模型中的大粒度模块进行数据 dump,使其比对时,对于无法以 API 粒度比对的模块可以直接以模块粒度进行比对。 -模块指的是继承 nn.Module 类(PyTorch场景)或 nn.Cell 类(MindSpore场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 +模块指的是继承 nn.Module 类(PyTorch 与 MSAdapter 场景)或 nn.Cell 类(MindSpore 场景)的子类,通常情况下这类模块就是一个小模型,可以被视为一个整体,dump 数据时以模块为粒度进行 dump。 @@ -36,21 +36,23 @@ - - - - + + - + - + + +
参数 | 解释 | 是否必选
scopePyTorch 和 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为2,需要配置按照工具命名格式的完整模块名或API名称,用于锁定区间,dump 该范围内的数据。
配置示例: +
scope:PyTorch、MSAdapter 以及 MindSpore 动态图场景 dump 范围,list[str] 类型,默认未配置(list 也未配置时表示 dump 所有 API 的数据)。该参数可以在 [ ] 内配置两个模块名或 API 名,要求列表长度必须为 2,需要配置按照工具命名格式的完整模块名或 API 名称,用于锁定区间,dump 该范围内的数据。<br>
配置示例: "scope": ["Module.conv1.Conv2d.forward.0", "Module.fc2.Linear.forward.0"], 或 "scope": ["Cell.conv1.Conv2d.forward.0", "Cell.fc2.Dense.backward.0"], 或"scope": ["Tensor.add.0.forward", "Functional.square.2.forward"]。与 level 参数取值相关,level 为 L0 级别时,可配置模块名;level 为 L1 级别时,可配置 API 名, level为 mix 级别时,可配置为模块名或API名。
list:自定义采集的算子列表,list[str] 类型,默认未配置(scope 也未配置时表示 dump 所有 API 的数据),包含以下配置方法:<br>
PyTorch 和 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。 +
PyTorch、MSAdapter 以及 MindSpore 动态图场景配置具体的 API 全称,dump 该 API 数据。在 PyTorch 场景,如果 level 配置成 L2,该配置为必填项。
配置示例:"list": ["Tensor.permute.1.forward", "Tensor.transpose.2.forward", "Torch.relu.3.backward"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景在level为 mix 级别时可以配置模块名称,dump该模块展开数据 (dump该模块从执行开始到执行结束期间的所有数据)。<br>
配置示例:"list": ["Module.module.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"], 或 "list": ["Cell.network_with_loss.language_model.encoder.layers.0.mlp.ParallelMlp.forward.0"]
PyTorch 和 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch 和 MindSpore 动态图场景在level为 mix 级别时, 会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型("level": "L2"时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式。
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
PyTorch、MSAdapter 以及 MindSpore 动态图场景指定某一类 API,dump 某一类的 API 级别输入输出数据。
配置示例:"list": ["relu"]。
PyTorch、MSAdapter 以及 MindSpore 动态图场景在level为 mix 级别时,会dump名称中包含list中配置的字符串的API数据,还会将名称中包含list中配置的字符串的模块进行展开dump (dump该模块从执行开始到执行结束期间的所有数据)。<br>
MindSpore 静态图场景配置 kernel_name,可以是算子的名称列表,也可以指定算子类型(jit_level=O2 时不支持),还可以配置算子名称的正则表达式(当字符串符合“name-regex(xxx)”格式时,后台则会将其作为正则表达式)。<br>
配置示例:list: ["name-regex(Default/.+)"]
可匹配算子名称以“Default/”开头的所有算子。
data_mode:dump 数据过滤,str 类型。<br>
PyTorch 与 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
PyTorch、MSAdapter 以及 MindSpore 动态图场景:支持"all"、"forward"、"backward"、"input"和"output",除"all"外,其余参数可以自由组合。默认为["all"],即保存所有 dump 的数据。
配置示例:"data_mode": ["backward"] (仅保存反向数据)或 "data_mode": ["forward", "input"](仅保存前向的输入数据)。
MindSpore 静态图场景:仅支持"all"、"input"和"output"参数,且各参数只能单独配置,不支持自由组合。
配置示例:"data_mode": ["all"]。
summary_mode控制 dump 文件输出的模式,str 类型,仅 PyTorch 与 MindSpore 动态图场景支持,可选参数:
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore静态图jit_level=O2场景L2级dump,支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
配置示例:"summary_mode": ["max", "min"]。
summary_mode:控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图以及 MindSpore 静态图 jit_level=O2 场景。<br>
PyTorch、MSAdapter 以及 MindSpore 动态图场景:可选参数为
md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;
statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。
配置示例:"summary_mode": "md5"。
MindSpore 静态图 jit_level=O2 场景:支持上述配置的同时额外支持配置统计项列表,可选统计项为max、min、mean、l2norm,可从中任意选取组合搭配。其中mean、l2norm的结果为float数据格式。
配置示例:"summary_mode": ["max", "min"]。
-**说明**:"summary_mode"配置为"md5"时,所使用的校验算法为CRC-32算法。 +**说明**:"summary_mode" 配置为 "md5" 时,所使用的校验算法为 CRC-32 算法。 ### 1.3 task 配置为 tensor @@ -86,16 +88,16 @@ ### 1.5 task 配置为 overflow_check -PyTorch 与 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 +PyTorch、MSAdapter 以及 MindSpore 动态图场景下,"level"须为"L0"或"L1";MindSpore 静态图场景下,"level"须为"L2",且模型编译优化等级(jit_level)须为"O2"。 | 参数 | 解释 | 是否必选 | | ------------- | ---------------------- | -------- | | overflow_nums | 最大溢出次数,int 类型,默认为 1,仅 PyTorch 与 MindSpore 动态图场景支持。表示第 N 次溢出后,不再进行溢出检测。过程中检测到溢出 API 对应的 输入输出 数据均 dump。
**配置示例**:"overflow_nums": 3。配置为 -1 时,表示持续检测溢出直到训练结束。 | 否 | -| check_mode | 溢出类型,str 类型,仅 MindSpore 场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"atomic":开启 Atomic 的溢出检测,不支持 MindSpore v2.3.0 以上版本;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | +| check_mode | 溢出类型,str 类型,仅 MindSpore v2.3.0 以下版本的静态图场景支持,可选参数:
"aicore":开启 AI Core 的溢出检测;
"atomic":开启 Atomic 的溢出检测;
"all":开启算子的溢出检测,默认值。
**配置示例**:"check_mode": "all"。 | 否 | ### 1.6 task 配置为 free_benchmark -仅 PyTorch 场景与 MindSpore 动态图场景支持,且"level"为"L1"。 +仅 PyTorch 与 MindSpore 动态图场景支持,且"level"为"L1"。 - task 配置为 free_benchmark 时,开启**无标杆比对**,在 NPU 环境下通过对当前模型 API 的输入添加扰动因子,二次执行,将得到的输出与未添加扰动因子前的输出进行比对,从而**得出该模型中可能存在因迁移等变化导致精度降低的 API**。 diff --git a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md index db9a989c9d1c731fd9099d311f3ab3b95e5c7d5d..be9386df1d906330f76b536f19ef1a13d1553754 100644 --- a/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/05.data_dump_PyTorch.md @@ -183,10 +183,25 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | +### 1.10 set_init_step + +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需在 **start** 函数调用前使用,建议写在训练迭代的循环开始前。 + +**原型**: + +```Python +debugger.set_init_step(step) +``` + +**参数说明**: + +1.step: 指定的起始step数。 + + ## 2 示例代码 ### 2.1 快速上手 @@ -355,7 +370,7 @@ if __name__ == "__main__": ``` * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-dumpjson文件介绍pytorch)。 +* `dump.json`: 保存API或Module前反向数据的统计量信息。包含dump数据的API名称或Module名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#1-PyTorch场景下的dump.json文件)。 * `stack.json`:API/Module的调用栈信息。 * `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 diff 
--git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f7507facd2a92f3acbefdc92fa6cd808a155d6e3..d99f61ba40e2978ec950b375382bb58d7379f37c 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -144,13 +144,27 @@ save(variable, name, save_backward=True) **参数说明**: | 参数名称 | 参数含义 | 支持数据类型 | 是否必选| | ---------- | ------------------| ------------------- | ------------------- | -| variable | 需要保存的变量 |dict, list, torch.tensor, int, float, str | 是 | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | | name | 指定的名称 | str | 是 | | save_backward | 是否保存反向数据 | boolean | 否 | +#### 6.1.6 set_init_step -### 6.2 msprobe.mindspore.common.utils.MsprobeStep +**功能说明**:设置起始step数,step数默认从0开始计数,使用该接口后step从指定值开始计数。该函数需在 **start** 函数调用前使用,建议写在训练迭代的循环开始前。 + +**原型**: + +```Python +set_init_step(step) +``` + +**参数说明**: + +1.step: 指定的起始step数。 + + +### 6.2 msprobe.mindspore.MsprobeStep **功能说明**:MindSpore Callback类,自动在每个step开始时调用start()接口,在每个step结束时调用stop()、step()接口。实现使用 Model 高阶 API 的动态图场景下 L0、L1、mix 级别的精度数据采集控制,控制粒度为单个 **Step** ,而 PrecisionDebugger.start, PrecisionDebugger.stop 接口的控制粒度任意训练代码段。 @@ -164,7 +178,17 @@ MsprobeStep(debugger) 1. debugger:PrecisionDebugger对象。 -### 6.3 msprobe.mindspore.seed_all +### 6.3 msprobe.mindspore.MsprobeInitStep + +**功能说明**:MindSpore Callback 类,自动获取并设置初始 step 值。仅适用于静态图 O0/O1 模式的断点续训场景。 + +**原型**: + +```Python +MsprobeInitStep() +``` + +### 6.4 msprobe.mindspore.seed_all **功能说明**:用于固定网络中的随机性和开启确定性计算。 @@ -181,9 +205,6 @@ seed_all(seed=1234, mode=False, rm_dropout=True) 3. 
rm_dropout:控制dropout失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout失效,以避免因随机dropout造成的网络随机性。建议在采集mindspore数据前开启。注意:通过rm_dropout控制dropout失效或生效需要在初始化Dropout实例前调用才能生效。 - - - ## 7. 示例代码 ### 7.1 静态图场景 @@ -372,7 +393,7 @@ dump 结果目录结构示例如下: * `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 * `dump_tensor_data`:保存采集到的张量数据。 -* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-dumpjson文件示例mindspore)。 +* `dump.json`: 保存API或Cell前反向数据的统计量信息。包含dump数据的API名称或Cell名称,各数据的dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置summary_mode="md5"时的CRC-32数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#2-MindSpore场景下的dump.json文件)。 * `stack.json`:API/Cell的调用栈信息。 * `construct.json`:分层分级结构,level为L1时,construct.json内容为空。 diff --git a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md index 8e5ab781ce0652ea572e0a0e5fb053655c5f48ec..3bf65032edae2b8e35c5818d5c030c9ce4c79e95 100644 --- a/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/09.accuracy_checker_MindSpore.md @@ -2,7 +2,7 @@ ## 1 简介 -**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API,输出精度情况的诊断和分析。工具以模型中所有 Mint API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 Mint API。本工具支持**随机生成模式和真实数据模式**b。 +**MindSpore 动态图精度预检**a通过扫描昇腾 NPU 上用户训练 MindSpore 模型中的所有 Mint API 以及 Msadapter场景下迁移的 Mindspore API,输出精度情况的诊断和分析。工具以模型中所有 API 前反向的 dump 结果为输入,构造相应的 API 单元测试,将 NPU 输出与标杆(CPU 高精度)比对,计算对应的精度指标,从而找出 NPU 中存在精度问题的 API。本工具支持**随机生成模式和真实数据模式**b。 a. 
支持 Mindspore 版本:2.4/2.5; diff --git a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md index b4525d738d849a17ca5049bd2214784c6f788d21..6f886215b0a389582bc3cc4c31943f76e6a414a3 100644 --- a/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md @@ -257,11 +257,11 @@ PyTorch 精度比对是以 CPU 或 GPU 的计算结果为标杆,通过计算 统计量有 4 种:最大值(max)、最小值(min)、平均值(mean)和 L2-范数(L2 norm)。 -|dump 数据模式|Cosine (tensor 余弦相似度)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)|Result (比对结果)|Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)|Data_Name (NPU 真实数据名)| -|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| -|真实数据模式|√|√|√|√|√|||√||||√|√|√|√| -|统计数据模式||||||√|√|√|||√||√|√|| -|MD5 模式|||||||||√|√|√|||√|| +|dump 数据模式|Cosine (tensor 余弦相似度)|EucDist (tensor 欧式距离)|MaxAbsErr (tensor 最大绝对误差)|MaxRelativeErr (tensor 最大相对误差)|One Thousandth Err Ratio (tensor 相对误差小于千分之一的比例)|Five Thousandth Err Ratio (tensor 相对误差小于千分之五的比例)|NPU 和 bench 的统计量绝对误差 (max, min, mean, L2 norm) diff| NPU 和 bench 的统计量相对误差 (max, min, mean, L2 norm) RelativeErr |NPU 和 bench 的统计量 (max, min, mean, L2 norm)|NPU MD5 (NPU 数据 CRC-32 值)|BENCH MD5 (bench 数据 CRC-32 值)|Result (比对结果)|Accuracy Reached or Not (计算精度是否达标)|Err_message (错误信息提示)|NPU_Stack_Info (堆栈信息)| Data_Name ([NPU真实数据名,Bench真实数据名]) | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---------------------------------:| +|真实数据模式|√|√|√|√|√|√|||√||||√|√|√| √ | +|统计数据模式|||||||√|√|√|||√||√|√| | +|MD5 
模式||||||||||√|√|√|||√| | 上表中NPU_Stack_Info字段需要配置-s参数生成。 @@ -320,7 +320,7 @@ MD5 模式: 5. "This is empty data, can not compare.":读取到的数据为空(真实数据模式); 6. "Shape of NPU and bench Tensor do not match. Skipped.":NPU 和 Bench 的数据结构不一致(真实数据模式); 7. "The Position of inf or nan in NPU and bench Tensor do not match.":NPU 和 Bench 的数据有 nan/inf(真实数据模式); -8. "This is type of 0-d tensor, can not calculate 'Cosine', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); +8. "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'.":NPU 为0维张量(真实数据模式); 9. "Dtype of NPU and bench Tensor do not match.":NPU 和 Bench 数据的数据类型不同(真实数据模式); 10. "":除以上情况的其余情况(真实数据模式、统计数据模式)。 @@ -330,13 +330,15 @@ MD5 模式: 1. Cosine:通过计算两个向量的余弦值来判断其相似度,数值越接近于 1 说明计算出的两个张量越相似,实际可接受阈值为大于 0.99。在计算中可能会存在 nan,主要由于可能会出现其中一个向量为 0。 -2. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 +2. EucDist:通过计算两个向量的欧式距离来判断其相似度,定义为多维空间中两个点之间的绝对距离。数值越接近0,张量越相似,数值越大,差异越大。 -3. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 +3. MaxAbsErr:当最大绝对误差越接近 0 表示其计算的误差越小,实际可接受阈值为小于 0.001。 + +4. MaxRelativeErr:当最大相对误差越接近 0 表示其计算的误差越小。 当 dump 数据中存在 0 或 Nan 时,比对结果中最大相对误差则出现 inf 或 Nan 的情况,属于正常现象。 -4. One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 +5. 
One Thousandth Err Ratio(相对误差小于千分之一的元素比例)、Five Thousandths Err Ratio(相对误差小于千分之五的元素比例)精度指标:是指 NPU 的 Tensor 中的元素逐个与对应的标杆数据对比,相对误差小于千分之一、千分之五的比例占总元素个数的比例。该数据仅作为精度下降趋势的参考,并不参与计算精度是否通过的判定。 ## 4 多卡比对结果提取汇总通信算子数据 diff --git a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md index 97b049000c6aca9a69aeca66e1a27a4260b3d142..983477554e138f3e547f2d3efcf14fdfc4a991a0 100644 --- a/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md +++ b/debug/accuracy_tools/msprobe/docs/12.overflow_check_PyTorch.md @@ -28,7 +28,7 @@ msprobe 工具在 PyTorch 场景下提供溢出数据采集功能和溢出数据 溢出数据采集功能在昇腾 NPU 上支持饱和模式(仅支持 Atlas 训练系列产品)和 INF/NAN 模式。 -INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 +INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2 训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。 INF/NAN 模式的使能方式如下: diff --git a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md index 33ff4a0259aef02d122022402966c65358e8efff..ef83aa17237d1cc56b8a67bf4b3ec9f57647fb9c 100644 --- a/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/13.overflow_check_MindSpore.md @@ -11,7 +11,7 @@ export INF_NAN_MODE_ENABLE=1 export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" ``` -**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 
侧配置,Atlas 训练系列产品,默认为饱和模式,且不支持使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 溢出检测任务的配置示例见[MindSpore 静态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#23-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)、[MindSpore 动态图场景下 task 配置为 overflow_check](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/03.config_examples.md#33-task-%E9%85%8D%E7%BD%AE%E4%B8%BA-overflow_check)。 diff --git a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md index f994dc2301bcae6b23dc7a7503297aa4fe5b3724..bf5998bce0b4cd174b9713d9417d1afb674c2b56 100644 --- a/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md +++ b/debug/accuracy_tools/msprobe/docs/27.dump_json_instruction.md @@ -1,8 +1,8 @@ # dump.json文件说明及示例 -## 1. dump.json文件示例(PyTorch) +## 1. PyTorch 场景下的 dump.json 文件 -### 1.1 L0级别 +### 1.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为: `output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` @@ -168,7 +168,7 @@ dump.json文件中包含以下数据名称: } ``` -### 1.2 L1级别 +### 1.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为: `output = torch.nn.functional.relu(input)` @@ -264,13 +264,13 @@ dump.json文件中包含以下数据名称: } ``` -### 1.3 mix级别 +### 1.3 mix 级别 mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 -## 2. dump.json文件示例(MindSpore) +## 2. 
MindSpore 场景下的 dump.json 文件 -### 2.1 L0级别 +### 2.1 L0 级别 L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。 以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为: @@ -429,7 +429,7 @@ dump.json文件中包含以下数据名称: } ``` -### 2.2 L1级别 +### 2.2 L1 级别 L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为: `output = mindspore.ops.relu(input)` @@ -521,5 +521,275 @@ L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的 } ``` -### 2.3 mix级别 +### 2.3 mix 级别 + mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 + +## 3. MSAdapter 场景下的 dump.json 文件 + +### 3.1 L0 级别 + +L0 级别的 dump.json 文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以 Conv2d 模块为例,网络中模块调用代码为: +`output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)` + +dump.json文件中包含以下数据名称: + +- `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。 +- `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。 +- `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。 + +**说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv1.Conv2d.forward.0`。 + +```json +{ + "task": "tensor", + "level": "L0", + "framework": "mindtorch", + "dump_data_dir": "/dump/path", + "data": { + "Module.conv2.Conv2d.forward.0": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 1.638758659362793, + "Min": 0.0, + "Mean": 0.2544615864753723, + "Norm": 70.50277709960938, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 1.6815717220306396, + "Min": -1.5120246410369873, + "Mean": -0.025344856083393097, + "Norm": 149.65576171875, + "requires_grad": true, + 
"data_name": "Module.conv2.Conv2d.forward.0.output.0.npy" + } + ], + "parameters": { + "weight": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.05992485210299492, + "Min": -0.05999220535159111, + "Mean": -0.0006165213999338448, + "Norm": 3.421217441558838, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.weight.npy" + }, + "bias": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.05744686722755432, + "Min": -0.04894155263900757, + "Mean": 0.006410328671336174, + "Norm": 0.17263513803482056, + "requires_grad": true, + "data_name": "Module.conv2.Conv2d.forward.0.parameters.bias.npy" + } + } + }, + "Module.conv2.Conv2d.parameters_grad": { + "weight": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 5, + 5 + ], + "Max": 0.018550323322415352, + "Min": -0.008627401664853096, + "Mean": 0.0006675920449197292, + "Norm": 0.26084786653518677, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.weight.npy" + } + ], + "bias": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32 + ], + "Max": 0.014914230443537235, + "Min": -0.006656786892563105, + "Mean": 0.002657240955159068, + "Norm": 0.029451673850417137, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.parameters_grad.bias.npy" + } + ] + }, + "Module.conv2.Conv2d.backward.0": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 32, + 10, + 10 + ], + "Max": 0.0015069986693561077, + "Min": -0.001139344065450132, + "Mean": 3.3215508210560074e-06, + "Norm": 0.020567523315548897, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 8, + 16, + 14, + 14 + ], + "Max": 0.0007466732058674097, + "Min": -0.00044813455315306783, + "Mean": 
6.814070275140693e-06, + "Norm": 0.01474067009985447, + "requires_grad": false, + "data_name": "Module.conv2.Conv2d.backward.0.output.0.npy" + } + ] + } + } +} +``` + +### 3.2 L1 级别 +L1级别的dump.json文件包括API的前反向的输入输出。以 relu API 为例,网络中 API 调用代码为: +`output = torch.nn.functional.relu(input)` + +dump.json文件中包含以下数据名称: +- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。 +- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。 + +```json +{ + "task": "tensor", + "level": "L1", + "framework": "mindtorch", + "dump_data_dir":"/dump/path", + "data": { + "Functional.relu.0.forward": { + "input_args": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": -1.3364859819412231, + "Mean": 0.03711778670549393, + "Norm": 236.20692443847656, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.input.0.npy" + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 1.3864083290100098, + "Min": 0.0, + "Mean": 0.16849493980407715, + "Norm": 175.23345947265625, + "requires_grad": true, + "data_name": "Functional.relu.0.forward.output.0.npy" + } + ] + }, + "Functional.relu.0.backward": { + "input": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00013352684618439525, + "Mean": 0.00011915402356442064, + "Norm": 0.007598237134516239, + "requires_grad": false, + "data_name": "Functional.relu.0.backward.input.0.npy" + } + ], + "output": [ + { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": [ + 32, + 16, + 28, + 28 + ], + "Max": 0.0001815402356442064, + "Min": -0.00012117840378778055, + "Mean": 2.0098118724831693e-08, + "Norm": 0.006532244384288788, + "requires_grad": false, + "data_name": 
"Functional.relu.0.backward.output.0.npy" + } + ] + } + } +} +``` + +### 3.3 mix 级别 + +mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式与上述示例相同。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md index 6b8cc558aa22526158033cfb35f31203d8b04278..4988586c0568b391739f7c14f1a9452461f1a6f1 100644 --- a/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/28.kernel_dump_MindSpore.md @@ -1,4 +1,4 @@ -# MindSpore 场景的 kernel dump 说明 +# MindSpore 动态图场景的 kernel dump 说明 当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。 diff --git a/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..cefcabafbcbbdbb33a3d9d63c17a30396c9e4c52 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/29.data_dump_MSAdapter.md @@ -0,0 +1,229 @@ +# MSAdapter 场景的精度数据采集 + +MSAdapter 是一款 MindSpore 生态适配工具,可以将 PyTorch 训练脚本高效迁移至 MindSpore 框架执行,以实现在不改变原有 PyTorch 用户开发习惯的情况下,使得 PyTorch 代码能在昇腾上获得高效性能。 + +msprobe 工具主要通过在训练脚本内添加 dump 接口、启动训练的方式采集精度数据。 + +本工具提供固定的 API 支持列表,若需要删除或增加 dump 的 API,可以在 msprobe/pytorch/hook_module/support_wrap_ops.yaml 文件内手动修改,如下示例: + +```yaml +functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API + - conv1d + - conv2d + - conv3d +``` + +删除 API 的场景:部分模型代码逻辑会存在 API 原生类型校验,工具执行dump操作时,对封装后的模型 API 可能与模型的原生 API 类型不一致,此时可能引发校验失败,详见《[FAQ](FAQ.md)》中“异常情况”的第10和11条。 + +## 1. 工具安装 + +请参见[《msprobe 工具安装指南》](./01.installation.md)。 + +## 2 接口介绍 + +### 2.1 msprobe.mindspore.PrecisionDebugger + +**功能说明**:通过加载 dump 配置文件的方式来确定 dump 操作的详细配置。 + +**原型**: + +```Python +PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, step=None) +``` + +**参数说明**: + +1. 
config_path:指定 dump 配置文件路径,string 类型。参数示例:"./config.json"。未配置该路径时,默认使用 [config.json](../config.json) 文件的默认配置,配置选项含义可见 [config.json 介绍](./02.config_introduction.md)。 + +2. 其他参数与 [config.json](../config.json) 文件中的同名配置字段含义相同,具体可见 [config.json 介绍](./02.config_introduction.md)。当参数值非None时,优先级高于 [config.json](../config.json) 文件中的同名配置。 + +#### 2.1.1 start + +**功能说明**:启动精度数据采集。需要与 [**stop**](#212-stop) 接口一起添加在训练迭代的 for 循环内。 + +**原型**: + +```Python +start(model=None) +``` + +**参数说明**: + +1. model:指定需要采集 Module 级数据的模型,支持传入 torch.nn.Module、list[torch.nn.Module]或Tuple[torch.nn.Module] 类型,默认未配置。level 配置为 "L0" 或 "mix" 时,必须在该接口中配置该参数。API级别("L1" level)dump 时,传入 model 可以采集 model 内包含 primitive op 对象在内的所有 API 数据,若不传入 model 参数,则只采集非 primitive op 的 API 数据。 + +#### 2.1.2 stop + +**功能说明**:停止精度数据采集。在 **start** 接口调用之后的任意位置添加。若 **stop** 接口添加在反向计算代码之后,则会采集 **start** 和该接口之间的前反向数据。 +若 **stop** 接口添加在反向计算代码之前,则需要将 [**step**](#213-step) 接口添加到反向计算代码之后,才能采集 **start** 和该接口之间的前反向数据。 + +**注意**:**stop** 接口必须调用,否则可能导致精度数据落盘不全。 + +**原型**: + +```Python +stop() +``` + +#### 2.1.3 step + +**功能说明**:进行训练 step 数的自增,完成当前 step 所有数据的落盘并更新 dump 参数。在一个 step 训练结束的位置添加,且必须在 **stop** 接口之后的位置调用。该接口需要配合 **start** 和 **stop** 函数使用,尽量添加在反向计算代码之后,否则可能会导致反向数据丢失。 + +**原型**: + +```Python +step() +``` + +#### 2.1.4 forward_backward_dump_end + +**功能说明**:停止精度数据采集。与 **stop** 接口功能相同,该函数在将来会被移除,建议使用 **stop** 接口。 + +**原型**: + +```Python +forward_backward_dump_end() +``` + +#### 2.1.5 save + +**功能说明**:单点保存网络执行过程中正反向数值,并以统计值/张量文件落盘。 + +**原型**: +```python +save(variable, name, save_backward=True) +``` + +**参数说明**: +| 参数名称 | 参数含义 | 支持数据类型 | 是否必选| +| ---------- | ------------------| ------------------- | ------------------- | +| variable | 需要保存的变量 |dict, list, tuple, torch.tensor, int, float, str | 是 | +| name | 指定的名称 | str | 是 | +| save_backward | 是否保存反向数据 | boolean | 否 | + +### 2.2 msprobe.mindspore.seed_all + +**功能说明**:用于固定网络中的随机性和开启确定性计算。 + +**原型**: +```python +seed_all(seed=1234, mode=False, rm_dropout=True) +``` + +**参数说明**: + +1. 
seed: 随机性种子,默认值:1234,非必选。参数示例: seed=1000。该参数用于 random、numpy.random, mindspore.common.Initializer、mindspore.nn.probability.distribution的随机数生成以及 Python 中 str、bytes、datetime 对象的 hash 算法。 + +2. mode:确定性计算使能,可配置 True 或 False,默认值:False,非必选。参数示例:mode=True。该参数设置为 True 后,将会开启算子确定性运行模式与归约类通信算子(AllReduce、ReduceScatter、Reduce)的确定性计算。注意:确定性计算会导致 API 执行性能降低,建议在发现模型多次执行结果不同的情况下开启。 + +3. rm_dropout:控制 dropout 失效的开关。可配置 True 或 False,默认值:True,非必选。参数示例:rm_dropout=True。该参数设置为 True 后,将会使 mindspore.ops.Dropout,mindspore.ops.Dropout2D,mindspore.ops.Dropout3D,mindspore.mint.nn.Dropout和mindspore.mint.nn.functional.dropout 失效,以避免因随机 dropout 造成的网络随机性。建议在采集数据前调用。 + +**注意**:通过 rm_dropout 控制 dropout 失效或生效需要在初始化 Dropout 实例前调用才能生效。 + +## 3 示例代码 + +以下为添加了 msprobe 工具 dump 接口的示例训练脚本。 + +```python +import mindspore as ms +import torch +import torch.nn as nn +import torch.nn.functional as F + +# 导入工具的数据采集接口 +from msprobe.pytorch import PrecisionDebugger + +# 在模型训练开始前实例化PrecisionDebugger +debugger = PrecisionDebugger(config_path='./config.json') + + +# 定义网络 +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = nn.Linear(in_features=8, out_features=4) + self.linear2 = nn.Linear(in_features=4, out_features=2) + + def forward(self, x): + x1 = self.linear1(x) + x2 = self.linear2(x1) + logits = F.relu(x2) + return logits + + +net = Net() + + +def train_step(inputs): + return net(inputs) + + +if __name__ == "__main__": + data = (torch.randn(10, 8), torch.randn(10, 8), torch.randn(10, 8)) + grad_fn = ms.value_and_grad(train_step, grad_position=0) + + for inputs in data: + # 开启数据 dump + debugger.start(model=net) + + out, grad = grad_fn(inputs) + + # 停止数据 dump + debugger.stop() + # 更新 step 信息 + debugger.step() +``` + +## 4 dump 结果文件介绍 + +训练结束后,工具将 dump 的数据保存在 dump_path 参数指定的目录下。目录结构示例如下: + +```lua +├── dump_path +│ ├── step0 +│ | ├── rank0 +│ | │ ├── dump_tensor_data +| | | | ├── Tensor.permute.1.forward.npy +| | | | ├── Functional.linear.5.backward.output.npy # 
命名格式为{api_type}.{api_name}.{API调用次数}.{forward/backward}.{input/output}.{参数序号}, 其中,“参数序号”表示该API的第n个输入或输出,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该API的第1个参数的第1个元素。 +| | | | ... +| | | | ├── Module.conv1.Conv2d.forward.0.input.0.npy # 命名格式为{Module}.{module_name}.{class_name}.{forward/backward}.{调用次数}.{input/output}.{参数序号}, 其中,“参数序号”表示该Module的第n个参数,例如1,则为第一个参数,若该参数为list格式,则根据list继续排序,例如1.1,表示该Module的第1个参数的第1个元素。 +| | | | ├── Module.conv1.Conv2D.forward.0.parameters.bias.npy # 模块参数数据:命名格式为{Module}.{module_name}.{class_name}.forward.{调用次数}.parameters.{parameter_name}。 +| | | | └── Module.conv1.Conv2D.parameters_grad.weight.npy # 模块参数梯度数据:命名格式为{Module}.{module_name}.{class_name}.parameters_grad.{parameter_name}。因为同一模块的参数使用同一梯度进行更新,所以参数梯度文件名不包含调用次数。 +| | | | # 当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为{Module}.{index}.*,*表示以上三种模块级数据的命名格式,例如:Module.0.conv1.Conv2d.forward.0.input.0.npy。 +│ | | ├── dump.json +│ | | ├── stack.json +│ | | └── construct.json +│ | ├── rank1 +| | | ├── dump_tensor_data +| | | | └── ... +│ | | ├── dump.json +│ | | ├── stack.json +| | | └── construct.json +│ | ├── ... +│ | | +| | └── rank7 +│ ├── step1 +│ | ├── ... 
+│ ├── step2 +``` +* `rank`:设备 ID,每张卡的数据保存在对应的 `rank{ID}` 目录下。非分布式场景下没有 rank ID,目录名称为 rank。 +* `dump_tensor_data`:保存采集到的张量数据。 +* `dump.json`: 保存 API 或 Module 前反向数据的统计量信息。包含 dump 数据的 API 名称或 Module 名称,各数据的 dtype、 shape、max、min、mean、L2norm(L2范数,平方根)统计信息以及当配置 summary_mode="md5" 时的 CRC-32 数据。具体介绍可参考[dump.json文件说明](./27.dump_json_instruction.md#3-MSAdapter场景下的dump.json文件)。 +* `stack.json`:API/Module 的调用栈信息。 +* `construct.json`:分层分级结构,level 为 L1 时,construct.json 内容为空。 + + +当 task 为 tensor 时,dump 过程中,npy 文件在对应算子或者模块被执行后就会落盘,而 json 文件则需要在正常执行 PrecisionDebugger.stop() 后才会写入完整数据。因此如果程序异常终止,终止前被执行算子的相关 npy 文件得以保存,但 json 文件中的数据可能丢失。 + +其中 rank 为设备上各卡的 ID,每张卡上 dump 的数据会生成对应 dump 目录。非分布式场景下没有 rank ID,目录名称为 rank。 + +npy 文件名的前缀含义如下: + +| 前缀 | 含义 | +| ----------- | ---------------------------- | +| Tensor | torch.Tensor API数据 | +| Torch | torch API数据 | +| Functional | torch.nn.functional API数据 | +| NPU | NPU 亲和API数据 | +| Distributed | torch.distributed API数据 | +| Jit | 被 "jit" 装饰的模块或函数数据 | +| Module | torch.nn.Module 类(模块)数据 | \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md new file mode 100644 index 0000000000000000000000000000000000000000..01d64c808d40a1e5c4ea2190c028a7c389ffbdc4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/docs/30.overflow_check_MSAdapter.md @@ -0,0 +1,31 @@ +# MSAdapter 场景的溢出检测 + +msprobe 工具提供 MSAdapter 场景下的溢出检测功能。其检测对象为 **API** 级别(除 Primitive 和 Jit 类 API)或**模块**级别,分别对应 config.json 配置中的 **"L1"** 、**"L0"** level。 + +需要注意,本工具仅支持在 INF/NAN 模式a下进行溢出检测。INF/NAN 模式的使能方式如下: + +```Shell +# 使能 CANN 侧 INF/NAN 模式 +export INF_NAN_MODE_ENABLE=1 +# 使能 MindSpore 框架侧 INF/NAN 模式 +export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE" +``` + +**a**:在处理浮点数计算溢出问题时,NPU 当前支持两种溢出模式:INF/NAN 模式与饱和模式。INF/NAN 模式遵循 IEEE 754 标准,根据定义输出 INF/NAN 的计算结果。与之对应的饱和模式在计算出现溢出时,饱和为浮点数极值(+-MAX)。对于 CANN 侧配置,Atlas 训练系列产品,默认为饱和模式,且不建议使用 INF/NAN 模式;Atlas A2训练系列产品,默认为 INF/NAN 
模式,且不建议使用饱和模式。对于 MindSpore 框架侧配置,仅支持对 Atlas A2 训练系列产品进行设置,默认为 INF/NAN 模式。CANN 侧 与 MindSpore 框架侧配置须一致。 + +溢出检测任务的配置示例见["**MindSpore 动态图场景 task 配置为 overflow_check**"](./03.config_examples.md#33-task配置为overflow_check)小节。 + + +## 1 接口介绍 + +溢出检测功能提供的接口与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**2 接口介绍**"](./29.data_dump_MSAdapter.md#2-接口介绍)小节。 + +需要注意,目前暂不支持 "L1" level 下 primitive op 的溢出检测。 + +## 2 示例代码 + +溢出检测功能使用方式与数据采集任务一致,详见 MSAdapter 场景的精度数据采集中的["**3 示例代码**"](./29.data_dump_MSAdapter.md#3-示例代码)小节。 + +## 3 溢出检测结果文件介绍 + +溢出检测结果文件目录结构与含义与数据采集任务一致,但仅保存溢出 API 或 模块 的真实数据或统计信息。详见 MSAdapter 场景的精度数据采集中的["**4 dump 结果文件介绍**"](./29.data_dump_MSAdapter.md#4-dump-结果文件介绍)小节。 \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/docs/img/compare_result.png b/debug/accuracy_tools/msprobe/docs/img/compare_result.png index 07cdb51707fe43d07723ed976275d99f55b50571..b6d7ec6dfcbc44b4b7056e1297a481f495ceb86e 100644 Binary files a/debug/accuracy_tools/msprobe/docs/img/compare_result.png and b/debug/accuracy_tools/msprobe/docs/img/compare_result.png differ diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py index 089c29eb098ad4305edcca1306462f8924dd9291..8471e29a0d000aa9e61e01f1eb1bc578247fdda8 100644 --- a/debug/accuracy_tools/msprobe/mindspore/__init__.py +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -17,12 +17,12 @@ import os try: from msprobe.lib import _msprobe_c - os.environ["MS_HOOK_ENABLE"] = "on" - os.environ["HOOK_TOOL_PATH"] = _msprobe_c.__file__ except ImportError: from .common.log import logger logger.info("Module _msprobe_c has not been installed. 
L2-Dump may not work normally.") from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger -from msprobe.mindspore.common.utils import seed_all -from msprobe.mindspore.monitor.module_hook import TrainerMon \ No newline at end of file +from msprobe.mindspore.common.utils import seed_all, MsprobeStep, MsprobeInitStep +from msprobe.mindspore.monitor.module_hook import TrainerMon + +os.environ["MS_HOOK_ENABLE"] = "on" diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py index 98c6b4b98530ec447c2e239c11b5d4d7b927d874..557d731e042913da3a622035219ec8dea0409ab4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py @@ -16,7 +16,7 @@ import os from tqdm import tqdm -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml from msprobe.core.common.utils import add_time_as_suffix from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo @@ -25,6 +25,7 @@ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compar from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context, trim_output_compute_element_list) +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer @@ -156,6 +157,7 @@ class ApiAccuracyChecker: real_api_str = Const.SEP.join(api_name_str_list[1:-2]) api_list = load_yaml(yaml_path) supported_tensor_api_list = 
api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY) + supported_fusion_api_list = MsCompareConst.SUPPORTED_FUSION_LIST if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True @@ -165,6 +167,9 @@ class ApiAccuracyChecker: if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \ and global_context.get_framework() == Const.MS_FRAMEWORK: return True + if api_type_str == MsCompareConst.FUNCTIONAL_API and real_api_str in supported_fusion_api_list \ + and global_context.get_framework() == Const.MS_FRAMEWORK: + return True return False def parse(self, api_info_path): diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py index f42702be0b114e40e5e31dc4326bd9ca21f82202..36e506f67737cdea4452ba27f4fad0524d4c2884 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py @@ -15,11 +15,13 @@ import mindspore from mindspore import ops -from msprobe.core.common.const import Const, MsCompareConst +from msprobe.core.common.const import Const from msprobe.core.common.exceptions import ApiAccuracyCheckerException from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple +from msprobe.mindspore.api_accuracy_checker.bench_functions.fusion_operator import fusion +from msprobe.mindspore.common.const import MsCompareConst from msprobe.mindspore.common.log import logger @@ -64,7 +66,9 @@ api_parent_module_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func, (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): 
torch.nn.functional, (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist, - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed, + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): mindspore.ops, + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): fusion } @@ -83,7 +87,9 @@ api_parent_module_str_mapping = { (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func", (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional", (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist", - (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed" + (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed", + (MsCompareConst.FUNCTIONAL_API, Const.MS_FRAMEWORK): "mindspore.ops", + (MsCompareConst.FUSION_API, Const.PT_FRAMEWORK): "fusion" } @@ -125,7 +131,8 @@ class ApiRunner: err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) api_type_str, api_sub_name = api_name_list[0], api_name_list[1] - if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \ + if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, + MsCompareConst.FUNCTIONAL_API] \ and api_platform == Const.MS_FRAMEWORK: err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) @@ -139,9 +146,9 @@ class ApiRunner: def get_api_instance(api_type_str, api_sub_name, api_platform): """ Args: - api_type_str: str, Union["MintFunctional", "Mint", "Tensor"] + api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Functional"] api_sub_name: str, e.g. 
"relu"
-            api_platform: str: Union["mindpore", "torch"]
+            api_platform: str: Union["mindspore", "pytorch"]
 
         Return:
             api_instance: function object
@@ -151,9 +158,12 @@ class ApiRunner:
         mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
         mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
         """
-
-        api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
-        api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
+        if api_sub_name in MsCompareConst.SUPPORTED_FUSION_LIST and api_platform == "pytorch":
+            api_parent_module = api_parent_module_mapping.get((MsCompareConst.FUSION_API, api_platform))
+            api_parent_module_str = api_parent_module_str_mapping.get((MsCompareConst.FUSION_API, api_platform))
+        else:
+            api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
+            api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
 
         full_api_name = api_parent_module_str + Const.SEP + api_sub_name
         if not hasattr(api_parent_module, api_sub_name):
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py
index ead03d25ea5c2e6bb0422486f1939c5b31ee589b..da2f8ad612fcf3a42083894ff1b8e56db757f919 100644
--- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py
+++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py
@@ -18,9 +18,10 @@ from abc import ABC, abstractmethod
 import mindspore
 import numpy as np
 import torch
-from msprobe.core.common.const import CompareConst, MsCompareConst
+from msprobe.core.common.const import CompareConst
 from msprobe.core.common.exceptions import ApiAccuracyCheckerException
 from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class CompareResult:
diff --git 
a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py new file mode 100644 index 0000000000000000000000000000000000000000..c5adc9842414fd616816f30f3c0b66c22b0e86b7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py @@ -0,0 +1,541 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
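下文 flash_attention_score.py 中的 softmax_forward 采用先减去按行最大值、再取指数的数值稳定写法。以下为一个独立的 NumPy 最小示意(仅为说明该技巧,stable_softmax 为假设性的演示函数,非本补丁接口):

```python
import numpy as np


def stable_softmax(x):
    # 先减去最后一维的最大值再做 exp,避免大数值输入时 exp 上溢为 inf
    x_max = np.max(x, axis=-1, keepdims=True)
    y = np.exp(x - x_max)
    return y / np.sum(y, axis=-1, keepdims=True)


# 直接计算 np.exp(1000.0) 会上溢,减最大值后指数参数均不超过 0
out = stable_softmax(np.array([[1000.0, 1001.0, 1002.0]]))
```

减最大值不改变 softmax 的数学结果(分子分母同乘一个常数因子),但保证了指数运算的输入非正,因此不会溢出。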
+
+from collections import namedtuple
+import torch
+import torch.nn as nn
+import numpy as np
+
+from einops import rearrange
+
+
+from msprobe.pytorch.common.utils import logger
+
+GTYPE = torch.float64  # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
+SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"
+
+FaForwardParams = namedtuple("FaForwardParams",
+                             ["q", "k", "v", "drop_mask", "attn_mask", "pse", "scalar_value", "keep_prob"])
+FaBackwardParams = namedtuple("FaBackwardParams",
+                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scalar_value", "keep_prob"])
+RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
+                                 ["q", "k", "attn_mask", "pse", "scalar_value", "softmax_max", "softmax_sum"])
+
+
+def softmax_forward(x):
+    x_max = torch.max(x, dim=-1, keepdims=True)[0]
+    x_sub = x.sub(x_max)
+    y = torch.exp(x_sub)
+    x_sum = y.sum(dim=-1, keepdims=True)
+    res = y.div(x_sum)
+    return res, x_max, x_sum
+
+
+def softmax_grad(dp, softmax_res):
+    muls = dp * softmax_res
+    muls_r = muls.sum(dim=-1, keepdims=True)
+    sub_r = dp - muls_r
+    res = sub_r * softmax_res
+    return res
+
+
+def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+    if num_kv_heads == 0 or num_kv_heads > num_heads:
+        raise ValueError("num_kv_heads must be non-zero and no bigger than num_heads.")
+
+    factor = num_heads // num_kv_heads
+    kv_shape = kv_tensor.shape
+    b = kv_shape[0]
+    s = kv_shape[2]
+    d = kv_shape[3]
+    kv_res = torch.zeros([b, num_heads, s, d]).to(dtype)
+    for i in range(num_heads):
+        j = i // factor
+        kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
+    return kv_res
+
+
+def calculate_qk(q, k, attn_mask, pse, scalar_value):
+    if k.dim() == 3:
+        k = k.unsqueeze(1)  # 在head维度扩展
+
+    if k.dim() != 4:
+        raise ValueError(f"k tensor dimension must be 4, but got {k.dim()} dimensions (shape: {k.shape})")
+
+    if pse is None or len(pse.shape) == 0:
+        qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scalar_value)
+    else:
+        qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + 
pse).mul(scalar_value) + if attn_mask is None or len(attn_mask.shape) == 0: + return qk + else: + qk = qk + attn_mask.bool() * (-40000.0) # -10000 + return qk + + +def fusion_attention_forward(forward_params): + q = forward_params.q + k = forward_params.k + v = forward_params.v + drop_mask = forward_params.drop_mask + attn_mask = forward_params.attn_mask + pse = forward_params.pse + scalar_value = forward_params.scalar_value + keep_prob = forward_params.keep_prob + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, softmax_max, softmax_sum = softmax_forward(qk) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res + else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(backward_params): + dx = backward_params.dx + q = backward_params.q + k = backward_params.k + v = backward_params.v + softmax_res = backward_params.softmax_res + drop_mask = backward_params.drop_mask + pse = backward_params.pse + scalar_value = backward_params.scalar_value + keep_prob = backward_params.keep_prob + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scalar_value) + dq = torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + b, s1, s2, n1, n2, d, h1, h2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise 
ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + try: + if input_layout == "BSH": + b, s1, h1 = query.shape + _, s2, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "SBH": + s1, b, h1 = query.shape + s2, _, h2 = key.shape + d = h1 // n1 + n2 = h2 // d + elif input_layout == "BSND": + b, s1, n1, d = query.shape + _, s2, n2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + elif input_layout == "BNSD": + b, n1, s1, d = query.shape + _, n2, s2, _ = key.shape + h1 = n1 * d + h2 = n2 * d + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + if d == 0: + raise ValueError(f"Value d must be non-zero.") + _dtype = query.dtype + ret = (b, s1, s2, n1, n2, d, h1, h2, _dtype) + return ret + + +def convert_from_bnsd(_input, input_layout): + """ + transform qkv from bnsd to input_layout. + B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) + input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + """ + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + """ + transform qkv from input_layout to bnsd. 
+ B: batch_size + S: sequence_length + N: num_heads + D: head_dim + Args: + _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H) + n (int): num_heads + input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND" + Returns: + tensor of shape (B,N,S,D) + """ + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(GTYPE) + + +def generate_attn_mask(*args): + """ + # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现 + ===> attn_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype) + """ + + sparse_mode, attn_mask, b, n1, s1, s2, pre_tocken, next_tocken, dtype = args + shape = [s1, s2] + + if attn_mask is not None: + # 当FA的输入已经包含attn_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原 + if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4: + logger.info(f"s1: {s1}, s2:{s2}, attn_mask.shape:{attn_mask.shape}, attn_mask.dtype:{attn_mask.dtype}") + + if attn_mask.dim() == 2 and attn_mask.shape[0] == 2048 and attn_mask.shape[1] == 2048: + if attn_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(attn_mask.dtype)): + if sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = 
attn_mask_u + attn_mask_l + logger.debug(f"反向转换attn_mask {attn_mask.shape}") + return attn_mask.to(dtype) + + return attn_mask.to(dtype) + + if attn_mask is not None: + if attn_mask.dim() == 2: + if attn_mask.shape[0] != s1 or attn_mask.shape[1] != s2: + raise ValueError(f"Invalid attn_mask shape `SS` {attn_mask.shape}") + shape = [s1, s2] + elif attn_mask.dim() == 4: + if attn_mask.shape[1] == 1: + shape = [b, 1, s1, s2] if b != 1 else [1, 1, s1, s2] + else: + shape = [b, n1, s1, s2] if b != 1 else [1, n1, s1, s2] + + if sparse_mode == 0: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + elif sparse_mode == 1: # no sparse + attn_mask = torch.from_numpy(np.zeros(shape)) + elif sparse_mode == 2: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + attn_mask = torch.from_numpy(np.triu(np.ones(shape), k=s2 - s1 + 1)) + elif sparse_mode == 4: + attn_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + attn_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + attn_mask = attn_mask_u + attn_mask_l + # 注:不会出现sparse_mode=5的情况,该情况要求必须要传入attn_mask,且attn_mask矩阵数据格式须为BNSS或B1SS, + # 因此可以认为FA的输入已经是正确的attn_mask了 + return attn_mask.to(dtype) + + +def generate_kv(key, value, n1, n2): + # N不等长适配by cdy + if not (n1 == n2): + k_new = broadcast_kv(n1, n2, key, key.dtype) + v_new = broadcast_kv(n1, n2, value, value.dtype) + else: + k_new = key + v_new = value + return k_new, v_new + + +def rebuid_softmax_by_qkv(q, k, attn_mask, pse, scalar_value): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max)) + """ + logger.info("Using QKV to rebuild original softmax") + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + softmax_res, _, _ = softmax_forward(qk) + return softmax_res + + +def 
rebuild_softmax_by_max_sum(softmax_params): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) + """ + q = softmax_params.q + k = softmax_params.k + attn_mask = softmax_params.attn_mask + pse = softmax_params.pse + scalar_value = softmax_params.scalar_value + softmax_max = softmax_params.softmax_max + softmax_sum = softmax_params.softmax_sum + logger.info("Using softmax_max and softmax_sum to rebuild original softmax") + + qk = calculate_qk(q, k, attn_mask, pse, scalar_value) + if softmax_max.shape[-1] == 0: + raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}") + repeat_dim = qk.shape[-1] // softmax_max.shape[-1] + softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div( + softmax_sum.repeat(1, 1, 1, repeat_dim)) + return softmax_res + + +def get_head_num(*args, **kwargs): + if kwargs.get("head_num", None): + head_num = kwargs.get("head_num") + elif len(args) >= 4: + head_num = args[3] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return head_num + + +def get_input_layout(*args, **kwargs): + if kwargs.get("input_layout", None): + input_layout = kwargs.get("input_layout") + elif len(args) >= 5: + input_layout = args[4] + else: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + return input_layout + + +def npu_fusion_attention_forward_patch(*args, **kwargs): + if len(args) < 2: + raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.") + + # query, key, value, head_num, input_layout + head_num = get_head_num(*args, **kwargs) + input_layout = get_input_layout(*args, **kwargs) + + b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout) + if n1 == n2 and s1 == s2: + logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.debug(f"running case: BNSD = 
{b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+    new_kwargs = {
+        "keep_prob": 1,
+        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "attn_mask": kwargs.get("attn_mask")
+    }
+
+    return args, dims_kwargs, new_kwargs
+
+
+def npu_fusion_attention_backward_patch(*args, **kwargs):
+    if len(args) != 6:
+        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
+
+    b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
+    if n1 == n2 and s1 == s2:
+        logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    else:
+        logger.info(f"running case: bnsd = {b}_{n1}({n2})_{s1}({s2})_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
+    if not (n1 % n2 == 0 and n1 >= n2):
+        raise ValueError(f"N1与N2不匹配,请检查: n1 = {n1}, n2 = {n2}.")
+
+    dims_kwargs = {
+        "b": b, "s1": s1, "s2": s2, "n1": n1, "n2": n2,
+        "d": d, "h1": h1, "h2": h2, "dtype": dtype
+    }
+
+    new_kwargs = {
+        "keep_prob": 1,
+        "scalar_value": kwargs.get("scalar_value", 1 / (d ** 0.5)),
+        "sparse_mode": kwargs.get("sparse_mode", 0),
+        "prefix": kwargs.get("prefix"),
+        "pre_tockens": kwargs.get("pre_tockens", 2147483647),
+        "next_tockens": kwargs.get("next_tockens", 2147483647),
+        "pse": kwargs.get("pse"),
+        "padding_mask": kwargs.get("padding_mask"),
+        "softmax_max": kwargs.get("softmax_max"),
+        "softmax_sum": kwargs.get("softmax_sum"),
+        "softmax_in": kwargs.get("softmax_in"),
+        "attention_in": 
kwargs.get("attention_in"),
+        "seed": kwargs.get("seed", 0),
+        "offset": kwargs.get("offset", 0),
+        "numels": kwargs.get("numels", 0),
+        "attn_mask": kwargs.get("attn_mask")
+    }
+
+    return args, dims_kwargs, new_kwargs
+
+
+class FlashAttentionScore(nn.Module):
+    def __init__(self):
+        super(FlashAttentionScore, self).__init__()
+        # You can initialize any parameters here if necessary
+
+    def forward(self, *inputs, **kwargs):
+        # Extract the inputs for the attention calculation
+        new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*inputs, **kwargs)
+        query, key, value = new_args[0], new_args[1], new_args[2]
+
+        input_layout = get_input_layout(*inputs, **kwargs)
+
+        n1 = dims_kwargs.get("n1")
+        n2 = dims_kwargs.get("n2")
+        s1 = dims_kwargs.get("s1")
+        s2 = dims_kwargs.get("s2")
+        b = dims_kwargs.get("b")
+        dtype = dims_kwargs.get("dtype")
+        attn_mask = new_kwargs.get("attn_mask")
+        keep_prob = new_kwargs.get("keep_prob")
+        sparse_mode = new_kwargs.get("sparse_mode")
+        pre_tockens = new_kwargs.get("pre_tockens")
+        next_tockens = new_kwargs.get("next_tockens")
+        pse = new_kwargs.get("pse")
+        scalar_value = new_kwargs.get("scalar_value")
+
+        args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype]
+
+        attn_mask = generate_attn_mask(*args_temp)
+        query = convert_to_bnsd(query, n1, input_layout)
+        key = convert_to_bnsd(key, n2, input_layout)
+        value = convert_to_bnsd(value, n2, input_layout)
+
+        forward_params = FaForwardParams(
+            q=query,
+            k=key,
+            v=value,
+            drop_mask=None,
+            attn_mask=attn_mask,
+            pse=pse,
+            scalar_value=scalar_value,
+            keep_prob=keep_prob
+        )
+
+        out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
+
+        # If output dimension is 5, reshape accordingly
+        if out_golden.dim() == 5:
+            out_golden = out_golden.reshape(out_golden.size(0),
+                                            out_golden.size(1) * out_golden.size(2),
+                                            out_golden.size(3), out_golden.size(4))
+
+        out_golden = convert_from_bnsd(out_golden, 
input_layout) + + # Ensure the output matches the desired layout + out_golden = out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu() + + return out_golden + + def backward(self, *inputs, **kwargs): + # The backward pass will be similar to what was described for the gradient computation + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*inputs, **kwargs) + query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5] + n1 = dims_kwargs.get("n1") + n2 = dims_kwargs.get("n2") + s1 = dims_kwargs.get("s1") + s2 = dims_kwargs.get("s2") + b = dims_kwargs.get("b") + dtype = dims_kwargs.get("dtype") + attn_mask = new_kwargs.get("attn_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + softmax_max = new_kwargs.get("softmax_max") + softmax_sum = new_kwargs.get("softmax_sum") + scalar_value = new_kwargs.get("scalar_value") + + args_temp = [sparse_mode, attn_mask, b, n1, s1, s2, pre_tockens, next_tockens, dtype] + attn_mask = generate_attn_mask(*args_temp) + + query = convert_to_bnsd(query, n1, input_layout) + dx = convert_to_bnsd(dx, n1, input_layout) + key = convert_to_bnsd(key, n2, input_layout) + value = convert_to_bnsd(value, n2, input_layout) + + k_new, v_new = generate_kv(key, value, n1, n2) + + if SOFTMAX_BUILD_MODE == "QKV": + softmax_res = rebuid_softmax_by_qkv(query, k_new, attn_mask, pse, scalar_value) + else: + softmax_params = RebuildSoftmaxParams(query, k_new, attn_mask, pse, scalar_value, softmax_max, softmax_sum) + softmax_res = rebuild_softmax_by_max_sum(softmax_params) + + backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scalar_value, keep_prob) + dq, dk, dv = fusion_attention_backward(backward_params) + + # Reshape as needed + if dq.dim() == 5: + dq = 
dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4)) + if dk.dim() == 5: + dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4)) + if dv.dim() == 5: + dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4)) + + dq = convert_from_bnsd(dq, input_layout) + dk = convert_from_bnsd(dk, input_layout) + dv = convert_from_bnsd(dv, input_layout) + + return dq.cpu(), dk.cpu(), dv.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py similarity index 34% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py rename to debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py index 05ee3bc92257be9882c20cf825ebb7561f41ddb1..e1344541e89c4dafd9d49d63e3fdea117366bdd9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py @@ -13,48 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
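The golden forward above ultimately computes standard scaled-dot-product attention per (batch, head) slice. As a framework-free sanity sketch of what `fusion_attention_forward` produces for one such slice (names and the plain-list representation are illustrative; dropout, `pse` and `attn_mask` handling from the real golden implementation are omitted):

```python
import math

def naive_attention_forward(q, k, v, scale):
    """Toy scaled-dot-product attention on nested lists, one (batch, head) slice.

    Sketches out = softmax(q @ k^T * scale) @ v, the core of the golden
    forward; this is an illustration, not the msprobe implementation.
    """
    s1, d = len(q), len(q[0])
    s2 = len(k)
    out = []
    for i in range(s1):
        # attention scores for query row i: (q_i . k_j) * scale
        scores = [scale * sum(q[i][t] * k[j][t] for t in range(d)) for j in range(s2)]
        # numerically stable softmax: subtract the row max before exponentiating
        row_max = max(scores)
        exps = [math.exp(s - row_max) for s in scores]
        denom = sum(exps)
        probs = [e / denom for e in exps]
        # weighted sum of value rows
        out.append([sum(probs[j] * v[j][t] for j in range(s2)) for t in range(len(v[0]))])
    return out
```

With two identical key rows the softmax weights are uniform, so the output is the mean of the value rows.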
-import os
-import torch
+from msprobe.mindspore.api_accuracy_checker.bench_functions.flash_attention_score import FlashAttentionScore

-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.common.utils import torch_device_guard

+class FusionOperator:
+    """
+    所有融合算子的父类,定义了通用的接口和属性。
+    """

-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+    # 初始化操作符字典
+    def __init__(self):
+        self.flash_attention_score = None  # 用于存放 FlashAttentionScore 操作符
+        self._register_operators()

+    def __getattr__(self, name):
+        """ 动态获取算子类 """
+        # __getattr__ 仅在常规属性查找失败时才会触发,这里直接抛出异常;
+        # 若在此调用 hasattr(self, name) 会再次触发 __getattr__,导致无限递归
+        raise AttributeError(f"'FusionOperator' object has no attribute '{name}'")

-def get_vf_ops():
-    yaml_data = load_yaml(yaml_path)
-    wrap_vf_ops = yaml_data.get('_VF')
-    return wrap_vf_ops
+    def _register_operators(self):
+        """ 注册操作符到父类,以便通过 fusion.xxx 调用 """
+        self.flash_attention_score = FlashAttentionScore()

-class HOOKVfOP(object):
-    pass
-
-
-class VfOPTemplate(HOOKModule):
-    def __init__(self, op_name, hook):
-        self.op_name_ = op_name
-        self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP
-        super().__init__(hook)
-
-    @torch_device_guard
-    def forward(self, *args, **kwargs):
-        return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
-
-
-def wrap_vf_op(op_name, hook):
-    def vf_op_template(*args, **kwargs):
-        return VfOPTemplate(op_name, hook)(*args, **kwargs)
-
-    return vf_op_template
-
-
-def wrap_vf_ops_and_bind(hook):
-    _vf_ops = get_vf_ops()
-    for op_name in _vf_ops:
-        setattr(HOOKVfOP, "wrap_" + op_name, wrap_vf_op(op_name, hook))
+fusion = FusionOperator()
diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py
index
748adf7d02cafe3983fe1990b40b1e77e993698b..fc2680d68a5697dae165c70a276b21038f87fbe0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py @@ -16,12 +16,13 @@ import os import csv -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms from msprobe.core.common.file_utils import check_file_or_directory_path from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class ResultCsvEntry: diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py index e764140badf4c107ea83044353aba19a1c412fe0..1913675ad162bf690fc0aed5fc84c245ae4f73ca 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py @@ -27,10 +27,11 @@ import numpy as np from tqdm import tqdm # 本地应用/库特定导入 -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import MsCompareConst class MultiApiAccuracyChecker(ApiAccuracyChecker): diff --git 
a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py index 84f2706cc55fa3d0a1fba13d54ba8310371f1a43..7b319382eb4eba4abac3bd6894cc3b0262032d88 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py @@ -19,7 +19,8 @@ import sys from pathlib import Path import mindspore from msprobe.mindspore.common.log import logger -from msprobe.core.common.const import Const, CompareConst, MsCompareConst +from msprobe.core.common.const import Const, CompareConst +from msprobe.mindspore.common.const import MsCompareConst import torch as mindtorch from torch import Tensor as mindtorch_tensor import torch.nn.functional as mindtorch_func diff --git a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py index ee35750fb35c100e2025b0dcbdd9e20ef998b2ee..262e3b6fe0703f9ea1bc757ae3873e4fc3c99fac 100644 --- a/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py +++ b/debug/accuracy_tools/msprobe/mindspore/code_mapping/graph_parser.py @@ -34,19 +34,6 @@ class Parser: if isinstance(subgraph_node.attrs, list): subgraph_node.attrs.extend(attrs) - @staticmethod - def parse_graph_attributes(text: str, graph_node: GraphNode) -> None: - attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL) - match = attr_pattern.search(text, graph_node.pos) - if match: - attrs = match.group(1).strip().split('\n') - for attr in attrs: - if not attr: - break - key, value = attr.split(':') - if isinstance(graph_node.attrs, dict): - graph_node.attrs[key.strip()] = value.strip() - @staticmethod def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]: code_info = [] @@ -203,11 +190,6 @@ class Parser: subgraph_info.end = end_pos 
logging.info('Parsed subgraph: %s', subgraph_name) - def count_nodes(self) -> Tuple[int, int]: - total_nodes = len(self.nodes) - total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode')) - return total_nodes, total_cnodes - def create_backward_map(self): for node in self.nodes.values(): if node.scope and node.scope.startswith("Gradients"): diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index 9e8c79e51284b8e9696dde150481609f7da8b488..067e783842f13899feaba00476777ded707e9eb7 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -70,6 +70,67 @@ class Const: } +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + TENSOR_API = "Tensor" + FUNCTIONAL_API = "Functional" + FUSION_API = "FUSION" + + API_NAME_STR_LENGTH = 4 + MAX_RECURSION_DEPTH = 20 + + # Mindtorch api_info field + MINDTORCH_TENSOR = "Tensor" + MINDTORCH = "Torch" + MINDTORCH_FUNC = "Functional" + MINDTORCH_NPU = "NPU" + MINDTORCH_DIST = "Distributed" + + + + MT_VALID_API_TYPES = [ + MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR + ] + SUPPORTED_FUSION_LIST = ["flash_attention_score"] + + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + FRAMEWORK = "framework" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + # supported api yaml + SUPPORTED_API_LIST_FILE = "checker_support_api.yaml" + SUPPORTED_TENSOR_LIST_KEY = "tensor" + + # detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + # result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + 
RESULT_CSV_FILE_NAME = "accuracy_checking_result"
+
+    EPSILON = 1e-8
+
+    class ProcessStatus:
+        SUCCESS = "success"
+        API_NOT_FOUND = "api_not_found"
+        EXCEPTION_SKIP = "exception_skip"
+
+
+
+
 class FreeBenchmarkConst:
     ADD_NOISE = "add_noise"
     BIT_NOISE = "bit_noise"
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
index ded3faaa22b565ef35c17a7596782976ddf9125d..625842da589a3090cddc75c50175ac577f1777b6 100644
--- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
@@ -28,6 +28,30 @@
 from msprobe.core.common.const import Const
 from msprobe.core.common.utils import CompareException, check_seed_all
 
 
+class MsprobeStep(ms.train.Callback):
+    def __init__(self, debugger):
+        super(MsprobeStep, self).__init__()
+        self.debugger = debugger
+
+    def on_train_step_begin(self, run_context):
+        self.debugger.start()
+
+    def on_train_step_end(self, run_context):
+        self.debugger.stop()
+        self.debugger.step()
+
+
+class MsprobeInitStep(ms.train.Callback):
+    def on_train_begin(self, run_context):
+        try:
+            from mindspore._c_expression import _set_init_iter
+        except ImportError:
+            logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
+            return
+        cb_params = run_context.original_args()
+        _set_init_iter(cb_params.cur_step_num)
+
+
 def get_rank_if_initialized():
     if ms.communication.GlobalComm.INITED:
         return ms.communication.get_rank()
@@ -93,20 +117,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
     remove_dropout()
 
 
-class MsprobeStep(ms.train.Callback):
-
-    def __init__(self, debugger):
-        super(MsprobeStep, self).__init__()
-        self.debugger = debugger
-
-    def on_train_step_begin(self, run_context):
-        self.debugger.start()
-
-    def on_train_step_end(self, run_context):
-        self.debugger.stop()
-        self.debugger.step()
-
-
 class Dropout(ops.Dropout):
     def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
         super().__init__(1., seed0,
seed1) @@ -169,7 +179,7 @@ def set_register_backward_hook_functions(): from msprobe.mindspore.mindtorch import (_call_impl, register_full_backward_pre_hook, register_full_backward_hook) - if not hasattr(torch, "register_full_backward_hook"): + if not hasattr(torch.nn.Module, "register_full_backward_hook"): setattr(torch.nn.Module, "_call_impl", _call_impl) setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook) setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook) @@ -182,9 +192,9 @@ def set_register_backward_hook_functions(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)): + if not isinstance(variable, (list, dict, tuple, ms.Tensor, int, float, str)): logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, ms.Tensor, int, float or string. " + "should be one of list, dict, tuple, ms.Tensor, int, float or string. " "Skip current save process.") raise ValueError if not isinstance(name, str): @@ -196,4 +206,4 @@ def check_save_param(variable, name, save_backward): logger.warning("PrecisionDebugger.save_backward name not valid, " "should be bool. 
" "Skip current save process.") - raise ValueError \ No newline at end of file + raise ValueError diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py index 8509a7f38add0c2e8d3f3638f4c247895e07bd6d..4f158512bb4671d6cb46f706d8bf32772a4c971e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -22,10 +22,10 @@ import pandas as pd from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml +from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \ - check_op_str_pattern_valid, get_dump_mode, set_dump_path + check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.check import dtype_mapping from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping @@ -78,6 +78,11 @@ class MSComparator(Comparator): raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got " f"{type(self.data_mapping)}") + @staticmethod + def process_data_name(result): + result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1) + return result + def calc_accuracy(self, result_df, header): condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A) @@ -125,7 +130,8 @@ class MSComparator(Comparator): result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING 
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.' else: - fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST, + CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, CompareConst.ERROR_MESSAGE] result_df.loc[~condition_no_bench, fill_cols] = '' @@ -139,6 +145,8 @@ class MSComparator(Comparator): header.append(CompareConst.STACK) if self.dump_mode == Const.ALL: header.append(CompareConst.DATA_NAME) + result = self.process_data_name(result) + result.rename(columns={'op_name_x': CompareConst.NPU_NAME, 'op_name_y': CompareConst.BENCH_NAME, 'dtype_x': CompareConst.NPU_DTYPE, @@ -169,6 +177,7 @@ class MSComparator(Comparator): result[npu_summary] = result['summary_x'].apply(set_summary).tolist() result[bench_summary] = result['summary_y'].apply(set_summary).tolist() + result_df = pd.DataFrame(columns=header) for h in header: if h in result.columns: @@ -269,15 +278,15 @@ class MSComparator(Comparator): bench_dtype = match_result['dtype_y'] if self.cross_frame: npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype) - return ((npu_dtype == bench_dtype) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) | - ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) | - ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) | - ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) | - ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) | - ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16))) + + equal_condition = npu_dtype == bench_dtype + match_condition = ( + 
(npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[0])) | + (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin( + CompareConst.DTYPE_MATCH_GROUPS[1])) + ) + return equal_condition | match_condition match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A return self.make_result_df(match_result) @@ -382,12 +391,11 @@ class MSComparator(Comparator): def check_cross_framework(bench_json_path): - pattern = r'"data_name":\s*"[^"]+\.pt"' - with FileOpen(bench_json_path, 'r') as file: - for line in file: - if re.search(pattern, line): - return True - return False + framework = detect_framework_by_dump_json(bench_json_path) + if framework == Const.PT_FRAMEWORK: + return True + else: + return False def ms_compare(input_param, output_path, **kwargs): diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py index 701988ba483de4e13d85892dbb42d62c7cc805b8..153f4fd655212b24904c33d29dad694ee1dd2c1f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_graph_compare.py @@ -195,11 +195,12 @@ class GraphMSComparator: if not error_flag: result_list, err_msg = compare_ops_apply(n_value, b_value, False, "") result_dict[CompareConst.COSINE] = result_list[0] - result_dict[CompareConst.MAX_ABS_ERR] = result_list[1] - result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[2] - result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[3] - result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[4] - result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[1]) + result_dict[CompareConst.EUC_DIST] = result_list[1] + result_dict[CompareConst.MAX_ABS_ERR] = result_list[2] + result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3] + 
result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4] + result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5] + result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2]) result_dict[CompareConst.ERROR_MESSAGE] = err_msg return pd.Series(result_dict) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 7694d71dd98ae1c7c4611f9435a274ac018e5df6..4f2109504fbc2c54bc73daa42ecd96567ac90502 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -22,12 +22,12 @@ from mindspore._c_expression import MSContext from msprobe.core.common.const import Const, FileCheckConst, MsgConst from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.utils import get_real_step_or_rank, check_init_step from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor from msprobe.mindspore.ms_config import parse_json_config @@ -84,7 +84,7 @@ class PrecisionDebugger: common_config.dump_path = dump_path if dump_path else common_config.dump_path self.config = DebuggerConfig(common_config, task_config) - if _msprobe_c: + if self._need_msprobe_c() and _msprobe_c: 
_msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path) self.config.execution_mode = self._get_execution_mode() @@ -151,7 +151,7 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: + if cls._need_msprobe_c() and _msprobe_c: _msprobe_c._PrecisionDebugger().start() if instance.task in PrecisionDebugger.task_not_need_service: return @@ -163,7 +163,7 @@ class PrecisionDebugger: instance.service.start(model) else: if not instance.first_start: - api_register.api_set_ori_func() + get_api_register().restore_all_api() handler = TaskHandlerFactory.create(instance.config) handler.handle() @@ -180,7 +180,7 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: + if cls._need_msprobe_c() and _msprobe_c: _msprobe_c._PrecisionDebugger().stop() if instance.task == Const.GRAD_PROBE: instance.gm.stop() @@ -195,7 +195,7 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if _msprobe_c: + if cls._need_msprobe_c() and _msprobe_c: _msprobe_c._PrecisionDebugger().step() if instance.task in PrecisionDebugger.task_not_need_service: return @@ -233,6 +233,14 @@ class PrecisionDebugger: instance.service = Service(instance.config) instance.service.save(variable, name, save_backward) + @classmethod + def set_init_step(cls, step): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + check_init_step(step) + instance.service.init_step = step + @classmethod def _need_service(cls): instance = cls._instance @@ -241,4 +249,11 @@ class PrecisionDebugger: if instance.config.execution_mode != MsConst.PYNATIVE_MODE: return False else: - return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config) \ No newline at end of file + return instance.config.task != Const.FREE_BENCHMARK 
and not instance._is_graph_dump(instance.config) + + @classmethod + def _need_msprobe_c(cls): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + return instance.config.level_ori == Const.LEVEL_L2 diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..53271ff07bea418db66b3d3724a84eda5b52c296 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
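The new `api_register.py` below follows a save-patch-restore pattern: original module attributes are remembered before hooked wrappers are installed, so `restore_all_api()` can put them back. A toy illustration of that pattern (class and function names here are hypothetical, not the msprobe API):

```python
import types

class MiniApiRegistry:
    """Toy sketch of the save/patch/restore pattern: remember each original
    attribute before installing a hooked wrapper, so it can be restored."""

    def __init__(self):
        self._originals = {}  # (module, api_name) -> original callable

    def set_hook(self, module, name, wrapper_factory):
        original = getattr(module, name)
        self._originals[(module, name)] = original
        setattr(module, name, wrapper_factory(original))

    def restore_all(self):
        # put every patched attribute back to its saved original
        for (module, name), original in self._originals.items():
            setattr(module, name, original)
        self._originals.clear()


def make_logging_wrapper(func, call_log):
    # wrapper that records the call before delegating to the original API
    def wrapped(*args, **kwargs):
        call_log.append(func.__name__)
        return func(*args, **kwargs)
    return wrapped


# demo on a throwaway module object instead of a real framework namespace
fake_ops = types.ModuleType("fake_ops")

def add(a, b):
    return a + b

fake_ops.add = add
```

After `set_hook(fake_ops, "add", ...)`, calls to `fake_ops.add` go through the wrapper; `restore_all()` reinstates the original function object.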
+ +import os + +from mindspore import Tensor, ops, mint +from mindspore.mint.nn import functional +from mindspore.common._stub_tensor import StubTensor +from mindspore.communication import comm_func + +from msprobe.core.common.file_utils import load_yaml +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +if not is_mindtorch(): + _api_types = { + Const.MS_FRAMEWORK: { + Const.MS_API_TYPE_OPS: (ops, (ops,)), + Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)), + Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,)), + Const.MS_API_TYPE_MINT: (mint, (mint,)), + Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)), + Const.MS_API_TYPE_COM: (comm_func, (comm_func,)) + } + } + _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),) +else: + import torch + import torch_npu + _api_types = { + Const.MT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: (torch, (torch,)), + Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)), + Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d)) + } + } + _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module', + MsConst.SUPPORTED_API_LIST_FILE),) + +_inner_used_api = { + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: ( + ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point" + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: ( + Tensor, "to", "numel" + ), + Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: ( + mint, "max", "min", "mean", "norm" + ) +} + + +class 
ApiTemplate(HOOKCell): + def __init__(self, api_name, api_func, prefix, hook_build_func): + self.api_name = api_name + self.api_func = api_func + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + super().__init__(hook_build_func) + + @staticmethod + def async_to_sync(output): + # Fake handle, used to return after the CommHandle executes the wait method + fake_handle = type("FakeHandle", (), {"wait": lambda self: None})() + if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"): + output[1].wait() + output = (output[0], fake_handle) + elif hasattr(output, "wait"): + output.wait() + output = fake_handle + return output + + def construct(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + + output = self.api_func(*args, **kwargs) + + if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): + if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: + output = self.async_to_sync(output) + return output + + def forward(self, *args, **kwargs): + if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX): + return args[0] if args else kwargs.get(Const.INPUT) + return self.api_func(*args, **kwargs) + + +api_register = None +stub_tensor_set = False + + +def get_api_register(return_new=False): + global stub_tensor_set + + def stub_method(method): + def wrapped_method(*args, **kwargs): + return method(*args, **kwargs) + return wrapped_method + if not is_mindtorch() and not stub_tensor_set: + for attr_name in dir(StubTensor): + attr = getattr(StubTensor, attr_name) + api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, []) + if attr_name in api_names and callable(attr): + setattr(StubTensor, attr_name, stub_method(attr)) + stub_tensor_set = True + + if return_new: + return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + + global 
api_register + if api_register is None: + api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + return api_register diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py deleted file mode 100644 index 7aee1deccd9689985c7a2e270648bd0877cd7cf3..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-from mindspore import Tensor, ops, mint
-from mindspore.mint.nn import functional
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-
-from msprobe.mindspore.dump.hook_cell.wrap_api import (HOOKTensor, HOOKStubTensor, HOOKFunctionalOP,
-                                                       HOOKMintOP, HOOKMintNNFunctionalOP, HOOKDistributedOP,
-                                                       HOOKTorchOP, HOOKTorchTensor, HOOKTorchFunctionalOP,
-                                                       HOOKTorchDistributedOP, HOOKTorchNpuOP,
-                                                       get_wrap_api_list, get_wrap_torch_api_list, setup_hooks)
-from msprobe.core.common.utils import Const
-from msprobe.mindspore.common.utils import is_mindtorch
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-
-def stub_method(method):
-    def wrapped_method(*args, **kwargs):
-        return method(*args, **kwargs)
-    return wrapped_method
-
-
-class ApiRegistry:
-    def __init__(self):
-        self.tensor_ori_attr = {}
-        self.stub_tensor_ori_attr = {}
-        self.functional_ori_attr = {}
-        self.mint_ops_ori_attr = {}
-        self.mint_func_ops_ori_attr = {}
-        self.distributed_ori_attr = {}
-        self.norm_inner_ops_ori_attr = {}
-
-        self.torch_ori_attr = {}
-        self.torch_tensor_ori_attr = {}
-        self.torch_functional_ori_attr = {}
-        self.torch_distributed_ori_attr = {}
-        self.torch_npu_ori_attr = {}
-
-        self.tensor_hook_attr = {}
-        self.stub_tensor_hook_attr = {}
-        self.functional_hook_attr = {}
-        self.mint_ops_hook_attr = {}
-        self.mint_func_ops_hook_attr = {}
-        self.distibuted_hook_attr = {}
-        self.norm_inner_ops_hook_attr = {}
-
-        self.torch_hook_attr = {}
-        self.torch_tensor_hook_attr = {}
-        self.torch_functional_hook_attr = {}
-        self.torch_distributed_hook_attr = {}
-        self.torch_npu_hook_attr = {}
-
-        self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
-
-    @staticmethod
-    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
-        for api in api_list:
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(ori_api_group, sub_module_name)
-                ori_api_func = getattr(sub_module, sub_op)
-            else:
-                ori_api_func = getattr(ori_api_group, api)
-            if ori_api_group == StubTensor:
-                api_ori_attr[api] = stub_method(ori_api_func)
-                continue
-            api_ori_attr[api] = ori_api_func
-
-    @staticmethod
-    def set_api_attr(api_group, attr_dict):
-        for api, api_attr in attr_dict.items():
-            if Const.SEP in api:
-                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
-                sub_module = getattr(api_group, sub_module_name, None)
-                if sub_module is not None:
-                    setattr(sub_module, sub_op, api_attr)
-            else:
-                setattr(api_group, api, api_attr)
-
-    def norm_inner_op_set_hook_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_hook_attr)
-
-    def norm_inner_op_set_ori_func(self):
-        self.set_api_attr(ops, self.norm_inner_ops_ori_attr)
-
-    def api_set_hook_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_hook_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_hook_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_hook_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_hook_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_hook_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_hook_attr)
-            self.set_api_attr(ops, self.functional_hook_attr)
-            self.set_api_attr(mint, self.mint_ops_hook_attr)
-            self.set_api_attr(functional, self.mint_func_ops_hook_attr)
-            self.set_api_attr(comm_func, self.distibuted_hook_attr)
-
-    def api_set_ori_func(self):
-        if is_mindtorch():
-            self.set_api_attr(torch, self.torch_ori_attr)
-            self.set_api_attr(torch.Tensor, self.torch_tensor_ori_attr)
-            self.set_api_attr(torch.nn.functional, self.torch_functional_ori_attr)
-            self.set_api_attr(torch.distributed, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch.distributed.distributed_c10d, self.torch_distributed_ori_attr)
-            self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
-        else:
-            self.set_api_attr(Tensor, self.tensor_ori_attr)
-            self.set_api_attr(StubTensor, self.stub_tensor_ori_attr)
-            self.set_api_attr(ops, self.functional_ori_attr)
-            self.set_api_attr(mint, self.mint_ops_ori_attr)
-            self.set_api_attr(functional, self.mint_func_ops_ori_attr)
-            self.set_api_attr(comm_func, self.distributed_ori_attr)
-
-    def initialize_hook(self, hook):
-        setup_hooks(hook)
-        if is_mindtorch():
-            wrap_torch_api_name = get_wrap_torch_api_list()
-            self.store_ori_attr(torch,
-                                wrap_torch_api_name.torch_api_names, self.torch_ori_attr)
-            self.store_ori_attr(torch.Tensor,
-                                wrap_torch_api_name.tensor_api_names, self.torch_tensor_ori_attr)
-            self.store_ori_attr(torch.nn.functional,
-                                wrap_torch_api_name.functional_api_names, self.torch_functional_ori_attr)
-            self.store_ori_attr(torch.distributed,
-                                wrap_torch_api_name.distributed_api_names, self.torch_distributed_ori_attr)
-            self.store_ori_attr(torch_npu,
-                                wrap_torch_api_name.npu_api_names, self.torch_npu_ori_attr)
-            for attr_name in dir(HOOKTorchOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_hook_attr[api_name] = getattr(HOOKTorchOP, attr_name)
-            for attr_name in dir(HOOKTorchTensor):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_tensor_hook_attr[api_name] = getattr(HOOKTorchTensor, attr_name)
-            for attr_name in dir(HOOKTorchFunctionalOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_functional_hook_attr[api_name] = getattr(HOOKTorchFunctionalOP, attr_name)
-            for attr_name in dir(HOOKTorchDistributedOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_distributed_hook_attr[api_name] = getattr(HOOKTorchDistributedOP, attr_name)
-            for attr_name in dir(HOOKTorchNpuOP):
-                if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                    api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                    self.torch_npu_hook_attr[api_name] = getattr(HOOKTorchNpuOP, attr_name)
-            return
-
-        wrap_api_name = get_wrap_api_list()
-        self.store_ori_attr(Tensor, wrap_api_name.tensor_api_names, self.tensor_ori_attr)
-        self.store_ori_attr(StubTensor, wrap_api_name.stub_tensor_api_names, self.stub_tensor_ori_attr)
-        self.store_ori_attr(ops, wrap_api_name.ops_api_names, self.functional_ori_attr)
-        self.store_ori_attr(mint, wrap_api_name.mint_api_names, self.mint_ops_ori_attr)
-        self.store_ori_attr(functional, wrap_api_name.mint_nn_func_api_names, self.mint_func_ops_ori_attr)
-        self.store_ori_attr(comm_func, wrap_api_name.distributed_api_names, self.distributed_ori_attr)
-        self.store_ori_attr(ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
-        for attr_name in dir(HOOKTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.tensor_hook_attr[api_name] = getattr(HOOKTensor, attr_name)
-        for attr_name in dir(HOOKStubTensor):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.stub_tensor_hook_attr[api_name] = getattr(HOOKStubTensor, attr_name)
-        for attr_name in dir(HOOKFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.functional_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-                if api_name in self.norm_inner_ops:
-                    self.norm_inner_ops_hook_attr[api_name] = getattr(HOOKFunctionalOP, attr_name)
-        for attr_name in dir(HOOKMintOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_ops_hook_attr[api_name] = getattr(HOOKMintOP, attr_name)
-        for attr_name in dir(HOOKMintNNFunctionalOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.mint_func_ops_hook_attr[api_name] = getattr(HOOKMintNNFunctionalOP, attr_name)
-        for attr_name in dir(HOOKDistributedOP):
-            if attr_name.startswith(Const.ATTR_NAME_PREFIX):
-                api_name = attr_name[Const.ATTR_NAME_PREFIX_LEN:]
-                self.distibuted_hook_attr[api_name] = getattr(HOOKDistributedOP, attr_name)
-
-
-api_register = ApiRegistry()
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
index b68a7d995a56497a219281c5a43d692c46cfac4d..7007992ca4540a06b1ebc85a068179e88ec589cc 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
@@ -28,23 +28,22 @@ def get_cell_count(name):
     return HOOKCell.cell_count[name]
 
 
-def __init__(self, build_hook) -> None:
+def __init__(self, hook_build_func) -> None:
     super(HOOKCell, self).__init__()
     self.changed_status = False
     self.input_kwargs = {}
-    self.prefix = ""
     if not HOOKCell.g_stop_hook:
         HOOKCell.g_stop_hook = True
         self.changed_status = True
-        if hasattr(self, "prefix_api_name"):
-            self.prefix = self.prefix_api_name
-        self.forward_data_collected = False
-        forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
-        self.register_forward_pre_hook(forward_pre_hook)
-        self.register_forward_hook(forward_hook)
-        register_backward_hook_functions["full"](self, backward_hook)
-        register_backward_hook_functions["pre"](self, backward_pre_hook)
+
+        prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+        if callable(hook_build_func):
+            forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = hook_build_func(prefix)
+            self.register_forward_pre_hook(forward_pre_hook)
+            self.register_forward_hook(forward_hook)
+            register_backward_hook_functions["full"](self, backward_hook)
+            register_backward_hook_functions["pre"](self, backward_pre_hook)
 
 
 # Override __call__ and set the global flag.
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
index 723b0cbc93f78d50f703838eb488de6733008906..364062b46478b63369269c2470ea526eec59a3d3 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
@@ -564,15 +564,15 @@ tensor:
   - all
   - amax
   - amin
+  - angle
   - any
   - arccos
   - arccosh
-  - argmax
-  - angle
   - arcsin
   - arcsinh
   - arctan
   - arctanh
+  - argmax
   - argmin
   - argsort
   - asin
@@ -582,19 +582,23 @@ tensor:
   - atan
   - atan2
   - atanh
   - baddbmm
   - bernoulli
+  - bfloat16
   - bincount
   - bitwise_and
   - bitwise_or
   - bitwise_xor
   - bmm
   - bool
+  - bool astype
   - broadcast_to
+  - byte
   - ceil
-  - cholesky_solve
   - cholesky
+  - cholesky_solve
   - clamp
   - clip
   - conj
+  - copy
   - copysign
   - cos
   - cosh
@@ -606,11 +610,13 @@ tensor:
   - deg2rad
   - diag
   - diagflat
+  - diagonal
   - diff
   - digamma
   - div
   - div_
   - divide
+  - double
   - equal
   - erf
   - erfc
@@ -618,13 +624,16 @@ tensor:
   - exp
   - expand_as
   - expm1
+  - flatten
   - flip
   - fliplr
   - flipud
+  - float
   - float_power
   - floor
   - fmod
   - frac
+  - from_numpy
   - gather_elements
   - ge
   - geqrf
@@ -648,12 +657,12 @@ tensor:
   - inner
   - int
   - inverse
+  - is_complex
+  - is_signed
   - isclose
   - isfinite
   - isinf
   - isnan
-  - is_complex
-  - is_signed
   - isneginf
   - isposinf
   - isreal
@@ -704,28 +713,27 @@ tensor:
   - new_ones
   - new_zeros
   - nextafter
-  - norm
   - nonzero
+  - norm
   - not_equal
   - ormqr
   - permute
   - pow
   - prod
   - qr
+  - rad2deg
   - ravel
   - real
   - reciprocal
   - remainder
   - renorm
-  - rad2deg
-  - tile
   - repeat_interleave
   - reshape
   - reshape
-  - round
+  - resize
   - rot90
+  - round
   - rsqrt
-  - sum_to_size
   - scatter
   - sgn
   - short
@@ -745,7 +753,8 @@ tensor:
   - sub
   - sub_
   - subtract
-  - subtract
+  - sum
+  - sum_to_size
   - svd
   - swapaxes
   - swapdims
@@ -753,13 +762,13 @@ tensor:
   - t
   - take
   - tan
   - tanh
-  - trace
-  - swapaxes
+  - tensor_split
   - tile
+  - to
   - topk
-  - tril
-  - tensor_split
+  - trace
   - transpose
+  - tril
   - true_divide
   - trunc
   - unbind
@@ -769,17 +778,6 @@ tensor:
   - view
   - where
   - xlogy
-  - from_numpy
-  - std
-  - take
-  - var
-  - all
-  - any
-  - copy
-  - diagonal
-  - flatten
-  - resize
-  - sum
 
 mint.ops:
   - abs
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py
deleted file mode 100644
index 0e97929ecd7f8444b19fd531efc49883d0df58de..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_api.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
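The files removed above (`api_registry.py`, `wrap_api.py`) are replaced in this PR by a unified `api_register` module exposing `get_api_register()`, whose `register_all_api()` / `restore_all_api()` methods are used throughout the later hunks. The underlying pattern — save a module's original attributes, swap in wrapped versions, and restore them later — can be sketched as follows. This is a simplified illustration, not the real msprobe implementation; the module `demo`, the `make_wrapper` factory, and the two-dict layout are assumptions for the example.

```python
import types

class ApiRegistry:
    """Minimal sketch of the save/patch/restore pattern behind the hook registry."""

    def __init__(self):
        self._originals = {}  # (module, name) -> original callable
        self._hooked = {}     # (module, name) -> wrapped callable

    def register(self, module, name, wrapper_factory):
        original = getattr(module, name)
        self._originals[(module, name)] = original
        self._hooked[(module, name)] = wrapper_factory(original)

    def register_all_api(self):
        # Swap every recorded attribute for its hooked version.
        for (module, name), wrapped in self._hooked.items():
            setattr(module, name, wrapped)

    def restore_all_api(self):
        # Put the untouched originals back (e.g. around jit compilation).
        for (module, name), original in self._originals.items():
            setattr(module, name, original)

# Demo on a throwaway module standing in for mindspore.ops / torch etc.
mod = types.ModuleType("demo")
mod.add = lambda a, b: a + b

calls = []
def make_wrapper(fn):
    def wrapped(*args, **kwargs):
        calls.append(args)          # record the call, then delegate
        return fn(*args, **kwargs)
    return wrapped

reg = ApiRegistry()
reg.register(mod, "add", make_wrapper)
reg.register_all_api()
mod.add(1, 2)        # recorded
reg.restore_all_api()
mod.add(3, 4)        # no longer recorded
```

Restoring before re-entering compiled code and re-registering afterwards is exactly what the `jit_dump.py` and `api_pynative_self_check.py` hunks below do with the real registry.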
-
-import os
-
-from mindspore import Tensor, mint, ops
-from mindspore.common._stub_tensor import StubTensor
-from mindspore.communication import comm_func
-from mindspore.mint.nn import functional
-
-from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_yaml
-from msprobe.mindspore.common.const import Const as MsConst
-from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
-
-if is_mindtorch():
-    import torch
-    import torch_npu
-
-cur_path = os.path.dirname(os.path.realpath(__file__))
-yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
-torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)
-
-
-class HOOKTensor(object):
-    pass
-
-
-class HOOKStubTensor(object):
-    pass
-
-
-class HOOKFunctionalOP(object):
-    pass
-
-
-class HOOKMintOP(object):
-    pass
-
-
-class HOOKMintNNFunctionalOP(object):
-    pass
-
-
-class HOOKDistributedOP(object):
-    pass
-
-
-class HOOKTorchOP(object):
-    pass
-
-
-class HOOKTorchTensor(object):
-    pass
-
-
-class HOOKTorchFunctionalOP(object):
-    pass
-
-
-class HOOKTorchDistributedOP(object):
-    pass
-
-
-class HOOKTorchNpuOP(object):
-    pass
-
-
-class ApiTemplate(HOOKCell):
-    def __init__(self, api_name, api_dict, prefix, hook):
-        self.api_name = api_name
-        self.api_func = api_dict[api_name]
-        self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
-        super().__init__(hook)
-
-    @staticmethod
-    def async_to_sync(output):
-        # Fake handle, used to return after the CommHandle executes the wait method
-        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
-        if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
-            output[1].wait()
-            output = (output[0], fake_handle)
-        elif hasattr(output, "wait"):
-            output.wait()
-            output = fake_handle
-        return output
-
-    def construct(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-
-        output = self.api_func(*args, **kwargs)
-
-        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
-            if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
-                output = self.async_to_sync(output)
-        return output
-
-    def forward(self, *args, **kwargs):
-        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
-            return args[0] if args else kwargs.get(Const.INPUT)
-        return self.api_func(*args, **kwargs)
-
-
-class WrapApiName:
-    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
-                 distributed_api_names):
-        self.tensor_api_names = tensor_api_names
-        self.stub_tensor_api_names = stub_tensor_api_names
-        self.ops_api_names = ops_api_names
-        self.mint_api_names = mint_api_names
-        self.mint_nn_func_api_names = mint_nn_func_api_names
-        self.distributed_api_names = distributed_api_names
-
-
-class WrapTorchApiName:
-    def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
-        self.torch_api_names = torch_api_names
-        self.tensor_api_names = tensor_api_names
-        self.functional_api_names = functional_api_names
-        self.distributed_api_names = distributed_api_names
-        self.npu_api_names = npu_api_names
-
-
-def get_wrap_api_list():
-    api_list = load_yaml(yaml_path)
-    tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
-    ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
-    mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
-    mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
-    distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
-    wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
-                                set(tensor_api) & set(dir(StubTensor)),
-                                set(ops_api) & set(dir(ops)),
-                                set(mint_api) & set(dir(mint)),
-                                set(mint_nn_func_api) & set(dir(functional)),
-                                set(distributed_api) & set(dir(comm_func)))
-    return wrap_api_name
-
-
-def get_wrap_torch_api_list():
-    api_list = load_yaml(torch_yaml_path)
-    torch_api = api_list.get("torch")
-    tensor_api = api_list.get("tensor")
-    functional_api = api_list.get("functional")
-    distributed_api = api_list.get("distributed")
-    npu_api = api_list.get("torch_npu")
-    wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
-                                     set(tensor_api) & set(dir(torch.Tensor)),
-                                     set(functional_api) & set(dir(torch.nn.functional)),
-                                     set(distributed_api) & set(dir(torch.distributed)),
-                                     set(npu_api) & set(dir(torch_npu)))
-    return wrap_api_name
-
-
-def wrap_api_func(api_name, api_dict, prefix, hook):
-    def api_function(*args, **kwargs):
-        return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
-    return api_function
-
-
-def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):
-    for api_name in api_list:
-        if callable(api_dict[api_name]):
-            setattr(hook_class, Const.ATTR_NAME_PREFIX + api_name, wrap_api_func(api_name, api_dict, prefix, hook))
-
-
-def setup_hooks(hook):
-    if is_mindtorch():
-        torch_wrap_api_name = get_wrap_torch_api_list()
-        wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
-                               {f: getattr(torch, f) for f in dir(torch)},
-                               MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
-                               {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
-                               MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
-        wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
-                               {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
-                               MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
-                               {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
-                               MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
-        wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
-                               MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
-        return
-
-    wrap_api_name = get_wrap_api_list()
-    wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
-                           MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
-    wrap_api_func_and_bind(wrap_api_name.stub_tensor_api_names, {f: getattr(StubTensor, f) for f in dir(StubTensor)},
-                           MsConst.STUB_TENSOR_DATA_PREFIX, hook, HOOKStubTensor)
-    wrap_api_func_and_bind(wrap_api_name.ops_api_names, {f: getattr(ops, f) for f in dir(ops)},
-                           MsConst.OPS_DATA_PREFIX, hook, HOOKFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_api_names, {f: getattr(mint, f) for f in dir(mint)},
-                           MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
-    wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
-                           MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
-    wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
-                           MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py
index 0a32200639a1f3805f815c37caaef5d3bb64c82f..634b15767528da447adadbe324aa4163adc14838 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py
@@ -16,6 +16,7 @@
 import os
 from collections import defaultdict
 
+import mindspore
 from mindspore._c_expression import PyNativeExecutor_
 try:
     from mindspore.common.api import _MindsporeFunctionExecutor
@@ -25,7 +26,10 @@ except ImportError:
 from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.mindspore.dump.hook_cell.api_register import
get_api_register
+
+
+_api_register = get_api_register()
 
 
 def dump_jit(name, in_feat, out_feat, is_forward):
@@ -69,7 +73,7 @@ class JitDump(_MindsporeFunctionExecutor):
 
     def __call__(self, *args, **kwargs):
         if JitDump.jit_dump_switch:
-            api_register.api_set_ori_func()
+            _api_register.restore_all_api()
         out = super().__call__(*args, **kwargs)
         if JitDump.jit_dump_switch and len(args) > 0:
             if self.name and self.name != "construct":
@@ -80,7 +84,7 @@ class JitDump(_MindsporeFunctionExecutor):
         elif len(args) == 0:
             logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
         if JitDump.jit_dump_switch:
-            api_register.api_set_hook_func()
+            _api_register.register_all_api()
         return out
 
     @classmethod
@@ -101,9 +105,12 @@ class JitDump(_MindsporeFunctionExecutor):
 
     def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
         if JitDump.jit_dump_switch and JitDump.jit_enable:
-            api_register.api_set_ori_func()
-        output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
+            _api_register.restore_all_api()
+        if mindspore.__version__ >= "2.5":
+            output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values()))
+        else:
+            output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
         if JitDump.jit_dump_switch and JitDump.jit_enable:
             dump_jit(obj, args, None, False)
-            api_register.api_set_hook_func()
+            _api_register.register_all_api()
         return output
diff --git a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc
index b72d68741da491fc450c2d697a3ebfec895a3447..9ef4eec3ad4855f91e5b93322f3009afe2b42a41 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc
+++ b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.cc
@@ -18,37 +18,10 @@
 #include
 #include
 #include
+#include
 #include "utils/log_adapter.h"
 
-namespace {
-
-// Utility function to check if a file path is valid
-bool IsValidPath(const std::string &path) {
-  struct stat fileStat;
-  if (stat(path.c_str(), &fileStat) != 0) {
-    MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
-    return false;
-  }
-
-  if (S_ISLNK(fileStat.st_mode)) {
-    MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
-    return false;
-  }
-
-  if (!S_ISREG(fileStat.st_mode)) {
-    MS_LOG(ERROR) << "File is not a regular file: " << path;
-    return false;
-  }
-
-  if (path.substr(path.find_last_of(".")) != ".so") {
-    MS_LOG(ERROR) << "File is not a .so file: " << path;
-    return false;
-  }
-
-  return true;
-}
-
-}  // namespace
+namespace py = pybind11;
 
 HookDynamicLoader &HookDynamicLoader::GetInstance() {
   static HookDynamicLoader instance;
@@ -65,38 +38,31 @@ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionNa
   return true;
 }
 
-bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
-  char *realPath = realpath(libPath.c_str(), nullptr);
-  if (!realPath) {
-    MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
-    return false;
-  }
-
-  bool isValid = IsValidPath(realPath);
-  free(realPath);  // Free memory allocated by realpath
-  return isValid;
-}
-
 bool HookDynamicLoader::LoadLibrary() {
-  const char *libPath = std::getenv("HOOK_TOOL_PATH");
-  if (!libPath) {
-    MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
-    return false;
-  }
-
-  std::string resolvedLibPath(libPath);
-  if (!validateLibraryPath(resolvedLibPath)) {
-    MS_LOG(WARNING) << "Library path validation failed.";
-    return false;
-  }
-
+  std::string msprobePath = "";
+  // Acquire the GIL
+  py::gil_scoped_acquire acquire;
+  try {
+    py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
+    if (!py::hasattr(msprobeMod, "__file__")) {
+      MS_LOG(WARNING) << "Adump mod not found";
+      return false;
+    }
+    msprobePath = msprobeMod.attr("__file__").cast();
+  } catch (const std::exception& e) {
+    MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
+    return false;
+  }
   std::lock_guard lock(mutex_);
   if (handle_) {
     MS_LOG(WARNING) << "Hook library already loaded!";
     return false;
   }
-
-  handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
+  if (msprobePath == "") {
+    MS_LOG(WARNING) << "Adump path not loaded";
+    return false;
+  }
+  handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
   if (!handle_) {
     MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
     return false;
@@ -104,7 +70,7 @@ bool HookDynamicLoader::LoadLibrary() {
 
   for (const auto &functionName : functionList_) {
     if (!loadFunction(handle_, functionName)) {
-      MS_LOG(WARNING) << "Failed to load function: " << functionName;
+      MS_LOG(WARNING) << "Failed to load adump function";
       dlclose(handle_);
       handle_ = nullptr;
       return false;
diff --git a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h
index 6309e60b662a03d7f77cb450986ded5329fd8960..3e604558aee825e69dcffc00b817f7980e17e3e1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h
+++ b/debug/accuracy_tools/msprobe/mindspore/dym_loader/hook_dynamic_loader.h
@@ -40,7 +40,6 @@ class HookDynamicLoader {
  private:
   // Helper functions
   bool loadFunction(void *handle, const std::string &functionName);
-  bool validateLibraryPath(const std::string &libPath);
 
   HookDynamicLoader() = default;
diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py
index 57b7de4fa567d73a19178256d79f5e4cbeb38864..da4821b3ac45a689fab5ba5c63515f88bd6e17c3 100644
--- a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py
+++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py
@@ -19,6 +19,7 @@ import os
 import traceback
 
 import mindspore as ms
+
 from msprobe.core.common.const import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import check_path_length, load_yaml
@@ -27,7 +28,7 @@ from msprobe.mindspore.common.const import FreeBenchmarkConst
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.utils import get_rank_if_initialized
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 from msprobe.mindspore.free_benchmark.common.config import Config
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
@@ -37,6 +38,9 @@ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import P
 from msprobe.mindspore.runtime import Runtime
 
 
+_api_register = get_api_register()
+
+
 class ApiPyNativeSelfCheck:
     def __init__(self, config: DebuggerConfig):
         Config.is_enable = True
@@ -60,8 +64,8 @@ class ApiPyNativeSelfCheck:
         self.store_original_func()
 
     def handle(self):
-        api_register.initialize_hook(self.build_hook)
-        api_register.api_set_hook_func()
+        _api_register.initialize_hook(self.build_hook)
+        _api_register.register_all_api()
 
     def build_hook(self, api_name):
         def pre_hook(cell, input_data):
@@ -166,13 +170,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
         return ret
 
     logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
-    api_register.api_set_ori_func()
+    _api_register.restore_all_api()
 
     try:
         perturbation = PerturbationFactory.create(api_name_with_id)
         params.fuzzed_result = perturbation.handle(params)
         if params.fuzzed_result is False:
-            api_register.api_set_hook_func()
+            _api_register.register_all_api()
             return ret
         if Config.stage == Const.BACKWARD:
             params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
@@ -183,7 +187,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
         logger.error(f"[{api_name_with_id}] Error: {str(e)}")
         logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
 
-    api_register.api_set_hook_func()
+    _api_register.register_all_api()
 
     return ret
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
index a27172f19ead537f276c5ce0820b405d7abb6e25..c85e66a65ba26fdbc1d10a8e55c8273236409b36 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
@@ -258,7 +258,7 @@ def validate_config(config):
     step_interval = config.get('step_interval', 1)
     validate_step_interval(step_interval)
 
-    collect_times = config.get('collect_times', 1e8)
+    collect_times = config.get('collect_times', int(1e8))
     validate_collect_times(collect_times)
 
     if not targets:
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 5afbd046be4caf29c4b247a0f8fdd655c5208fd0..11d6db7a981a950120731df26edabee2e751b720 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -41,7 +41,7 @@ from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs, is_mindtorch,
                                             register_backward_hook_functions)
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
 from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
 from msprobe.mindspore.dump.jit_dump import JitDump
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -63,6 +63,8 @@ class Service:
         self.inner_switch = False
         self.primitive_switch = False
         self.current_iter = 0
+        self.loop = 0
+        self.init_step = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
@@ -71,6 +73,7 @@ class Service:
         self.params_grad_info = {}
         self.hook_handle_dict = {}
         # Register in advance to mount as many API hooks as possible
+        self.api_register = get_api_register()
         self.register_api_hook()
         self.init_for_debug_level()
@@ -276,11 +279,12 @@ class Service:
         if self.config.task == Const.TENSOR:
             self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
-        self.current_iter += 1
-        self.data_collector.update_iter(self.current_iter)
+        self.loop += 1
         self.reset_status()
 
     def start(self, model=None):
+        self.current_iter = self.loop + self.init_step
+        self.data_collector.update_iter(self.current_iter)
         if self.config.level == Const.LEVEL_DEBUG:
             return
         self.start_call = True
@@ -321,7 +325,7 @@ class Service:
             PIJitCaptureContext.__exit__ = self.empty
         self.first_start = False
 
-        api_register.api_set_hook_func()
+        self.api_register.register_all_api()
         self.switch = True
         self.primitive_switch = True
         logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
@@ -410,8 +414,8 @@ class Service:
 
     def register_api_hook(self):
         if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
             logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
-            api_register.api_set_hook_func()
+            self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+            self.api_register.register_all_api()
 
     def get_cells_and_names(self):
         cells_and_names_with_index = {}
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py
index 797210f09c3b55a64002a4aa84a3d39770ae803c..641eada030353ec67f6ce7b59bd3d14909a56e51 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py
@@ -183,6 +183,7 @@ class APIExtractor:
                 self.update_data_name(v, dump_data_dir)
         return value
 
+    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
     def update_data_name(self, data, dump_data_dir):
         if isinstance(data, list):
             for item in data:
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
index 498102b475f564564d6039a81e305fba3bceec17..3eb7fc0df96ce4bc52c7e45fdd3d07fba986feec 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
             backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
 
     input_data = load_json(input_file)
+    if "dump_data_dir" not in input_data.keys():
logger.error("Invalid input file, 'dump_data_dir' field is missing") + raise CompareException("Invalid input file, 'dump_data_dir' field is missing") if input_data.get("data") is None: logger.error("Invalid input file, 'data' field is missing") raise CompareException("Invalid input file, 'data' field is missing") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 905687c1bfc932883396481410c333a7566fd342..51fe32de810d4580313c503ec8cca69e9e819f92 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -65,6 +65,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" not_backward_list = ['repeat_interleave'] unsupported_backward_list = ['masked_select'] +unsupported_api_list = ["to"] tqdm_params = { @@ -218,6 +219,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list): If api is both in black_list and black_list, black_list first. 
return: False for exec api, True for not exec """ + black_list.extend(unsupported_api_list) if black_list and api_name in black_list: return True if white_list and api_name not in white_list: @@ -317,7 +319,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content): if kwargs.get("device"): del kwargs["device"] - device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs) + device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None) + device_out = exec_api(device_exec_params) device_out = move2device_exec(device_out, "cpu") return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index dc0174212e3f8f8cf70fa1701aadc664138dbcdf..e62315a1a163a1dbba605de31d3b01703314ee5b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -1,9 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # @@ -18,8 +16,8 @@ import os from collections import namedtuple import re -import torch +import torch try: import torch_npu except ImportError: @@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException +from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate -from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate -from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate -from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate + hf_32_standard_api = ["conv1d", "conv2d"] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} @@ -108,17 +104,30 @@ def exec_api(exec_params): kwargs = exec_params.kwargs is_autocast = exec_params.is_autocast autocast_dtype = exec_params.autocast_dtype - - if api_type == "Functional": - torch_api = FunctionalOPTemplate(api_name, str, False) - if api_type == "Tensor": - torch_api = TensorOPTemplate(api_name, str, False) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - if api_type == "Aten": + out = None + + prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {}) + if not prefix_map or api_type not in prefix_map.values() or \ + api_type not in ( + Const.FUNCTIONAL_API_TYPE_PREFIX, + Const.TENSOR_API_TYPE_PREFIX, + Const.TORCH_API_TYPE_PREFIX, + Const.ATEN_API_TYPE_PREFIX, + Const.NPU_API_TYPE_PREFIX + ): + return out + + if api_type == Const.ATEN_API_TYPE_PREFIX: torch_api = AtenOPTemplate(api_name, None, False) - if api_type == "NPU": - torch_api = NpuOPTemplate(api_name, None, False, device) + else: + api_register = 
get_api_register() + api_register.initialize_hook(None) + api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)] + api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name) + if api_func is None: + return out + + torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device) if is_autocast: with autocast(dtype=autocast_dtype): out = torch_api.forward(*args, **kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py index f31c29c6bb6fa8a863b83bf09d15aba09645436f..f858067b6616ff34ee95e1c4394e63fe4385b397 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py @@ -27,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer from msprobe.core.common.file_utils import remove_path from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl +from msprobe.core.common.utils import recursion_depth_decorator + BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] @@ -168,11 +170,12 @@ class ATTL: return buffer +@recursion_depth_decorator("move2device_exec") def move2device_exec(obj, device): if isinstance(obj, (tuple, list)): data_list = [move2device_exec(val, device) for val in obj] return data_list if isinstance(obj, list) else tuple(data_list) - if isinstance(obj, dict): + if isinstance(obj, dict): return {key: move2device_exec(val, device) for key, val in obj.items()} elif isinstance(obj, torch.Tensor): obj = obj.detach() diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py 
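The `recursion_depth_decorator` newly applied to `move2device_exec` above guards a recursive traversal against pathological nesting. A minimal sketch of such a depth guard (assumed behavior; the real decorator lives in `msprobe.core.common.utils` and may differ in details):

```python
import functools

def recursion_depth_decorator(name, max_depth=5):
    """Raise once the decorated recursive function exceeds max_depth.
    Sketch only; msprobe's actual decorator may use different limits."""
    def wrapper(func):
        depth = {"n": 0}  # shared across recursive calls via closure

        @functools.wraps(func)
        def inner(*args, **kwargs):
            depth["n"] += 1
            try:
                if depth["n"] > max_depth:
                    raise RecursionError(f"{name}: depth {max_depth} exceeded")
                return func(*args, **kwargs)
            finally:
                depth["n"] -= 1  # unwind correctly even on error
        return inner
    return wrapper

@recursion_depth_decorator("move2device_exec_demo", max_depth=5)
def walk(obj):
    # Simplified analogue of move2device_exec's container traversal.
    if isinstance(obj, (list, tuple)):
        return [walk(v) for v in obj]
    return obj
```

Shallow structures traverse normally, while nesting deeper than the limit raises instead of exhausting the Python stack.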
b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 16067f6d2bee70645bcc337d1809a14f41ae5b96..40360ef44c3d482b9fc3de2ee7f538b0b161c617 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -57,7 +57,7 @@ def parameter_adapter(func): @wraps(func) def inner(self, *args, **kwargs): - if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): + if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor): input_tensor = args[0] indices = args[1] if indices.dtype == torch.uint8: @@ -77,7 +77,7 @@ def parameter_adapter(func): else: res = [input_tensor[tensor_index] for tensor_index in indices] return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0) - if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None: + if self.api_name == "__eq__" and len(args) > 1 and args[1] is None: return False return func(self, *args, **kwargs) @@ -449,9 +449,9 @@ def is_recomputation(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call - if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): + if not isinstance(variable, (list, dict, tuple, torch.Tensor, int, float, str)): logger.warning("PrecisionDebugger.save variable type not valid, " - "should be one of list, dict, torch.Tensor, int, float or string. " + "should be one of list, dict, tuple, torch.Tensor, int, float or string. 
" "Skip current save process.") raise ValueError if not isinstance(name, str): diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 5bb1d3a14e82d7b4bce9d7da8921a1d701e82222..e6b014e528463ccae248ba01c3cca7be782aaaaf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -19,7 +19,7 @@ import torch from msprobe.core.common.const import Const, FileCheckConst, MsgConst from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import FileChecker -from msprobe.core.common.utils import get_real_step_or_rank +from msprobe.core.common.utils import get_real_step_or_rank, check_init_step from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import check_save_param from msprobe.pytorch.debugger.debugger_config import DebuggerConfig @@ -172,6 +172,14 @@ class PrecisionDebugger: return instance.service.save(variable, name, save_backward) + @classmethod + def set_init_step(cls, step): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + check_init_step(step) + instance.service.init_step = step + def module_dump(module, dump_name): if not isinstance(module, torch.nn.Module): diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py index 4700de6f1f9f3b5ddfb9507decb6f8739b5eda9b..cc78962f401a9e4f46d5794d7ca074f2e37f45e0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_dump.py @@ -17,7 +17,7 @@ import torch from msprobe.core.common.const import Const from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.common.log import logger -from 
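The new `PrecisionDebugger.set_init_step` classmethod follows a singleton-guarded pattern: refuse to act before the instance exists, validate the step, then forward it to the service. The pattern in isolation (class and attribute names mirror the diff but are simplified stand-ins, not msprobe's real API):

```python
# Illustrative stand-in for the set_init_step classmethod pattern.
class Debugger:
    _instance = None

    def __init__(self):
        self.init_step = 0
        Debugger._instance = self

    @classmethod
    def set_init_step(cls, step):
        if cls._instance is None:
            # mirrors raising MsgConst.NOT_CREATED_INSTANCE in the diff
            raise RuntimeError("instance not created yet")
        if not isinstance(step, int) or step < 0:
            # stand-in for check_init_step's validation
            raise ValueError("step must be a non-negative int")
        cls._instance.init_step = step

dbg = Debugger()
Debugger.set_init_step(100)
print(dbg.init_step)  # 100
```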
msprobe.pytorch.hook_module.api_registry import api_register +from msprobe.pytorch.hook_module.api_register import get_api_register torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' @@ -26,13 +26,14 @@ class ModuleDumper: def __init__(self, service): self.service = service self.hook_handle_list = [] + self.api_register = get_api_register() def start_module_dump(self, module, dump_name): - api_register.api_originality() + self.api_register.restore_all_api() self.register_hook(module, dump_name) def stop_module_dump(self): - api_register.api_modularity() + self.api_register.register_all_api() for hook_handle in self.hook_handle_list: if isinstance(hook_handle, torch.utils.hooks.RemovableHandle): hook_handle.remove() diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index b5ca1da461fd4235a09172de4b9dcea34a624e58..c1c433276fba441177e45892809883ea5e946dcf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -17,6 +17,7 @@ from functools import wraps import torch from msprobe.core.common.const import Const +from msprobe.core.common.utils import recursion_depth_decorator from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import replace_last_occurrence @@ -58,6 +59,7 @@ class ModuleProcesser: return clone_return_value_func @staticmethod + @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor") def clone_if_tensor(result): if isinstance(result, torch.Tensor): return result.clone() @@ -109,6 +111,8 @@ class ModuleProcesser: for name, module in modules_and_names: if module == model: continue + if module.__class__.__name__ == "FullyShardedDataParallel": + continue module_index = (index + Const.SEP) 
if index != "-1" else "" prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index + name + Const.SEP + module.__class__.__name__ + Const.SEP) diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py index 247e2cd0ed5ea11047cc0d75954dbc1e92b889f4..f515b5d4783c0e20a2303579f6954d42a7b9deac 100644 --- a/debug/accuracy_tools/msprobe/pytorch/function_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -70,7 +70,7 @@ class Register(dict): def add_register_item(key, value): if key in self._dict: - logger.warning(f"{value.__name__} has been registered before, so we will overriden it.") + logger.warning(f"{value.__name__} has been registered before, so we will override it.") self[key] = value return value diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py new file mode 100644 index 0000000000000000000000000000000000000000..30a45a84d877b5b1dd4b4d0231faece010c59f61 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import os + +import torch +import torch.distributed as dist + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.api_registry import ApiRegistry +from msprobe.pytorch.common.utils import ( + torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter +) +from msprobe.pytorch.function_factory import npu_custom_functions +from msprobe.pytorch.hook_module.hook_module import HOOKModule + + +torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' + +_api_types = { + Const.PT_FRAMEWORK: { + Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)), + Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)), + Const.PT_API_TYPE_TORCH: (torch, (torch,)), + Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)), + Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d)) + } +} +if not is_gpu: + import torch_npu + if torch_without_guard_version: + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu)) + } + ) + else: + _api_types.get(Const.PT_FRAMEWORK).update( + {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))} + ) + _api_types.get(Const.PT_FRAMEWORK).update( + { + Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed, + torch_npu.distributed.distributed_c10d)) + } + ) + +_inner_used_api = {} +_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) +_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} + + +@parameter_adapter +def tensor_module_forward(module, *args, **kwargs): + return module.api_func(*args, **kwargs) + + +def dist_module_forward(module, *args, **kwargs): + handle = module.api_func(*args, **kwargs) + if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]: + if handle and hasattr(handle, 'wait'): + handle.wait() + if module.api_name == "batch_isend_irecv": 
+ if isinstance(handle, list): + for req in handle: + req.wait() + return handle + + +def npu_module_forward(module, *args, **kwargs): + if not module.need_hook: + if module.api_name not in npu_custom_functions: + raise Exception(f'There is no bench function {module.api_name}') + if module.device == Const.CUDA_LOWERCASE: + module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name) + if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: + return npu_custom_functions[module.api_name](*args, **kwargs) + return module.api_func(*args, **kwargs) + + +forward_methods = { + "Tensor": tensor_module_forward, + "Distributed": dist_module_forward, + "NPU": npu_module_forward +} + + +class ApiTemplate(HOOKModule): + def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE): + self.api_name = api_name + self.api_func = api_func + self.prefix = prefix + self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP + self.need_hook = need_hook + self.device = device + if self.need_hook: + super().__init__(hook_build_func) + if prefix == Const.DIST_API_TYPE_PREFIX: + self.op_is_distributed = True + + @torch_device_guard + def forward(self, *args, **kwargs): + exec_func = forward_methods.get(self.prefix) + exec_func = functools.partial(exec_func, self) if exec_func else self.api_func + return exec_func(*args, **kwargs) + + +api_register = None + + +def get_api_register(return_new=False): + if return_new: + return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + + global api_register + if api_register is None: + api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate) + return api_register diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py deleted file mode 100644 index
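`get_api_register` above is a lazy singleton with an opt-out: `return_new=True` hands back a private registry for callers (like the run_ut path) that must not disturb the process-wide one. The pattern in isolation, with an illustrative stand-in class rather than msprobe's `ApiRegistry`:

```python
class ApiRegistry:
    """Stand-in for msprobe's ApiRegistry; holds wrapped-API state."""
    def __init__(self):
        self.ori_api_attr = {}

_registry = None

def get_api_register(return_new=False):
    # return_new=True yields a fresh, independent registry; otherwise
    # a process-wide singleton is created on first use and reused.
    if return_new:
        return ApiRegistry()
    global _registry
    if _registry is None:
        _registry = ApiRegistry()
    return _registry
```

Module-level state makes the singleton per-process; callers that register and later restore hooks therefore all operate on the same saved-attribute tables, which is what lets `ModuleDumper.start_module_dump`/`stop_module_dump` toggle the hooks consistently.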
1aad89bd6e89ae839513001b1d51572b50d8280b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed as dist - -from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops -from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops -from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops -from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops -from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops -from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu -from msprobe.core.common.const import Const - -torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' - -if not is_gpu: - import torch_npu - from . 
import wrap_npu_custom - from .wrap_npu_custom import get_npu_ops - - -class ApiRegistry: - def __init__(self): - self.tensor_ori_attr = {} - self.torch_ori_attr = {} - self.functional_ori_attr = {} - self.distributed_ori_attr = {} - self.npu_distributed_ori_attr = {} - self.vf_ori_attr = {} - self.aten_ori_attr = {} - self.torch_npu_ori_attr = {} - - self.tensor_hook_attr = {} - self.torch_hook_attr = {} - self.functional_hook_attr = {} - self.distributed_hook_attr = {} - self.npu_distributed_hook_attr = {} - self.vf_hook_attr = {} - self.aten_hook_attr = {} - self.torch_npu_hook_attr = {} - - @staticmethod - def store_ori_attr(ori_api_group, api_list, api_ori_attr): - for api in api_list: - if '.' in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(ori_api_group, sub_module_name) - api_ori_attr[api] = getattr(sub_module, sub_op) - else: - api_ori_attr[api] = getattr(ori_api_group, api) - - @staticmethod - def set_api_attr(api_group, attr_dict): - for api, api_attr in attr_dict.items(): - if '.' 
in api: - sub_module_name, sub_op = api.rsplit('.', 1) - sub_module = getattr(api_group, sub_module_name, None) - if sub_module is not None: - setattr(sub_module, sub_op, api_attr) - else: - setattr(api_group, api, api_attr) - - def api_modularity(self): - self.set_api_attr(torch.Tensor, self.tensor_hook_attr) - self.set_api_attr(torch, self.torch_hook_attr) - self.set_api_attr(torch.nn.functional, self.functional_hook_attr) - self.set_api_attr(dist, self.distributed_hook_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_hook_attr) - self.set_api_attr(torch._VF, self.vf_hook_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_hook_attr) - - def api_originality(self): - self.set_api_attr(torch.Tensor, self.tensor_ori_attr) - self.set_api_attr(torch, self.torch_ori_attr) - self.set_api_attr(torch.nn.functional, self.functional_ori_attr) - self.set_api_attr(dist, self.distributed_ori_attr) - self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr) - if not is_gpu and not torch_without_guard_version: - self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr) - self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr) - if torch_version_above_2: - self.set_api_attr(torch.ops.aten, self.aten_ori_attr) - self.set_api_attr(torch._VF, self.vf_ori_attr) - if not is_gpu: - self.set_api_attr(torch_npu, self.torch_npu_ori_attr) - - def initialize_hook(self, hook, online_run_ut=False): - """ - initialize_hook - Args: - hook (_type_): initialize_hook - online_run_ut (bool): default False, whether online run_ut or not. - If online_run_ut is True, the hook will not wrap the aten ops. 
- """ - self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr) - wrap_tensor.wrap_tensor_ops_and_bind(hook) - for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name) - - self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr) - wrap_torch.wrap_torch_ops_and_bind(hook) - for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name) - - self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr) - wrap_functional.wrap_functional_ops_and_bind(hook) - for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name) - - self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr) - wrap_distributed.wrap_distributed_ops_and_bind(hook) - if not is_gpu and not torch_without_guard_version: - self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr) - for attr_name in dir(wrap_distributed.HOOKDistributedOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name) - if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api: - self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, - attr_name) - - if torch_version_above_2 and not online_run_ut: - self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr) - wrap_aten.wrap_aten_ops_and_bind(hook) - for attr_name in dir(wrap_aten.HOOKAtenOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name) 
- - self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr) - wrap_vf.wrap_vf_ops_and_bind(hook) - for attr_name in dir(wrap_vf.HOOKVfOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name) - - if not is_gpu: - self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr) - wrap_npu_custom.wrap_npu_ops_and_bind(hook) - for attr_name in dir(wrap_npu_custom.HOOKNpuOP): - if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name) - - -api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index b59d4be82f2b55326c2a1d6a8a9e127a8470bff6..1eba9897b08dfa6bc49d14c9036ffe86510bb563 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,28 +28,27 @@ class HOOKModule(nn.Module): module_count = defaultdict(int) inner_stop_hook = {} - def __init__(self, build_hook) -> None: + def __init__(self, hook_build_func) -> None: super(HOOKModule, self).__init__() self.has_overflow = False - self.prefix = "" self.current_thread = threading.current_thread().ident if self.current_thread not in HOOKModule.inner_stop_hook: HOOKModule.inner_stop_hook[self.current_thread] = False self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False) if not self.stop_hook: - if hasattr(self, "prefix_op_name_"): - self.prefix = self.prefix_op_name_ - self.forward_data_collected = False - forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix) - if torch_version_above_or_equal_2: - self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) - self.register_forward_hook(forward_hook, with_kwargs=True) - else: - self.register_forward_pre_hook(forward_pre_hook) - self.register_forward_hook(forward_hook) - self.register_backward_hook(backward_hook) + + prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else "" + if callable(hook_build_func): + forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix) + if torch_version_above_or_equal_2: + self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + self.register_forward_hook(forward_hook, with_kwargs=True) + else: + self.register_forward_pre_hook(forward_pre_hook) + self.register_forward_hook(forward_hook) + self.register_backward_hook(backward_hook) def __call__(self, *args, **kwargs): changed = False @@ -111,6 +110,10 @@ class HOOKModule(nn.Module): return result else: return result + + if not (var.requires_grad and torch.is_grad_enabled()): + return result + grad_fn = var.grad_fn if grad_fn is not None: for hook in non_full_backward_hooks: diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml 
b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index 4bc22f51ceb5497f307fb4ac3226c8c590ea459a..f7f4955485435b3c2a5c13da62139d2bb519d9a1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -149,9 +149,10 @@ tensor: - __bool__ - __div__ - __eq__ + - __floordiv__ - __ge__ - - __gt__ - __getitem__ + - __gt__ - __iadd__ - __iand__ - __idiv__ @@ -160,23 +161,33 @@ tensor: - __imod__ - __imul__ - __ior__ + - __ipow__ - __irshift__ - __isub__ - __ixor__ + - __le__ - __lshift__ + - __lt__ - __matmul__ - __mod__ - __mul__ + - __ne__ - __nonzero__ - __or__ + - __pow__ - __radd__ + - __rdiv__ + - __rmod__ - __rmul__ + - __ror__ + - __rpow__ - __rshift__ + - __rsub__ + - __rxor__ - __setitem__ - __sub__ - __truediv__ - __xor__ - - __pow__ - abs - abs_ - absolute @@ -199,12 +210,14 @@ tensor: - addmv_ - addr - addr_ + - adjoint - align_as - align_to - all - allclose - amax - amin + - aminmax - angle - any - arccos @@ -216,12 +229,15 @@ tensor: - arcsinh - arcsinh_ - arctan + - arctan2 + - arctan2_ - arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -236,39 +252,51 @@ tensor: - baddbmm_ - bernoulli - bernoulli_ + - bfloat16 - bincount - bitwise_and - bitwise_and_ + - bitwise_left_shift + - bitwise_left_shift_ - bitwise_not - bitwise_not_ - bitwise_or - bitwise_or_ + - bitwise_right_shift + - bitwise_right_shift_ - bitwise_xor - bitwise_xor_ - bmm + - bool - broadcast_to + - byte - cauchy_ - ceil - ceil_ + - cfloat + - char - cholesky + - cholesky_inverse + - cholesky_solve - chunk - clamp - - cholesky_solve - - cholesky_inverse - clamp_ - clamp_max - clamp_max_ - - clip - clamp_min - clamp_min_ + - clip - clip_ + - conj_physical - copysign - copysign_ + - corrcoef - cos - cos_ - cosh - cosh_ - count_nonzero + - cov - cummax - cummin - cumprod @@ -282,20 +310,23 @@ tensor: - diag_embed - diagflat - 
diagonal + - diagonal_scatter - diff - - dist - digamma - digamma_ + - dist - div - div_ - divide - divide_ - dot + - double + - dsplit - eig - eq - eq_ - - erf - equal + - erf - erf_ - erfc - erfc_ @@ -304,18 +335,21 @@ tensor: - exp - exp2 - exp2_ - - expm1 - exp_ + - expand + - expand_as + - expm1 - expm1_ - exponential_ - fill_ - - fix - fill_diagonal_ + - fix - fix_ + - flatten - flip - fliplr - - flatten - flipud + - float - float_power - float_power_ - floor @@ -328,6 +362,7 @@ tensor: - fmod_ - frac - frac_ + - frexp - gather - gcd - gcd_ @@ -338,31 +373,37 @@ tensor: - ger - greater - greater_ - - gt - - gt_ - greater_equal - greater_equal_ + - gt + - gt_ + - half - hardshrink - heaviside - heaviside_ - histc + - histogram + - hsplit - hypot - hypot_ + - i0 + - i0_ - igamma - igamma_ - igammac - igammac_ - index_add - index_add_ - - inverse - index_copy - index_copy_ - index_fill - index_fill_ - index_put - index_put_ - - inner - index_select + - inner + - int + - inverse - isclose - isfinite - isinf @@ -380,7 +421,6 @@ tensor: - le_ - lerp - lerp_ - - where - less - less_ - less_equal @@ -397,43 +437,47 @@ tensor: - log_ - log_normal_ - log_softmax - - logcumsumexp - - logdet - logaddexp - logaddexp2 + - logcumsumexp + - logdet - logical_and - logical_and_ - logical_not - - logit - logical_not_ - logical_or - logical_or_ - logical_xor - logical_xor_ + - logit - logit_ - logsumexp + - long - lstsq - lt - lt_ + - lu - lu_solve - map2_ - map_ - masked_fill - - matmul - masked_fill_ - masked_scatter - masked_scatter_ - masked_select + - matmul - matrix_exp + - matrix_power - max - maximum - mean - - matrix_power - median - min - minimum - mm - mode + - moveaxis + - movedim - msort - mul - mul_ @@ -443,6 +487,11 @@ tensor: - mv - mvlgamma - mvlgamma_ + - nan_to_num + - nan_to_num_ + - nanmean + - nanmedian + - nanquantile - nansum - narrow - narrow_copy @@ -452,20 +501,29 @@ tensor: - neg_ - negative - negative_ + - nextafter + - nextafter_ - nonzero - norm - 
normal_ - not_equal - not_equal_ + - numpy + - orgqr + - ormqr + - outer - permute - pinverse - polygamma + - polygamma_ - pow - pow_ - - polygamma_ - prelu - prod - put_ + - q_zero_point + - qr + - quantile - rad2deg - rad2deg_ - ravel @@ -474,15 +532,16 @@ tensor: - relu - relu_ - remainder - - repeat_interleave - - reshape - remainder_ - renorm - renorm_ - repeat + - repeat_interleave + - reshape - reshape_as - resize_ - resize_as_ + - resolve_neg - roll - rot90 - round @@ -496,6 +555,7 @@ tensor: - select - sgn - sgn_ + - short - sigmoid - sigmoid_ - sign @@ -507,11 +567,13 @@ tensor: - sinc_ - sinh - sinh_ + - slice_scatter - slogdet - smm - softmax - solve - sort + - split - split_with_sizes - sqrt - sqrt_ @@ -521,21 +583,29 @@ tensor: - squeeze_ - sspaddmm - std + - stft + - stride - sub - sub_ + - subtract - sum - sum_to_size - svd + - swapaxes + - swapdims + - swapdims_ - symeig - t - t_ - take + - take_along_dim - tan - tan_ - tanh - tanh_ - tensor_split - tile + - to - topk - transpose - transpose_ @@ -543,8 +613,8 @@ tensor: - tril - tril_ - triu - - true_divide - triu_ + - true_divide - true_divide_ - trunc - trunc_ @@ -552,37 +622,20 @@ tensor: - unbind - unflatten - unfold + - unique + - unique_consecutive - unsafe_chunk - - unsqueeze - unsafe_split - unsafe_split_with_sizes + - unsqueeze + - unsqueeze_ - var - vdot - - unsqueeze_ - view_as + - vsplit + - where - xlogy - xlogy_ - - split - - stft - - nan_to_num - - dsplit - - orgqr - - bitwise_left_shift_ - - arctan2 - - histogram - - q_zero_point - - adjoint - - ormqr - - bitwise_right_shift_ - - nanquantile - - lu - - quantile - - arctan2_ - - qr - - diagonal_scatter - - corrcoef - - vsplit - - aminmax torch: - linalg.norm @@ -642,13 +695,14 @@ torch: - addmv - addmv_ - addr - - amax - affine_grid_generator - align_tensors - all - alpha_dropout - - amin - alpha_dropout_ + - amax + - amin + - aminmax - angle - any - arange @@ -661,12 +715,14 @@ torch: - arcsinh - arcsinh_ - arctan + - arctan2 - 
arctan_ - arctanh - arctanh_ - argmax - argmin - argsort + - argwhere - asin - asin_ - asinh @@ -687,13 +743,13 @@ torch: - batch_norm_elemt - batch_norm_gather_stats - batch_norm_gather_stats_with_counts - - bernoulli - batch_norm_stats - batch_norm_update_stats + - bernoulli - bilinear + - binary_cross_entropy_with_logits - bincount - binomial - - binary_cross_entropy_with_logits - bitwise_and - bitwise_not - bitwise_or @@ -739,9 +795,9 @@ torch: - conv_transpose1d - conv_transpose2d - conv_transpose3d - - cos - convolution - copysign + - cos - cos_ - cosh - cosh_ @@ -755,14 +811,16 @@ torch: - cummin - cumprod - cumsum + - cumulative_trapezoid - deg2rad - deg2rad_ - det - diag - diag_embed - - diff - diagflat - diagonal + - diagonal_scatter + - diff - digamma - dist - div @@ -771,12 +829,15 @@ torch: - dropout - dropout_ - dsmm + - dsplit - dstack - eig - einsum - embedding - embedding_bag - embedding_renorm_ + - empty + - empty_like - eq - equal - erf @@ -791,12 +852,12 @@ torch: - expm1 - expm1_ - eye - - feature_dropout - feature_alpha_dropout - feature_alpha_dropout_ + - feature_dropout - feature_dropout_ - - fix - fill_ + - fix - fix_ - flatten - flip @@ -811,8 +872,9 @@ torch: - fmod - frac - frac_ - - full + - frexp - frobenius_norm + - full - full_like - gather - gcd @@ -824,8 +886,8 @@ torch: - greater_equal - grid_sampler - grid_sampler_2d - - group_norm - grid_sampler_3d + - group_norm - gru - gru_cell - gt @@ -835,23 +897,29 @@ torch: - heaviside - hinge_embedding_loss - histc + - histogram + - histogramdd - hsmm + - hsplit - hspmm - hstack - hypot + - i0 + - i0_ - igamma - igammac - index_add - index_copy - - inner - index_fill - index_put - index_put_ - index_select + - inner - instance_norm - inverse - isclose - isfinite + - isin - isinf - isnan - isneginf @@ -879,8 +947,8 @@ torch: - log1p_ - log2 - log2_ - - log_softmax - log_ + - log_softmax - logaddexp - logaddexp2 - logcumsumexp @@ -899,18 +967,18 @@ torch: - lt - lu_solve - lu_unpack - - 
masked_fill - margin_ranking_loss + - masked_fill - masked_scatter - masked_select - - matrix_exp - matmul + - matrix_exp - matrix_power - matrix_rank - max - max_pool1d - - max_pool2d - max_pool1d_with_indices + - max_pool2d - max_pool3d - maximum - mean @@ -929,18 +997,20 @@ torch: - mvlgamma - nan_to_num - nan_to_num_ + - nanmean - nanmedian + - nanquantile - nansum - narrow + - narrow_copy - native_batch_norm - native_group_norm - - narrow_copy - native_layer_norm - native_norm - ne - neg - - negative - neg_ + - negative - negative_ - nextafter - nonzero @@ -972,30 +1042,31 @@ torch: - ravel - real - reciprocal - - relu - reciprocal_ + - relu - relu_ - remainder - renorm - repeat_interleave - reshape - resize_as_ + - resolve_neg - roll - rot90 - round - round_ + - row_stack - rrelu - rrelu_ - rsqrt - - row_stack - rsqrt_ - rsub - saddmm - scalar_tensor - scatter - - select - scatter_add - searchsorted + - select - selu - selu_ - sgn @@ -1015,12 +1086,12 @@ torch: - solve - sort - sparse_coo_tensor - - square - split - split_with_sizes - spmm - sqrt - sqrt_ + - square - square_ - squeeze - sspaddmm @@ -1042,8 +1113,8 @@ torch: - tan_ - tanh - tanh_ - - tensordot - tensor_split + - tensordot - threshold - threshold_ - tile @@ -1059,19 +1130,21 @@ torch: - true_divide - trunc - trunc_ - - unique_consecutive - - xlogy - unbind + - unflatten + - unique_consecutive - unsafe_chunk - unsafe_split - - vander - - var - - vdot - unsafe_split_with_sizes - unsqueeze + - vander + - var - var_mean + - vdot + - vsplit - vstack - where + - xlogy - xlogy_ _VF: @@ -1165,6 +1238,27 @@ torch_npu: - npu_moe_finalize_routing - npu_moe_gating_top_k_softmax - npu_trans_quant_param + - npu_gelu + - npu_ffn + - npu_quant_matmul + - npu_format_cast_ + - npu_dynamic_quant + - npu_moe_compute_expert_tokens + - npu_weight_quant_batchmatmul + - npu_dynamic_quant_asymmetric + - npu_grouped_matmul + - npu_quant_scatter_ + - npu_group_quant + - npu_fused_infer_attention_score + - npu_quantize + 
- npu_fast_gelu + - npu_weight_quant_batchmatmul + - scatter_update + - scatter_update_ + - npu_moe_init_routing + - npu_scatter_nd_update_ + - npu_scatter_nd_update + - npu_prefetch aten: - signbit @@ -1912,4 +2006,8 @@ distributed: - all_to_all - all_gather_into_tensor - reduce_scatter_tensor - - batch_isend_irecv \ No newline at end of file + - batch_isend_irecv + +npu_distributed: + - isend + - irecv \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py deleted file mode 100644 index 1cd11842c31bacdad7c1bb90f98ac81c3415a40e..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from functools import wraps -import torch.distributed as dist - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -distributed_func = {} -for f in dir(dist): - distributed_func[f] = getattr(dist, f) - - -def get_distributed_ops(): - _all_distributed_ops = dir(dist) - yaml_data = load_yaml(yaml_path) - wrap_distributed_ops = yaml_data.get('distributed') - return set(wrap_distributed_ops) & set(_all_distributed_ops) - - -class HOOKDistributedOP(object): - pass - - -class DistributedOPTemplate(HOOKModule): - def __init__(self, op_name, build_hook): - self.op_name_ = op_name - self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP - super().__init__(build_hook) - if not self.stop_hook: - self.op_is_distributed = True - - @torch_device_guard - def forward(self, *args, **kwargs): - handle = distributed_func.get(self.op_name_)(*args, **kwargs) - if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: - if handle and hasattr(handle, 'wait'): - handle.wait() - if self.op_name_ == "batch_isend_irecv": - if isinstance(handle, list): - for req in handle: - req.wait() - return handle - - -def wrap_distributed_op(op_name, hook): - @wraps(DistributedOPTemplate) - def distributed_op_template(*args, **kwargs): - return DistributedOPTemplate(op_name, hook)(*args, **kwargs) - - distributed_op_template.__name__ = op_name - return distributed_op_template - - -def wrap_distributed_ops_and_bind(hook): - _distributed_ops = get_distributed_ops() - for op_name in _distributed_ops: - setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py 
b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py deleted file mode 100644 index 6164169476dab66ac2bdb8d0cbc41a04ddce6713..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_functional_ops(): - yaml_data = load_yaml(yaml_path) - wrap_functional_ops = yaml_data.get('functional') - _all_functional_ops = dir(torch.nn.functional) - return set(wrap_functional_ops) & set(_all_functional_ops) - - -TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()} - - -class HOOKFunctionalOP(object): - pass - - -class FunctionalOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - def 
forward(self, *args, **kwargs): - return TorchFunctions[str(self.op_name_)](*args, **kwargs) - - -def wrap_functional_op(op_name, hook): - def functional_op_template(*args, **kwargs): - return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) - - return functional_op_template - - -def wrap_functional_ops_and_bind(hook): - _functional_ops = get_functional_ops() - for op_name in _functional_ops: - setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py deleted file mode 100644 index 1c0afc59f50c069fbcd7e9a546c5b57c467400a9..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import load_yaml -from msprobe.pytorch.function_factory import npu_custom_functions - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -cuda_func_mapping = {"npu_fusion_attention" : "gpu_fusion_attention"} - - -def get_npu_ops(): - if torch_without_guard_version: - _npu_ops = dir(torch.ops.npu) - else: - _npu_ops = dir(torch_npu._C._VariableFunctionsClass) - yaml_data = load_yaml(yaml_path) - wrap_npu_ops = yaml_data.get('torch_npu') - return set(wrap_npu_ops) & set(_npu_ops) - - -class HOOKNpuOP(object): - pass - - -class NpuOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True, device=Const.CPU_LOWERCASE): - self.op_name_ = op_name - self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP - self.need_hook = need_hook - self.device = device - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - if not self.need_hook: - if self.op_name_ not in npu_custom_functions: - raise Exception(f'There is not bench function {self.op_name_}') - if self.device == Const.CUDA_LOWERCASE: - self.op_name_ = cuda_func_mapping.get(self.op_name_, self.op_name_) - if self.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]: - return npu_custom_functions[self.op_name_](*args, **kwargs) - if torch_without_guard_version: - return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) - else: - return getattr(torch_npu._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_npu_op(op_name, hook): - def 
npu_op_template(*args, **kwargs): - return NpuOPTemplate(op_name, hook)(*args, **kwargs) - return npu_op_template - - -def wrap_npu_ops_and_bind(hook): - _npu_ops = get_npu_ops() - for op_name in _npu_ops: - setattr(HOOKNpuOP, "wrap_" + str(op_name), wrap_npu_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py deleted file mode 100644 index f93c09a12415f22d96306ebc9de919520c025236..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_tensor_ops(): - _tensor_ops = dir(torch.Tensor) - yaml_data = load_yaml(yaml_path) - wrap_tensor_ops = yaml_data.get('tensor') - return set(wrap_tensor_ops) & set(_tensor_ops) - - -TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()} - - -class HOOKTensor(object): - pass - - -class TensorOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - @parameter_adapter - def forward(self, *args, **kwargs): - return TensorOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_tensor_op(op_name, hook): - - def tensor_op_template(*args, **kwargs): - return TensorOPTemplate(op_name, hook)(*args, **kwargs) - - return tensor_op_template - - -def wrap_tensor_ops_and_bind(hook): - _tensor_ops = get_tensor_ops() - for op_name in _tensor_ops: - setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py deleted file mode 100644 index fc9d61c206bcfaeda7fefb5cb8b90fda2d67cb16..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import torch - -from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.common.utils import torch_device_guard -from msprobe.core.common.const import Const -from msprobe.core.common.file_utils import load_yaml - - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") - - -def get_torch_ops(): - _torch_ops = [] - yaml_data = load_yaml(yaml_path) - wrap_torch_ops = yaml_data.get('torch') - for operation in wrap_torch_ops: - if '.' in operation: - operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) - operation_sub_module = getattr(torch, operation_sub_module_name) - if operation_sub_op in dir(operation_sub_module): - _torch_ops.append(operation) - else: - if hasattr(torch, operation): - _torch_ops.append(operation) - return set(_torch_ops) - - -TorchOps = {} -for op in get_torch_ops(): - if '.' 
in op: - sub_module_name, sub_op = op.rsplit('.', 1) - sub_module = getattr(torch, sub_module_name) - TorchOps[op] = getattr(sub_module, sub_op) - else: - TorchOps[op] = getattr(torch, op) - - - -class HOOKTorchOP(object): - pass - - -class TorchOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return TorchOps[str(self.op_name_)](*args, **kwargs) - - -def wrap_torch_op(op_name, hook): - - def torch_op_template(*args, **kwargs): - return TorchOPTemplate(op_name, hook)(*args, **kwargs) - - return torch_op_template - - -def wrap_torch_ops_and_bind(hook): - _torch_ops = get_torch_ops() - for op_name in _torch_ops: - setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py index b2fa26a58e702120fcabd5d82f8e1e0ed27f3bc4..20ef3757d4ad45cc2bb90769f44eef1cebe82560 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py @@ -24,6 +24,7 @@ import torch.nn as nn from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import load_yaml from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name +from msprobe.pytorch.common.log import logger try: import torch_npu @@ -37,6 +38,7 @@ WrapDistributedOps = load_yaml(OpsPath).get("distributed", []) StackBlackListPath = os.path.join(os.path.dirname(__file__), "stack_blacklist.yaml") StackBlackList = load_yaml(StackBlackListPath).get("stack", []) +MAX_STRING_LENGTH = 1000 distributed_func = {} for f in dir(dist): @@ -139,6 +141,8 @@ def 
get_process_group(process_group): def stack_filter(stack): + if len(stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH} characters. re match is skipped.') for pattern in StackBlackList: if re.search(pattern, stack): return False @@ -188,10 +192,12 @@ def update_data(old, new): def is_target_line(codeline): - stack = get_callstack() - whole_stack = ';'.join(stack) if codeline == []: return True + stack = get_callstack() + whole_stack = ';'.join(stack) + if len(whole_stack) > MAX_STRING_LENGTH: + logger.warning(f'The character string contains more than {MAX_STRING_LENGTH} characters. re match is skipped.') for pattern in codeline: if re.search(pattern, whole_stack): return True diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index d0285564d3cb5c00b69933db3259b7c3339c443d..2db2a9712566e41406ff9e82d1e4e8cff4b32ef9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -26,6 +26,7 @@ from torch.utils.hooks import BackwardHook from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.common.utils import recursion_depth_decorator from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import is_recomputation from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter @@ -735,6 +736,7 @@ class TrainerMon: logger.info_on_rank_0(f"> {hooked_count} modules are monitored.") + @recursion_depth_decorator('msprobe.pytorch.monitor.clone_if_tensor') def clone_if_tensor(args): if isinstance(args, tuple): return tuple([clone_if_tensor(arg) for arg in args]) @@ -1066,7 +1068,7 @@ class TrainerMon: self.enable_megatron = True logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0") except ImportError: - self.enable_megatron = False + self.enable_megatron = False 
| self.enable_megatron if not self.enable_megatron: self._hook_weights() diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 602514836d2531ad4a6be3a23f56bc3b942ba199..131f3ecab47e3935dfd8512b1b3e8ad479b3ba50 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -185,7 +185,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon): for opt in torch_opt.chained_optimizers: self.map_fp16_tp_fp32_param(opt) - if not isinstance(torch_opt, torch.optim.Optimizer): + if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'): torch_opt.state = {} for opt in torch_opt.chained_optimizers: torch_opt.state.update(opt.optimizer.state) @@ -198,7 +198,7 @@ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): for opt in torch_opt.chained_optimizers: self.map_fp16_tp_fp32_param(opt) - if not isinstance(torch_opt, torch.optim.Optimizer): + if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'): torch_opt.state = {} for opt in torch_opt.chained_optimizers: torch_opt.state.update(opt.optimizer.state) @@ -206,9 +206,60 @@ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon): class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon): - def fetch_mv(self, monitor, torch_opt, params2name): - return self._fetch_mv_in_adam(monitor, torch_opt, params2name) + def get_group_index(self, torch_opt): + bit16_groups = torch_opt.bf16_groups + param2group = defaultdict() + for group_idx, bit16_group in enumerate(bit16_groups): + for param in bit16_group: + param2group[param] = group_idx + return param2group + + def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None): + param2group = self.get_group_index(torch_opt) + exp_avg_dict = defaultdict(float) + exp_avg_sq_dict = 
defaultdict(float) + update_dict = defaultdict() + ratio_dict = defaultdict() + + param_slice_mappings = torch_opt.state_dict()['param_slice_mappings'] + for param, name in params2name.items(): + group_idx = param2group[param] + state = torch_opt.state[torch_opt.fp32_groups_flat_partition[group_idx]] + if state.get('exp_avg', None) is None: + logger.warning("optimizer state is None. Something is wrong if this is not the first step.") + break + param_slice_mapping = param_slice_mappings[group_idx] + hp_address = param_slice_mapping.get(torch_opt.param_names[param]) + if hp_address is None: + continue + start = hp_address.start + numel = hp_address.numel + if monitor.mv_distribution: + exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel) + exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel) + if monitor.mg_direction: + exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel) + if monitor.ur_distribution: + if len(torch_opt.param_groups) > 1: + logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.") + if 'step' in state: + step = state['step'] # Optimizer from pytorch or FusedAdam from apex (used by megatron) + elif 'step' in torch_opt.param_groups[0]: + step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed + else: + logger.warning(f"step of {name} is None; something may have gone wrong.") + continue + exp_avg = state['exp_avg'].narrow(0, start, numel) + exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel) + exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) + exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) + update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps']) + ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat) + monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) + monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) + return MVResult(exp_avg=exp_avg_dict, 
exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) + class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon): def get_param_index(self, params2name, name2index, torch_opt): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index b185bc1110d4062d8a31b9cc94dc946d8fb8456c..a154064755ed116eeb2f2ea97b50160bf4b7beb9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -19,6 +19,8 @@ import os from datetime import datetime, timezone import torch +from msprobe.core.common.const import Const +from msprobe.core.common.utils import recursion_depth_decorator from msprobe.core.common.file_utils import FileOpen, save_npy, save_json from msprobe.pytorch.common.log import logger @@ -91,6 +93,7 @@ def support_basic_type(data): return False +@recursion_depth_decorator("dump_data") def dump_data(data, prefix, dump_path): if isinstance(data, (tuple, list)) and data: for i, item in enumerate(data): diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index ae8b9435a34ced607d4e70fab615b2b017083fe9..2116186cc046865388c40d33142384301f84acd2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -27,8 +27,10 @@ else: pta_cpu_device = torch.device("cpu") from msprobe.core.common.const import CompareConst +from msprobe.core.common.utils import recursion_depth_decorator from msprobe.pytorch.common.log import logger + cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' COLOR_GREEN = '\033[32m' @@ -85,6 +87,7 @@ def get_callstack(): return callstack +@recursion_depth_decorator("data_to_cpu") def data_to_cpu(data, deep, data_cpu): global cpu_device list_cpu = [] diff --git 
a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 66229d36b8d0b532eea48f1aa5d96e178ed80cdc..db731b338244ca78bfd460633a7442b1e2ef2d5d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -264,7 +264,7 @@ class Util: match = re_pattern.match(name) if not match: continue - if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name): + if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern): continue file_list[name] = gen_info_func(name, match, file["root"]) return file_list diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index 8293ac969490b103eef630081b6001234ca8bb07..fedc62a32d6f4dd6c894b264a22f46a55ac291d2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -148,7 +148,7 @@ class FreeBenchmarkCheckConfig(BaseConfig): self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST ): msg = ( - f"You neet to and can only set fuzz_device as {DeviceType.CPU} " + f"You need to and can only set fuzz_device as {DeviceType.CPU} " f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}" ) logger.error_log_with_exp( diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index fd81a7f1cf064506a4fb91481429828c97113509..27235292e1ed740a33e8d1778b06c7993d4ae94e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -30,7 +30,7 @@ from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json from 
msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser -from msprobe.pytorch.hook_module.api_registry import api_register +from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook @@ -50,6 +50,8 @@ class Service: self.switch = False self.inner_switch = False self.current_iter = 0 + self.loop = 0 + self.init_step = 0 self.first_start = True self.current_rank = None self.dump_iter_dir = None @@ -58,6 +60,7 @@ class Service: self.params_grad_info = {} self.hook_handle_dict = {} # 提前注册,确保注册尽可能多的API hook + self.api_register = get_api_register() self.register_api_hook() self.init_for_debug_level() @@ -246,6 +249,8 @@ class Service: return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) def start(self, model): + self.current_iter = self.loop + self.init_step + self.data_collector.update_iter(self.current_iter) if self.config.level == Const.LEVEL_DEBUG: return if self.need_stop_service(): @@ -304,8 +309,7 @@ class Service: if self.config.task == Const.TENSOR: self.data_collector.data_processor.dump_async_data() self.data_collector.write_json() - self.current_iter += 1 - self.data_collector.update_iter(self.current_iter) + self.loop += 1 self.reset_status() def need_stop_service(self): @@ -370,11 +374,10 @@ class Service: def register_api_hook(self): if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.") - api_register.initialize_hook( - functools.partial(self.build_hook, BaseScope.Module_Type_API), - self.config.online_run_ut + self.api_register.initialize_hook( + functools.partial(self.build_hook, BaseScope.Module_Type_API) ) - api_register.api_modularity() + self.api_register.register_all_api() def 
register_module_hook(self): if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: @@ -409,7 +412,7 @@ class Service: if self.config.nfs_path: self.attl.upload("end") elif self.attl.socket_manager is not None: - logger.info(f"pid: {os.getpid()} finished, start send STOP signal.") + logger.info(f"pid: {os.getpid()} finished, start sending STOP signal.") self.attl.socket_manager.send_stop_signal() def reset_status(self): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json new file mode 100644 index 0000000000000000000000000000000000000000..63a062d8ffa264a0254fc2bab0208dcf951ae094 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/dump_no_pt_no_ms.json @@ -0,0 +1,3 @@ +{ + "task": "tensor" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..b223c74b2315af1b9454e5f1e70c29502d449c56 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/ms_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "mindspore.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json new file mode 100644 index 0000000000000000000000000000000000000000..2444ae1fd4096b083a9e8a0e51c9166bb990f51f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_dump_file/pt_dump_no_framework.json @@ -0,0 +1,4 @@ +{ + "task": "tensor", + "type": "torch.float16" +} \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py 
b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index 3472ca9018e189ffb48e4d26cfeb79e1ba1ff16d..61766ed27c0a58f4fff81fb2f45618de60bb5b48 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -18,11 +18,13 @@ import json import os import tempfile from datetime import datetime, timezone +import unittest from unittest import TestCase from unittest.mock import MagicMock, mock_open, patch import OpenSSL import numpy as np +from pathlib import Path from msprobe.core.common.const import Const from msprobe.core.common.file_utils import ( @@ -53,7 +55,8 @@ from msprobe.core.common.utils import (CompareException, recursion_depth_decorator, MsprobeBaseException, check_str_param, - is_json_file) + is_json_file, + detect_framework_by_dump_json) class TestUtils(TestCase): @@ -488,3 +491,42 @@ class TestCheckCrtValid(TestCase): with self.assertRaises(RuntimeError) as context: check_crt_valid(self.cert_file_path) self.assertIn('The SSL certificate is invalid', str(context.exception)) + + +class TestDetectFrameworkByDumpJson(unittest.TestCase): + + @patch('msprobe.core.common.utils.load_json') + def test_valid_pytorch_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": Const.PT_FRAMEWORK} + + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.PT_FRAMEWORK) + + @patch('msprobe.core.common.utils.load_json') + def test_valid_mindspore_framework(self, mock_load_json): + mock_load_json.return_value = {"framework": 
Const.MS_FRAMEWORK} + + result = detect_framework_by_dump_json("dummy_path") + + self.assertEqual(result, Const.MS_FRAMEWORK) + + def test_detect_framework_in_file(self): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/pt_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.PT_FRAMEWORK) + + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/ms_dump_no_framework.json" + result = detect_framework_by_dump_json(file_path) + self.assertEqual(result, Const.MS_FRAMEWORK) + + @patch("msprobe.core.common.utils.logger") + def test_detect_framework_exception(self, mock_logger): + self.current_dir = Path(__file__).parent + file_path = self.current_dir / "test_dump_file/dump_no_pt_no_ms.json" + with self.assertRaises(CompareException) as context: + result = detect_framework_by_dump_json(file_path) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_logger.error.assert_called_once_with(f"{file_path} must be based on the MindSpore or PyTorch framework.") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py index b4566fcfe6f48d9040feb4dc22f3a96cd08719a7..94244be326e9954c700339abec2db16a2ab31b07 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare.py @@ -11,7 +11,7 @@ import torch from msprobe.core.common.const import CompareConst, Const from msprobe.core.common.utils import CompareException -from msprobe.core.compare.acc_compare import Comparator, ModeConfig, get_bench_data_name +from msprobe.core.compare.acc_compare import Comparator, ModeConfig from msprobe.core.compare.highlight import find_error_rows, find_compare_result_error_rows, ApiBatch from msprobe.core.compare.utils import get_accuracy from 
msprobe.pytorch.compare.pt_compare import PTComparator @@ -159,16 +159,16 @@ aten_result = [ -10.640625, -0.008758544921875, 5.397906303405762, -5.796811580657959, 2.5283952709287405e-10, 'Warning', 'Need double check api accuracy.', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.1', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.2', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.3', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None'], ['Aten__native_batch_norm_legit_functional.default_0_forward.output.4', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', - ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', + ' ', ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', 'Yes', '', 'None']] highlight_dict = {'red_rows': [], 'yellow_rows': []} @@ -191,17 +191,21 @@ summary_line_3 = ['Functional_batch_norm_0_forward.output.2', 'Functional_batch_ 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, 'Warning', ''] line_input = ['Functional.batch.norm.0.forward.input.0', 
'Functional.batch.norm.0.forward.input.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 1, 1, 0.95, 1, 1, 1, 1, 1, 1.01, 1, 1, 1, + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 0.5, 1, 1, 0.95, 1, + 1, 1, 1, 1, 1.01, 1, 1, 1, 'Yes', ''] line_1 = ['Functional.batch.norm.0.forward.output.0', 'Functional.batch.norm.0.forward.output.0', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1, 1, 0.59, 1, 'nan', 0, 1, 1, 19, 1, 1, 1, - 'Warning', ''] + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 0.5, 1, 1, 0.59, 1, + 'nan', 0, 1, 1, 19, 1, 1, 1, + 'Yes', ''] line_2 = ['Functional.batch.norm.0.forward.output.1', 'Functional.batch.norm.0.forward.output.1', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 1, 1, 0.8, 1, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, - 'Warning', ''] + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 0.5, 1, 1, 0.8, 1, + 0, 0.12, 0, 1, 1, 0.1, 1, 1, + 'Yes', ''] line_3 = ['Functional.batch.norm.0.forward.output.2', 'Functional.batch.norm.0.forward.output.2', 'torch.float16', - 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1.1e+10, 1, 0.85, 1, 9, 0.12, 0, 1, 1, 0.1, 1, - 1, 1, 'Warning', ''] + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 0.5, 1.1e+10, 1, 0.85, 1, + 9, 0.12, 0, 1, 1, 0.1, 1, 1, + 'Yes', ''] op_data = { 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], @@ -363,7 +367,7 @@ class TestUtilsMethods(unittest.TestCase): 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File']] result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', 'File', '-1']] 
columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER + ['NPU_Stack_Info'] result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype=object) @@ -403,10 +407,10 @@ class TestUtilsMethods(unittest.TestCase): 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '']] result_all_test = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '', '-1']] result_all = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', + 'torch.float32', 'torch.float32', [2, 2], [2, 2], '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] columns_md5_stack_mode_true = CompareConst.MD5_COMPARE_RESULT_HEADER result_table_md5_true = pd.DataFrame(result_md5, columns=columns_md5_stack_mode_true, dtype='object') @@ -632,11 +636,11 @@ class TestUtilsMethods(unittest.TestCase): def test_do_multi_process(self): data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', ['-1', '-1']]] o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', - 'torch.float32', 'torch.float32', [2, 2], [2, 2], 'unsupported', 'unsupported', 'unsupported', - 'unsupported', 'unsupported', - 1, 1, 1, 1, 1, 1, 1, 1, 'None', 'No bench data matched.', '-1']] + 'torch.float32', 'torch.float32', [2, 2], [2, 2], + 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', + 1, 1, 1, 1, 1, 1, 1, 1, 'None', 'No bench data matched.', ['-1', '-1']]] 
columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) o_result = pd.DataFrame(o_data, columns=columns) @@ -666,10 +670,10 @@ class TestUtilsMethods(unittest.TestCase): mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode) pt_comparator = PTComparator(mode_config) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) + result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'No bench data matched.']) + 'unsupported', 'No bench data matched.']) def test_compare_by_op_2(self): npu_op_name = 'Functional.linear.0.forward.input.0' @@ -684,42 +688,22 @@ class TestUtilsMethods(unittest.TestCase): pt_comparator = PTComparator(mode_config) pt_name = '-1' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} input_param = {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) + result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - f'Dump file: {pt_path} not found.']) + 'unsupported', 'No bench data matched.']) pt_name = 'Functional.linear.0.forward.input.0.pt' - pt_path = os.path.join(base_dir, pt_name) - op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_path, pt_path]} + op_name_mapping_dict = {'Functional.linear.0.forward.input.0': [pt_name, pt_name]} input_param 
= {'npu_dump_data_dir': base_dir, 'bench_dump_data_dir': base_dir} - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, {}) + result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) self.assertEqual(result, ['unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', - 'Bench does not have data file.']) + 'unsupported', 'Dump file: Functional.linear.0.forward.input.0.pt not found.']) generate_pt(base_dir) - result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param, - {'Functional.linear.0.forward': {'input_args': [ - {'data_name': 'Functional.linear.0.forward.input.0.pt'}]}}) - self.assertEqual(result, [1.0, 0.0, 0.0, 1.0, 1.0, '']) - - def test_get_bench_data_name_input(self): - bench_op_name = "Functional.linear.0.forward.input.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [{"data_name": "Functional.linear.0.forward.input.0.pt"}], "input_kwargs": {}, "output": []}} - result = get_bench_data_name(bench_op_name, bench_data) - - self.assertEqual(result, "Functional.linear.0.forward.input.0.pt") - - def test_get_bench_data_name_output(self): - bench_op_name = "Functional.linear.0.forward.output.0" - bench_data = {"Functional.linear.0.forward": {"input_args": [], "input_kwargs": {}, "output": [{"data_name": "Functional.linear.0.forward.output.0.pt"}]}} - result = get_bench_data_name(bench_op_name, bench_data) - - self.assertEqual(result, "Functional.linear.0.forward.output.0.pt") + result = pt_comparator.compare_by_op(npu_op_name, bench_op_name, op_name_mapping_dict, input_param) + self.assertEqual(result, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, '']) class TestComparator(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py index 
aec6cdc51173ae817f32dd76455bec645659b45c..da315b657c8c1fc691136a1dbc56574d69c92076 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_npy_compare.py @@ -20,7 +20,7 @@ from unittest.mock import patch from msprobe.core.common.const import CompareConst from msprobe.core.compare.npy_compare import handle_inf_nan, reshape_value, get_error_flag_and_msg, \ npy_data_check, statistics_data_check, get_relative_err, GetCosineSimilarity, GetMaxAbsErr, GetMaxRelativeErr, \ - GetErrRatio, error_value_process, compare_ops_apply + GetErrRatio, error_value_process, compare_ops_apply, GetEuclideanDistance op_name = 'Functional.conv2d.0.backward.input.0' @@ -113,7 +113,7 @@ class TestUtilsMethods(unittest.TestCase): n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value, error_flag=error_flag) self.assertFalse(error_flag) - self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', " + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', " "'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") def test_get_error_flag_and_msg_shape_unmatch(self): @@ -239,15 +239,17 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array(1) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. 
") n_value_2 = np.array([1, 2]) b_value_2 = np.array([1, 2]) relative_err = get_relative_err(n_value_2, b_value_2) n_value_2, b_value_2 = reshape_value(n_value_2, b_value_2) - result, err_msg = op.apply(n_value_2, b_value_2, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_2, b_value_2, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -255,7 +257,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_3 = np.array([0, 0]) relative_err = get_relative_err(n_value_3, b_value_3) n_value_3, b_value_3 = reshape_value(n_value_3, b_value_3) - result, err_msg = op.apply(n_value_3, b_value_3, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_3, b_value_3, relative_err, err_msg) self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -263,7 +266,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_4 = np.array([1, 2]) relative_err = get_relative_err(n_value_4, b_value_4) n_value_4, b_value_4 = reshape_value(n_value_4, b_value_4) - result, err_msg = op.apply(n_value_4, b_value_4, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_4, b_value_4, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.') @@ -271,7 +275,8 @@ class TestUtilsMethods(unittest.TestCase): b_value_5 = np.array([0, 0]) relative_err = get_relative_err(n_value_5, b_value_5) n_value_5, b_value_5 = reshape_value(n_value_5, b_value_5) - result, err_msg = op.apply(n_value_5, b_value_5, relative_err) + err_msg = "" + result, err_msg = op.apply(n_value_5, b_value_5, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.') @@ -282,7 +287,9 @@ class TestUtilsMethods(unittest.TestCase): b_value_1 = np.array([1]) relative_err = get_relative_err(n_value_1, b_value_1) n_value_1, b_value_1 = 
reshape_value(n_value_1, b_value_1) - result, err_msg = op.apply(n_value_1, b_value_1, relative_err) + err_msg = "" + + result, err_msg = op.apply(n_value_1, b_value_1, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) self.assertEqual(err_msg, "This is a 1-d tensor of length 1.") @@ -294,8 +301,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by Cosine Similarity, the dump data has NaN.") @@ -319,8 +327,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([0, 0]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 2.0) self.assertEqual(err_msg, "") @@ -333,8 +342,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data.") @@ -347,8 +357,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) 
self.assertEqual(result, 1.0) self.assertEqual(err_msg, "") @@ -361,8 +372,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.") @@ -375,8 +387,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -387,11 +400,12 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array(1) # 标量 b_value = np.array(1) relative_err = np.array(0) + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. " - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.UNSUPPORTED) - self.assertEqual(err_msg, "") + self.assertEqual(err_msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. 
") def test_GetThousandErrRatio_not_size(self): op = GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD) @@ -399,8 +413,9 @@ class TestUtilsMethods(unittest.TestCase): n_value = np.array([1, 2]) b_value = np.array([1, 2]) relative_err = np.array([]) # 空数组 + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, CompareConst.NAN) self.assertEqual(err_msg, "") @@ -412,8 +427,9 @@ class TestUtilsMethods(unittest.TestCase): b_value = np.array([1, 1]) relative_err = get_relative_err(n_value, b_value) n_value, b_value = reshape_value(n_value, b_value) + err_msg = "" - result, err_msg = op.apply(n_value, b_value, relative_err) + result, err_msg = op.apply(n_value, b_value, relative_err, err_msg) self.assertEqual(result, 0.5) self.assertEqual(err_msg, "") @@ -471,5 +487,34 @@ class TestUtilsMethods(unittest.TestCase): error_flag = False err_msg = '' a, b = compare_ops_apply(n_value, b_value, error_flag, err_msg) - self.assertEqual(a, [1.0, 0.0, 0.0, 1.0, 1.0]) + self.assertEqual(a, [1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) self.assertEqual(b, '') + + +class TestGetEuclideanDistance(unittest.TestCase): + + def setUp(self): + self.euc_distance = GetEuclideanDistance() + + def test_euclidean_distance_normal(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array([1, 2, 3]) + b_value = np.array([4, 5, 6]) + relative_err = None + err_msg = "" + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + expected_distance = np.linalg.norm(n_value - b_value) + self.assertEqual(result, expected_distance) + self.assertEqual(msg, '') + + def test_euclidean_distance_0d_tensor(self): + # 测试计算两个张量之间的欧式距离 + n_value = np.array(1) + b_value = np.array(1) + relative_err = None + err_msg = "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. 
" + + result, msg = self.euc_distance.apply(n_value, b_value, relative_err, err_msg) + self.assertEqual(result, CompareConst.UNSUPPORTED) + self.assertEqual(msg, "This is type of 0-d tensor, can not calculate 'Cosine', 'EucDist', 'One Thousandth Err Ratio' and 'Five Thousandths Err Ratio'. ") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index ab8703dcd353ff32dc0722fc314ade6042d6f567..bf23f4de1dac73a44a2497e1a927ba30e5440715 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -221,28 +221,34 @@ o_result_unmatch_2 = [ 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None'] ] o_result_unmatch_3 = [ - ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', - 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 
'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', - 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', - 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'], - ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', 'N/A', 'N/A', 'N/A', - 'N/A', 'N/A', 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', '-1'] + ['Functional.conv2d.0.forward.input.0', 'N/A', 'torch.float32', 'N/A', [1, 1, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 3.029174327850342, -2.926689624786377, -0.06619918346405029, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.1', 'N/A', 'torch.float32', 'N/A', [16, 1, 5, 5], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.input.2', 'N/A', 'torch.float32', 'N/A', [16], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 
'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.parameters.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.forward.output.0', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 2.1166646480560303, -2.190781354904175, -0.003579073818400502, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.weight', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']], + ['Functional.conv2d.0.parameters_grad.bias', 'N/A', 'torch.float32', 'N/A', [1, 16, 28, 28], 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 1.0, 1.0, 1.0, 1.0, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'No bench data matched.', 'None', ['-1', '-1']] ] # test_merge_tensor @@ -558,7 +564,7 @@ class TestUtilsMethods(unittest.TestCase): dump_mode = Const.ALL result_item = result_item_init(n_info, b_info, dump_mode) self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', - 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ']) + 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ', ' ']) dump_mode = Const.SUMMARY result_item = result_item_init(n_info, b_info, dump_mode) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py index f561a3e05ec84c3ee75dac50ed5aec2a2af7f7b5..3261bce5d6d0a15d8e46c7d9fc22df0cf64c9e4d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py +++ 
b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_highlight.py @@ -26,7 +26,7 @@ def generate_result_xlsx(base_dir): data_path = os.path.join(base_dir, 'target_result.xlsx') data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'] ] columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name'] result_df = pd.DataFrame(data, columns=columns) @@ -101,8 +101,8 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(result, None) def test_CheckOneThousandErrorRatio_str(self): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, "unsupported"] + api_in = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, "unsupported"] + api_out = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, "unsupported"] info = (api_in, api_out, 1) color_columns = () dump_mode = Const.ALL @@ -113,8 +113,8 @@ class TestUtilsMethods(unittest.TestCase): @patch("msprobe.core.compare.highlight.add_highlight_row_info") def test_CheckOneThousandErrorRatio_red(self, mock_add_highlight_row_info): - api_in = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - api_out = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0.5] + api_in = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, 1] + api_out = [1, 1, 1, 1, 1, 1, 0.9, 0.5, 1, 1, 0.5] info = (api_in, api_out, 1) ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) color_columns = ColorColumns(red=[], yellow=[]) @@ -315,7 +315,7 @@ class TestUtilsMethods(unittest.TestCase): columns = CompareConst.COMPARE_RESULT_HEADER data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0', 'torch.float32', 'torch.float32', [2, 2], [2, 2], - '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', ''] + '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', ''] ] result_df = pd.DataFrame(data, columns=columns) @@ -329,7 +329,7 @@ class 
TestUtilsMethods(unittest.TestCase):
     def test_highlight_rows_xlsx_red(self):
         data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
                 ]
         columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name']
         result_df = pd.DataFrame(data, columns=columns)
@@ -342,7 +342,7 @@ class TestUtilsMethods(unittest.TestCase):
     def test_highlight_rows_xlsx_yellow(self):
         data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
                 ]
         columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name']
         result_df = pd.DataFrame(data, columns=columns)
@@ -356,7 +356,7 @@ class TestUtilsMethods(unittest.TestCase):
     def test_highlight_rows_xlsx_malicious_columns(self, mock_save_book):
         data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
                 ]
         columns = CompareConst.COMPARE_RESULT_HEADER + ['=Data_name']
         result_df = pd.DataFrame(data, columns=columns)
@@ -378,10 +378,10 @@ class TestUtilsMethods(unittest.TestCase):
     def test_highlight_rows_xlsx_malicious_type(self, mock_save_book):
         data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  '=torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'],
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'],
                 ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  '=torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
                 ]
         columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name']
         result_df = pd.DataFrame(data, columns=columns)
@@ -416,10 +416,10 @@ class TestUtilsMethods(unittest.TestCase):
     def test_update_highlight_err_msg(self):
         data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'],
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1'],
                 ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                  'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                 '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
+                 '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', '', '-1']
                 ]
         columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name']
         result_df = pd.DataFrame(data, columns=columns)
@@ -433,10 +433,10 @@ class TestUtilsMethods(unittest.TestCase):
         t_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                    'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                   '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'a\nb', '-1'],
+                   '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'a\nb', '-1'],
                   ['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
                    'torch.float32', 'torch.float32', [2, 2], [2, 2],
-                   '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'd', '-1']
+                   '', '', '', '', '', '', 1, 1, 1, 1, 1, 1, 1, 1, 'Yes', 'd', '-1']
                   ]
         target_result_df = pd.DataFrame(t_data, columns=columns)
         self.assertTrue(result_df.equals(target_result_df))
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py
index 9c2dea835fea13af7902bf796d9ab06c9eb6a61b..49f084ce07c8e90afb2aa1c3340bb4c3965c8fa7 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_cmp_multiprocessing_compute.py
@@ -16,14 +16,14 @@ from test_acc_compare import generate_dump_json
 
 data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
          'torch.float32', 'torch.float32', [2, 2], [2, 2],
-         '', '', '', '', '',
+         '', '', '', '', '', '',
          1, 1, 1, 1, 1, 1, 1, 1,
-         'Yes', '', '-1']]
+         'Yes', '', ['-1', '-1']]]
 o_data = [['Functional.linear.0.forward.input.0', 'Functional.linear.0.forward.input.0',
            'torch.float32', 'torch.float32', [2, 2], [2, 2],
-           'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported',
+           'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported', 'unsupported',
            1, 1, 1, 1, 1, 1, 1, 1,
-           'None', 'No bench data matched.', '-1']]
+           'None', 'No bench data matched.', ['-1', '-1']]]
 columns = CompareConst.COMPARE_RESULT_HEADER + ['Data_name']
 result_df = pd.DataFrame(data, columns=columns)
 o_result = pd.DataFrame(o_data, columns=columns)
@@ -34,9 +34,9 @@ class TestUtilsMethods(unittest.TestCase):
 
     def setUp(self):
         self.result_df = pd.DataFrame(columns=[
-            CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
-            CompareConst.ERROR_MESSAGE, CompareConst.ACCURACY,
-            CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO
+            CompareConst.COSINE, CompareConst.EUC_DIST, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
+            CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
+            CompareConst.ACCURACY, CompareConst.ERROR_MESSAGE
         ])
         os.makedirs(base_dir, mode=0o750, exist_ok=True)
         self.lock = threading.Lock()
@@ -54,9 +54,9 @@ class TestUtilsMethods(unittest.TestCase):
         func = Comparator(mode_config).compare_ops
 
         generate_dump_json(base_dir)
-        input_parma = {'bench_json_path': os.path.join(base_dir, 'dump.json')}
+        input_param = {'bench_json_path': os.path.join(base_dir, 'dump.json')}
         lock = multiprocessing.Manager().RLock()
-        result = _handle_multi_process(func, input_parma, result_df, lock)
+        result = _handle_multi_process(func, input_param, result_df, lock)
         self.assertTrue(result.equals(o_result))
 
     def test_read_dump_data(self):
@@ -72,9 +72,10 @@ class TestUtilsMethods(unittest.TestCase):
             cos_result=[0.99, 0.98],
             max_err_result=[0.01, 0.02],
             max_relative_err_result=[0.001, 0.002],
-            err_msgs=['', 'Error in comparison'],
+            euc_dist_result=[0.5, 0.49],
             one_thousand_err_ratio_result=[0.1, 0.2],
-            five_thousand_err_ratio_result=[0.05, 0.1]
+            five_thousand_err_ratio_result=[0.05, 0.1],
+            err_msgs=['', 'Error in comparison']
         )
         offset = 0
         updated_df = _save_cmp_result(offset, comparison_result, self.result_df, self.lock)
@@ -88,9 +89,10 @@ class TestUtilsMethods(unittest.TestCase):
             cos_result=[0.99],
             max_err_result=[],
             max_relative_err_result=[0.001],
-            err_msgs=[''],
+            euc_dist_result=[0.5],
             one_thousand_err_ratio_result=[0.1],
-            five_thousand_err_ratio_result=[0.05]
+            five_thousand_err_ratio_result=[0.05],
+            err_msgs=['']
         )
         with self.assertRaises(CompareException) as context:
             _save_cmp_result(0, comparison_result, self.result_df, self.lock)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
index 34064e7cc2b9d0aa5c0c2e98806b8993137a589c..3d31a1bb51679c28d2cc25ecced891e31ce4dcfd 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
@@ -19,6 +19,7 @@ from msprobe.core.data_dump.data_processor.pytorch_processor import (
     KernelDumpDataProcessor
 )
 from torch import distributed as dist
+from torch._subclasses import FakeTensorMode
 
 
 class TestPytorchDataProcessor(unittest.TestCase):
@@ -62,6 +63,15 @@ class TestPytorchDataProcessor(unittest.TestCase):
         result = PytorchDataProcessor.get_stat_info(mock_data)
         self.assertIsInstance(result, TensorStatInfo)
 
+    def test_get_stat_info_with_fake_tensor(self):
+        with FakeTensorMode() as fake_tensor_mode:
+            fake_tensor = fake_tensor_mode.from_tensor(torch.randn(1, 2, 3))
+            result = PytorchDataProcessor.get_stat_info(fake_tensor)
+            self.assertIsNone(result.max)
+            self.assertIsNone(result.min)
+            self.assertIsNone(result.mean)
+            self.assertIsNone(result.norm)
+
     def test_get_stat_info_float(self):
         tensor = torch.tensor([1.0, 2.0, 3.0])
         result = self.processor.get_stat_info(tensor)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67c5d8ee9efd201cdcf09bc82471cac1f6607c3
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_api_registry.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from unittest import TestCase
+from unittest.mock import patch
+
+import torch
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.api_registry import _get_attr, ApiWrapper
+
+
+class TestFunctions(TestCase):
+    def test__get_attr(self):
+        module = torch
+
+        attr_name = 'linalg.norm'
+        target_value = torch.linalg.norm
+        actual_value = _get_attr(module, attr_name)
+        self.assertEqual(target_value, actual_value)
+
+        attr_name = 'norm'
+        target_value = torch.norm
+        actual_value = _get_attr(module, attr_name)
+        self.assertEqual(target_value, actual_value)
+
+
+class TestApiWrapper(TestCase):
+    api_types = {
+        Const.PT_FRAMEWORK: {
+            Const.PT_API_TYPE_TORCH: (torch, torch),
+        }
+    }
+    supported_api_list_path = (Const.SUPPORT_API_FILE_NAME,)
+    yaml_value = {'torch': ['linalg.norm', 'norm']}
+    api_names = {Const.PT_FRAMEWORK: {'torch': {'linalg.norm', 'norm'}}}
+
+    def test___init__(self):
+        with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value):
+            api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
+            self.assertEqual(api_wrapper.api_types, self.api_types)
+            self.assertEqual(api_wrapper.api_list_paths, self.supported_api_list_path)
+            self.assertEqual(api_wrapper.api_names, self.api_names)
+            self.assertEqual(api_wrapper.wrapped_api_functions, {})
+
+            api_wrapper = ApiWrapper(self.api_types, Const.SUPPORT_API_FILE_NAME)
+            self.assertEqual(api_wrapper.api_list_paths, list(self.supported_api_list_path))
+
+            with self.assertRaises(Exception) as context:
+                api_wrapper = ApiWrapper(self.api_types, (Const.SUPPORT_API_FILE_NAME, Const.SUPPORT_API_FILE_NAME))
+            self.assertEqual(str(context.exception),
+                             "The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
+                             "when api_list_paths is a list or tuple.")
+
+    def test__get_api_names(self):
+        target_value = self.api_names
+        with patch('msprobe.core.data_dump.api_registry.load_yaml', return_value=self.yaml_value):
+            api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
+            actual_value = api_wrapper._get_api_names()
+            self.assertEqual(target_value, actual_value)
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py
index bb4c8b197ef8362921858839ca3790224715a39a..9cfad00d8ff13e91eb84fff5f46ab434f9ed1d4d 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_data_manager.py
@@ -2,7 +2,8 @@ import unittest
 from unittest.mock import patch, mock_open, MagicMock
 import os
 from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import DataManager
-from msprobe.core.common.const import MsCompareConst, CompareConst
+from msprobe.core.common.const import CompareConst
+from msprobe.mindspore.common.const import MsCompareConst
 
 
 class TestDataManager(unittest.TestCase):
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json
index 5b954f6d6443c92e6321e5f55e373e99f428653d..48800c0455c6651b146600e61e636d4dc25fac31 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/mindspore_data/dump.json
@@ -1,6 +1,7 @@
 {
     "task": "statistics",
     "level": "mix",
+    "framework": "mindspore",
     "dump_data_dir": null,
     "data": {
         "Tensor.__add__.0.forward": {
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json
index 150cbd43b169573e48542aa0c46c26e7df69843e..b2704185ff19b961b43453f81247236d77677d83 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/dump_file/pytorch_data/dump.json
@@ -1,6 +1,7 @@
 {
     "task": "statistics",
     "level": "mix",
+    "framework": "pytorch",
     "dump_data_dir": null,
     "data": {
         "Tensor.__add__.0.forward": {
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py
index b5cbff9784a837ea4d64ac9eccdf30175564f712..6f7377894002e60add41dc7b2d3c1d3d68391e0b 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_compare.py
@@ -5,8 +5,10 @@ import random
 import shutil
 import tempfile
 import unittest
+from unittest.mock import patch
 
 import numpy as np
+import pandas as pd
 import torch
 import yaml
@@ -350,21 +352,21 @@ class TestUtilsMethods(unittest.TestCase):
         finally:
             shutil.rmtree(data_path)
 
-    def test_check_cross_framework(self):
-        ms_data = {
-            "data_name": "Cell.model.language_model.encoder.layers.5.input_norm.FusedRMSNorm.forward.0.input.0.npy",
-        }
-        pt_data = {
-            "data_name": "Module.module.module.language_model.encoder.layers.0.input_norm.RMSNorm.forward.0.input.0.pt",
-        }
+    @patch('msprobe.mindspore.compare.ms_compare.detect_framework_by_dump_json')
+    def test_check_cross_framework_valid_pytorch(self, mock_detect_framework):
+        mock_detect_framework.return_value = Const.PT_FRAMEWORK
+
+        result = check_cross_framework("dummy_path")
+
+        self.assertTrue(result)
 
-        def check_data(data):
-            with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', encoding='utf-8', delete=True) as temp_file:
-                json.dump(data, temp_file, ensure_ascii=False, indent=4)
-                temp_file.flush()
-                return check_cross_framework(temp_file.name)
-        self.assertFalse(check_data(ms_data))
-        self.assertTrue(check_data(pt_data))
+    @patch('msprobe.mindspore.compare.ms_compare.detect_framework_by_dump_json')
+    def test_check_cross_framework_invalid_framework(self, mock_detect_framework):
+        mock_detect_framework.return_value = Const.MS_FRAMEWORK
+
+        result = check_cross_framework("dummy_path")
+
+        self.assertFalse(result)
 
     def test_comapre_process(self):
         data_path = tempfile.mkdtemp(prefix='dump_data', dir='/tmp')
@@ -533,4 +535,28 @@ class TestUtilsMethods(unittest.TestCase):
         api_list = ["Mint"]
 
         with self.assertRaises(CompareException):
-            ms_comparator.get_api_name(api_list)
\ No newline at end of file
+            ms_comparator.get_api_name(api_list)
+
+    def test_process_data_name(self):
+        stack_mode = True
+        auto_analyze = True
+        fuzzy_match = False
+        dump_mode = Const.ALL
+
+        mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
+        mapping_config = MappingConfig()
+        ms_comparator = MSComparator(mode_config, mapping_config)
+
+        data = pd.DataFrame({
+            'data_name_x': ['A', 'B', 'C'],
+            'data_name_y': ['X', 'Y', 'Z']
+        })
+
+        result = ms_comparator.process_data_name(data.copy())
+
+        expected = pd.DataFrame({
+            'data_name_x': [['A', 'X'], ['B', 'Y'], ['C', 'Z']],
+            'data_name_y': ['X', 'Y', 'Z']
+        })
+
+        pd.testing.assert_frame_equal(result, expected)
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py
index e3fd9348efe7dd4df0a6db2cd52a45f4757dae01..c2e7c9368c3f049511657469ebb16388015b621a 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/compare/test_ms_graph_compare.py
@@ -78,7 +78,7 @@ class TestMsGraphCompare(unittest.TestCase):
 
         result_correct = (
             f"[['{npu_file_path}', '{bench_file_path}', dtype('float16'), dtype('float16'), (10, 10), (10, 10), "
-            f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 1.0, 1.0]]")
+            f"44.0, 44.0, 44.0, inf, 44.0, 44.0, 44.0, inf, 'Yes', '', 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]]")
 
         self.assertNotEqual(len(files), 0)
         self.assertEqual(result, result_correct)
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py
index 066ff537ce6fba12f712ae3d4681115499be35a6..86fdcc08385fc0fe68da541b473a02e130d28478 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_ms_precision_debugger.py
@@ -98,6 +98,9 @@ class TestPrecisionDebugger(unittest.TestCase):
             def __init__(self):
                 self.task = Const.TENSOR
                 self.service = None
+                self.config = MagicMock()
+                self.config.level_ori = MagicMock()
+                self.config.level_ori.return_value = Const.LEVEL_L1
 
         PrecisionDebugger._instance = None
         with self.assertRaises(Exception) as context:
             PrecisionDebugger.stop()
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
index e589dd4d58715d74644047f8c7e7a6ce79ccf225..c4482a22f042723f12431b56c730a8d69957b63c 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,12 +23,18 @@ from mindspore import Tensor, mint, ops
 from msprobe.core.common.const import Const
 from msprobe.mindspore.common.const import FreeBenchmarkConst
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.mindspore.free_benchmark.api_pynative_self_check import (ApiPyNativeSelfCheck, check_all_tensor,
-                                                                      check_self, data_pre_deal,
-                                                                      deal_fuzzed_and_original_result,
-                                                                      get_module, get_supported_ops,
-                                                                      get_target_arg_index, need_wrapper_func)
+from msprobe.mindspore.free_benchmark.api_pynative_self_check import (
+    ApiPyNativeSelfCheck,
+    check_all_tensor,
+    check_self,
+    data_pre_deal,
+    deal_fuzzed_and_original_result,
+    get_module,
+    get_supported_ops,
+    get_target_arg_index,
+    need_wrapper_func,
+    _api_register
+)
 from msprobe.mindspore.free_benchmark.common.config import Config
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
 from msprobe.mindspore.free_benchmark.common.utils import Tools
@@ -83,8 +89,8 @@ class TestApiPyNativeSelfCheck(TestCase):
         self.assertEqual(self_checker.ori_func, target_ori_func)
 
     def test_handle(self):
-        with patch.object(api_register, "initialize_hook") as mock_init_hook, \
-                patch.object(api_register, "api_set_hook_func") as mock_set_hook:
+        with patch.object(_api_register, "initialize_hook") as mock_init_hook, \
+                patch.object(_api_register, "register_all_api") as mock_set_hook:
             self.checker.handle()
             mock_init_hook.assert_called_with(self.checker.build_hook)
             mock_set_hook.assert_called_once()
@@ -156,8 +162,8 @@ class TestApiPyNativeSelfCheck(TestCase):
         mock_warning.reset_mock()
 
         Config.stage = Const.FORWARD
         with patch.object(logger, "info") as mock_info, \
-                patch.object(api_register, "api_set_ori_func") as mock_set_ori, \
-                patch.object(api_register, "api_set_hook_func") as mock_set_hook, \
+                patch.object(_api_register, "restore_all_api") as mock_set_ori, \
+                patch.object(_api_register, "register_all_api") as mock_set_hook, \
                 patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.deal_fuzzed_and_original_result",
                       return_value="ret"):
             args = (1.0, 1.0)
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
index 912830ea1ab705aae63c69f5c240887d4b4ce5b7..96d76c17d3ce50cba87feb5a5f392f179758aad9 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py
@@ -21,12 +21,13 @@ from unittest.mock import MagicMock, patch
 from mindspore import nn, ops
 
 from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.utils import Const, DumpPathAggregation
+from msprobe.core.common.utils import Const
+from msprobe.core.data_dump.api_registry import ApiRegistry
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.utils import register_backward_hook_functions
-from msprobe.mindspore.dump.hook_cell.api_registry import ApiRegistry, api_register
+from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 from msprobe.mindspore.dump.jit_dump import JitDump
 from msprobe.mindspore.service import Service
@@ -49,7 +50,7 @@ class TestService(unittest.TestCase):
         self.service.primitive_hook_service = MagicMock()
 
     def tearDown(self) -> None:
-        api_register.api_set_ori_func()
+        get_api_register().restore_all_api()
 
     def test_init(self):
         self.assertEqual(self.service.config.level, "L0")
@@ -197,7 +198,7 @@ class TestService(unittest.TestCase):
     @patch.object(Service, 'need_end_service', return_value=False)
     @patch.object(JitDump, 'set_config')
     @patch.object(JitDump, 'set_data_collector')
-    @patch.object(ApiRegistry, 'api_set_hook_func')
+    @patch.object(ApiRegistry, 'register_all_api')
     def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, mock_set_config,
                                          mock_need_end_service, mock_register_cell_hook,
                                          mock_register_primitive_hook):
@@ -218,10 +219,9 @@ class TestService(unittest.TestCase):
         HOOKCell.cell_count = {"test_api": 1}
         JitDump.jit_count = {"test_api": 1}
         self.service.primitive_hook_service.primitive_counters = {"test_api": 1}
-        self.service.current_iter = 0
+        self.service.loop = 0
         self.service.step()
-        self.assertEqual(self.service.current_iter, 1)
-        self.service.data_collector.update_iter.assert_called_once_with(1)
+        self.assertEqual(self.service.loop, 1)
         self.service.data_collector.reset_status.assert_called_once()
         self.assertEqual(JitDump.jit_count, defaultdict(int))
         self.assertEqual((self.service.primitive_hook_service.primitive_counters), {})
@@ -269,7 +269,7 @@ class TestService(unittest.TestCase):
                          primitive_combined_name)
 
     @patch.object(ApiRegistry, 'initialize_hook')
-    @patch.object(ApiRegistry, 'api_set_hook_func')
+    @patch.object(ApiRegistry, 'register_all_api')
     @patch("msprobe.mindspore.service.logger.info")
     def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook):
         self.service.config.level = Const.LEVEL_MIX
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
index 3cafd49f2c101c45dbb65a08803dd77c6bca485d..79deeee08e13273f08f32be26a375d1d26f5d2f1 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
@@ -84,9 +84,9 @@ class TestService(unittest.TestCase):
         self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1)
 
     def test_step_updates_iteration(self):
-        initial_iter = self.service.current_iter
+        initial_iter = self.service.loop
         self.service.step()
-        self.assertEqual(self.service.current_iter, initial_iter + 1)
+        self.assertEqual(self.service.loop, initial_iter + 1)
 
     @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int))
     def test_step_resets_counters(self, _):
@@ -96,12 +96,13 @@ class TestService(unittest.TestCase):
         self.assertEqual(self.service.primitive_hook_service.primitive_counters, {})
         self.assertEqual(HOOKCell.cell_count, defaultdict(int))
 
-    def test_step_calls_update_iter(self):
-        # Check that update_iter is called when step is called
+    def test_start_calls_update_iter(self):
+        # Check that update_iter is called when start is called
        with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter:
-            initial_iter = self.service.current_iter
-            self.service.step()
-            mock_update_iter.assert_called_once_with(initial_iter + 1)
+            initial_iter = self.service.loop
+            init_step = self.service.init_step
+            self.service.start()
+            mock_update_iter.assert_called_once_with(initial_iter + init_step)
 
 
 class TestPrimitiveHookService(unittest.TestCase):
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
index 1ad191a0d4e85715e6199367d1d305c10a728630..8eb8fde4fdca88c97a4165f541f6dd6e7133303f 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
@@ -136,7 +136,7 @@ class TestMultiRunUT(unittest.TestCase):
 
     def setUp(self):
         self.test_json_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json")
-        self.test_data = {'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}}
+        self.test_data = {'dump_data_dir': '/test', 'data': {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'}}
         self.test_json_content = json.dumps(self.test_data)
         self.forward_split_files_content = [
             {'key1': 'TRUE', 'key2': 'TRUE'},
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py
index 0cf30461aec70b85577c38ebed011bf9f818874d..8cead7b0093ce68dca8a12b0ea6dbcde78a70c0b 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut_utils.py
@@ -1,13 +1,27 @@
-# coding=utf-8
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import unittest
-from unittest.mock import patch, MagicMock
+
 import torch
+
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import *
 from msprobe.core.common.file_utils import create_directory, write_csv
 
 
 class TestRunUtUtils(unittest.TestCase):
-
     def setUp(self):
         save_path = "temp_save_path"
         create_directory(save_path)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
index cdc922cc98d59b59ec0be85833d2000cd38913c8..42035932e56131b3eabd0977f72e9895b7accedc 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import io
 import unittest
@@ -19,7 +34,7 @@ class TestParameterAdapter(unittest.TestCase):
 
     def setUp(self):
         self.func_mock = MagicMock()
         self.decorated_func = parameter_adapter(self.func_mock)
-        self.op_name_ = "__getitem__"
+        self.api_name = "__getitem__"
 
     def test_handle_masked_select_bfloat16(self):
         input_tensor = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
@@ -45,7 +60,7 @@ class TestParameterAdapter(unittest.TestCase):
         self.assertTrue(torch.equal(result, torch.tensor([20.0, 30.0])))
 
     def test_op_name_eq_with_none(self):
-        self.op_name_ = "__eq__"
+        self.api_name = "__eq__"
         args = (torch.tensor([1]), None)
         result = self.decorated_func(self, *args)
         self.assertFalse(result)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
index 63d6abc3a2430bb6f092820c4b97a02cdf675612..5aaf0820a78339ff4f1cc5d28aff8762bae31a39 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py
@@ -18,8 +18,10 @@ from unittest.mock import patch, MagicMock
 
 import torch
 import torch.nn as nn
+
+from msprobe.core.data_dump.api_registry import ApiRegistry
 from msprobe.pytorch import PrecisionDebugger
-from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.hook_module.api_register import get_api_register
 from msprobe.pytorch.service import torch_version_above_or_equal_2
 
 
@@ -27,12 +29,12 @@ class TestModuleDumper(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         PrecisionDebugger._instance = None
-        api_register.api_originality()
+        get_api_register().restore_all_api()
 
     @classmethod
     def tearDownClass(cls):
         PrecisionDebugger._instance = None
-        api_register.api_originality()
+        get_api_register().restore_all_api()
 
     def setUp(self):
         self.module = nn.Linear(8, 4)
@@ -41,7 +43,7 @@ class TestModuleDumper(unittest.TestCase):
 
     def test_stop_module_dump(self):
         self.module_dumper.hook_handle_list.extend([1, 2, 3])
-        with patch('msprobe.pytorch.dump.module_dump.module_dump.api_register') as mock_api_register:
+        with patch.object(ApiRegistry, 'register_all_api') as mock_api_register:
             mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle)
             mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle)
             self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2])
@@ -50,7 +52,7 @@ class TestModuleDumper(unittest.TestCase):
             mock_handle1.remove.assert_called_once()
             mock_handle2.remove.assert_called_once()
             self.assertEqual(self.module_dumper.hook_handle_list, [])
-            mock_api_register.api_modularity.assert_called_once()
+            mock_api_register.assert_called_once()
 
     def test_register_hook(self):
         self.module_dumper.register_hook(self.module, "TestModule")
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py
deleted file mode 100644
index 837ad23df76be2a012a7408dab4879847937f229..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import unittest
-from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu
-
-
-class TestApiRegistry(unittest.TestCase):
-
-    def test_store_ori_attr(self):
-        class A():
-            a1 = 1
-        class B():
-            a = A()
-            b1 = 1
-            b2 = 2
-
-        api_list = ["a.a1", "b1", "b2"]
-        expect_output = {"a.a1":1, "b1":1, "b2":2}
-        actual_output = dict()
-        ApiRegistry.store_ori_attr(B, api_list, actual_output)
-        self.assertEqual(actual_output, expect_output)
-
-
-    def test_set_api_attr(self):
-        class A():
-            a1 = 1
-        class B():
-            a = A().__class__
-            b1 = 1
-
-        attr_dict = {"a.a2":2, "b2":2, "b3":3}
-        ApiRegistry.set_api_attr(B, attr_dict)
-
-        for k, v in attr_dict.items():
-            if '.' in k:
-                sub_module_name, sub_op = k.rsplit('.', 1)
-                sub_module = getattr(B, sub_module_name, None)
-
-                self.assertEqual(getattr(sub_module, sub_op), v)
-            else:
-                self.assertEqual(getattr(B, k), v)
-
-    def test_api_modularity(self):
-
-        import torch
-        import torch.distributed as dist
-        #import torch_npu  # torch_npu is not installed in the CI environment
-        from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
-
-
-
-        reg = ApiRegistry()
-        attr_dict = {"b2":2, "b3":3}
-        reg.tensor_hook_attr = attr_dict
-        reg.torch_hook_attr = attr_dict
-        reg.functional_hook_attr = attr_dict
-        reg.distributed_hook_attr = attr_dict
-        reg.npu_distributed_hook_attr = attr_dict
-        reg.aten_hook_attr = attr_dict
-        reg.vf_hook_attr = attr_dict
-        reg.torch_npu_hook_attr = attr_dict
-
-        reg.api_modularity()
-        self.assertEqual(torch.Tensor.b2, 2)
-
-        self.assertEqual(torch.b2, 2)
-        self.assertEqual(torch.nn.functional.b2, 2)
-        self.assertEqual(dist.b2, 2)
-        self.assertEqual(dist.distributed_c10d.b2, 2)
-        #if not is_gpu and not torch_without_guard_version:
-            #self.assertEqual(torch_npu.distributed.b2, 2)
-            #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2)
-        if torch_version_above_2:
-            self.assertEqual(torch.ops.aten.b2, 2)
-            self.assertEqual(torch._VF.b2, 2)
-        #if not is_gpu:
-            #self.assertEqual(torch_npu.b2, 2)
-
-
-    def test_api_originality(self):
-        import torch
-        import torch.distributed as dist
-        #import torch_npu  # torch_npu is not installed in the CI environment
-        from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2
-
-
-
-        reg = ApiRegistry()
-        attr_dict = {"b2":2, "b3":3}
-        reg.tensor_hook_attr = attr_dict
-        reg.torch_hook_attr = attr_dict
-        reg.functional_hook_attr = attr_dict
-        reg.distributed_hook_attr = attr_dict
-        reg.npu_distributed_hook_attr = attr_dict
-        reg.aten_hook_attr = attr_dict
-        reg.vf_hook_attr = attr_dict
-        reg.torch_npu_hook_attr = attr_dict
-
-        reg.api_originality()
-        self.assertEqual(torch.Tensor.b2, 2)
-
-        self.assertEqual(torch.b2, 2)
-        self.assertEqual(torch.nn.functional.b2, 2)
-        self.assertEqual(dist.b2, 2)
-        self.assertEqual(dist.distributed_c10d.b2, 2)
-        #if not is_gpu and not torch_without_guard_version:
-            #self.assertEqual(torch_npu.distributed.b2, 2)
-            #self.assertEqual(torch_npu.distributed.distributed_c10d.b2, 2)
-        if torch_version_above_2:
-            self.assertEqual(torch.ops.aten.b2, 2)
-            self.assertEqual(torch._VF.b2, 2)
-        #if not is_gpu:
-            #self.assertEqual(torch_npu.b2, 2)
-
-    def test_initialize_hook(self):
-        def hook_test():
-            pass
-
-        reg = ApiRegistry()
-        reg.initialize_hook(hook_test)
-        empty_list = []
-        self.assertFalse(empty_list==reg.tensor_hook_attr)
-        self.assertFalse(empty_list==reg.torch_hook_attr)
-        self.assertFalse(empty_list==reg.functional_hook_attr)
-        self.assertFalse(empty_list==reg.distributed_hook_attr)
-        self.assertFalse(empty_list==reg.npu_distributed_hook_attr)
-        if torch_version_above_2:
-            #print(True)
-            self.assertFalse(empty_list==reg.aten_hook_attr)
-        if not is_gpu:
-            #print(True)
-            self.assertFalse(empty_list==reg.torch_npu_hook_attr)
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py
deleted file mode 100644
index 246feb56becf9942de9214f5b24b8471e9b4024a..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import unittest
-import torch.distributed as dist
-from msprobe.pytorch.hook_module.wrap_distributed import *
-
-class TestWrapDistributed(unittest.TestCase):
-    def hook(name, prefix):
-        def forward_pre_hook(nope, input, kwargs):
-            return input, kwargs
-
-        def forward_hook(nope, input, kwargs, result):
-            return 2
-
-        def backward_hook():
-            pass
-
-        def forward_hook_torch_version_below_2():
-            pass
-
-        return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2
-
-    def test_get_distributed_ops(self):
-        ops = get_distributed_ops()
-        self.assertIsInstance(ops, set)
-
-    def test_DistributedOPTemplate(self):
-        self.setUp()
-        op_name = 'all_reduce'
-        if op_name in get_distributed_ops():
-            op = DistributedOPTemplate(op_name, self.hook)
-            self.assertEqual(op.op_name_, op_name)
-
-    def test_wrap_distributed_op(self):
-        op_name = 'all_reduce'
-        if op_name in get_distributed_ops():
-            wrapped_op = wrap_distributed_op(op_name, self.hook)
-            self.assertTrue(callable(wrapped_op))
-
-    def test_wrap_distributed_ops_and_bind(self):
-        wrap_distributed_ops_and_bind(self.hook)
-        for op_name in get_distributed_ops():
-            self.assertTrue(hasattr(HOOKDistributedOP, "wrap_" + str(op_name)))
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py
deleted file mode 100644
index 282551e3cefdb2ae63efda284f5e7ae7482ae81c..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import unittest
-import torch
-import torch.nn.functional as F
-from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops, \
-    wrap_functional_ops_and_bind, HOOKFunctionalOP
-from msprobe.pytorch.common.utils import remove_dropout
-
-
-class TestDropoutFunctions(unittest.TestCase):
-
-    def setUp(self):
-        self.input_tensor = torch.ones(10, 10)
-        remove_dropout()
-
-    def test_function_dropout_no_dropout(self):
-        output = F.dropout(self.input_tensor, p = 0., training = True)
-        self.assertTrue(torch.equal(self.input_tensor, output))
-
-    def test_function_dropout_train_vs_eval(self):
-        output_train = F.dropout(self.input_tensor, p = 0., training = True)
-        output_eval = F.dropout(self.input_tensor, p = 0., training = False)
-        self.assertTrue(torch.equal(output_train, output_eval))
-
-    def test_function_dropout_invalid_probability(self):
-        with self.assertRaises(ValueError):
-            F.dropout(self.input_tensor, p = -0.1)
-        with self.assertRaises(ValueError):
-            F.dropout(self.input_tensor, p = 1.1)
-
-    def test_function_dropout2d_no_dropout(self):
-        output = F.dropout2d(self.input_tensor, p = 0., training = True)
-        self.assertTrue(torch.equal(self.input_tensor, output))
-
-    def test_function_dropout2d_train_vs_eval(self):
-        output_train = F.dropout2d(self.input_tensor, p = 0., training = True)
-        output_eval = F.dropout2d(self.input_tensor, p = 0., training = False)
-        self.assertTrue(torch.equal(output_train, output_eval))
-
-    def test_function_dropout2d_invalid_probability(self):
-        with self.assertRaises(ValueError):
-            F.dropout2d(self.input_tensor, p = -0.1)
-        with self.assertRaises(ValueError):
-            F.dropout2d(self.input_tensor, p = 1.1)
-
-    def test_function_dropout3d_no_dropout(self):
-        input_tensor_3d = self.input_tensor.unsqueeze(0)
-        output = F.dropout3d(input_tensor_3d, p = 0., training = True)
-        self.assertTrue(torch.equal(input_tensor_3d, output))
-
-    def test_function_dropout3d_train_vs_eval(self):
-        input_tensor_3d = self.input_tensor.unsqueeze(0)
-        output_train = F.dropout3d(input_tensor_3d, p = 0., training = True)
-        output_eval = F.dropout3d(input_tensor_3d, p = 0., training = False)
-        self.assertTrue(torch.equal(output_train, output_eval))
-
-    def
test_function_dropout3d_invalid_probability(self): - input_tensor_3d = self.input_tensor.unsqueeze(0) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = -0.1) - with self.assertRaises(ValueError): - F.dropout3d(input_tensor_3d, p = 1.1) - - -class TestWrapFunctional(unittest.TestCase): - - def test_get_functional_ops(self): - expected_ops = {'relu', 'sigmoid', 'softmax'} - actual_ops = get_functional_ops() - self.assertTrue(expected_ops.issubset(actual_ops)) - - def test_wrap_functional_ops_and_bind(self): - wrap_functional_ops_and_bind(None) - self.assertTrue(hasattr(HOOKFunctionalOP, 'wrap_relu')) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py deleted file mode 100644 index 573d6d000f37f429619b89507cecd1258fbe4c8b..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_npu_custom.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from msprobe.core.common.const import Const -from msprobe.core.common.log import logger -from msprobe.pytorch.function_factory import npu_custom_functions -from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate - -try: - import torch_npu -except ImportError: - logger.info("Failing to import torch_npu.") - - -class TestNpuOPTemplate(unittest.TestCase): - - def setUp(self): - self.mock_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) - self.template = NpuOPTemplate("sum", self.mock_hook) - - def test_init(self): - self.assertEqual(self.template.op_name_, "sum") - self.assertEqual(self.template.prefix_op_name_, f"NPU{Const.SEP}sum{Const.SEP}") - self.assertTrue(self.template.need_hook) - self.assertEqual(self.template.device, Const.CPU_LOWERCASE) - - @patch('torch.ops.npu.sum') - def test_forward_without_hook(self, mock_npu_sum): - 
self.template.need_hook = False - npu_custom_functions["sum"] = MagicMock(return_value="output_from_custom") - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_custom") - mock_npu_sum.assert_not_called() - - @patch('torch.ops.npu.sum') - def test_forward_with_hook(self, mock_npu_sum): - self.template.need_hook = True - mock_npu_sum.return_value = "output_from_npu" - - result = self.template.forward(1, 2, key='value') - self.assertEqual(result, "output_from_npu") - mock_npu_sum.assert_called_once_with(1, 2, key='value') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py deleted file mode 100644 index 6868c5bda7a88c84702d15e995c7f60af2b4e4c5..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind - -class TestWrapTensor(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def test_get_tensor_ops(self): - result = get_tensor_ops() - self.assertIsInstance(result, set) - - def test_HOOKTensor(self): - hook_tensor = HOOKTensor() - self.assertIsInstance(hook_tensor, HOOKTensor) - - def test_TensorOPTemplate(self): - tensor_op_template = TensorOPTemplate('add', self.hook) - self.assertTrue(tensor_op_template.op_name_, 'add') - - def test_wrap_tensor_op(self): - wrapped_op = wrap_tensor_op('add', self.hook) - self.assertTrue(callable(wrapped_op)) - 
- def test_wrap_tensor_ops_and_bind(self): - wrap_tensor_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTensor, 'wrap_add')) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py deleted file mode 100644 index e0e4d000c0bd83be4facbbb406357427faf875ec..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module.wrap_torch import * - -class TestWrapTorch(unittest.TestCase): - - def hook(name, prefix): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return 2 - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - - def setUp(self): - - self.op_name = 'add' - self.torch_op = wrap_torch_op(self.op_name, self.hook) - - def test_get_torch_ops(self): - self.setUp() - ops = get_torch_ops() - self.assertIsInstance(ops, set) - self.assertIn(self.op_name, ops) - - def test_TorchOPTemplate(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - self.assertEqual(template.op_name_, self.op_name) - self.assertEqual(template.prefix_op_name_, "Torch." 
+ str(self.op_name) + ".") - - def test_forward(self): - self.setUp() - template = TorchOPTemplate(self.op_name, self.hook) - result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) - torch.testing.assert_close(result, torch.tensor([5, 7, 9])) - - def test_wrap_torch_ops_and_bind(self): - self.setUp() - wrap_torch_ops_and_bind(self.hook) - self.assertTrue(hasattr(HOOKTorchOP, "wrap_" + self.op_name)) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py deleted file mode 100644 index 98efb4bc5b8a30284fe820124e48af7f487d1c54..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest -import torch -from msprobe.pytorch.hook_module import wrap_vf - -class TestWrapVF(unittest.TestCase): - def setUp(self): - self.hook = lambda x: x - - def test_get_vf_ops(self): - ops = wrap_vf.get_vf_ops() - self.assertIsInstance(ops, list) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py index f5de419440224cca261b62df2495e8ce28b8e2d4..820b1f7476d3d92288069bc00ac798c44bf14da6 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/demo_model.py @@ -1,7 +1,25 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.nn.functional as F from msprobe.pytorch import TrainerMon from msprobe.pytorch.common import seed_all +from msprobe.pytorch.hook_module.api_register import get_api_register + +get_api_register().restore_all_api() device = torch.device('cpu') dtype_float32 = torch.float32 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py index f2bc82ffafc2a1f10719d4a46669bc0050c12782..4178e2ef8fbfb2c2bafa90b32fa92d622b95e3cd 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_csv2tb.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
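The `get_api_register().restore_all_api()` calls added above un-hook msprobe's wrapped APIs before the monitor tests run, so the tests exercise the native framework functions. A minimal sketch of this store-then-restore pattern (all names below are illustrative stand-ins, not msprobe's actual implementation):

```python
# Hypothetical miniature of the registry pattern behind restore_all_api():
# remember each original attribute before installing a hook, so everything
# can be put back afterwards.
class MiniApiRegistry:
    def __init__(self):
        self._originals = {}  # (owner, name) -> original attribute

    def register(self, owner, name, wrapper):
        # store the original attribute, then install the wrapper in its place
        self._originals[(owner, name)] = getattr(owner, name)
        setattr(owner, name, wrapper)

    def restore_all_api(self):
        # put every hooked attribute back to its stored original
        for (owner, name), original in self._originals.items():
            setattr(owner, name, original)
        self._originals.clear()


class FakeTorch:  # stand-in for a hooked framework module
    @staticmethod
    def add(a, b):
        return a + b


reg = MiniApiRegistry()
reg.register(FakeTorch, "add", lambda a, b: ("hooked", a + b))
hooked = FakeTorch.add(1, 2)    # wrapper runs while the hook is installed
reg.restore_all_api()
restored = FakeTorch.add(1, 2)  # original behavior is back
```

Calling the restore at module import time, as the patched tests do, guarantees the un-hooking happens before any test in the file touches the framework APIs.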
+ import os import shutil import random @@ -11,6 +26,9 @@ from tensorboard.backend.event_processing.event_accumulator import EventAccumula from msprobe.pytorch import TrainerMon from msprobe.core.common.const import MonitorConst from msprobe.pytorch.monitor.csv2tb import parse_step_fn, csv2tensorboard_by_step +from msprobe.pytorch.hook_module.api_register import get_api_register + +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py index eefacb73c8e76636086554775b0e6f2e916ddf6e..66d016f9487a4e7f7fc747dfb021b1f887c51f4a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os.path import shutil import unittest @@ -8,10 +23,13 @@ import torch from msprobe.core.common.const import MonitorConst, Const from torch import distributed as dist +from msprobe.pytorch import TrainerMon +from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.monitor.module_hook import CommunicationContext, GradContext, ModuleHookContext, \ param_is_not_tensor_parallel_duplicate, param_is_data_parallel_duplicate from msprobe.test.pytorch_ut.monitor.demo_model import monitor_demo -from msprobe.pytorch import TrainerMon + +get_api_register().restore_all_api() base_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py index c1b8bac47fda100636b55fbc5ad452c2843e8aaa..d5c7fdea28abe18a47594eaa1bea750f67334a1f 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py @@ -268,7 +268,7 @@ class TestFreeBenchmarkCheckConfig(unittest.TestCase): invalid_config["fuzz_device"] = "cpu" invalid_config["pert_mode"] = "INVALID_CPU_MODE" config = FreeBenchmarkCheckConfig(invalid_config) - self.assertIn("You neet to and can only set fuzz_device as ", str(mock_error.call_args)) + self.assertIn("You need to and can only set fuzz_device as ", str(mock_error.call_args)) @patch('msprobe.core.common.log.logger.error_log_with_exp') def test_check_handler_type_invalid(self, mock_error): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py index 6687f3111050ea53e14e62f3afd55ae1eff2b8c0..f0da7c467f63cd95e3e8585816aba374912d4347 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py @@ -1,7 +1,23 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. 
+# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import patch, mock_open, MagicMock from msprobe.core.common.utils import Const +from msprobe.core.data_dump.api_registry import ApiRegistry from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.pt_config import parse_json_config from msprobe.pytorch.service import Service @@ -67,7 +83,7 @@ class TestService(unittest.TestCase): def test_step_success(self): self.service.step() - self.assertEqual(self.service.current_iter, 1) + self.assertEqual(self.service.loop, 1) def test_step_fail(self): self.service.should_stop_service = True @@ -87,8 +103,8 @@ class TestService(unittest.TestCase): self.service.build_hook = MagicMock() self.config.level = "L1" with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ - patch("msprobe.pytorch.service.api_register.initialize_hook") as mock_init_hook, \ - patch("msprobe.pytorch.service.api_register.api_modularity") as mock_api_modularity: + patch.object(ApiRegistry, "initialize_hook") as mock_init_hook, \ + patch.object(ApiRegistry, 'register_all_api') as mock_api_modularity: self.service.register_api_hook() self.assertEqual(mock_logger.call_count, 1) mock_init_hook.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json index 
b55f9e0699fe6329ceeb09a51fe20118c65545e7..153d84e7d117b5be89dfdb522edc39dc066929cb 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/mindspore/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "mindspore", "dump_data_dir": null, "data": { "Cell.network_with_loss.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json index d7dd1c0c38e2d24c8b0d19c346a50eb33437d232..02239176a9d690c4ce70c06cc6ab117a3c122811 100644 --- a/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json +++ b/debug/accuracy_tools/msprobe/test/resources/layer_mapping/pytorch/dump.json @@ -1,6 +1,7 @@ { "task": "statistics", "level": "mix", + "framework": "pytorch", "dump_data_dir": null, "data": { "Module.module.module.language_model.embedding.word_embeddings.VocabParallelEmbedding.forward.0": { diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py index 706dc8bf82e59f413c3fd559a39af89c6a70be47..9b69e8bc2a740937eb4f6a93072486997f51d8c6 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/builder/test_graph_builder.py @@ -111,3 +111,23 @@ class TestGraphBuilder(unittest.TestCase): self.assertEqual(graph.root.subnodes[2].op, NodeOp.module) self.assertEqual(len(graph.root.subnodes[0].subnodes), 0) self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0') + + def test_add_parameters_grad(self): + graph = Graph('TestNet') + graph.add_node(NodeOp.module, 'Module.a.backward.0', graph.root) + graph.add_node(NodeOp.module, 
'Module.b.backward.0', graph.root) + graph.add_node(NodeOp.module, 'Module.a.backward.1', graph.root) + graph.add_node(NodeOp.module, 'Module.aa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.0', graph.get_node('Module.a.backward.0')) + graph.add_node(NodeOp.module, 'Module.aa.backward.1', graph.get_node('Module.a.backward.1')) + graph.add_node(NodeOp.module, 'Module.aaa.backward.1', graph.get_node('Module.a.backward.1')) + + data_dict = {'Module.a.parameters_grad': {}, 'Module.aaa.parameters_grad': {}} + GraphBuilder._add_parameters_grad(graph, data_dict) + root_nodes_id = [node.id for node in graph.get_node('TestNet').subnodes] + sub_nodes_id0 = [node.id for node in graph.get_node('Module.a.backward.0').subnodes] + sub_nodes_id1 = [node.id for node in graph.get_node('Module.a.backward.1').subnodes] + + self.assertEqual(root_nodes_id[-1], 'Module.a.backward.1') + self.assertEqual(sub_nodes_id0[-1], 'Module.aaa.backward.0') + self.assertEqual(sub_nodes_id1[-1], 'Module.a.parameters_grad') diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py index 87d1f9ee5f01c7c9b2f264f3e6ec16b5155c1f8e..4c38e4e6200219e59663e7f14b9cbc521dda0977 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/compare/test_mode_adapter.py @@ -2,7 +2,8 @@ import json import unittest from unittest.mock import patch, MagicMock from msprobe.visualization.compare.mode_adapter import ModeAdapter -from msprobe.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst, ToolTip from msprobe.core.common.const import CompareConst diff --git 
a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py index 480b95620e6a81577d825b7af55b45fc0a04c34c..64b7101c6b036113e018faec649974753acdaec3 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_base_node.py @@ -1,6 +1,6 @@ import unittest -from msprobe.visualization.graph.base_node import BaseNode, NodeOp -from msprobe.visualization.utils import GraphConst +from msprobe.visualization.graph.base_node import BaseNode +from msprobe.visualization.graph.node_op import NodeOp class TestBaseNode(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py index 81f9fdca5277de6e1670da409bcf93e56ece3206..f1c4ee95567e9021cc30e2e2f55702751e27bb94 100644 --- a/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/visualization_ut/graph/test_graph.py @@ -54,17 +54,6 @@ class TestGraph(unittest.TestCase): matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b) self.assertIsNotNone(matched_node) self.assertEqual(ancestors, ['node_id_a']) - - def test_dfs(self): - graph = Graph("model_name") - graph.add_node(NodeOp.module, "node_a") - graph.add_node(NodeOp.module, "node_b") - node_a = BaseNode(self.node_op, self.node_id) - result = {} - graph.dfs(node_a, result) - self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, - 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], - 'matched_node_link': [], 'suggestions': {}, 'stack_info': []}}) def test_split_nodes_by_micro_step(self): nodes = [BaseNode(NodeOp.module, 'a.forward.0'), BaseNode(NodeOp.module, 'a.backward.0'), diff --git 
a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py index 814882e6b819e9e6b6b421aec5f8f0b89f03f7c6..0b3305e8c6a3bc7fba955e4231eb8b2258191248 100644 --- a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py @@ -27,8 +27,8 @@ from msprobe.visualization.utils import save_json_file, GraphConst class GraphBuilder: backward_pattern = re.compile(r"(\.backward\.)(\d+)$") forward_pattern = re.compile(r"(\.forward\.)(\d+)$") - # 匹配以大写字母开头,后接任意字母,并以Template(结尾 - template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(') + # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串 + template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(') @staticmethod def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False): @@ -51,6 +51,7 @@ class GraphBuilder: graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict) GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict) GraphBuilder._collect_apis_between_modules(graph) + GraphBuilder._add_parameters_grad(graph, data_dict) return graph @staticmethod @@ -235,6 +236,44 @@ class GraphBuilder: graph.root.subnodes = output + @staticmethod + def _add_parameters_grad(graph, data_dict): + """ + 将parameters_grad信息添加到graph中, + 对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点 + + 例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2 + 则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点 + """ + prefixes = [] + suffix = Const.SEP + Const.PARAMS_GRAD + for node_id in data_dict.keys(): + if node_id not in graph.node_map and node_id.endswith(suffix): + prefixes.append(node_id.replace(suffix, '')) + + max_info = {prefix: 0 for prefix in prefixes} + + for key in graph.node_map.keys(): + for prefix in prefixes: + # 构建正则表达式,匹配以 "backward.数字" 
结尾的键 + pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$') + match = pattern.match(key) + if match: + num = int(match.group(1)) + if num > max_info[prefix]: + max_info[prefix] = num + + for prefix, num in max_info.items(): + node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num) + node = graph.get_node(node_id) + if node: + parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node) + # 添加输入输出数据 + node_data = data_dict.get(parameters_grad_node_id, {}) + input_data, output_data = get_input_output(node_data, parameters_grad_node_id) + # 更新数据 + graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data) + class GraphExportConfig: def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='', diff --git a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py index 902d721a8d1047b687b878eb45a802a1df4154bd..3f695d23483c8980c958995a36025b0514877cf9 100644 --- a/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/visualization/compare/graph_comparator.py @@ -17,12 +17,14 @@ import re from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df from msprobe.visualization.graph.graph import Graph, NodeOp -from msprobe.visualization.graph.node_colors import NodeColors from msprobe.visualization.compare.mode_adapter import ModeAdapter from msprobe.core.common.const import Const +from msprobe.core.common.utils import recursion_depth_decorator class GraphComparator: + MAX_DEPTH = 1000 + def __init__(self, graphs, dump_path_param, args, mapping_dict=None): self.graph_n = graphs[0] self.graph_b = graphs[1] @@ -41,7 +43,7 @@ class GraphComparator: else: 
self._compare_nodes(self.graph_n.root) self._postcompare() - + def add_compare_result_to_node(self, node, compare_result_list): """ 将比对结果添加到节点的输入输出数据中 @@ -66,43 +68,8 @@ class GraphComparator: self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) node.data[GraphConst.JSON_INDEX_KEY] = precision_index node.data.update(other_dict) - - def _parse_param(self, dump_path_param, output_path): - self.dump_path_param = dump_path_param - self.output_path = output_path - compare_mode = get_compare_mode(self.dump_path_param) - self.ma = ModeAdapter(compare_mode) - self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path')) - self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path')) - self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path')) - - def _postcompare(self): - self._handle_api_collection_index() - if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: - return - df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode) - df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False) - compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} - for node in self.ma.compare_nodes: - precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - - def _handle_api_collection_index(self): - """ - api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标 - md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差 - """ - for node in self.graph_n.root.subnodes: - if node.op == NodeOp.api_collection: - precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else GraphConst.MIN_INDEX_KEY - for api in node.subnodes: - precision_index = min(precision_index, - api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ - if self.ma.compare_mode == GraphConst.MD5_COMPARE \ - else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, 
GraphConst.MIN_INDEX_KEY)) - node.data[GraphConst.JSON_INDEX_KEY] = precision_index + @recursion_depth_decorator('GraphComparator._compare_nodes', max_depth=MAX_DEPTH) def _compare_nodes(self, node_n): """ 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查它们的祖先和参数信息,检查一致则进行精度数据对比 @@ -126,6 +93,7 @@ class GraphComparator: for subnode in node_n.subnodes: self._compare_nodes(subnode) + @recursion_depth_decorator('GraphComparator._compare_nodes_fuzzy', max_depth=MAX_DEPTH) def _compare_nodes_fuzzy(self, node_n): if node_n.op != NodeOp.function_api: # 模块经过模糊匹配 @@ -146,6 +114,42 @@ class GraphComparator: for sub_node in node_n.subnodes: self._compare_nodes_fuzzy(sub_node) + def _parse_param(self, dump_path_param, output_path): + self.dump_path_param = dump_path_param + self.output_path = output_path + compare_mode = get_compare_mode(self.dump_path_param) + self.ma = ModeAdapter(compare_mode) + self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path')) + self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path')) + self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path')) + + def _postcompare(self): + self._handle_api_collection_index() + if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE: + return + df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode) + df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False) + compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} + for node in self.ma.compare_nodes: + precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) + node.data[GraphConst.JSON_INDEX_KEY] = precision_index + + def _handle_api_collection_index(self): + """ + api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标 + md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差 + """ + for node in self.graph_n.root.subnodes: + if node.op == NodeOp.api_collection: + precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == 
GraphConst.MD5_COMPARE \ + else GraphConst.MIN_INDEX_KEY + for api in node.subnodes: + precision_index = min(precision_index, + api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \ + if self.ma.compare_mode == GraphConst.MD5_COMPARE \ + else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY)) + node.data[GraphConst.JSON_INDEX_KEY] = precision_index + def _get_and_add_result(self, node_n, node_b): compare_result_list = compare_node([node_n.id, node_b.id], [self.data_n_dict, self.data_b_dict], diff --git a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py index 2642ff1e97ebcc055212d4d776eb7c8a08866dc8..fd1541b87bf5e7ba54a95089646683c41f546ca6 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/base_node.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from msprobe.core.overflow_check.level import OverflowLevel -from msprobe.visualization.graph.node_op import NodeOp from msprobe.visualization.utils import GraphConst from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy +from msprobe.core.common.log import logger class BaseNode: @@ -114,7 +115,13 @@ class BaseNode: """ ancestors = [] current_node = self.upnode + seen_nodes = set() while current_node: + if current_node.id in seen_nodes: + logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, ' + f'current node is {current_node.id}.') + return [] + seen_nodes.add(current_node.id) ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index 5ce12d1cadb9aec2cc7c65954bb861b85032212d..569d8ea21b55ee24410c2b2fd1824d9b991966cd 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -19,7 +19,6 @@ from msprobe.visualization.utils import GraphConst from msprobe.core.common.log import logger from msprobe.core.common.const import Const - MAX_RECUR_LEVEL = 100 @@ -67,7 +66,6 @@ class Graph: ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod def fuzzy_match(node_n, node_b): if not node_n or not node_b or not node_n.fuzzy_eq(node_b): @@ -76,13 +74,6 @@ class Graph: ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b - @staticmethod - def dfs(node, result): - info = node.to_dict() - result[node.id] = info - for subnode in node.subnodes: - Graph.dfs(subnode, result) - @staticmethod def split_nodes_by_micro_step(nodes): """ diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py index 
623bcd11c45f1ff8e9c283d30a982af239706ce4..f6e8258bb67cd850b5fc93b2b541b263fb48578e 100644 --- a/debug/accuracy_tools/msprobe/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/visualization/utils.py @@ -182,11 +182,8 @@ class GraphConst: STR_MAX_LEN = 50 SMALL_VALUE = 1e-3 MD5_INDEX_LIST = [CompareConst.RESULT] - REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, - CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] - SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, - CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, - CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] + REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX + SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM] APIS_BETWEEN_MODULES = 'Apis_Between_Modules' NULL = 'null' diff --git a/dynolog_npu/README.md b/dynolog_npu/README.md index 9cc015e66c656c65fa48ad73a8246487a2016bef..86a23b7f82925079c26623b070936538768d9b8c 100644 --- a/dynolog_npu/README.md +++ b/dynolog_npu/README.md @@ -51,6 +51,8 @@ sudo yum install -y cmake ninja ### 3. 编译 +- dynolog编译 + 默认编译生成dyno和dynolog二进制文件, -t参数可以支持将二进制文件打包成deb包或rpm包. 
```bash @@ -64,6 +66,10 @@ bash scripts/build.sh -t deb bash scripts/build.sh -t rpm ``` +- dynolog_npu_plugin wheel包编译 + +dynolog_npu_plugin wheel包提供IPCMonitor,MsptiMonitor等公共能力,使用nputrace和npu-monitor功能前必须安装该wheel包,具体编译安装指导可参考dynolog_npu/plugin/README.md。 + ## 使用方式 ### Profiler trace dump功能 @@ -85,32 +91,69 @@ nputrace子命令支持的参数选项 | 子命令 | 参数类型 | 说明 | |-------|-------|-------| -| record_shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | -| profile_memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | -| with_stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | -| with_flops | action | 是否采集算子flops,设置参数采集,默认不采集 | -| with_modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | +| job-id | u64 | 采集任务的job id,默认值0,dynolog原生参数 | +| pids | String | 采集任务的pid列表,多个pid用逗号分隔,默认值0,dynolog原生参数 | +| process-limit | u64 | 最大采集进程的数量,默认值3,dynolog原生参数 | +| profile-start-time | u64 | 用于同步采集的Unix时间戳,单位毫秒,默认值0,dynolog原生参数 | +| duration-ms | u64 | 采集的周期,单位毫秒,默认值500,dynolog原生参数 | +| iterations | i64 | 采集总迭代数,默认值-1,dynolog原生参数 | +| log-file | String | 采集落盘的路径,必选值 | +| start-step | u64 | 开始采集的迭代数,默认值0 | +| record-shapes | action | 是否采集算子的InputShapes和InputTypes,设置参数采集,默认不采集 | +| profile-memory | action | 是否采集算子内存信息,设置参数采集,默认不采集 | +| with-stack | action | 是否采集Python调用栈,设置参数采集,默认不采集 | +| with-flops | action | 是否采集算子flops,设置参数采集,默认不采集 | +| with-modules | action | 是否采集modules层级的Python调用栈,设置参数采集,默认不采集 | | analyse | action | 采集后是否自动解析,设置参数解析,默认不解析 | -| l2_cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | -| op_attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | -| data_simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | +| l2-cache | action | 是否采集L2 Cache数据,设置参数采集,默认不采集 | +| op-attr | action | 是否采集算子属性信息,设置参数采集,默认不采集 | +| msprof-tx | action | 是否使能MSTX,设置参数使能,默认不使能 | +| data-simplification | String | 解析完成后是否数据精简,可选值范围[`true`, `false`],默认值`true` | | activities | String | 控制CPU、NPU事件采集范围,可选值范围[`CPU,NPU`, `NPU,CPU`, `CPU`, `NPU`],默认值`CPU,NPU` | -| profiler_level | String |
控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0`| -| aic_metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| -| export_type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| -| gc_detect_threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | +| profiler-level | String | 控制profiler的采集等级,可选值范围[`Level_none`, `Level0`, `Level1`, `Level2`],默认值`Level0`| +| aic-metrics | String | AI Core的性能指标采集项,可选值范围[`AiCoreNone`, `PipeUtilization`, `ArithmeticUtilization`, `Memory`, `MemoryL0`, `ResourceConflictRatio`, `MemoryUB`, `L2Cache`, `MemoryAccess`],默认值`AiCoreNone`| +| export-type | String | profiler解析导出数据的类型,可选值范围[`Text`, `Db`],默认值`Text`| +| gc-detect-threshold | Option | GC检测阈值,单位ms,只采集超过阈值的GC事件。该参数为可选参数,默认不设置时不开启GC检测 | + + +- nputrace使用方法 + +Step0: 参考`3.编译`章节完成dynolog的编译,以及dynolog_npu_plugin wheel包的编译和安装。 + +Step1:拉起dynolog daemon进程 +```bash +# 方法1:使用systemd拉起service +# 修改配置文件/etc/dynolog.gflags, 使能ipc_monitor +echo "--enable_ipc_monitor" | sudo tee -a /etc/dynolog.gflags +sudo systemctl start dynolog + +# 方法2:命令行执行 +dynolog --enable-ipc-monitor + +#dynolog daemon的日志路径为:/var/log/dynolog.log +``` + +Step 2:使能dynolog trace dump环境变量 +```bash +export KINETO_USE_DAEMON=1 +``` -- nputrace示例命令 +Step 3: 拉起训练任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` +Step 4:使用dyno CLI动态触发trace dump ```bash -# 示例1:采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data -dyno nputrace --activities CPU,NPU --analyse --data_simplification false --log-file /tmp/profile_data +# 示例1:从第10个step开始采集,采集2个step,采集框架、CANN和device数据,同时采集完后自动解析以及解析完成不做数据精简,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities CPU,NPU --analyse --data-simplification false --log-file /tmp/profile_data -# 示例2:只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data -dyno 
nputrace --activities NPU --analyse --data_simplification true --log-file /tmp/profile_data +# 示例2:从第10个step开始采集,采集2个step,只采集CANN和device数据,同时采集完后自动解析以及解析完成后开启数据精简,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities NPU --analyse --data-simplification true --log-file /tmp/profile_data -# 示例3:只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data -dyno nputrace --activities NPU --log-file /tmp/profile_data +# 示例3:从第10个step开始采集,采集2个step,只采集CANN和device数据,只采集不解析,落盘路径为/tmp/profile_data +dyno nputrace --start-step 10 --iterations 2 --activities NPU --log-file /tmp/profile_data ``` ### NPU Monitor功能 @@ -129,20 +172,50 @@ dyno npu-monitor [SUBCOMMANDS] npu-monitor子命令支持的参数选项 | 子命令 | 参数类型 | 说明 | |-------|-------|-------| -| npu_monitor_start | action | 开启性能监控,设置参数开启,默认不采集 | -| npu_monitor_stop | action | 停止性能监控,设置参数开启,默认不采集 | -| report_interval_s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | -| mspti_activity_kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| +| npu-monitor-start | action | 开启性能监控,设置参数开启,默认不采集 | +| npu-monitor-stop | action | 停止性能监控,设置参数开启,默认不采集 | +| report-interval-s | int | 性能监控数据上报周期,单位s,需要在启动时设置。默认值60 | +| mspti-activity-kind | String | 性能监控数据上报数据类型,可以设置单个或多个,多个类型以逗号分隔,需要在启动时设置。可选值范围[`Marker`, `Kernel`, `API`, `Hccl`, `Memory`, `MemSet`, `MemCpy`] , 默认值`Marker`| + +- npu-monitor使用方法 + +Step1: 拉起dynolog daemon进程 +```bash +# 方法1:使用systemd拉起service +# 修改配置文件/etc/dynolog.gflags, 使能ipc_monitor +echo "--enable_ipc_monitor" | sudo tee -a /etc/dynolog.gflags +sudo systemctl start dynolog + +# 方法2:命令行执行 +dynolog --enable-ipc-monitor + +#dynolog daemon的日志路径为:/var/log/dynolog.log +``` -- npu-monitor示例命令 +Step 2:使能dynolog trace dump环境变量 +```bash +export KINETO_USE_DAEMON=1 +``` +Step 3: 拉起训练任务 +```bash +# 训练任务中需要使用pytorch的优化器/继承原生优化器 +bash train.sh +``` + +Step 4:使用dyno CLI使能npu-monitor ```bash # 示例1:开启性能监控,使用默认配置 -dyno npu-monitor --npu_monitor_start +dyno 
npu-monitor --npu-monitor-start # 示例2:暂停性能监控 -dyno npu-monitor --npu_monitor_stop +dyno npu-monitor --npu-monitor-stop + +# 示例3:性能监控过程中修改配置 +# 上报周期30s, 上报数据类型Marker和Kernel +dyno npu-monitor --report-interval-s 30 --mspti-activity-kind Marker,Kernel -# 示例3:开启性能监控,上报周期30s, 上报数据类型Marker和Kernel -dyno npu-monitor --npu_monitor_start 30 --mspti_activity_kind Marker,Kernel +# 示例4:性能监控开启时修改配置 +# 上报周期30s, 上报数据类型Marker和Kernel +dyno npu-monitor --npu-monitor-start --report-interval-s 30 --mspti-activity-kind Marker,Kernel ``` \ No newline at end of file diff --git a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs b/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs index 4bf7132de338d8eee0de556449269712617772e2..f70923bca4cc5ce29a8855a464c411b63a930ef0 100644 --- a/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs +++ b/dynolog_npu/dynolog_npu/cli/src/commands/nputrace.rs @@ -55,6 +55,7 @@ pub struct NpuTraceOptions { pub aic_metrics: String, pub l2_cache: bool, pub op_attr: bool, + pub msprof_tx: bool, pub gc_detect_threshold: Option, pub data_simplification: String, pub export_type: String, @@ -75,6 +76,7 @@ PROFILE_PROFILER_LEVEL={} PROFILE_AIC_METRICS={} PROFILE_L2_CACHE={} PROFILE_OP_ATTR={} +PROFILE_MSPROF_TX={} PROFILE_GC_DETECT_THRESHOLD={} PROFILE_DATA_SIMPLIFICATION={} PROFILE_EXPORT_TYPE={}"#, @@ -89,6 +91,7 @@ PROFILE_EXPORT_TYPE={}"#, self.aic_metrics, self.l2_cache, self.op_attr, + self.msprof_tx, self.gc_detect_threshold.map_or("None".to_string(), |v| v.to_string()), self.data_simplification, self.export_type @@ -213,6 +216,7 @@ ACTIVITIES_ITERATIONS=1000"# aic_metrics: "AiCoreNone".to_string(), l2_cache: true, op_attr: true, + msprof_tx: true, gc_detect_threshold: 0.1, data_simplification: "true", export_type: "Text".to_string(), @@ -234,6 +238,7 @@ PROFILE_PROFILER_LEVEL=Level0 PROFILE_AIC_METRICS=AiCoreNone PROFILE_L2_CACHE=true PROFILE_OP_ATTR=true +PROFILE_MSPROF_TX=true PROFILE_GC_DETECT_THRESHOLD=0.1 PROFILE_DATA_SIMPLIFICATION=true 
PROFILE_EXPORT_TYPE=Text"# diff --git a/dynolog_npu/dynolog_npu/cli/src/main.rs b/dynolog_npu/dynolog_npu/cli/src/main.rs index 8bc4a2af0e2c19d6e783663924578e3c2ad7408a..9fdea3d1254467081356b2e0daeb8ed3ca05a16d 100644 --- a/dynolog_npu/dynolog_npu/cli/src/main.rs +++ b/dynolog_npu/dynolog_npu/cli/src/main.rs @@ -172,6 +172,9 @@ enum Command { /// Whether to collect op attributes. #[clap(long, action)] op_attr: bool, + /// Whether to enable MSTX. + #[clap(long, action)] + msprof_tx: bool, /// GC detect threshold. #[clap(long)] gc_detect_threshold: Option, @@ -290,6 +293,7 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, @@ -318,6 +322,7 @@ fn main() -> Result<()> { aic_metrics, l2_cache, op_attr, + msprof_tx, gc_detect_threshold, data_simplification, export_type, diff --git a/dynolog_npu/plugin/README.md b/dynolog_npu/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0cd51633bb15504416da85a49845a0ffb53d4452 --- /dev/null +++ b/dynolog_npu/plugin/README.md @@ -0,0 +1,42 @@ + + +# Plugins for Dynolog NPU +## 模块说明 +### IPCMonitor +提供IPC(Inter-Process Communication)通信接口,用于实现 +1. IPC控制通道: profiler backend向dynolog daemon获取profiler配置 + + +__PyDynamicMonitorProxy__: +* `init_dyno` 向dynolog daemon发送注册请求 + * input: npuId(int) + * return: None +* `poll_dyno` 向dynolog daemon获取Profiler控制参数 + * input: None + * return: str, 返回控制参数 + +## 安装方式 +### 1. 通过shell脚本一键安装 +``` +chmod +x build.sh +./build.sh +``` +### 2. 
手动安装 +* 安装依赖 +``` +pip install wheel +pip install pybind11 +``` +* 编译whl包 +``` +python3 setup.py bdist_wheel +``` +以上命令执行完成后在plugin/dist目录下生成dynolog_npu插件whl安装包dynolog-npu-plugin-{version}.whl +* 安装 +``` +pip install dist/dynolog-npu-plugin-{version}.whl +``` +* 卸载 +``` +pip uninstall dynolog-npu-plugin +``` \ No newline at end of file diff --git a/dynolog_npu/plugin/Readme.md b/dynolog_npu/plugin/Readme.md deleted file mode 100644 index c59bfffad5aaac5383b407e3ff3d23ed126131f5..0000000000000000000000000000000000000000 --- a/dynolog_npu/plugin/Readme.md +++ /dev/null @@ -1,17 +0,0 @@ - - -# Build and Install npu-dynolog-plugin -``` -# install pybind11 -pip install pybind11 - -# build dynolog_npu_plugin wheel -python3 setup.py bdist_wheel -# install -pip install dist/{dynolog-npu-plugin-xxx.wheel} - -# example -import IPCMonitor -dyno_worker = IPCMonitor.PyDynamicMonitorProxy() -dyno_worker.init_dyno(0) -``` diff --git a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp index 940f5aae167f088361057fe2a7a389a76f5bb2b4..bba66d7297af1eec929a0149b0b2d1df35eaf843 100644 --- a/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp +++ b/dynolog_npu/plugin/ipc_monitor/DynoLogNpuMonitor.cpp @@ -1,7 +1,4 @@ #include "DynoLogNpuMonitor.h" - -#include <iostream> - #include "utils.h" namespace dynolog_npu { @@ -10,13 +7,13 @@ namespace ipc_monitor { bool DynoLogNpuMonitor::Init() { if (isInitialized_) { - std::cout << "[WRARNING] DynoLog npu monitor already initialized" << std::endl; + LOG(WARNING) << "DynoLog npu monitor already initialized"; return true; } bool res = ipcClient_.RegisterInstance(npuId_); if (res) { isInitialized_ = true; - std::cout << "[INFO] DynoLog npu monitor initialized success !"
<< std::endl; + LOG(INFO) << "DynoLog npu monitor initialized success!"; } return res; } @@ -24,11 +21,6 @@ bool DynoLogNpuMonitor::Init() std::string DynoLogNpuMonitor::Poll() { std::string res = ipcClient_.IpcClientNpuConfig(); - if (res.empty()) { - std::cout << "[INFO] Request for dynolog server is empty !" << std::endl; - return ""; - } - std::cout << "[INFO] Received NPU configuration successfully" << std::endl; return res; } diff --git a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp index 97966e8eeacc7276426feb237aa122eb8dee046f..ca2429f1e368ad996b8a8a954810ed7439c78bea 100644 --- a/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp +++ b/dynolog_npu/plugin/ipc_monitor/NpuIpcClient.cpp @@ -1,6 +1,5 @@ #include "NpuIpcClient.h" -#include <iostream> namespace dynolog_npu { namespace ipc_monitor { @@ -15,14 +14,14 @@ bool IpcClient::RegisterInstance(int32_t id) std::unique_ptr message = Message::ConstructMessage(context, "ctxt"); try { if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING]Failed to send register ctxt for pid " << context.pid << " with dyno" << std::endl; + LOG(ERROR) << "Failed to send register ctxt for pid " << context.pid << " with dyno"; return false; } } catch (const std::exception &e) { - std::cout << "[WARNING] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when SyncSendMessage: " << e.what(); return false; } - std::cout << "[INFO] Resigter pid " << context.pid << " for dynolog success !" << std::endl; + LOG(INFO) << "Register pid " << context.pid << " for dynolog success!"; return true; } std::string IpcClient::IpcClientNpuConfig() @@ -37,7 +36,7 @@ std::string IpcClient::IpcClientNpuConfig() } std::unique_ptr message = Message::ConstructMessage(*req, "req", size); if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) { - std::cout << "[WARNING] Failed to send config to dyno server fail !"
<< std::endl; + LOG(ERROR) << "Failed to send config to dyno server!"; free(req); req = nullptr; return ""; @@ -45,7 +44,7 @@ std::string IpcClient::IpcClientNpuConfig() free(req); message = PollRecvMessage(MAX_IPC_RETRIES, MAX_SLEEP_US); if (!message) { - std::cout << "[WARNING] Failed to receive on-demand config !" << std::endl; + LOG(ERROR) << "Failed to receive on-demand config!"; return ""; } std::string res = std::string(ReinterpretConvert(message->buf.get()), message->metadata.size); @@ -65,7 +64,7 @@ std::unique_ptr IpcClient::ReceiveMessage() bool IpcClient::SyncSendMessage(const Message &message, const std::string &destName, int numRetry, int seepTimeUs) { if (destName.empty()) { - std::cout << "[WARNING] Can not send to empty socket name !" << std::endl; + LOG(ERROR) << "Cannot send to empty socket name!"; return false; } int i = 0; @@ -79,7 +78,7 @@ bool IpcClient::SyncSendMessage(const Message &message, const std::string &destN seepTimeUs *= 2; // 2: double sleep time } } catch (const std::exception &e) { - std::cout << "[ERROR] Error when SyncSendMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when SyncSendMessage: " << e.what(); return false; } return i < numRetry; @@ -94,7 +93,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryPeekMessage(*peekCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryPeekMessage: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryPeekMessage: " << e.what(); return false; } if (successFlag) { @@ -108,7 +107,7 @@ bool IpcClient::Recv() try { successFlag = ep_.TryRcvMessage(*recvCtxt); } catch (std::exception &e) { - std::cout << "[ERROR] Error when TryRecvMsg: " << e.what() << std::endl; + LOG(ERROR) << "Error when TryRecvMsg: " << e.what(); return false; } if (successFlag) { } } } catch (std::exception &e) { - std::cout << "[ERROR] Error in Recv(): " << e.what() << std::endl; + LOG(ERROR) << "Error in Recv(): " <<
e.what(); return false; } return false; diff --git a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h index 8b5f88abf9d2cf589bec685cd3a520729afe8dd5..0471a70a3419eeeee2986d1d18710ee112c70313 100644 --- a/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h +++ b/dynolog_npu/plugin/ipc_monitor/PyDynamicMonitorProxy.h @@ -1,7 +1,7 @@ #ifndef PYDYNAMIC_MONITOR_PROXY_H #define PYDYNAMIC_MONITOR_PROXY_H -#include <iostream> +#include <glog/logging.h> #include #include "MonitorBase.h" #include "DynoLogNpuMonitor.h" @@ -14,15 +14,21 @@ public: PyDynamicMonitorProxy() = default; bool InitDyno(int npuId) { - try { - monitor_ = DynoLogNpuMonitor::GetInstance(); - monitor_->SetNpuId(npuId); - bool res = monitor_->Init(); - return res; - } catch (const std::exception &e) { - std::cout << "[ERROR] Error when init dyno " << e.what() << std::endl; - return false; - } + try { + if (!google::IsGoogleLoggingInitialized()) { + google::InitGoogleLogging("DynoLogNpuMonitor"); + google::SetLogDestination(google::GLOG_INFO, "/var/log/dynolog_npu_"); + google::SetLogFilenameExtension(".log"); + } + monitor_ = DynoLogNpuMonitor::GetInstance(); + monitor_->SetNpuId(npuId); + bool res = monitor_->Init(); + LOG(ERROR) << res; + return res; + } catch (const std::exception &e) { + LOG(ERROR) << "Error when init dyno " << e.what(); + return false; + } } std::string PollDyno diff --git a/dynolog_npu/plugin/ipc_monitor/utils.cpp b/dynolog_npu/plugin/ipc_monitor/utils.cpp index 936821fd34bc34bc9db9e09515132e8af39ba57a..b57942082e0fd52426ddce47bfc70620bf19019f 100644 --- a/dynolog_npu/plugin/ipc_monitor/utils.cpp +++ b/dynolog_npu/plugin/ipc_monitor/utils.cpp @@ -68,11 +68,11 @@ std::pair GetParentPidAndCommand(int32_t pid) if (std::getline(statFile, line)) { int ret = sscanf(line.c_str(), "%*d (%[^)]) %*c %d", command.data(), &parentPid); if (ret == 2) { // 2: 接收到2个字段 - std::cout << "[INFO] Success to get parent pid: " << parentPid << std::endl; + LOG(INFO) <<
"Success to get parent pid: " << parentPid; return std::make_pair(parentPid, command); } } - std::cout << "[WARNING] Failed to parse /proc/" << pid << "/stat" << std::endl; + LOG(ERROR) << " Failed to parse /proc/" << pid << "/stat"; return std::make_pair(0, ""); } diff --git a/dynolog_npu/plugin/ipc_monitor/utils.h b/dynolog_npu/plugin/ipc_monitor/utils.h index 0d8ceb8cfd0bf81b6d8b807c6ac1b505276ddf83..2374a27d417f91bc23108a892c6eb25cbb5039d8 100644 --- a/dynolog_npu/plugin/ipc_monitor/utils.h +++ b/dynolog_npu/plugin/ipc_monitor/utils.h @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include diff --git a/dynolog_npu/plugin/setup.py b/dynolog_npu/plugin/setup.py index 151b9b3fb3fa1a42e147685f632163c8b3f5a564..55e924c6b6950c2a9f8f466159ea56184f77e1a6 100644 --- a/dynolog_npu/plugin/setup.py +++ b/dynolog_npu/plugin/setup.py @@ -13,25 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from glob import glob from setuptools import setup from pybind11.setup_helpers import Pybind11Extension BASE_DIR = os.path.dirname(os.path.realpath(__file__)) +DYNOLOG_PATH = os.path.join(os.path.dirname(BASE_DIR), "third_party", "dynolog") +GLOG_INC_PATH = os.path.join(DYNOLOG_PATH, "third_party", "glog", "src") +GLOG_LIB_PATH = os.path.join(DYNOLOG_PATH, "build", "third_party", "glog") # Define the extension module ext_modules = [ Pybind11Extension( "IPCMonitor", # Name of the Python module - sources=["bindings.cpp", - "ipc_monitor/utils.cpp", - "ipc_monitor/DynoLogNpuMonitor.cpp", - "ipc_monitor/NpuIpcClient.cpp", - ], # Source files - include_dirs=[os.path.join(BASE_DIR, "ipc_monitor")], # Include Pybind11 headers + sources=["bindings.cpp"] + list(glob("ipc_monitor/*.cpp")), # Source files + include_dirs=[os.path.join(BASE_DIR, "ipc_monitor"), GLOG_INC_PATH, GLOG_LIB_PATH], # Include Pybind11 headers + library_dirs=[GLOG_LIB_PATH], + libraries=["glog"], language="c++", # Specify the language 
), ] + # Set up the package setup( name="dynolog_npu_plugin", diff --git a/flight_recoder/analysis_flight.py b/flight_recoder/analysis_flight.py deleted file mode 100644 index f81f771ab1c81ad79cb93401e200b600a4b17af3..0000000000000000000000000000000000000000 --- a/flight_recoder/analysis_flight.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
- -import os -import pickle -import sys -import logging -from collections import defaultdict - -from check_path import get_valid_read_path - - -logging.basicConfig( - level=logging.INFO, # 设置日志级别为 INFO - format="%(asctime)s - %(levelname)s - %(message)s", # 设置日志格式 - handlers=[logging.StreamHandler()], # 输出到控制台 -) - - -SAFE_CLASSES = { - # 内置安全类型 - "builtins": {"str", "int", "float", "list", "dict", "tuple"}, -} - - -class SafeUnpickler(pickle.Unpickler): - def find_class(self, module, name): - # 检查模块和类是否在白名单中 - if module in SAFE_CLASSES and name in SAFE_CLASSES[module]: - return super().find_class(module, name) - raise pickle.UnpicklingError(f"Forbidden class: {module}.{name}") - - -def load_recorder_data(path, world_size): - """加载所有 rank 的 recorder 数据""" - recorder_dict = {} - for rank in range(world_size): - file_path = os.path.join(path, str(rank)) if not path.endswith("/") else path + str(rank) - file_path = get_valid_read_path(file_path) - try: - with open(file_path, "rb") as f: - res = SafeUnpickler(f).load() - recorder_dict[str(rank)] = res - except Exception as e: - logging.error(f"Failed to load data from {file_path}: {e}") - return recorder_dict - - -def extract_hccl_info(recorder_dict): - """从 recorder 数据中提取 HCCL 相关信息""" - hccl_dict = {} - for rank, recorder in recorder_dict.items(): - entries = recorder.get("entries", []) - if not entries: - continue - last_entry = entries[-1] - hccl_dict[rank] = { - "state": last_entry.get("state", None), - "record_id": last_entry.get("record_id", None), - "pg_id": last_entry.get("pg_id", None), - "time_discovered_completed_ns": last_entry.get("time_discovered_completed_ns", None), - "name": last_entry.get("frames", [{}])[0].get("name", None), - } - return hccl_dict - - -def analyze_pg_groups(hccl_dict): - """分析 HCCL 数据,按 pg_id 分组并检查问题""" - pg_groups = defaultdict(list) - for _, op in hccl_dict.items(): - pg_groups[op["pg_id"]].append(op) - - for pg_id, group in pg_groups.items(): - scheduled_ops = [op for op in group 
if op["state"] == "scheduled"] - completed_ops = [op for op in group if op["state"] == "completed"] - - # 情况 1: 所有卡都是 scheduled,且 record_id 和 name 相同 - if len(scheduled_ops) == len(group): - record_id = scheduled_ops[0]["record_id"] - name = scheduled_ops[0]["name"] - all_same = all(op["record_id"] == record_id and op["name"] == name for op in scheduled_ops) - if all_same: - logging.info( - f"The pg_id {pg_id}'s Communication Operator {name}" - " executed too slowly, causing the HCCL to time out." - ) - - # 情况 2: 存在 completed 算子且 该算子的record_id 比其他 scheduled 算子少 1 - elif completed_ops and scheduled_ops: - completed_op = completed_ops[0] - scheduled_record_id = scheduled_ops[0]["record_id"] - if completed_op["record_id"] == scheduled_record_id - 1: - logging.info( - f"The pg_id {pg_id}'s rank {completed_op['pg_id']}'s " - "Computational task took too long, causing the other ranks' " - "HCCL task to time out." - ) - - # 情况 3: 所有算子均为 completed - elif not scheduled_ops and completed_ops: - latest_op = max(completed_ops, key=lambda x: x["time_discovered_completed_ns"] or 0) - logging.info( - f"The computational task of the pg_id {pg_id} " - f"after the communication operator {latest_op['name']} " - "took too long." 
- ) - - else: - logging.info(f"The situation cannot be recognized!") - - -def get_int_arg(args, idx, default): - if len(args) > idx: - try: - return int(args[idx]) - except ValueError: - logging.warning(f"Invalid input {args[idx]}, using default: {default}") - return default - - -def main(): - # 设置默认值 - default_path = os.getenv("TORCH_HCCL_DEBUG_INFO_TEMP_FILE") - default_world_size = 8 - - # 获取命令行参数,如果未提供则使用默认值 - path = sys.argv[1] if len(sys.argv) > 1 else default_path - world_size = get_int_arg(sys.argv, 2, default_world_size) - - if not path: - raise ValueError("Path is required and cannot be empty.") - - logging.info(f"Path: {path}") - logging.info(f"World Size: {world_size}") - - # 加载数据 - recorder_dict = load_recorder_data(path, world_size) - if not recorder_dict: - logging.error("No valid recorder data found.") - return - - # 提取 HCCL 信息 - hccl_dict = extract_hccl_info(recorder_dict) - - # 分析 HCCL 数据 - analyze_pg_groups(hccl_dict) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/flight_recoder/check_path.py b/flight_recoder/check_path.py deleted file mode 100644 index b34e4dcdb68b28b44f387cb14919ad127658ca8f..0000000000000000000000000000000000000000 --- a/flight_recoder/check_path.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import re -import os -import sys -import stat - - -PATH_WHITE_LIST_REGEX = re.compile(r"[^_A-Za-z0-9/.-]") -MAX_READ_FILE_SIZE_4G = 4294967296 # 4G, 4 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_32G = 34359738368 # 32G, 32 * 1024 * 1024 * 1024 -MAX_READ_FILE_SIZE_512G = 549755813888 # 512G, 512 * 1024 * 1024 * 1024 - -# group not writable, others no permission, max stat is 750 -WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH | stat.S_IROTH | stat.S_IXOTH -# group not writable, others not writable, max stat is 755 -READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH - - -def type_to_str(value_type): - return ' or '.join([ii.__name__ for ii in value_type]) if isinstance(value_type, tuple) else value_type.__name__ - - -def check_type(value, value_type, param_name="value"): - if not isinstance(value, value_type): - raise TypeError('{} must be {}, not {}.'.format(param_name, type_to_str(value_type), type(value).__name__)) - - -def get_valid_path(path): - check_type(path, str, "path") - if not path or len(path) == 0: - raise ValueError("The value of the path cannot be empty.") - if PATH_WHITE_LIST_REGEX.search(path): # Check special char - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - path = os.path.expanduser(path) # Consider paths starting with "~" - if os.path.islink(os.path.abspath(path)): # when checking link, get rid of the "/" at the path tail if any - raise ValueError("The value of the path cannot be soft link: {}.".format(path)) - - real_path = os.path.realpath(path) - - if len(real_path) > 4096: - raise ValueError("The length of file path should be less than 4096.") - - if real_path != path and PATH_WHITE_LIST_REGEX.search(real_path): # Check special char again - raise ValueError("Input path contains invalid characters.") # Not printing out the path value for invalid char - - return real_path - - -def is_belong_to_user_or_group(file_stat): - return file_stat.st_uid == 
os.getuid() or file_stat.st_gid in os.getgroups() - - -def get_valid_read_path(path, size_max=MAX_READ_FILE_SIZE_4G, check_user_stat=True, is_dir=False): - real_path = get_valid_path(path) - if not os.path.isfile(real_path): - raise ValueError("The path {} doesn't exists or not a file.".format(path)) - - file_stat = os.stat(real_path) - if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat): - raise ValueError("The file {} doesn't belong to the current user or group.".format(path)) - if check_user_stat and os.stat(path).st_mode & READ_FILE_NOT_PERMITTED_STAT > 0: - raise ValueError("The file {} is group writable, or is others writable.".format(path)) - if not os.access(real_path, os.R_OK) or file_stat.st_mode & stat.S_IRUSR == 0: # At least been 400 - raise ValueError("Current user doesn't have read permission to the file {}.".format(path)) - if not is_dir and size_max > 0 and file_stat.st_size > size_max: - raise ValueError("The file {} exceeds size limitation of {}.".format(path, size_max)) - return real_path \ No newline at end of file diff --git a/flight_recoder/flight_recoder.md b/flight_recoder/flight_recoder.md deleted file mode 100644 index 8b398a6730bae0823b04c20a22258a81392922c9..0000000000000000000000000000000000000000 --- a/flight_recoder/flight_recoder.md +++ /dev/null @@ -1,49 +0,0 @@ -# 飞行记录器超时类问题分析 - -训练任务卡住是阻塞AI大规模分布式集群训练任务的主要和关键问题,当前需要等待集合通信超时才能感知,影响集群可用性。框架需要支持检测训练任务卡住问题,做到提前识别并保存必要的诊断信息,提高问题定位效率和集群设备可用性。当HeartbeatMonitor长时间未检测到心跳时,即可认为训练任务已经卡住,需要触发诊断信息保存。 - -本工具提供torch npu上飞行记录器flight recorder记录日志的读取解析能力,并根据解析后的日志提供超时类问题的初步分析能力,主要支持以下三种情况的超时类问题的识别和分析 - -|问题| 具体内容 | -| --- | --- | -|类型一 | 同通信域内的某张卡计算超时,导致其他卡等待触发飞行记录器和hccl time out | -|类型二 | 同通信域内的通信算子之后的非通信任务耗时过长| -|类型三 | 同通信域内的某个通信算子进行通信时执行超时 | - -## 使用方法 - -### 1 飞行记录器开启方法 - -按照如下方法设置环境变量开启飞行记录器 - -``` -export TORCH_HCCL_ENABLE_MONITORING=1 #用于检测是否开启卡住问题检测 -export TORCH_HCCL_DUMP_ON_TIMEOUT=1 # 用于控制是否保存诊断信息 -export TORCH_HCCL_TRACE_BUFFER_SIZE=1 
# 用于控制保存的集合通信状态数量 -export TORCH_HCCL_HEARTBEAT_TIMEOUT_SEC=20 # 用于控制心跳超时时间,即训练业务多久未下发集合通信算子时需要判定为卡住,默认10分钟,单位s。(需要小于HCCL_EXEC_TIMEOUT,避免集合通信先报超时错误) -export TORCH_HCCL_DEBUG_INFO_TEMP_FILE=/tmp/ #保存诊断信息的文件路径 -``` - -### 2 工具使用方法 - -``` -python analysis_flight.py path world_size -``` - -脚本从命令行参数获取 `path` 和 `world_size` 的值,并记录日志。如果未提供命令行参数,则使用默认值。 - -* `path`:从命令行第一个参数获取,如果未提供则使用 `default_path`, default_path从TORCH_HCCL_DEBUG_INFO_TEMP_FILE获取。 -* `world_size`:从命令行第二个参数获取,如果未提供则使用 `default_world_size`,默认为8。 - -| 参数名| 含义 | 使用限制 | -| --- | --- | --- | -| path | 飞行记录器的日志 | 可选。数据类型:string 默认为环境变量中的TORCH_HCCL_DEBUG_INFO_TEMP_FILE,若设置日志格式指定有前缀,则需要在路径中加入前缀 | -| world_size | 同一个通信域中的卡数 | 可选。数据类型:int 默认为8 | - -### 3 输出示例 - -``` -2025-02-19 08:10:07,160 - INFO - Path: /tmp/ -2025-02-19 08:10:07,160 - INFO - World Size: 8 -2025-02-19 08:10:07,162 - INFO - The pg_id 0's rank 0's Computational task took too long, causing the other ranks' HCCL task to time out. -``` diff --git a/profiler/msprof_analyze/README.md b/profiler/msprof_analyze/README.md index 7e2267a55596bac342b0e2ada564ea31c5625a84..d39aea89a521eaae504d177fac8ac2c9c5982afb 100644 --- a/profiler/msprof_analyze/README.md +++ b/profiler/msprof_analyze/README.md @@ -117,6 +117,7 @@ Successfully installed msprof-analyze-{version} | profiler版本 | 发布日期 | 下载链接 | 校验码 | |------------|------------|-------------------------------------------------------------------------------------------------------------------------------------------| ------------------------------------------------------------ | +| 2.0.1 | 2025-02-28 | [msprof_analyze-2.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/2.0.1/msprof_analyze-2.0.1-py3-none-any.whl) | 82dfe2c779dbab9015f61d36ea0c32d832b6d182454b3f7db68e6c0ed49c0423 | | 2.0.0 | 2025-02-08 | [msprof_analyze-2.0.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/2.0.0/msprof_analyze-2.0.0-py3-none-any.whl) | 
8e44e5f3e7681c377bb2657a600ad9841d3bed11061ddd7844c30e8a97242101 | | 1.3.4 | 2025-01-20 | [msprof_analyze-1.3.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.4/msprof_analyze-1.3.4-py3-none-any.whl) | 8de92188d1a97105fb14cadcb0875ccd5f66629ee3bb25f37178da1906f4cce2 | | 1.3.3 | 2024-12-26 | [msprof_analyze-1.3.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.3/msprof_analyze-1.3.3-py3-none-any.whl) | 27676f2eee636bd0c65243f81e292c7f9d30d7f985c772ac9cbaf10b54d3584e | diff --git a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py b/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py index fa62cd6f8958e28320d19e09d8ef1dae5609d03f..6d42e317599a83661615e727d8abcbd9997a3a9b 100644 --- a/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py +++ b/profiler/msprof_analyze/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py @@ -306,7 +306,7 @@ class AICorePerformanceChecker: aic_fixpipe_ratio = self.safe_divide(aic_fixpipe_ratio, length) aic_mte2_ratio = self.safe_divide(aic_mte2_ratio, length) if aic_mte2_ratio is None or aic_fixpipe_ratio is None: - return None, None, None + return None, None, None, None aic_fixpipe_ratio_rule, aic_mte2_ratio_rule = None, None for rule in self._operator_rules["fa_operators"]: if rule["target"] == "aic_fixpipe_ratio": @@ -341,7 +341,7 @@ class AICorePerformanceChecker: aiv_vec_ratio = self.safe_divide(aiv_vec_ratio, length) aic_mte2_ratio = self.safe_divide(aic_mte2_ratio, length) if aiv_vec_ratio is None or aic_mte2_ratio is None: - return None, None, None + return None, None, None, None aiv_vec_ratio_rule, aic_mte2_ratio_rule = None, None for rule in self._operator_rules["fa_operators"]: if rule["target"] == "aiv_vec_ratio": @@ -440,7 +440,7 @@ class AICorePerformanceChecker: 
result.add(OptimizeRecord(optimization_item)) headers = [ "Type", - "Description and Suggestion", + "Description", ] result.add_detail(problem_map[op_type], headers=headers) for opti_issue in self.result[op_type][0]: diff --git a/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html b/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html index 77e5e0cb55200efdf5b854e03ac2844ddc631a8f..9ee1ae9cb308c4e813ceb8fb86db26801fd1efdc 100644 --- a/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html +++ b/profiler/msprof_analyze/advisor/display/html/templates/ai_core_performance.html @@ -1,155 +1,200 @@ {% if format_result|length > 0 %} +

AI CORE Performance Analysis

{% if language == "cn" %} {% set title_ns = namespace(type='类别', desc='描述及建议', opti_set='性能优化算子集合', bound_set='bound算子集合', affinity_set='不亲和算子集合', - opti_refer=' 参考性能优化空间: ', bound_refer=' bound类型为: ', affinity_refer=' 不亲和类型为: ', title_desc='算子相关分析,参考如下: ') %} + opti_refer=' 参考性能优化空间', bound_refer=' bound类型为', affinity_refer=' 不亲和类型为', title_desc='算子相关分析,参考如下: ') %} {% else %} {% set title_ns = namespace(type='Type', desc='Description and Suggestion', opti_set='set of performance optimization operators', - bound_set='set of bound operators', affinity_set='set of unaffine operators', opti_refer=' refer to Performance Optimization Space: ', - bound_refer=' bound type: ', affinity_refer=' type of disaffinity: ', title_desc=' Operator related analysis, referenced below: ') %} + bound_set='set of bound operators', affinity_set='set of unaffine operators', opti_refer=' refer to Performance Optimization Space', + bound_refer=' bound type', affinity_refer=' type of disaffinity', title_desc=' Operator related analysis, referenced below: ') %} {% endif %} {% if format_result.cube[0]|length + format_result.cube[1]|length + format_result.cube[2]|length > 0 %} - MatMul{{ title_ns.title_desc }} + Cube{{ title_ns.title_desc }}
- + {% set opti_ns = namespace(total_opti='') %} {% for opti in format_result.cube[0] %} {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = opti_ns.total_opti ~ "" %} {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = "" %} {% endif %} {% endfor %} {% if opti_ns.total_opti|length > 0 %} - + {% endif %} {% set bound_ns = namespace(total_bound='') %} {% for bound in format_result.cube[1] %} {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = bound_ns.total_bound ~ "" %} {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = "" %} {% endif %} {% endfor %} {% if bound_ns.total_bound|length > 0 %} - + {% endif %} {% set affinity_ns = namespace(total_affinity='') %} {% for affinity in format_result.cube[2] %} {% if not loop.first %} - {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "
" ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} + {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "" %} {% else %} - {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} + {% set affinity_ns.total_affinity = "" %} {% endif %} {% endfor %} {% if affinity_ns.total_affinity|length > 0 %} - + {% endif %}
{{ title_ns.type }}{{ title_ns.type }} {{ title_ns.desc }}
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
{{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }} + + + {{ opti_ns.total_opti | safe }} +
nameshapedtype{{ title_ns.opti_refer }}
+
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
{{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }} + + + {{ bound_ns.total_bound | safe }} +
nameshapedtype{{ title_ns.bound_refer }}
+
" ~ affinity.op_name ~ "" ~ affinity.shape ~ "" ~ affinity.dtype ~ "" ~ affinity.suggestion ~ "
" ~ affinity.op_name ~ "" ~ affinity.shape ~ "" ~ affinity.dtype ~ "" ~ affinity.suggestion ~ "
{{ title_ns.affinity_set }}{{ affinity_ns.total_affinity | safe }} + + + {{ affinity_ns.total_affinity | safe }} +
nameshapedtype{{ title_ns.affinity_refer }}
+
{% endif %} {% if format_result.fa[0]|length + format_result.fa[1]|length + format_result.fa[2]|length > 0 %} - FA{{ title_ns.title_desc }} + FA{{ title_ns.title_desc }}
- + {% set opti_ns = namespace(total_opti='') %} {% for opti in format_result.fa[0] %} {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = opti_ns.total_opti ~ "" %} {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = "" %} {% endif %} {% endfor %} {% if opti_ns.total_opti|length > 0 %} - + {% endif %} {% set bound_ns = namespace(total_bound='') %} {% for bound in format_result.fa[1] %} {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = bound_ns.total_bound ~ "" %} {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = "" %} {% endif %} {% endfor %} {% if bound_ns.total_bound|length > 0 %} - + {% endif %} {% set affinity_ns = namespace(total_affinity='') %} {% for affinity in format_result.fa[2] %} {% if not loop.first %} - {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "
" ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} + {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "" %} {% else %} - {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %} + {% set affinity_ns.total_affinity = "" %} {% endif %} {% endfor %} {% if affinity_ns.total_affinity|length > 0 %} - + {% endif %}
{{ title_ns.type }}{{ title_ns.type }} {{ title_ns.desc }}
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
{{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }} + + + {{ opti_ns.total_opti | safe }} +
nameshapedtype{{ title_ns.opti_refer }}
+
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
{{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }} + + + {{ bound_ns.total_bound | safe }} +
nameshapedtype{{ title_ns.bound_refer }}
+
" ~ affinity.op_name ~ "" ~ affinity.shape ~ "" ~ affinity.dtype ~ "" ~ affinity.suggestion ~ "
" ~ affinity.op_name ~ "" ~ affinity.shape ~ "" ~ affinity.dtype ~ "" ~ affinity.suggestion ~ "
{{ title_ns.affinity_set }}{{ affinity_ns.total_affinity | safe }} + + + {{ affinity_ns.total_affinity | safe }} +
nameshapedtype{{ title_ns.affinity_refer }}
+
{% endif %} {% if format_result.vector[0]|length + format_result.vector[1]|length > 0 %} - Vector{{ title_ns.title_desc }} + Vector{{ title_ns.title_desc }}
- + {% set opti_ns = namespace(total_opti='') %} {% for opti in format_result.vector[0] %} {% if not loop.first %} - {% set opti_ns.total_opti = opti_ns.total_opti ~ "
" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = opti_ns.total_opti ~ "" %} {% else %} - {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %} + {% set opti_ns.total_opti = "" %} {% endif %} {% endfor %} {% if opti_ns.total_opti|length > 0 %} - + {% endif %} {% set bound_ns = namespace(total_bound='') %} {% for bound in format_result.vector[1] %} {% if not loop.first %} - {% set bound_ns.total_bound = bound_ns.total_bound ~ "
" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = bound_ns.total_bound ~ "" %} {% else %} - {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %} + {% set bound_ns.total_bound = "" %} {% endif %} {% endfor %} {% if bound_ns.total_bound|length > 0 %} - + {% endif %}
{{ title_ns.type }}{{ title_ns.type }} {{ title_ns.desc }}
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
" ~ opti.op_name ~ "" ~ opti.shape ~ "" ~ opti.dtype ~ "" ~ opti.optimization ~ "%
{{ title_ns.opti_set }}{{ opti_ns.total_opti | safe }} + + + {{ opti_ns.total_opti | safe }} +
nameshapedtype{{ title_ns.opti_refer }}
+
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
" ~ bound.op_name ~ "" ~ bound.shape ~ "" ~ bound.dtype ~ "" ~ bound.bound ~ "
{{ title_ns.bound_set }}{{ bound_ns.total_bound | safe }} + + + {{ bound_ns.total_bound | safe }} +
nameshapedtype{{ title_ns.bound_refer }}
+
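The template rework above drops the old concatenated `operator shape: ... dtype: ...` description strings in favor of per-operator table rows accumulated through Jinja `namespace` attributes (`opti_ns.total_opti`, `bound_ns.total_bound`, `affinity_ns.total_affinity`). The accumulation pattern can be sketched in plain Python; the operator records and column set below are illustrative placeholders, not taken from the advisor's real data model:

```python
# Hypothetical operator records, shaped like the entries the template iterates over.
ops = [
    {"op_name": "MatMulV2", "shape": "256,1024;1024,512", "dtype": "float16", "optimization": 12.5},
    {"op_name": "BatchMatMul", "shape": "8,128,64;8,64,256", "dtype": "bfloat16", "optimization": 7.0},
]

rows = ""
for op in ops:
    # Mirrors the namespace accumulation: append one <tr> per operator.
    rows += ("<tr><td>{op_name}</td><td>{shape}</td><td>{dtype}</td>"
             "<td>{optimization}%</td></tr>").format(**op)

header = "<tr><th>name</th><th>shape</th><th>dtype</th><th>optimization</th></tr>"
table = "<table>{}{}</table>".format(header, rows)
print(table.count("<tr>"))  # → 3: the header row plus one row per operator
```

The `namespace()` indirection in the template exists because a plain `{% set %}` inside a Jinja `{% for %}` loop does not persist across iterations; mutating an attribute of a namespace object is the documented way to accumulate state across loop bodies, which is exactly what the row-building blocks above rely on.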
diff --git a/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml b/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml index 68ab59f16937880c7428330811005297d1551b0d..5f9f63869054a850aabfe41a2a4dd088afa70cfa 100644 --- a/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml +++ b/profiler/msprof_analyze/advisor/rules/en/aicore_performance.yaml @@ -1,7 +1,7 @@ -cube_problem: "Cube operator performance analysis" -fa_problem: "FA operator performance analysis" -vector_problem: "Vector operator performance analysis" -description: "Provide some reference bottlenecks for the AICORE operator" +cube_problem: "Cube Operator Perf Analysis" +fa_problem: "FA Operator Perf Analysis" +vector_problem: "Vector Operator Perf Analysis" +description: "Provide some reference bottlenecks for the AICORE operators" bound_description: "set of bound operators" optimization_description: "set of performance optimization operators" affinity_description: "set of unaffine operators" diff --git a/profiler/msprof_analyze/cluster_analyse/README.md b/profiler/msprof_analyze/cluster_analyse/README.md index 325a0984793297dfac28673f04a582ea7b4316b9..2911f6ef23649bdb7b0df5fdd22ad70e1ec20a9b 100644 --- a/profiler/msprof_analyze/cluster_analyse/README.md +++ b/profiler/msprof_analyze/cluster_analyse/README.md @@ -2,11 +2,11 @@ cluster_analyse(集群分析工具)是在集群场景下,通过此工具来进行集群数据的分析,当前主要对基于通信域的迭代内耗时分析、通信时间分析以及通信矩阵分析为主, 从而定位慢卡、慢节点以及慢链路问题。 ## 性能数据采集 -当前集群调优工具主要支持PyTorch场景的Ascend PyTorch Profiler采集方式和MindSpore场景的MindSpore Profiler采集方式下的集群数据。 +当前集群调优工具主要支持PyTorch场景的Ascend PyTorch Profiler采集方式和MindSpore场景的MindSpore Profiler采集方式以及msprof命令行工具采集方式下的集群数据。 此工具只需要NPU的性能数据作为输入。 -Ascend PyTorch Profiler采集方法请参见《[NPU性能数据采集](https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze)》,MindSpore Profiler采集方法请参见《[性能调试](https://www.mindspore.cn/mindinsight/docs/zh-CN/r2.3/performance_profiling_ascend.html)》。 +Ascend PyTorch 
Profiler采集方法请参见《[NPU性能数据采集](https://gitee.com/ascend/mstt/tree/master/profiler/msprof_analyze)》,MindSpore Profiler采集方法请参见《[性能调试](https://www.mindspore.cn/mindinsight/docs/zh-CN/r2.3/performance_profiling_ascend.html)》,msprof命令行采集方法请参见《[msprof命令行工具](https://www.hiascend.com/document/detail/zh/canncommercial/800/devaids/devtools/profiling/atlasprofiling_16_0010.html)》。 我们要求至少是L1级别的数据。 ```python @@ -16,19 +16,26 @@ experimental_config = torch_npu.profiler._ExperimentalConfig( ``` ### 确认数据是否可用 -打开采集到的某张卡数据(\*ascend_pt、\*ascend_ms结尾的文件夹),可用的数据应该具备: +通过上述三种方式获得性能数据,打开采集到的某张卡数据,可用的数据应该具备: -- ./profiler_info_x.json, -- ./ASCEND_PROFILER_OUTPUT/step_trace_time.csv, -- ./ASCEND_PROFILER_OUTPUT/trace_view.json, -- ./ASCEND_PROFILER_OUTPUT/kernel_details.csv, -- ./ASCEND_PROFILER_OUTPUT/communication.json, -- ./ASCEND_PROFILER_OUTPUT/communication_matrix.json +- Ascend PyTorch Profiler采集的\*ascend_pt目录或MindSpore Profiler采集的\*ascend_ms目录: -或者具备: + - ./profiler_info_x.json, + - ./ASCEND_PROFILER_OUTPUT/step_trace_time.csv, + - ./ASCEND_PROFILER_OUTPUT/trace_view.json, + - ./ASCEND_PROFILER_OUTPUT/kernel_details.csv, + - ./ASCEND_PROFILER_OUTPUT/communication.json, + - ./ASCEND_PROFILER_OUTPUT/communication_matrix.json -- analysis.db -- ascend_pytorch_profiler_{rank_id}.db + 或者具备: + + - analysis.db + - ascend_pytorch_profiler_{rank_id}.db + +- msprof命令行采集的PROF_XXX目录: + + - --type=db、--export=on情况下解析的:msprof_{timestamp}.db + - --type=db、--analyze=on情况下解析的:analyze/communication_analyzer.db 以上csv、json文件与db文件只能存在一类,否则集群分析工具解析异常。MindSpore场景暂不支持以上db文件。 @@ -79,6 +86,9 @@ experimental_config = torch_npu.profiler._ExperimentalConfig( | compute_op_sum | 集群场景性能数据的device运行算子信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/ComputeOpSum目录下输出交付件stats.ipynb;可根据实际情况决定是否是否打开--exclude_op_name。 | 否 | | hccl_sum | 
集合通信算子耗时分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/HcclSum目录下输出交付件stats.ipynb。 | 否 | | mstx_sum | 集群场景mstx打点信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。--export_type为db时,输出交付件cluster_analysis.db;--export_type为notebook时,在cluster_analysis_output/MstxSum目录下输出交付件stats.ipynb。 | 否 | + | freq_analysis | 集群场景aicore frequency信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。打屏输出是否存在aicore存在空闲(频率为800MHz)、异常(频率不为1800MHz或800MHz)的现象。如果有,则在输出交付件cluster_analysis.db增加对应的卡和频率信息。 | 否 | + | ep_load_balance | 集群场景moe负载信息汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输出交付件cluster_analysis.db增加EPTokensSummary, TopEPTokensInfo分析表格。 | 否 | + | slow_rank | 集群场景通信算子快慢卡汇总分析,输入性能数据需要基于ascend_pytorch_profiler_{rank_id}.db文件。输出交付件cluster_analysis.db中展示各个rank按照当前的快慢卡统计算法得出的快慢卡影响次数。 | 自定义分析参数 | 与cann_api_sum、compute_op_sum、hccl_sum等参数功能类似,用户可自定义一套性能数据的分析规则,需要详细了解性能分析的开发人员,具体开发指导请参见“[自定义分析规则开发指导](#自定义分析规则开发指导)”。 | 否 | --parallel_mode参数示例如下: diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/analysis_facade.py b/profiler/msprof_analyze/cluster_analyse/analysis/analysis_facade.py index aa9658f0c6229264e40eeacf6cfcc9fcd3a18538..4f9b68efe673961172708f08e9be10f5eb5b088a 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/analysis_facade.py @@ -16,17 +16,18 @@ from multiprocessing import Process, Value, Lock from tqdm import tqdm from msprof_analyze.cluster_analyse.analysis.communication_analysis import CommunicationAnalysis -from msprof_analyze.cluster_analyse.analysis.communication_analysis import CommunicationAnalysisOptimized from msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis import CommMatrixAnalysis -from msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis import CommMatrixAnalysisOptimized from 
msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis import StepTraceTimeAnalysis from msprof_analyze.cluster_analyse.analysis.host_info_analysis import HostInfoAnalysis from msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis import ClusterBaseInfoAnalysis from msprof_analyze.cluster_analyse.common_func.context import Context - from msprof_analyze.cluster_analyse.common_func.analysis_loader import get_class_from_name from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.cluster_analyse.recipes.comm_group_map.comm_group_map import CommGroupMap +from msprof_analyze.cluster_analyse.recipes.communication_time_sum.communication_time_sum import \ + CommunicationTimeSumRecipe +from msprof_analyze.cluster_analyse.recipes.communication_matrix_sum.communication_matrix_sum import CommMatrixSum logger = get_logger() @@ -34,10 +35,7 @@ logger = get_logger() class AnalysisFacade: default_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis, HostInfoAnalysis, ClusterBaseInfoAnalysis} - simplified_module = { - CommunicationAnalysisOptimized, StepTraceTimeAnalysis, CommMatrixAnalysisOptimized, HostInfoAnalysis, - ClusterBaseInfoAnalysis - } + simplified_module = {StepTraceTimeAnalysis, ClusterBaseInfoAnalysis, HostInfoAnalysis} def __init__(self, params: dict): self.params = params @@ -47,6 +45,7 @@ class AnalysisFacade: process_list = [] if self.params.get(Constant.DATA_SIMPLIFICATION) and self.params.get(Constant.DATA_TYPE) == Constant.DB: analysis_module = self.simplified_module + self.cluster_analyze_with_recipe() else: analysis_module = self.default_module @@ -92,3 +91,12 @@ class AnalysisFacade: recipe_class = get_class_from_name(self.params.get(Constant.ANALYSIS_MODE)) if recipe_class: self.do_recipe(recipe_class) + + def cluster_analyze_with_recipe(self): + recipes = [["CommGroupMap", CommGroupMap]] + if self.params.get(Constant.ANALYSIS_MODE) 
in (Constant.ALL, Constant.COMMUNICATION_TIME): + recipes.append(["CommunicationTimeSumRecipe", CommunicationTimeSumRecipe]) + if self.params.get(Constant.ANALYSIS_MODE) in (Constant.ALL, Constant.COMMUNICATION_MATRIX): + recipes.append(["CommMatrixSum", CommMatrixSum]) + for recipe_class in recipes: + self.do_recipe(recipe_class) diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py index 59b824c1defd3ea93655a71d9727962871d07537..0d14af7693abbf433af14431ff4958e5a24a3cde 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py @@ -54,7 +54,7 @@ class BaseAnalysis: if stat_name in op_name: if stat_name != total: return False - return True + return True @abstractmethod def run(self): diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py index cb280978c41a639a2f5d17e2a0ff08ed3a9962d6..664ef3ba1a424eec7b673a0fe765ca390ec372b4 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/cluster_base_info_analysis.py @@ -28,7 +28,6 @@ logger = get_logger() class ClusterBaseInfoAnalysis(BaseAnalysis): - KEY_DISTRIBUTED_ARGS = "distributed_args" def __init__(self, param: dict): super().__init__(param) @@ -56,7 +55,7 @@ class ClusterBaseInfoAnalysis(BaseAnalysis): result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) conn, curs = DBManager.create_connect_db(result_db) DBManager.create_tables(result_db, Constant.TABLE_CLUSTER_BASE_INFO) - save_distributed_args = [[json.dumps(self.distributed_args)]] + save_distributed_args = [[Constant.DISTRIBUTED_ARGS, json.dumps(self.distributed_args)]] sql = "insert into {} values 
({value})".format(Constant.TABLE_CLUSTER_BASE_INFO, value="?," * (len(save_distributed_args[0]) - 1) + "?") DBManager.executemany_sql(conn, sql, save_distributed_args) @@ -72,9 +71,9 @@ class ClusterBaseInfoAnalysis(BaseAnalysis): except RuntimeError as e: logger.error("Read json failed. %s", str(e)) continue - if not meta_data.get(self.KEY_DISTRIBUTED_ARGS): + if not meta_data.get(Constant.DISTRIBUTED_ARGS): continue - for key, value in meta_data[self.KEY_DISTRIBUTED_ARGS].items(): + for key, value in meta_data[Constant.DISTRIBUTED_ARGS].items(): if key == "rank": continue self.distributed_args.setdefault(key, value) diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py index a87803438aef3a733c41588413ff2281b85ae418..d4df5466c3845f7a2db7f3b5b439059155bc4d47 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py @@ -22,6 +22,8 @@ from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.cluster_analyse.common_func.utils import increase_shared_value from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.cluster_analyse.common_func.utils import double_hash +from msprof_analyze.prof_common.file_manager import FileManager logger = get_logger() @@ -70,48 +72,66 @@ class CommMatrixAnalysis(BaseAnalysis): self.combine_link_info(step_dict) def merge_same_links(self, step_dict: dict): - def process_link_key(rank_id, rank_dict): + def update_rank_map(step_dict): + for op_name, op_dict in step_dict.items(): + group_name = op_name.split("@")[-1] + for rank_id, rank_dict in op_dict.items(): + for link_key in rank_dict: + if '-' not in link_key: + logger.warning("%s has an invalid link key %s!", str(op_name), str(link_key)) + break + src_rank = 
link_key.split('-')[0] + dst_rank = link_key.split('-')[1] + if src_rank == dst_rank: + if src_rank not in project_local_global_rank_map.get(group_name, {}): + project_local_global_rank_map.setdefault(group_name, {})[src_rank] = rank_id + elif project_local_global_rank_map.get(group_name, {}).get(src_rank) != rank_id: + logger.warning(f"In the same communication group {group_name}, global rank {rank_id} " + f"and {project_local_global_rank_map.get(group_name, {}).get(src_rank)} " + f"get the same local rank {src_rank}!") + + def process_link_key(rank_dict): for link_key in rank_dict: if '-' not in link_key: logger.warning("%s has an invalid link key %s!", str(op_name), str(link_key)) break - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - if src_rank == dst_rank: - if src_rank not in project_local_global_rank_map: - project_local_global_rank_map[src_rank] = rank_id - elif project_local_global_rank_map.get(src_rank) != rank_id: - logger.warning("In the same communication group, local ranks projecting to global ranks " - "repeat!") self.combine_link(link_info[link_key], rank_dict[link_key]) - def convert_local_to_global_rank(): + def convert_local_to_global_rank(rank_map): tmp_link = {} for link_key, link_dict in link_info.items(): src_rank = link_key.split('-')[0] dst_rank = link_key.split('-')[1] - src_rank = project_local_global_rank_map[src_rank] \ - if src_rank in project_local_global_rank_map else src_rank - dst_rank = project_local_global_rank_map[dst_rank] \ - if dst_rank in project_local_global_rank_map else dst_rank + if src_rank not in rank_map: + logger.warning(f"The src local rank {src_rank} of the operator {op_name} " + f"cannot be mapped to the global rank.") + continue + if dst_rank not in rank_map: + logger.warning(f"The dst local rank {dst_rank} of the operator {op_name} " + f"cannot be mapped to the global rank.") + continue + src_rank = rank_map[src_rank] + dst_rank = rank_map[dst_rank] link_dict[Constant.BANDWIDTH_GB_S] = 
\ self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), link_dict.get(Constant.TRANSIT_TIME_MS, 0)) tmp_link[f"{src_rank}-{dst_rank}"] = link_dict return tmp_link - project_local_global_rank_map = dict() default_value = { Constant.TRANSPORT_TYPE: '', Constant.TRANSIT_TIME_MS: 0, Constant.TRANSIT_SIZE_MB: 0, Constant.OP_NAME: '' } + project_local_global_rank_map = self.get_parallel_group_info() + update_rank_map(step_dict) for op_name, op_dict in step_dict.items(): link_info = defaultdict(lambda: copy.deepcopy(default_value)) - for rank_id, rank_dict in op_dict.items(): - process_link_key(rank_id, rank_dict) - step_dict[op_name] = convert_local_to_global_rank() + group_name = op_name.split("@")[-1] + for rank_dict in op_dict.values(): + process_link_key(rank_dict) + step_dict[op_name] = convert_local_to_global_rank(project_local_global_rank_map.get(group_name, {})) def combine_link_info(self, step_dict: dict): default_value = { @@ -131,23 +151,15 @@ class CommMatrixAnalysis(BaseAnalysis): link_dict.get(Constant.TRANSIT_TIME_MS, 0)) step_dict[Constant.TOTAL_OP_INFO] = total_op_info - -class CommMatrixAnalysisOptimized(CommMatrixAnalysis): - SAVED_JSON = "cluster_communication_matrix.json" - COMMUNICATION_MATRIX_TABLE = "ClusterCommunicationMatrix" - - def __init__(self, param: dict): - super().__init__(param) - - def dump_db(self): - res_comm_matrix = self.adapter.transfer_matrix_from_json_to_db(self.comm_ops_struct) - output_path = os.path.join(self.cluster_analysis_output_path, Constant.CLUSTER_ANALYSIS_OUTPUT) - result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) - DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE) - conn, cursor = DBManager.create_connect_db(result_db) - if res_comm_matrix: - res_matrix_value = [list(data.values())[1:] for data in res_comm_matrix] - sql = "insert into {} values ({value})".format(self.COMMUNICATION_MATRIX_TABLE, - value="?," * (len(res_matrix_value[0]) - 1) + "?") - 
DBManager.executemany_sql(conn, sql, res_matrix_value) - DBManager.destroy_db_connect(conn, cursor) + def get_parallel_group_info(self): + parallel_group_info = {} + for profiler_path in self.data_map.values(): + meta_json = os.path.join(profiler_path, "profiler_metadata.json") + if os.path.exists(meta_json): + meta_data = FileManager.read_json_file(meta_json) + for group_name, group_info in meta_data.get("parallel_group_info", {}).items(): + global_ranks = group_info.get("global_ranks") + if isinstance(global_ranks, list) and global_ranks: + global_ranks.sort() + parallel_group_info[double_hash(group_name)] = dict(enumerate(global_ranks)) + return parallel_group_info diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py index 61daa5b943d4a718d90f80203bac3fc948202199..e8ca793f525b0279053bc9848f99f21016ea6295 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py @@ -137,147 +137,3 @@ class CommunicationBandwidthParams: self.step_id = step_id self.transport_type = transport_type self.package_size = package_size - - -class CommunicationAnalysisOptimized(BaseAnalysis): - COMMUNICATION_BANDWIDTH_TABLE = "ClusterCommunicationBandwidth" - COMMUNICATION_TIME_TABLE = "ClusterCommunicationTime" - - def __init__(self, param: dict): - super().__init__(param) - self._communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_OPS) - self._communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) - self._aggregate_time = {} - self._aggregate_bandwidth = {} - self._output_time = [] - self._output_bandwidth = [] - - @staticmethod - def _execute(conn, res_data, table_name): - if res_data: - sql = "insert into {} values ({value})".format(table_name, value="?," * (len(res_data[0]) - 1) + "?") - 
DBManager.executemany_sql(conn, sql, res_data) - - @staticmethod - def _format_time_data(communication_data): - data_dict = {} - for single_op in communication_data: - formatted_data = CommunicationTimeBean(single_op) - data_dict.setdefault(formatted_data.step_id, {}). \ - setdefault(formatted_data.rank_id, {}). \ - setdefault(formatted_data.group_name, []).extend([formatted_data]) - return data_dict - - def run(self, completed_processes, lock): - if not self._communication_ops[0] or not self._communication_ops[1]: - increase_shared_value(completed_processes, lock) - logger.info("CommunicationAnalysisOptimized completed") - return - self._aggregate_time = self._format_time_data(self._communication_ops[0]) - self._aggregate_bandwidth = self._format_bandwidth_data(self._communication_ops[1]) - self._compute_total_info() - self._dump_data() - increase_shared_value(completed_processes, lock) - logger.info("CommunicationAnalysisOptimized completed") - - def _format_bandwidth_data(self, communication_data: dict): - data_dict = {} - for single_op in communication_data: - formatted_data = CommunicationBandwidthBean(single_op) - rank_set = str(self.collective_group_dict.get(formatted_data.group_name, formatted_data.group_name)) - data_dict.setdefault(rank_set, {}).setdefault(formatted_data.step_id, {}). \ - setdefault(formatted_data.rank_id, {}). \ - setdefault(formatted_data.transport_type, {}). 
\ - setdefault(formatted_data.package_size, []).extend([formatted_data]) - return data_dict - - def _dump_data(self): - output_path = os.path.join(self.cluster_analysis_output_path, Constant.CLUSTER_ANALYSIS_OUTPUT) - result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) - DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE) - DBManager.create_tables(result_db, self.COMMUNICATION_BANDWIDTH_TABLE) - conn, cursor = DBManager.create_connect_db(result_db) - self._execute(conn, self._output_time, self.COMMUNICATION_TIME_TABLE) - self._execute(conn, self._output_bandwidth, self.COMMUNICATION_BANDWIDTH_TABLE) - DBManager.destroy_db_connect(conn, cursor) - - def _compute_time_info(self): - for step_id, rank_dict in self._aggregate_time.items(): - for rank_id, communication_op_info in rank_dict.items(): - rank_set_dict = {} - for group_name, single_group_op_info in communication_op_info.items(): - total_dict = { - TableConstant.RANK_ID: rank_id, - TableConstant.STEP: step_id, - TableConstant.GROUP_NAME: group_name, - TableConstant.HCCL_OP_NAME: Constant.TOTAL_OP_INFO - } - total_time_info = CommunicationTimeBean(total_dict) - for com_info_dict in single_group_op_info: - total_time_info += com_info_dict - self._output_time.append(com_info_dict.convert_output()) - rank_set = str(self.collective_group_dict.get(group_name)) - if not rank_set: - logger.warning("failed to find rank set with group name: %s.", str(group_name)) - continue - if rank_set_dict.get(rank_set): - rank_set_dict[rank_set] += total_time_info - else: - rank_set_dict[rank_set] = total_time_info - for _, total_time_info in rank_set_dict.items(): - total_time_info.compute_ratio() - self._output_time.append(total_time_info.convert_output()) - - def _process_package_info(self, package_info, total_transit_size, total_transit_time, op_group_set, - communication_bandwidth_params): - total_bw_info = CommunicationBandwidthBean({ - TableConstant.RANK_ID: 
communication_bandwidth_params.rank_id, - TableConstant.STEP: communication_bandwidth_params.step_id, - TableConstant.GROUP_NAME: '', - TableConstant.HCCL_OP_NAME: Constant.TOTAL_OP_INFO, - TableConstant.TRANSPORT_TYPE: communication_bandwidth_params.transport_type, - TableConstant.TRANSIT_SIZE: 0.0, - TableConstant.TRANSIT_TIME: 0.0, - TableConstant.BANDWIDTH: 0.0, - TableConstant.PACKAGE_SIZE: communication_bandwidth_params.package_size - }) - for bandwidth_package_info in package_info: - total_bw_info += bandwidth_package_info - if not total_bw_info.group_name: - total_bw_info.set_group_name(bandwidth_package_info.group_name) - self._output_bandwidth.append(bandwidth_package_info.convert_output()) - op_group = bandwidth_package_info.hccl_op_name + "@" + bandwidth_package_info.group_name - if op_group not in op_group_set: - op_group_set.add(op_group) - total_transit_size += bandwidth_package_info.transit_size - total_transit_time += bandwidth_package_info.transit_time - return total_bw_info, total_transit_size, total_transit_time - - def _compute_bandwidth_info(self): - for _, step_dict in self._aggregate_bandwidth.items(): - for step_id, rank_dict in step_dict.items(): - for rank_id, communication_op_info in rank_dict.items(): - for transport_type, bandwidth_info in communication_op_info.items(): - total_transit_size = 0.0 - total_transit_time = 0.0 - total_info = [] - op_group_set = set() - for package_size, package_info in bandwidth_info.items(): - total_bandwidth_info, total_transit_size, total_transit_time = self._process_package_info( - package_info, total_transit_size, total_transit_time, op_group_set, - CommunicationBandwidthParams(rank_id, step_id, transport_type, package_size) - ) - total_info.append(total_bandwidth_info) - total_bandwidth = total_transit_size / total_transit_time if total_transit_time else 0.0 - for single_total_info in total_info: - single_total_info.set_transit_size(total_transit_size) - 
single_total_info.set_transit_time(total_transit_time) - single_total_info.set_bandwidth(total_bandwidth) - self._output_bandwidth.append(single_total_info.convert_output()) - - def _compute_total_info(self): - if not self._aggregate_time or not self._aggregate_bandwidth: - logger.error("communication data is null.") - return - self._compute_time_info() - self._compute_bandwidth_info() diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/host_info_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/host_info_analysis.py index e8821b3c2971d1b9707a8347b164f031a44d2922..05cd3c9a2ab7332fb885993148505b3128f128fc 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/host_info_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/host_info_analysis.py @@ -21,6 +21,8 @@ from msprof_analyze.cluster_analyse.common_func.utils import increase_shared_val from msprof_analyze.prof_common.path_manager import PathManager from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor +from msprof_analyze.cluster_analyse.cluster_data_preprocess.mindspore_data_preprocessor import MindsporeDataPreprocessor logger = get_logger() @@ -33,6 +35,8 @@ class HostInfoAnalysis(BaseAnalysis): super().__init__(param) self.all_rank_host_info = {} self.all_rank_device_info = [] + self.is_msprof = param.get(Constant.IS_MSPROF) + self.is_mindspore = param.get(Constant.IS_MINDSPORE) def run(self, completed_processes=None, lock=None): if self.data_type != Constant.DB: @@ -83,7 +87,7 @@ class HostInfoAnalysis(BaseAnalysis): for rank_id, profiling_dir in self.data_map.items(): host_info = [] rank_device_info = [] - db_path = os.path.join(profiling_dir, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + db_path = self._get_db_path(rank_id, profiling_dir) if (os.path.exists(db_path) and 
DBManager.check_tables_in_db(db_path, self.TABLE_HOST_INFO)): conn, curs = DBManager.create_connect_db(db_path) sql = "select * from {0}".format(self.TABLE_HOST_INFO) @@ -98,6 +102,13 @@ class HostInfoAnalysis(BaseAnalysis): sql = "select * from {0}".format(self.TABLE_RANK_DEVICE_MAP) rank_device_info = DBManager.fetch_all_data(curs, sql, is_dict=False) DBManager.destroy_db_connect(conn, curs) + if self.is_msprof: + device_id = MsprofDataPreprocessor.get_device_id(profiling_dir) + rank_device_info = [[rank_id, device_id]] + if self.is_mindspore: + prof_dir = MindsporeDataPreprocessor.get_msprof_dir(profiling_dir) + device_id = MsprofDataPreprocessor.get_device_id(prof_dir) + rank_device_info = [[rank_id, device_id]] if not (rank_device_info and rank_device_info[0]): if not print_empty_host_info: print_empty_host_info = f"No {self.TABLE_RANK_DEVICE_MAP} data in {self.data_type} file." @@ -109,3 +120,10 @@ class HostInfoAnalysis(BaseAnalysis): self.all_rank_device_info.extend(rank_device_info) if print_empty_host_info: logger.warning(print_empty_host_info) + + def _get_db_path(self, rank_id, profiling_dir): + if self.is_msprof: + return MsprofDataPreprocessor.get_msprof_profiler_db_path(profiling_dir) + if self.is_mindspore: + return os.path.join(profiling_dir, Constant.SINGLE_OUTPUT, f"ascend_mindspore_profiler_{rank_id}.db") + return os.path.join(profiling_dir, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/msprof_step_trace_time_adapter.py b/profiler/msprof_analyze/cluster_analyse/analysis/msprof_step_trace_time_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..799fd86a477b5ea3e0b59e1831fdc5a7ff398a6b --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/analysis/msprof_step_trace_time_adapter.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprof_analyze.cluster_analyse.prof_bean.step_trace_time_bean import StepTraceTimeBean +from msprof_analyze.prof_common.utils import convert_to_float +from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.cluster_analyse.common_func.time_range_calculator import RangeCaculator +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.cluster_analyse.common_func.time_range_calculator import CommunicationTimeRange +from msprof_analyze.prof_common.constant import Constant + +logger = get_logger() + + +class MsprofStepTraceTimeAdapter: + COMPUTE = "Computing" + COMM_NOT_OVERLAP = "Communication(Not Overlapped)" + OVERLAPPED = "Overlapped" + COMMUNICATION = "Communication" + FREE = "Free" + STAGE = "Stage" + BUBBLE = "Bubble" + COMM_NOT_OVERLAP_EXCLUDE_RECEIVE = "Communication(Not Overlapped and Exclude Receive)" + PREPARE = "Preparing" + STEP = "Step" + + def __init__(self, file_path): + self.file_path = file_path + self._data = {self.STEP: None, self.COMPUTE: 0, self.COMM_NOT_OVERLAP: 0, self.OVERLAPPED: 0, + self.COMMUNICATION: 0, self.FREE: 0, self.STAGE: 0, self.BUBBLE: 0, + self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE: 0, self.PREPARE: 0} + + def generate_step_trace_time_data(self): + json_str = [] + for file_path in 
self.file_path: + json_str.extend(FileManager.read_json_file(file_path)) + receive_comm = [] + analysis_data = {} + for data in json_str: + event_name = data.get("name", "") + if event_name in {self.COMMUNICATION, self.COMPUTE, self.FREE, self.COMM_NOT_OVERLAP}: + analysis_data.setdefault(event_name, []).append(data) + elif event_name.startswith('hcom_receive'): + receive_comm.append(data) + for event_type, event_list in analysis_data.items(): + self._data[event_type] = sum((convert_to_float(event.get("dur", 0)) for event in event_list)) + self._data[self.BUBBLE] = sum((convert_to_float(event.get("dur", 0)) for event in receive_comm)) + self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] = self._data[self.COMM_NOT_OVERLAP] - self._data[self.BUBBLE] + self._data[self.OVERLAPPED] = self._data[self.COMMUNICATION] - self._data[self.COMM_NOT_OVERLAP] + e2e_time = self._data[self.FREE] + self._data[self.COMPUTE] + self._data[self.COMM_NOT_OVERLAP] + self._data[self.STAGE] = e2e_time - self._data[self.BUBBLE] + return [StepTraceTimeBean(self._data)] + + +class MsprofStepTraceTimeDBAdapter(MsprofStepTraceTimeAdapter): + OP_NAME = 0 + START_NS = 1 + END_NS = 2 + + def __init__(self, file_path): + super().__init__(file_path) + self.task_db_con = None + self.task_db_curs = None + self.string_id_map = None + self.compute_task_info = None + self.communication_op_info = None + + def generate_step_trace_time_data(self): + try: + self._init_task_info_from_db() + except Exception as err: + logger.error(err) + DBManager.destroy_db_connect(self.task_db_con, self.task_db_curs) + return [] + origin_compute_data = self._get_compute_data() + origin_communication_data, bubble_data = self._get_communication_data() + compute_data = RangeCaculator.merge_continuous_intervals(origin_compute_data) + self._data[self.COMPUTE] = sum(data.end_ts - data.start_ts for data in compute_data) + communication_data = RangeCaculator.merge_continuous_intervals(origin_communication_data) + 
self._data[self.COMMUNICATION] = sum(data.end_ts - data.start_ts for data in communication_data) + pure_communication_data, free_data = RangeCaculator.compute_pipeline_overlap(communication_data, compute_data) + self._data[self.COMM_NOT_OVERLAP] = sum(data.end_ts - data.start_ts for data in pure_communication_data) + self._data[self.FREE] = sum(data.end_ts - data.start_ts for data in free_data) + self._data[self.BUBBLE] = sum(data.end_ts - data.start_ts for data in bubble_data) + self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] = self._data[self.COMM_NOT_OVERLAP] - self._data[self.BUBBLE] + self._data[self.OVERLAPPED] = self._data[self.COMMUNICATION] - self._data[self.COMM_NOT_OVERLAP] + e2e_time = self._data[self.FREE] + self._data[self.COMPUTE] + self._data[self.COMM_NOT_OVERLAP] + self._data[self.STAGE] = e2e_time - self._data[self.BUBBLE] + return [[self._data[self.STEP], self._data[self.COMPUTE] / Constant.NS_TO_US, + self._data[self.COMM_NOT_OVERLAP] / Constant.NS_TO_US, self._data[self.OVERLAPPED] / Constant.NS_TO_US, + self._data[self.COMMUNICATION] / Constant.NS_TO_US, self._data[self.FREE] / Constant.NS_TO_US, + self._data[self.STAGE] / Constant.NS_TO_US, self._data[self.BUBBLE] / Constant.NS_TO_US, + self._data[self.COMM_NOT_OVERLAP_EXCLUDE_RECEIVE] / Constant.NS_TO_US, + self._data[self.PREPARE] / Constant.NS_TO_US]] + + def _init_task_info_from_db(self): + db_path = self.file_path.get(Constant.PROFILER_DB_PATH) + conn, curs = DBManager.create_connect_db(db_path) + if not (conn and curs): + logger.warning(f"Failed to connect to db file: {db_path}") + return + self.task_db_con = conn + self.task_db_curs = curs + if DBManager.judge_table_exists(curs, TableConstant.TABLE_STRING_IDS): + sql = "select id, value from {}".format(TableConstant.TABLE_STRING_IDS) + string_id_data = DBManager.fetch_all_data(curs, sql, is_dict=False) + self.string_id_map = {data[0]: data[1] for data in string_id_data} + if DBManager.judge_table_exists(curs, 
TableConstant.TABLE_COMPUTE_TASK_INFO): + sql = f"select TASK.startNs, TASK.endNs from {TableConstant.TABLE_COMPUTE_TASK_INFO} LEFT JOIN " \ + f"{TableConstant.TABLE_TASK} on {TableConstant.TABLE_TASK}.globalTaskId == " \ + f"{TableConstant.TABLE_COMPUTE_TASK_INFO}.globalTaskId" + self.compute_task_info = DBManager.fetch_all_data(curs, sql, is_dict=False) + if DBManager.judge_table_exists(curs, TableConstant.TABLE_COMMUNICATION_OP): + sql = "select opName, startNs, endNs from {}".format(TableConstant.TABLE_COMMUNICATION_OP) + self.communication_op_info = DBManager.fetch_all_data(curs, sql, is_dict=False) + DBManager.destroy_db_connect(conn, curs) + + def _get_communication_data(self): + communication_data = [] + bubble_data = [] + for op_info in self.communication_op_info: + op_start_time = op_info[self.START_NS] + time_range = RangeCaculator.generate_time_range( + op_start_time, op_info[self.END_NS], class_range=CommunicationTimeRange) + communication_data.append(time_range) + op_name = self.string_id_map.get(op_info[self.OP_NAME], '') + if op_name.startswith('hcom_receive'): + bubble_data.append(time_range) + return communication_data, bubble_data + + def _get_compute_data(self): + compute_data = [] + for compute_task in self.compute_task_info: + compute_data.append(RangeCaculator.generate_time_range(compute_task[0], compute_task[1])) + return compute_data diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/step_trace_time_analysis.py index 5168f63aef5113221766252ac6ef35f8d9504147..0ee79b0f2d06d1acb03b5d4f1635cbabd5341206 100644 --- a/profiler/msprof_analyze/cluster_analyse/analysis/step_trace_time_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/analysis/step_trace_time_analysis.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
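Both adapters above reduce the same four measured sums (compute, communication, non-overlapped communication, free, plus the receive "bubble") to identical derived columns. A minimal sketch of that arithmetic, with illustrative values rather than real profiler data:

```python
def derive_step_times(compute, communication, comm_not_overlap, free, bubble):
    """Derive the step-trace-time columns from the measured sums.

    Mirrors the identities used by MsprofStepTraceTimeAdapter and its DB
    variant; inputs are illustrative durations, not real profiler output.
    """
    overlapped = communication - comm_not_overlap
    comm_not_overlap_excl_receive = comm_not_overlap - bubble
    e2e = free + compute + comm_not_overlap
    stage = e2e - bubble
    return {
        "Overlapped": overlapped,
        "Communication(Not Overlapped and Exclude Receive)": comm_not_overlap_excl_receive,
        "Stage": stage,
    }


print(derive_step_times(100.0, 40.0, 25.0, 10.0, 5.0))
```

Note that end-to-end time is reconstructed as `Free + Computing + Communication(Not Overlapped)`, so `Stage` is simply that total minus the receive bubble.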
import os +import re from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.cluster_analyse.common_func.utils import increase_shared_value @@ -21,6 +22,9 @@ from msprof_analyze.cluster_analyse.prof_bean.step_trace_time_bean import StepTr from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.file_manager import FileManager from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.cluster_analyse.analysis.msprof_step_trace_time_adapter import MsprofStepTraceTimeAdapter +from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor +from msprof_analyze.cluster_analyse.analysis.msprof_step_trace_time_adapter import MsprofStepTraceTimeDBAdapter logger = get_logger() @@ -35,11 +39,13 @@ class StepTraceTimeAnalysis: self.collection_path = param.get(Constant.COLLECTION_PATH) self.cluster_analysis_output_path = param.get(Constant.CLUSTER_ANALYSIS_OUTPUT_PATH) self.data_map = param.get(Constant.DATA_MAP) - self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) + self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP, {}) self.step_time_dict = {} self.step_data_list = [] self.data_type = param.get(Constant.DATA_TYPE) self.distributed_args = None + self.is_msprof = param.get(Constant.IS_MSPROF) + self.is_mindspore = param.get(Constant.IS_MINDSPORE) @staticmethod def get_max_data_row(data_group_list: list): @@ -50,6 +56,26 @@ class StepTraceTimeAnalysis: ret.append(max(item)) return ret + @staticmethod + def find_msprof_json(path): + msprof_pattern = r'^msprof_\d{14}\.json$' + msprof_slice_pattern = r'^msprof_slice_\d{1}_\d{14}\.json$' + msprof_dict, msprof_slice_dict = {}, {} + for file_name in os.listdir(path): + if re.match(msprof_pattern, file_name): + timestamp = re.search(r"\d{14}", file_name).group() + msprof_dict.setdefault(timestamp, 
[]).append(os.path.join(path, file_name)) + elif re.match(msprof_slice_pattern, file_name): + timestamp = re.search(r"\d{14}", file_name).group() + msprof_slice_dict.setdefault(timestamp, []).append(os.path.join(path, file_name)) + if msprof_dict: + max_timestamp = max(msprof_dict.keys()) + return msprof_dict.get(max_timestamp) + if msprof_slice_dict: + max_timestamp = max(msprof_slice_dict.keys()) + return msprof_slice_dict.get(max_timestamp) + return [] + def run(self, completed_processes, lock): self.load_step_trace_time_data() self.analyze_step_time() @@ -132,19 +158,31 @@ class StepTraceTimeAnalysis: metadata = FileManager.read_json_file(metadata_path) self.distributed_args = metadata.get(Constant.DISTRIBUTED_ARGS, None) if metadata else None if self.data_type == Constant.TEXT: - step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) - if os.path.exists(step_time_file): - self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) + if self.is_msprof: + msprof_json = self.find_msprof_json(os.path.join(profiling_dir_path, "mindstudio_profiler_output")) + self.step_time_dict[rank_id] = MsprofStepTraceTimeAdapter( + msprof_json).generate_step_trace_time_data() + else: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) + if os.path.exists(step_time_file): + self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) else: - step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, - Constant.DB_COMMUNICATION_ANALYZER) - if (os.path.exists(step_time_file) and - DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE)): - conn, cursor = DBManager.create_connect_db(step_time_file) - sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) - data = DBManager.fetch_all_data(cursor, sql, is_dict=False) - self.step_time_dict[rank_id] = data - DBManager.destroy_db_connect(conn, cursor) 
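The `find_msprof_json` helper above groups candidate files by their 14-digit timestamp and keeps only the newest group, preferring full `msprof_*.json` exports over slices. A minimal sketch of the timestamp-selection part (directory layout assumed; the slice fallback is omitted):

```python
import os
import re


def latest_msprof_json(path):
    """Return the msprof_*.json files carrying the newest 14-digit timestamp.

    Sketch of the selection logic in find_msprof_json; only the non-slice
    pattern is handled here.
    """
    pattern = re.compile(r"^msprof_(\d{14})\.json$")
    by_ts = {}
    for name in os.listdir(path):
        match = pattern.match(name)
        if match:
            by_ts.setdefault(match.group(1), []).append(os.path.join(path, name))
    # Timestamps are fixed-width digit strings, so lexicographic max is
    # also the chronologically latest.
    return by_ts[max(by_ts)] if by_ts else []
```

Because every timestamp is exactly 14 digits, comparing the keys as strings is equivalent to comparing them as dates, which is why `max` over the dict keys suffices.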
+ if self.is_msprof or self.is_mindspore: + profiler_db = MsprofDataPreprocessor.get_msprof_profiler_db_path(profiling_dir_path) if \ + self.is_msprof else os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, + f"ascend_mindspore_profiler_{rank_id}.db") + self.step_time_dict[rank_id] = MsprofStepTraceTimeDBAdapter( + {Constant.PROFILER_DB_PATH: profiler_db}).generate_step_trace_time_data() + else: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, + Constant.DB_COMMUNICATION_ANALYZER) + if (os.path.exists(step_time_file) and + DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE)): + conn, cursor = DBManager.create_connect_db(step_time_file) + sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) + data = DBManager.fetch_all_data(cursor, sql, is_dict=False) + self.step_time_dict[rank_id] = data + DBManager.destroy_db_connect(conn, cursor) if not self.step_time_dict.get(rank_id): logger.warning("Rank %s does not have a valid step_trace_time data in %s file.", str(rank_id), str(self.data_type)) diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py index d7d71908506256eca3b8bd884a593188546189ce..ce085d015aad65a2fbf8acde6a6247b006d2f746 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_analysis.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import argparse +import copy import os import sys @@ -21,6 +22,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( from msprof_analyze.cluster_analyse.analysis.analysis_facade import AnalysisFacade from msprof_analyze.cluster_analyse.cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor from msprof_analyze.cluster_analyse.cluster_data_preprocess.mindspore_data_preprocessor import MindsporeDataPreprocessor +from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor from msprof_analyze.cluster_analyse.communication_group.communication_group_generator import CommunicationGroupGenerator from msprof_analyze.prof_common.additional_args_manager import AdditionalArgsManager from msprof_analyze.prof_common.constant import Constant @@ -47,6 +49,7 @@ ALL_FEATURE_LIST = COMM_FEATURE_LIST + get_all_recipes() class Interface: ASCEND_PT = "ascend_pt" ASCEND_MS = "ascend_ms" + PROF = "PROF_" def __init__(self, params: dict): self.collection_path = PathManager.get_realpath(params.get(Constant.PROFILING_PATH)) @@ -58,7 +61,6 @@ class Interface: self.matrix_ops = [] self.origin_params = params self.cluster_analysis_output_path = self.get_cluster_analysis_output_path(params) - self.force = params.get(Constant.FORCE, False) AdditionalArgsManager().init(params) def get_cluster_analysis_output_path(self, params): @@ -70,27 +72,41 @@ class Interface: def allocate_prof_data(self): ascend_pt_dirs = [] ascend_ms_dirs = [] + prof_dirs = [] for root, dirs, _ in os.walk(self.collection_path): for dir_name in dirs: if dir_name.endswith(self.ASCEND_PT): ascend_pt_dirs.append(os.path.join(root, dir_name)) if dir_name.endswith(self.ASCEND_MS): ascend_ms_dirs.append(os.path.join(root, dir_name)) + if dir_name.startswith(self.PROF): + prof_dirs.append(os.path.join(root, dir_name)) pytorch_processor = PytorchDataPreprocessor(ascend_pt_dirs) pt_data_map = pytorch_processor.get_data_map() 
- data_type = pytorch_processor.get_data_type() - ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs).get_data_map() + pt_data_type = pytorch_processor.get_data_type() + ms_processor = MindsporeDataPreprocessor(ascend_ms_dirs) + ms_data_map = ms_processor.get_data_map() + ms_data_type = ms_processor.get_data_type() if pt_data_map and ms_data_map: logger.error("Can not analyze pytorch and mindspore meantime.") - return [] - return (pt_data_map, data_type) if pt_data_map else (ms_data_map, Constant.TEXT) + return {} + if pt_data_map: + return {Constant.DATA_MAP: pt_data_map, Constant.DATA_TYPE: pt_data_type, Constant.IS_MSPROF: False} + if ms_data_map: + return {Constant.DATA_MAP: ms_data_map, Constant.DATA_TYPE: ms_data_type, Constant.IS_MSPROF: False, + Constant.IS_MINDSPORE: True} + msprof_processor = MsprofDataPreprocessor(prof_dirs) + prof_data_map = msprof_processor.get_data_map() + prof_data_type = msprof_processor.get_data_type() + return {Constant.DATA_MAP: prof_data_map, Constant.DATA_TYPE: prof_data_type, Constant.IS_MSPROF: True} def run(self): PathManager.check_input_directory_path(self.collection_path) PathManager.check_input_directory_path(self.cluster_analysis_output_path) PathManager.check_path_owner_consistent([self.collection_path, self.cluster_analysis_output_path]) - data_map, data_type = self.allocate_prof_data() + data_dict = self.allocate_prof_data() + data_map, data_type = data_dict.get(Constant.DATA_MAP), data_dict.get(Constant.DATA_TYPE) if not data_map: logger.warning("Can not get rank info or profiling data.") return @@ -98,34 +114,32 @@ class Interface: logger.error("The current folder contains both DB and other files. 
Please check.") return - params = { + params = copy.deepcopy(self.origin_params) + params.update({ Constant.COLLECTION_PATH: self.collection_path, + Constant.ANALYSIS_MODE: self.analysis_mode, Constant.DATA_MAP: data_map, Constant.DATA_TYPE: data_type, - Constant.ANALYSIS_MODE: self.analysis_mode, - Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.cluster_analysis_output_path, - Constant.DATA_SIMPLIFICATION: self.origin_params.get(Constant.DATA_SIMPLIFICATION, False), - Constant.FORCE: self.force - } - + Constant.IS_MSPROF: data_dict.get(Constant.IS_MSPROF, False), + Constant.IS_MINDSPORE: data_dict.get(Constant.IS_MINDSPORE, False), + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.cluster_analysis_output_path + }) if self.analysis_mode in COMM_FEATURE_LIST: FileManager.create_output_dir(self.cluster_analysis_output_path) PathManager.check_path_writeable(self.cluster_analysis_output_path) logger.info("Begin generate communication data.") - comm_data_dict = CommunicationGroupGenerator(params).generate() - logger.info("Communication data read completed.") - params[Constant.COMM_DATA_DICT] = comm_data_dict + if data_type == Constant.TEXT or not params.get(Constant.DATA_SIMPLIFICATION): + comm_data_dict = CommunicationGroupGenerator(params).generate() + logger.info("Communication data read completed.") + params[Constant.COMM_DATA_DICT] = comm_data_dict AnalysisFacade(params).cluster_analyze() logger.info("The cluster analysis result file has been generated: %s", self.cluster_analysis_output_path) - return - - if data_type != Constant.DB: + elif data_type == Constant.TEXT: logger.error("The current analysis node only supports DB as input data. 
Please check.") - return - FileManager.create_output_dir(self.cluster_analysis_output_path, is_overwrite=True) - self.origin_params.update(params) - AnalysisFacade(self.origin_params).recipe_analyze() + else: + FileManager.create_output_dir(self.cluster_analysis_output_path, is_overwrite=True) + AnalysisFacade(params).recipe_analyze() def cluster_analysis_main(): @@ -146,7 +160,12 @@ def cluster_analysis_main(): args, extra_args = parser.parse_known_args() parameter = vars(args) - parameter[Constant.EXTRA_ARGS] = extra_args + if extra_args: + if parameter.get(Constant.MODE) in COMM_FEATURE_LIST: + unknown_args = " ".join(extra_args) + logger.warning(f"Invalid parameters: {unknown_args}. It will not have any effect.") + else: + parameter[Constant.EXTRA_ARGS] = extra_args Interface(parameter).run() diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py index eaa14fb71f9583b65b0ed18c6ad9727d913eb2fe..c22ecb1ad907f8d20bfe2b380db798928d0343b9 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os +import re from collections import defaultdict from msprof_analyze.cluster_analyse.cluster_data_preprocess.data_preprocessor import DataPreprocessor - from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.file_manager import FileManager logger = get_logger() @@ -25,6 +28,15 @@ class MindsporeDataPreprocessor(DataPreprocessor): def __init__(self, path_list: list): super().__init__(path_list) + self.data_type = set() + + @classmethod + def get_msprof_dir(cls, profiling_path): + prof_pattern = r"^PROF_\d+_\d+_[0-9a-zA-Z]+" + for file_name in os.listdir(profiling_path): + if re.match(prof_pattern, file_name): + return os.path.join(profiling_path, file_name) + return "" def get_data_map(self) -> dict: rank_id_map = defaultdict(list) @@ -33,6 +45,15 @@ class MindsporeDataPreprocessor(DataPreprocessor): if rank_id < 0: logger.error("fail to get rankid or rankid invalid.") continue + for file_name in os.listdir(dir_name): + if file_name.startswith(self.PROFILER_INFO_HEAD) and file_name.endswith(self.PROFILER_INFO_EXTENSION): + file_path = os.path.join(dir_name, file_name) + config = FileManager.read_json_file(file_path) + export_type = (config.get(Constant.PROFILER_PARAMETER, {}).get(Constant.EXPORT_TYPE, Constant.TEXT)) + if isinstance(export_type, list): + self.data_type.add(Constant.DB if Constant.DB in export_type else Constant.TEXT) + else: + self.data_type.add(export_type) rank_id_map[rank_id].append(dir_name) try: @@ -42,3 +63,8 @@ except Exception as e: raise RuntimeError("Found invalid directory name!") from e return self.data_map + + def get_data_type(self): + if len(self.data_type) == 1: + return self.data_type.pop() + return Constant.INVALID diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/msprof_data_preprocessor.py 
b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/msprof_data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..f751de56fe3d622e705c481220cf4a6760b163d0 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/msprof_data_preprocessor.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from collections import defaultdict + +from msprof_analyze.cluster_analyse.cluster_data_preprocess.data_preprocessor import DataPreprocessor +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.file_manager import FileManager + +logger = get_logger() + + +class MsprofDataPreprocessor(DataPreprocessor): + DEVICE_PATTERN = r"device_\d{1,2}$" + INFO_JSON_PATTERN = r"^info\.json\.\d{1,2}$" + DB_PATTERN = r"^msprof_\d{1,20}\.db$" + + def __init__(self, path_list: list): + super().__init__(path_list) + self.data_type = set() + + @classmethod + def get_msprof_profiler_db_path(cls, data_path): + msprof_db_pattern = r"^msprof_\d{14}\.db$" + msprof_db_list = [] + for file_name in os.listdir(data_path): + if re.match(msprof_db_pattern, file_name): + msprof_db_list.append(file_name) + if msprof_db_list: + msprof_db_list.sort(key=lambda x: x.split(".")[0].split("_")[-1]) + return os.path.join(data_path, 
msprof_db_list[-1]) + return "" + + @classmethod + def get_device_id(cls, data_path): + for file_name in os.listdir(data_path): + if re.match(cls.DEVICE_PATTERN, file_name): + return int(file_name.split("_")[-1]) + return None + + def get_data_map(self) -> dict: + prof_data_uid = defaultdict(list) + prof_data_rank = defaultdict(list) + for dir_name in self.path_list: + info_json_file = self._find_info_json_file(dir_name) + if not info_json_file: + logger.error(f"Profiling data is not complete, please check the info.json file in the path {dir_name}") + continue + + if self._check_db_type(dir_name): + self.data_type.add(Constant.DB) + elif os.path.exists(os.path.join(dir_name, "mindstudio_profiler_output")): + if os.path.exists(os.path.join(dir_name, "analyze")): + self.data_type.add(Constant.TEXT) + else: + logger.error(f"The profiling data has not been fully parsed. You can parse it by executing " + f"the following command: msprof --analyze=on --output={dir_name}") + continue + else: + logger.error(f"The profiling data has not been fully parsed. 
You can parse it by executing " + f"the following command: msprof --export=on --output={dir_name}; " + f"msprof --analyze=on --output={dir_name}") + continue + info_json = FileManager.read_json_file(info_json_file) + rank_id = info_json.get("rank_id") + if rank_id != Constant.INVALID_RETURN: + prof_data_rank[rank_id].append(dir_name) + continue + host_id = info_json.get("hostUid") + device_id = int(os.path.basename(info_json_file).split(".")[-1]) + prof_data_uid[(host_id, device_id)].append(dir_name) + + if prof_data_rank: + for rank_id, dir_list in prof_data_rank.items(): + dir_list.sort(key=lambda x: x.split('_')[-2]) + self.data_map[rank_id] = dir_list[0] + else: + ordered_keys = sorted(prof_data_uid.keys(), key=lambda x: (x[0], x[1])) + rank_id = 0 + for key in ordered_keys: + dir_list = prof_data_uid[key] + dir_list.sort(key=lambda x: x.split('_')[-2]) + self.data_map[rank_id] = dir_list[0] + rank_id += 1 + return self.data_map + + def get_data_type(self): + if len(self.data_type) == 1: + return self.data_type.pop() + return Constant.INVALID + + def _find_info_json_file(self, dir_name): + for file_name in os.listdir(dir_name): + file_path = os.path.join(dir_name, file_name) + if not os.path.isdir(file_path): + continue + for device_file in os.listdir(file_path): + if re.match(self.INFO_JSON_PATTERN, device_file): + return os.path.join(dir_name, file_name, device_file) + return None + + def _check_db_type(self, dir_name): + for file_name in os.listdir(dir_name): + if re.match(self.DB_PATTERN, file_name): + return True + return False diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/context.py b/profiler/msprof_analyze/cluster_analyse/common_func/context.py index b41972c0d21ac73fb5b9f0291cddec8d9a06b94a..cde351508c0e814867d224b7ce5c454e0813a71b 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/context.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/context.py @@ -84,7 +84,12 @@ class ConcurrentContext(Context): def 
map(self, func, *iterables, **kwargs): partial_func = partial(func, **kwargs) - return list(self._executor.map(partial_func, *iterables)) + try: + res = list(self._executor.map(partial_func, *iterables)) + except Exception as err: + logger.error(err) + return [] + return res def wait(self, waitable): return waitable diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py b/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py index 27daae78cb9004a2f713b9ebf9bc6ab916dd9325..eae250b0f7a7b5ab138aeae0eed9aba863e1cbf1 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/table_constant.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. class TableConstant: - RANK_SET = "rank_set" STEP = "step" RANK_ID = "rank_id" @@ -39,3 +38,15 @@ class TableConstant: DST_RANK = "dst_rank" TRANSPORT_TYPE = "transport_type" OPNAME = "op_name" + GROUP_ID = "group_id" + PG_NAME = "pg_name" + NAME = "name" + VALUE = "value" + + # table name + TABLE_STRING_IDS = "STRING_IDS" + TABLE_COMPUTE_TASK_INFO = "COMPUTE_TASK_INFO" + TABLE_COMMUNICATION_OP = "COMMUNICATION_OP" + TABLE_TASK = "TASK" + TABLE_META_DATA = "META_DATA" + TABLE_COMM_ANALYZER_MATRIX = "CommAnalyzerMatrix" diff --git a/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py b/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py index 42c509694cfd1a896f60ec6b282de040f22204b6..7a40d9977d89883f8a33b743e5e80d73f2bfa328 100644 --- a/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py +++ b/profiler/msprof_analyze/cluster_analyse/common_func/tables_config.py @@ -139,6 +139,7 @@ class TablesConfig: ("pg_name", "TEXT, null") ], "ClusterBaseInfoMap": [ - ("distributed_args", "TEXT, null") + ("key", "TEXT, null"), + ("value", "TEXT, null") ] } diff --git 
a/profiler/msprof_analyze/cluster_analyse/common_func/time_range_calculator.py b/profiler/msprof_analyze/cluster_analyse/common_func/time_range_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..36ee94067ac37aa57f04cc1f64855be19b9807e0 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/common_func/time_range_calculator.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +DEFAULT_INT_VALUE = -1 + + +@dataclass +class TimeRange: + start_ts: int = DEFAULT_INT_VALUE + end_ts: int = DEFAULT_INT_VALUE + + +class CommunicationTimeRange(TimeRange): + + def __init__(self): + super().__init__() + + +class RangeCaculator: + + @staticmethod + def generate_time_range(start, end, class_range=TimeRange): + time_range = class_range() + time_range.start_ts, time_range.end_ts = start, end + return time_range + + @staticmethod + def merge_continuous_intervals(time_range_list: list): + result = [] + if not time_range_list: + return result + time_range_list.sort(key=lambda x: x.start_ts) + current_range = time_range_list[0] + for time_range in time_range_list: + if time_range.start_ts <= current_range.end_ts: + current_range.end_ts = max(current_range.end_ts, time_range.end_ts) + else: + result.append(current_range) + current_range = time_range + result.append(current_range) + return result + + @staticmethod + def compute_pipeline_overlap(communication_range, compute_range): + free_time_range = [] + pure_communication_range = [] + time_range_list = sorted(communication_range + compute_range, key=lambda x: x.start_ts) + if not time_range_list: + return pure_communication_range, free_time_range + + min_range = time_range_list.pop(0) + for time_range in time_range_list: + if min_range.end_ts - time_range.start_ts < 0: + free_time_range.append( + RangeCaculator.generate_time_range(min_range.end_ts, time_range.start_ts) + ) + if isinstance(min_range, CommunicationTimeRange): + pure_communication_range.append( + RangeCaculator.generate_time_range(min_range.start_ts, min_range.end_ts) + ) + min_range = time_range + continue + if min_range.end_ts - time_range.end_ts < 0: + if isinstance(min_range, CommunicationTimeRange): + pure_communication_range.append( + RangeCaculator.generate_time_range(min_range.start_ts, time_range.start_ts) + ) + min_range = RangeCaculator.generate_time_range(min_range.end_ts, time_range.end_ts) 
+ if isinstance(time_range, CommunicationTimeRange): + min_range = RangeCaculator.generate_time_range( + min_range.end_ts, time_range.end_ts, class_range=CommunicationTimeRange + ) + else: + if isinstance(min_range, CommunicationTimeRange): + pure_communication_range.append( + RangeCaculator.generate_time_range(min_range.start_ts, time_range.start_ts) + ) + min_range = RangeCaculator.generate_time_range( + time_range.end_ts, min_range.end_ts, class_range=CommunicationTimeRange + ) + if isinstance(time_range, CommunicationTimeRange): + min_range = RangeCaculator.generate_time_range(time_range.end_ts, min_range.end_ts) + if isinstance(min_range, CommunicationTimeRange): + pure_communication_range.append(min_range) + return pure_communication_range, free_time_range diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py index 2c02bfdbf1bdd22dec838b56b7eb0c9c9872cac2..bdb6371b1e587026d7805d5e4c7573ddb1aa5143 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py @@ -39,10 +39,10 @@ class BaseCommunicationGroup: self.data_map = params.get(Constant.DATA_MAP) self.data_type = params.get(Constant.DATA_TYPE) self.analysis_mode = params.get(Constant.ANALYSIS_MODE) + self.is_msprof = params.get(Constant.IS_MSPROF) self.rank_comm_dir_dict = {} - self.p2p_link = [] self.collective_group_dict = defaultdict(set) - self.p2p_comm_group = [] + self.p2p_group_dict = defaultdict(set) self.communication_group = {} self.parallel_group_info = {} self.communication_ops = [] @@ -54,8 +54,9 @@ class BaseCommunicationGroup: comm_op_dirs = [] for rank_id, profiling_dir_path in self.data_map.items(): if self.data_type == Constant.TEXT: - comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) 
- matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) + output_dir = "analyze" if self.is_msprof else Constant.SINGLE_OUTPUT + comm_dir = os.path.join(profiling_dir_path, output_dir, Constant.COMM_JSON) + matrix_dir = os.path.join(profiling_dir_path, output_dir, Constant.COMM_MATRIX_JSON) else: comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) matrix_dir = comm_dir @@ -69,52 +70,11 @@ class BaseCommunicationGroup: with Pool(processes=max_processes) as p: self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs) - def set_p2p_groups(self): - self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) - while self.p2p_link: - union_set = deepcopy(self.p2p_link[0]) - rm_list = [self.p2p_link[0]] - for _, link_rank_set_x in enumerate(self.p2p_link[1:]): - if UnionFind.is_connected(link_rank_set_x, union_set): - union_set = union_set.union(link_rank_set_x) - rm_list.append(link_rank_set_x) - self.p2p_comm_group.append(union_set) - self.p2p_link = [element for element in self.p2p_link if element not in rm_list] - - def generate_collective_communication_group(self): + def generate_communication_group(self): self.communication_group[Constant.COLLECTIVE] = \ [list(group) for _, group in self.collective_group_dict.items()] - - def generate_p2p_communication_group(self): - stage_group = {} - for _, rank_set in self.collective_group_dict.items(): - if not self.whether_valid_comm_group(rank_set): - continue - unioned_set = set() - remove_key = [] - for first_rank, stage in stage_group.items(): - if UnionFind.is_connected(rank_set, stage): - unioned_set = UnionFind.union(rank_set, stage, unioned_set) - remove_key.append(first_rank) - if unioned_set: - for key in remove_key: - del stage_group[key] - stage_group[min(unioned_set)] = unioned_set - else: - stage_group[min(rank_set)] = rank_set - first_rank_sort_list = sorted([first_rank for first_rank in stage_group]) 
self.communication_group[Constant.P2P] = \ - [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] - - def whether_valid_comm_group(self, rank_set: set): - """ - while distinguish which communication group should be used to infer stage info, these group should be ignored: - 1. group can not include more than 1 rank in every single p2p group - """ - for p2p_rank_set in self.p2p_comm_group: - if len(rank_set.intersection(p2p_rank_set)) > 1: - return False - return True + [list(group) for _, group in self.p2p_group_dict.items()] @abstractmethod def read_communication_func(self, params: tuple): @@ -138,7 +98,8 @@ class BaseCommunicationGroup: if not isinstance(step_id_dict, dict): logger.warning("rank%s's communication.json has a wrong data struct.", rank_id) continue - self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + self.add_collective_group_rank_map(rank_id, step_id_dict.get(Constant.COLLECTIVE, {})) + self.add_p2p_group_rank_map(rank_id, step_id_dict.get(Constant.P2P, {})) for comm_op_type, comm_op_dict in step_id_dict.items(): self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) @@ -146,8 +107,10 @@ class BaseCommunicationGroup: if not isinstance(step_id_dict, dict): logger.warning("rank%s's communication_matrix.json has a wrong data struct.", rank_id) continue - self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) - self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + self.add_matrix_ops(rank_id, step_id, step_id_dict) + self.add_collective_group_rank_map(rank_id, step_id_dict.get(Constant.COLLECTIVE, {})) + self.add_p2p_group_rank_map(rank_id, step_id_dict.get(Constant.P2P, {})) + @abstractmethod def dump_data(self): @@ -166,45 +129,25 @@ class BaseCommunicationGroup: self.load_communication_data() self.analyze_communication_data() self.read_parallel_group_info() - self.set_p2p_groups() - self.generate_collective_communication_group() - 
self.generate_p2p_communication_group() + self.generate_communication_group() self.analyze_parallel_group_info() self.dump_data() return self.collect_comm_data() - def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict): - ops = rank_id_matrix_dict.get(step_id, {}) - self.add_matrix_ops(rank_id, step_id, ops) - if not ops: - logger.warning( - "rank%s %s do not have communication matrix ops data.", rank_id, step_id - ) - return - p2p_ops = ops.get(Constant.P2P, {}) - for op_name, link_dict in p2p_ops.items(): - self.append_p2p_link(op_name, link_dict) - - def append_p2p_link(self, op_name, link_dict): - for link in link_dict: - if '-' not in link: - logger.warning("%s has an invalid link key %s!", op_name, link) - break - src_rank = int(link.split('-')[0]) - dst_rank = int(link.split('-')[1]) - if src_rank != dst_rank: - rank_set = {src_rank, dst_rank} - if rank_set in self.p2p_link: - continue - self.p2p_link.append(rank_set) - - def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): + def add_collective_group_rank_map(self, rank_id: int, comm_op_dict: dict): for comm_op in comm_op_dict: if comm_op.startswith('Total'): continue group_name = comm_op.split('@')[-1] self.collective_group_dict[group_name].add(rank_id) + def add_p2p_group_rank_map(self, rank_id: int, comm_op_dict: dict): + for comm_op in comm_op_dict: + if comm_op.startswith('Total'): + continue + group_name = comm_op.split('@')[-1] + self.p2p_group_dict[group_name].add(rank_id) + def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict): for comm_op in comm_op_dict: if comm_op.startswith('Total'): @@ -243,6 +186,8 @@ class BaseCommunicationGroup: comm_group_df = pd.DataFrame(columns=comm_group_cols) for group_name, rank_set in self.collective_group_dict.items(): comm_group_df.loc[comm_group_df.shape[0]] = [Constant.COLLECTIVE, list(rank_set), group_name] + for group_name, rank_set in self.p2p_group_dict.items(): + 
comm_group_df.loc[comm_group_df.shape[0]] = [Constant.P2P, list(rank_set), group_name] # create parallel group dataframe parallel_group_cols = ["group_name", "group_id", "pg_name"] @@ -256,9 +201,6 @@ class BaseCommunicationGroup: # merge by group_name df = pd.merge(comm_group_df, parallel_group_df, on='group_name', how='left') - # add p2p group - for rank_set in self.communication_group[Constant.P2P]: - df.loc[df.shape[0]] = [Constant.P2P, list(rank_set), None, None, None] df.fillna("", inplace=True) self.comm_group_parallel_info_df = df diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py index 99b55fb9956fff23ba36d9f4b80ba05caa33562c..f570ce204d95662569392bdbf416200ca66418c3 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_db_group.py @@ -100,13 +100,16 @@ class CommunicationDBGroupOptimized(BaseCommunicationGroup): comm_time_data = (time_data, bandwidth_data) return rank_id, comm_time_data, comm_matrix_data - def set_collective_group(self, rank_id: int, time_data: list): + def set_group_rank_map(self, rank_id: int, time_data: list): for single_time_data in time_data: - if single_time_data.get('type') == Constant.P2P: - continue + group_type = single_time_data.get(Constant.TYPE) group_name = single_time_data.get(Constant.GROUP_NAME) - if group_name: + if not group_name: + continue + if group_type == Constant.COLLECTIVE: self.collective_group_dict[group_name].add(rank_id) + elif group_type == Constant.P2P: + self.p2p_group_dict[group_name].add(rank_id) def analyze_communication_data(self): for rank_id, comm_time_data, comm_matrix_data in self.rank_comm_dir_dict: @@ -115,7 +118,7 @@ class CommunicationDBGroupOptimized(BaseCommunicationGroup): if not time_data: logger.warning("[WARNING] rank %s has error 
format in time data.", rank_id) continue - self.set_collective_group(rank_id, time_data) + self.set_group_rank_map(rank_id, time_data) self.communication_ops.extend(self._merge_data_with_rank(rank_id, time_data)) self.bandwidth_data.extend(self._merge_data_with_rank(rank_id, bandwidth_data)) if self.analysis_mode in [Constant.ALL, Constant.COMMUNICATION_MATRIX]: @@ -126,8 +129,8 @@ class CommunicationDBGroupOptimized(BaseCommunicationGroup): if not isinstance(step_id_dict, dict): logger.warning("[WARNING] rank %s has error format in matrix data.", rank_id) continue - self.set_p2p_link(rank_id, step_id, comm_matrix_data) - self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + self.add_matrix_ops(rank_id, step_id, step_id_dict) + self.set_group_rank_map(rank_id, time_data) def generate_collective_communication_group(self): collective_group = [] diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py index 2975050da0706870136ad4d8e84f28c56ded4718..e6fd3b41eeaddd0f01673e8b6c60a042dfb808f2 100644 --- a/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/communication_json_group.py @@ -15,9 +15,13 @@ import os from copy import deepcopy - + from msprof_analyze.cluster_analyse.communication_group.base_communication_group import BaseCommunicationGroup from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.cluster_analyse.communication_group.msprof_communication_matrix_adapter import \ + MsprofCommunicationMatrixAdapter +from msprof_analyze.cluster_analyse.communication_group.msprof_communication_time_adapter import \ + MsprofCommunicationTimeAdapter class CommunicationJsonGroup(BaseCommunicationGroup): @@ -42,7 +46,11 @@ class CommunicationJsonGroup(BaseCommunicationGroup): comm_data = 
{} matrix_data = {} if os.path.exists(comm_json_path) and self.analysis_mode in ["all", "communication_time"]: - comm_data = FileManager.read_json_file(comm_json_path) + comm_data = MsprofCommunicationTimeAdapter( + comm_json_path).generate_comm_time_data() if self.is_msprof else FileManager.read_json_file( + comm_json_path) if os.path.exists(matrix_json_path) and self.analysis_mode in ["all", "communication_matrix"]: - matrix_data = FileManager.read_json_file(matrix_json_path) + matrix_data = MsprofCommunicationMatrixAdapter( + matrix_json_path).generate_comm_matrix_data() if self.is_msprof else FileManager.read_json_file( + matrix_json_path) return rank_id, comm_data, matrix_data diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_matrix_adapter.py b/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_matrix_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1aef80b96cbd45dc362e6c507303e4ad8d0424 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_matrix_adapter.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import re +from collections import defaultdict + +from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger + +from msprof_analyze.prof_common.utils import compute_ratio + +logger = get_logger() + + +class MsprofCommunicationMatrixAdapter: + P2P_HCOM = ["hcom_send", "hcom_receive", "hcom_batchsendrecv"] + HCCL_PATTERN = r"send|reduce|invalid|broadcast|allreduce|" \ + r"receive|allgather|reducescatter|scatter|alltoall|alltoallv|alltoallvc|batchsendrecv" + BANDWIDTH_GB_S = "Bandwidth(GB/s)" + TRANSPORT_TYPE = "Transport Type" + TRANSIT_SIZE_MB = "Transit Size(MB)" + TRANSIT_TIME_MS = "Transit Time(ms)" + + def __init__(self, file_path): + self.file_path = file_path + + def generate_comm_matrix_data(self): + output_comm_matrix = {"step": {Constant.P2P: {}, Constant.COLLECTIVE: {}}} + comm_matrix_data = FileManager.read_json_file(self.file_path) + split_comm_dict = {Constant.P2P: {}, Constant.COLLECTIVE: {}} + for communication_op, comm_matrix_info in comm_matrix_data.items(): + lower_op_name = communication_op.lower() + if any(lower_op_name.startswith(start_str) for start_str in self.P2P_HCOM): + split_comm_dict[Constant.P2P][communication_op] = comm_matrix_info + elif lower_op_name.startswith(Constant.TOTAL): + continue + else: + split_comm_dict[Constant.COLLECTIVE][communication_op] = comm_matrix_info + output_comm_matrix["step"][Constant.P2P] = self.integrate_matrix_data( + self.get_comm_type(split_comm_dict[Constant.P2P])) + output_comm_matrix["step"][Constant.COLLECTIVE] = self.integrate_matrix_data( + self.get_comm_type(split_comm_dict[Constant.COLLECTIVE])) + return output_comm_matrix + + def get_comm_type(self, op_data: dict) -> dict: + new_comm_op_dict = defaultdict(list) + for communication_op, communication_info in op_data.items(): + match_obj = re.compile(self.HCCL_PATTERN).search((communication_op.lower())) + if match_obj: + 
comm_op_type = match_obj.group() + else: + comm_op_type = communication_op.split("__")[0] + logger.warning(f"Unknown communication op type: {comm_op_type}") + for link, data in communication_info.items(): + new_comm_op_name = (comm_op_type, communication_op.split("@")[-1], link) + data['Op Name'] = communication_op.split("@")[0] + new_comm_op_dict[new_comm_op_name].append(data) + return new_comm_op_dict + + def integrate_matrix_data(self, new_comm_op_dict: dict): + """integrate the matrix data""" + comm_op_dict = defaultdict(dict) + for new_comm_op_name, data in new_comm_op_dict.items(): + data.sort(key=lambda x: x[self.BANDWIDTH_GB_S], reverse=True) + t_type = data[0].get(self.TRANSPORT_TYPE, '') + t_size = sum(x.get(self.TRANSIT_SIZE_MB, 0) for x in data) + t_time = sum(x.get(self.TRANSIT_TIME_MS, 0) for x in data) + bandwidth = compute_ratio(t_size, t_time) + + link = new_comm_op_name[2] + new_comm_op_name_top1 = f'{new_comm_op_name[0]}-top1@{new_comm_op_name[1]}' + new_comm_op_name_middle = f'{new_comm_op_name[0]}-middle@{new_comm_op_name[1]}' + new_comm_op_name_bottom1 = f'{new_comm_op_name[0]}-bottom1@{new_comm_op_name[1]}' + new_comm_op_name_bottom2 = f'{new_comm_op_name[0]}-bottom2@{new_comm_op_name[1]}' + new_comm_op_name_bottom3 = f'{new_comm_op_name[0]}-bottom3@{new_comm_op_name[1]}' + new_comm_op_name_total = f'{new_comm_op_name[0]}-total@{new_comm_op_name[1]}' + comm_op_dict[new_comm_op_name_top1].update({link: data[0]}) + comm_op_dict[new_comm_op_name_middle].update({link: data[len(data) // 2]}) + comm_op_dict[new_comm_op_name_bottom1].update({link: data[-1]}) + comm_op_dict[new_comm_op_name_total].update({link: { + self.TRANSPORT_TYPE: t_type, + self.TRANSIT_SIZE_MB: t_size, + self.TRANSIT_TIME_MS: t_time, + self.BANDWIDTH_GB_S: bandwidth + }}) + if len(data) >= 2: + comm_op_dict[new_comm_op_name_bottom2].update({link: data[-2]}) + if len(data) >= 3: + comm_op_dict[new_comm_op_name_bottom3].update({link: data[-3]}) + return comm_op_dict diff --git 
a/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_time_adapter.py b/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_time_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7b63b700f5cb001eb85da3bf0a75d7bb8ae36ae7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/communication_group/msprof_communication_time_adapter.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from msprof_analyze.prof_common.file_manager import FileManager +from msprof_analyze.prof_common.constant import Constant + + +class MsprofCommunicationTimeAdapter: + P2P_HCOM = ["hcom_send", "hcom_receive", "hcom_batchsendrecv"] + TOTAL = "total" + + def __init__(self, file_path): + self.file_path = file_path + + def generate_comm_time_data(self): + output_communication = {"step": {Constant.P2P: {}, Constant.COLLECTIVE: {}}} + communication_data = FileManager.read_json_file(self.file_path) + for communication_op, communication_info in communication_data.items(): + lower_op_name = communication_op.lower() + if any(lower_op_name.startswith(start_str) for start_str in self.P2P_HCOM): + output_communication["step"][Constant.P2P][communication_op] = communication_info + elif lower_op_name.startswith(self.TOTAL): + continue + else: + output_communication["step"][Constant.COLLECTIVE][communication_op] = communication_info + + return output_communication diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py index a8b503592536e529b4a9043058284f0094b08038..4b6ddaf9855206cbc339f5d8563954a158f072ff 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/base_recipe_analysis.py @@ -26,6 +26,8 @@ from msprof_analyze.cluster_analyse.common_func.utils import convert_unit from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger from msprof_analyze.prof_common.path_manager import PathManager +from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor +from msprof_analyze.prof_common.file_manager import FileManager logger = get_logger() @@ -42,6 +44,8 @@ class BaseRecipeAnalysis(ABC): self._recipe_name = params.get(Constant.RECIPE_NAME, "") self._parallel_mode = 
params.get(Constant.PARALLEL_MODE, "") self._export_type = params.get(Constant.EXPORT_TYPE, "") + self._is_msprof = params.get(Constant.IS_MSPROF) + self._is_mindspore = params.get(Constant.IS_MINDSPORE) self._cluster_analysis_output_path = os.path.join( params.get(Constant.CLUSTER_ANALYSIS_OUTPUT_PATH, self._collection_dir), Constant.CLUSTER_ANALYSIS_OUTPUT) self._output_path = self._cluster_analysis_output_path if self._export_type == "db" else os.path.join( @@ -50,7 +54,7 @@ class BaseRecipeAnalysis(ABC): self._rank_list = rank_list if rank_list == "all" else [int(rank) for rank in rank_list.split(",") if rank.isdigit()] self._step_id = params.get(Constant.STEP_ID, Constant.VOID_STEP) - self._extra_args = self.get_extra_argument(params.get(Constant.EXTRA_ARGS)) + self._extra_args = self.get_extra_argument(params.get(Constant.EXTRA_ARGS, [])) PathManager.make_dir_safety(self._output_path) def __enter__(self): @@ -86,7 +90,10 @@ class BaseRecipeAnalysis(ABC): def get_extra_argument(cls, args_list) -> dict: parser = argparse.ArgumentParser() cls.add_parser_argument(parser) - args, _ = parser.parse_known_args(args_list) + args, unknown_args = parser.parse_known_args(args_list) + if unknown_args: + unknown_args = " ".join(unknown_args) + logger.warning(f"Invalid parameters: {unknown_args}. 
It will not have any effect.") return vars(args) @abstractmethod @@ -107,7 +114,7 @@ class BaseRecipeAnalysis(ABC): result_db = custom_db_path if custom_db_path else os.path.join(self.output_path, file_name) conn, cursor = DBManager.create_connect_db(result_db) if isinstance(data, pd.DataFrame): - data.to_sql(table_name, conn, if_exists='replace', index=True) + data.to_sql(table_name, conn, if_exists='replace', index=index) else: logger.error(f"Unknown dump data type: {type(data)}") DBManager.destroy_db_connect(conn, cursor) @@ -129,12 +136,10 @@ class BaseRecipeAnalysis(ABC): if replace_dict is None: shutil.copy(template_file, output_file_path) else: - with open(template_file, 'r') as f: - template_content = f.read() - for key, value in replace_dict.items(): - template_content = template_content.replace(str(key), str(value)) - with open(output_file_path, 'w') as f: - f.write(template_content) + template_content = FileManager.read_common_file(template_file) + for key, value in replace_dict.items(): + template_content = template_content.replace(str(key), str(value)) + FileManager.create_common_file(output_file_path, template_content) logger.info(f"Notebook export path is: {output_file_path}") def add_helper_file(self, helper_file): @@ -158,16 +163,41 @@ class BaseRecipeAnalysis(ABC): db_paths = [] for rank_id in rank_ids: rank_path = self._data_map[rank_id] - db_path = os.path.join(rank_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") - if os.path.exists(db_path): - db_paths.append({Constant.RANK_ID: rank_id, Constant.PROFILER_DB_PATH: db_path, - Constant.STEP_RANGE: self._get_step_range(db_path)}) + db_path_dict = {Constant.RANK_ID: rank_id, Constant.PROFILER_DB_PATH: "", Constant.ANALYSIS_DB_PATH: "", + Constant.STEP_RANGE: {}} + profiler_db_path = self._get_profiler_db_path(rank_id, rank_path) + analysis_db_path = self._get_analysis_db_path(rank_path) + if os.path.exists(profiler_db_path): + db_path_dict[Constant.PROFILER_DB_PATH] = 
profiler_db_path + db_path_dict[Constant.STEP_RANGE] = self._get_step_range(profiler_db_path) else: - logger.warning(f"DB file not found, rank id: {rank_id}, db path: {db_path}.") + logger.warning(f"Profiler DB file not found, rank id: {rank_id}, db path: {profiler_db_path}.") + + if os.path.exists(analysis_db_path): + db_path_dict[Constant.ANALYSIS_DB_PATH] = analysis_db_path + else: + logger.warning(f"Analysis DB file not found, rank id: {rank_id}, db path: {analysis_db_path}.") + if db_path_dict.get(Constant.PROFILER_DB_PATH): + db_paths.append(db_path_dict) if invalid_rank_id: logger.warning(f"Invalid Rank id: [{','.join(invalid_rank_id)}].") return db_paths + def _get_profiler_db_path(self, rank_id, data_path): + if self._is_msprof: + db_path = MsprofDataPreprocessor.get_msprof_profiler_db_path(data_path) + return db_path if db_path else os.path.join(data_path, "msprof_xx.db") + if self._is_mindspore: + return os.path.join(data_path, Constant.SINGLE_OUTPUT, f"ascend_mindspore_profiler_{rank_id}.db") + return os.path.join(data_path, Constant.SINGLE_OUTPUT, f"ascend_pytorch_profiler_{rank_id}.db") + + def _get_analysis_db_path(self, data_path): + if self._is_msprof: + return os.path.join(data_path, Constant.ANALYZE_DIR, "communication_analyzer.db") + if self._is_mindspore: + return os.path.join(data_path, Constant.SINGLE_OUTPUT, "communication_analyzer.db") + return os.path.join(data_path, Constant.SINGLE_OUTPUT, "analysis.db") + def _get_step_range(self, db_path): step_range = {} if self._step_id == Constant.VOID_STEP: @@ -204,9 +234,14 @@ class BaseRecipeAnalysis(ABC): Extract the profiling data required for cluster analysis from each device, and then aggregate the results from each device to be processed by a reduce function. Params: - data_map: eg. {"RANK_ID": 1, - "profiler_db_path": "xxxx/ascend_pytorch_profiler_1.db", - "step_range": {"id": 2, "startNs": 12345, "endNs": 12443]} + data_map: eg1. 
{"RANK_ID": 1, + "profiler_db_path": "xxx/ASCEND_PROFILER_OUTPUT/ascend_pytorch_profiler_1.db", + "analysis_db_path": "xxx/ASCEND_PROFILER_OUTPUT/analysis.db", + "step_range": {"id": 2, "startNs": 12345, "endNs": 12443]} + eg2. {"RANK_ID": 1, + "profiler_db_path": "xxx/msprof_20250227145123.db", + "analysis_db_path": "xxx/analyze/communication_analyzer.db", + "step_range": {"id": 2, "startNs": 12345, "endNs": 12443]} analysis_class: hccl_sum, compute_op_sum, cann_api_sum, mstx_sum…… """ pass diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/comm_group_map/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/comm_group_map/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/comm_group_map/comm_group_map.py b/profiler/msprof_analyze/cluster_analyse/recipes/comm_group_map/comm_group_map.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0b5d8fb20daed6ec6b8ee1459321b1bdb4ca07 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/comm_group_map/comm_group_map.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import os +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.utils import double_hash +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class CommGroupMap(BaseRecipeAnalysis): + COMMUNICATION_GROUP_MAPPING_TABLE = "CommunicationGroupMapping" + + def __init__(self, params): + super().__init__(params) + logger.info("CommGroupMap init.") + self.group_df = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @staticmethod + def get_comm_type_from_op_name(op_name: str): + op_name_lower = op_name.lower() + return Constant.P2P if ("send" in op_name_lower or "receive" in op_name_lower or "recv" in op_name_lower) \ + else Constant.COLLECTIVE + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error(f"CommGroupMap: {self._export_type} is not supported for export type.") + + def reducer_func(self, mapper_res): + # concat and process all comm group + comm_group_df_list = [df for df, _ in mapper_res] + comm_group_combined_df = pd.concat(comm_group_df_list).drop_duplicates() + comm_group_combined_df = (comm_group_combined_df.groupby([TableConstant.TYPE, TableConstant.GROUP_NAME]) + [TableConstant.RANK_ID].apply(lambda x: sorted(set(x))).reset_index()) + comm_group_combined_df[TableConstant.RANK_SET] = (comm_group_combined_df[TableConstant.RANK_ID]. 
+ apply(lambda x: "(" + ",".join(str(i) for i in x) + ")")) + + comm_group_combined_df = comm_group_combined_df.drop(columns=[TableConstant.RANK_ID]) + # concat all parallel group info + parallel_info_df_list = [df for _, df in mapper_res] + parallel_info_combined_df = pd.concat(parallel_info_df_list).drop_duplicates() + # merge by group_name + group_df = pd.merge(comm_group_combined_df, parallel_info_combined_df, on=TableConstant.GROUP_NAME, how="left") + group_df.fillna("", inplace=True) + # column order + column_order = [TableConstant.TYPE, TableConstant.RANK_SET, TableConstant.GROUP_NAME, + TableConstant.GROUP_ID, TableConstant.PG_NAME] + self.group_df = group_df[column_order] + + def save_db(self): + self.dump_data(self.group_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, + self.COMMUNICATION_GROUP_MAPPING_TABLE, index=False) + + def _mapper_func(self, data_map, analysis_class): + rank_id = data_map.get(Constant.RANK_ID) + # read CommAnalyzerTime table + analysis_db_path = data_map.get(Constant.ANALYSIS_DB_PATH) + analysis_data_service = DatabaseService(analysis_db_path, {}) + analysis_data_service.add_table_for_query(Constant.TABLE_COMM_ANALYZER_TIME, + [TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME]) + comm_time_res = analysis_data_service.query_data() + # process comm_time_df: group_name, type, rank_id + comm_time_df = comm_time_res.get(Constant.TABLE_COMM_ANALYZER_TIME) + comm_time_df[TableConstant.RANK_ID] = rank_id + comm_time_df[TableConstant.TYPE] = (comm_time_df[TableConstant.HCCL_OP_NAME]. 
+ apply(lambda x: self.get_comm_type_from_op_name(x))) + comm_time_df = comm_time_df.drop(columns=[TableConstant.HCCL_OP_NAME]) + comm_time_df = comm_time_df.drop_duplicates() + + # read META_DATA table + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + profiler_data_service = DatabaseService(profiler_db_path, {}) + profiler_data_service.add_table_for_query(Constant.TABLE_META_DATA, + [TableConstant.NAME, TableConstant.VALUE]) + meta_data_res = profiler_data_service.query_data() + meta_data_df = meta_data_res.get(Constant.TABLE_META_DATA) + # process parallel_info_df + parallel_info_df = pd.DataFrame(columns=[TableConstant.GROUP_NAME, + TableConstant.GROUP_ID, TableConstant.PG_NAME]) + if Constant.PARALLEL_GROUP_INFO not in meta_data_df[TableConstant.NAME].values: + return comm_time_df, parallel_info_df + info_str = meta_data_df.loc[meta_data_df[TableConstant.NAME] == Constant.PARALLEL_GROUP_INFO, + TableConstant.VALUE].values[0] + info_dict = json.loads(info_str) + for group_id, parallel_info in info_dict.items(): + group_name = str(double_hash(group_id)) # group_name is hashed group_id + pg_name = parallel_info.get(TableConstant.GROUP_NAME, "") + if not pg_name: + continue + parallel_info_df.loc[parallel_info_df.shape[0]] = [group_name, group_id, pg_name] + + return comm_time_df, parallel_info_df + + + + diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..8b91626fe50dfa9ebfa62321545c92bb4dde39d0 --- /dev/null +++ 
b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast +import os + +import pandas as pd +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.database_service import DatabaseService +from msprof_analyze.cluster_analyse.common_func.utils import double_hash + +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant + +logger = get_logger() + + +class CommMatrixSum(BaseRecipeAnalysis): + TABLE_CLUSTER_COMM_MATRIX = "ClusterCommunicationMatrix" + RANK_MAP = "rank_map" + MATRIX_DATA = "matrix_data" + RANK_SET = "rank_set" + P2P_HCOM = ["hcom_send", "hcom_receive", "hcom_batchsendrecv"] + + def __init__(self, params): + super().__init__(params) + self.cluster_matrix_df = None + logger.info("CommMatrixSum init.") + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + @classmethod + def _get_parallel_group_info(cls, profiler_db_path): + rank_map = {} + data_service = DatabaseService(profiler_db_path, {}) + data_service.add_table_for_query(TableConstant.TABLE_META_DATA) + meta_df = 
data_service.query_data().get(TableConstant.TABLE_META_DATA, None)
+        if meta_df is None or meta_df.empty:
+            return rank_map
+        filtered_df = meta_df[meta_df['name'] == "parallel_group_info"]
+        if filtered_df.shape[0] == 1 and filtered_df.shape[1] == 2:
+            parallel_group_info = ast.literal_eval(filtered_df['value'].tolist()[0])
+            for group_name, group_info in parallel_group_info.items():
+                global_ranks = group_info.get("global_ranks")
+                if isinstance(global_ranks, list) and global_ranks:
+                    global_ranks.sort()
+                    rank_map[double_hash(group_name)] = dict(enumerate(global_ranks))
+        return rank_map
+
+    @classmethod
+    def _trans_msprof_matrix_data(cls, matrix_data):
+        matrix_data["step"] = "step"
+        matrix_data["type"] = Constant.COLLECTIVE
+        for index, row in matrix_data.iterrows():
+            lower_op_name = row["hccl_op_name"].lower()
+            if any(lower_op_name.startswith(start_str) for start_str in cls.P2P_HCOM):
+                matrix_data.at[index, "type"] = Constant.P2P
+        matrix_data = matrix_data.rename(columns={'hccl_op_name': 'op_name'})
+        matrix_data["hccl_op_name"] = matrix_data["op_name"].str.split("__").str[0]
+
+        # Group by multiple key fields
+        grouped_df = matrix_data.groupby(['type', 'step', 'group_name', 'hccl_op_name', 'src_rank', 'dst_rank'])
+
+        # Helper that extracts representative records from each group
+        def get_specific_rows(group):
+            # Sort by bandwidth
+            sorted_group = group.sort_values(by='bandwidth')
+            bottom1 = sorted_group.iloc[-1]
+            bottom2 = sorted_group.iloc[-2] if len(group) > 1 else pd.Series()
+            bottom3 = sorted_group.iloc[-3] if len(group) > 2 else pd.Series()
+            top1 = sorted_group.iloc[0]
+            mid_index = len(group) // 2
+            middle = sorted_group.iloc[mid_index]
+            return pd.DataFrame([top1, bottom1, bottom2, bottom3, middle],
+                                index=['top1', 'bottom1', 'bottom2', 'bottom3', 'middle']).reset_index()
+
+        example_df = grouped_df.apply(get_specific_rows).reset_index(drop=True)
+        example_df = example_df.dropna().reset_index(drop=True)
+        example_df["hccl_op_name"] = example_df["hccl_op_name"].astype(str) + "-" + example_df["index"].astype(str)
+        example_df = example_df.drop(columns="index")
+
+        # total
+        total_df = matrix_data.groupby(['type', 'step', 'group_name', 'hccl_op_name', 'src_rank', 'dst_rank']).agg(
+            {'transport_type': 'first', "transit_size": "sum", "transit_time": "sum"})
+        total_df = total_df.reset_index()
+        total_df["op_name"] = None
+        total_df["hccl_op_name"] = total_df["hccl_op_name"].astype(str) + "-total"
+        total_df['bandwidth'] = (total_df['transit_size'] /
+                                 total_df['transit_time'].where(total_df['transit_time'] != 0)).fillna(0)
+        return pd.concat([example_df, total_df], ignore_index=True)
+
+    def run(self, context):
+        mapper_res = self.mapper_func(context)
+        self.reducer_func(mapper_res)
+
+        if self._export_type == "db":
+            self.save_db()
+        else:
+            logger.error("communication_matrix_sum only supports the db export type.")
+
+    def reducer_func(self, mapper_res):
+        rank_map = self._generate_rank_map(mapper_res)
+        concat_df = pd.DataFrame()
+        for rank_data in mapper_res:
+            matrix_df = rank_data.get(self.MATRIX_DATA)
+            concat_df = pd.concat([concat_df, matrix_df], ignore_index=True)
+        if concat_df.empty:
+            logger.error("Communication matrix data is None.")
+            return
+        concat_df[self.RANK_SET] = ""
+        for index, row in concat_df.iterrows():
+            if row["type"] == Constant.P2P:
+                concat_df.at[index, self.RANK_SET] = Constant.P2P
+                continue
+            rank_list = sorted(rank_map.get(row["group_name"], {}).values())
+            concat_df.at[index, self.RANK_SET] = ",".join([str(rank) for rank in rank_list])
+        grouped_df = concat_df.groupby(
+            [self.RANK_SET, 'step', "hccl_op_name", "group_name", "src_rank", "dst_rank"]).agg(
+            {'transport_type': 'first', 'op_name': 'first', "transit_size": "sum", "transit_time": "sum"})
+        grouped_df = grouped_df.reset_index()
+        grouped_df["is_mapped"] = False
+        grouped_df["bandwidth"] = None
+        for index, row in grouped_df.iterrows():
+            src_rank = row["src_rank"]
+            dst_rank = row["dst_rank"]
+            group_name = row["group_name"]
+            group_rank_map = rank_map.get(group_name, {})
+            if src_rank not in group_rank_map:
+                logger.warning(f"The src local rank {src_rank} of the group_name {group_name} "
+                               f"cannot be mapped to the global rank.")
+                continue
+            if dst_rank not in group_rank_map:
+                logger.warning(f"The dst local rank {dst_rank} of the group_name {group_name} "
+                               f"cannot be mapped to the global rank.")
+                continue
+            grouped_df.at[index, 'src_rank'] = group_rank_map[src_rank]
+            grouped_df.at[index, 'dst_rank'] = group_rank_map[dst_rank]
+            grouped_df.at[index, 'is_mapped'] = True
+            grouped_df.at[index, 'bandwidth'] = row["transit_size"] / row["transit_time"] if row["transit_time"] else 0
+        filtered_df = grouped_df[grouped_df["is_mapped"]].drop(columns="is_mapped")
+        total_op_info = filtered_df[filtered_df['hccl_op_name'].str.contains('total', na=False)].groupby(
+            [self.RANK_SET, 'step', "src_rank", "dst_rank"]).agg(
+            {"group_name": "first", 'transport_type': 'first', 'op_name': 'first', "transit_size": "sum",
+             "transit_time": "sum"}
+        )
+        total_op_info = total_op_info.reset_index()
+        total_op_info["hccl_op_name"] = Constant.TOTAL_OP_INFO
+        total_op_info['bandwidth'] = (total_op_info['transit_size'] /
+                                      total_op_info['transit_time'].where(total_op_info['transit_time'] != 0)).fillna(0)
+        self.cluster_matrix_df = pd.concat([filtered_df, total_op_info], ignore_index=True).drop(columns=self.RANK_SET)
+
+    def save_db(self):
+        self.dump_data(self.cluster_matrix_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER,
+                       self.TABLE_CLUSTER_COMM_MATRIX, index=False)
+
+    def _generate_rank_map(self, mapper_res):
+        rank_map = {}
+        rank_map_df = pd.DataFrame({"group_name": [], "src_rank": [], Constant.RANK_ID: []})
+        for rank_data in mapper_res:
+            rank_map.update(rank_data.get(self.RANK_MAP))
+            matrix_df = rank_data.get(self.MATRIX_DATA)
+            filter_matrix_df = matrix_df[matrix_df["src_rank"] == matrix_df["dst_rank"]]
+            grouped_matrix_df = filter_matrix_df[['group_name', 'src_rank']].drop_duplicates()
+            grouped_matrix_df[Constant.RANK_ID] =
rank_data.get(Constant.RANK_ID) + rank_map_df = pd.concat([grouped_matrix_df, rank_map_df], ignore_index=True) + rank_map_df = rank_map_df.drop_duplicates() + for _, row in rank_map_df.iterrows(): + group_name = row["group_name"] + local_rank = row["src_rank"] + global_rank = row[Constant.RANK_ID] + if group_name not in rank_map: + rank_map[group_name] = {local_rank: global_rank} + continue + if local_rank not in rank_map[group_name]: + rank_map[group_name][local_rank] = global_rank + continue + if rank_map[group_name][local_rank] != global_rank: + logger.warning(f"In the same communication group {group_name}, global rank {global_rank} " + f"and {rank_map[group_name][local_rank]} get the same local rank {local_rank}!") + return rank_map + + def _mapper_func(self, data_map, analysis_class): + result_data = {Constant.RANK_ID: data_map.get(Constant.RANK_ID)} + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + result_data[self.RANK_MAP] = self._get_parallel_group_info(profiler_db_path) + analysis_db_path = data_map.get(Constant.ANALYSIS_DB_PATH) + data_service = DatabaseService(analysis_db_path, {}) + data_service.add_table_for_query(TableConstant.TABLE_COMM_ANALYZER_MATRIX) + matrix_data = data_service.query_data().get(TableConstant.TABLE_COMM_ANALYZER_MATRIX) + if self._is_msprof or self._is_mindspore: + matrix_data = self._trans_msprof_matrix_data(matrix_data) + result_data[self.MATRIX_DATA] = matrix_data + return result_data diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a355e5a7f08206fc39dda4646817224c067f29f7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..2291be3405157fd7cdc282b062a1aadb1bb88d44 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os + +import numpy as np +import pandas as pd + +from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.database_service import DatabaseService +from msprof_analyze.prof_common.logger import get_logger + +logger = get_logger() + + +class CommunicationTimeSumRecipe(BaseRecipeAnalysis): + TABLE_CLUSTER_COMM_TIME = "ClusterCommunicationTime" + TABLE_CLUSTER_COMM_BANDWIDTH = "ClusterCommunicationBandwidth" + + def __init__(self, params): + super().__init__(params) + logger.info("CommunicationSum init.") + self.communication_time = None + self.communication_bandwidth = None + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error("Unknown export type.") + + def reducer_func(self, mapper_res): + mapper_res_time = list(item[0] for item in mapper_res if item[0] is not None) + mapper_res_bw = list(item[1] for item in mapper_res if item[1] is not None) + if not mapper_res_time and not mapper_res_bw: + logger.error("Mapper data is None.") + return + cluster_db_path = os.path.join(self.output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + data_service = DatabaseService(cluster_db_path, None) + data_service.add_table_for_query("CommunicationGroupMapping", + [TableConstant.RANK_SET, TableConstant.GROUP_NAME]) + df_dict = data_service.query_data() + rank_set_df = df_dict.get("CommunicationGroupMapping", None) + if rank_set_df is None or rank_set_df.empty: + logger.error(f"There is no CommunicationGroupMapping data in {cluster_db_path}.") + return + communication_time = pd.concat(mapper_res_time) + communication_bandwidth = pd.concat(mapper_res_bw) 
+        self._compute_time_info(communication_time, rank_set_df)
+        self._compute_bandwidth_info(communication_bandwidth, rank_set_df)
+
+    def save_db(self):
+        self.dump_data(self.communication_time, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER,
+                       self.TABLE_CLUSTER_COMM_TIME, index=False)
+        self.dump_data(self.communication_bandwidth, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER,
+                       self.TABLE_CLUSTER_COMM_BANDWIDTH, index=False)
+
+    def _compute_time_info(self, communication_time, rank_set_df):
+        """
+        communication_time: ['hccl_op_name', 'group_name', 'start_timestamp', 'elapse_time',
+                             'transit_time', 'wait_time', 'synchronization_time', 'idle_time',
+                             'step', 'type', 'rank_id']
+        rank_set_df: ['rank_set', 'group_name']
+        output: ['step', 'rank_id', 'hccl_op_name', 'group_name', 'start_timestamp', 'elapse_time', 'transit_time',
+                 'wait_time', 'synchronization_time', 'idle_time', 'synchronization_time_ratio', 'wait_time_ratio']
+
+        Group by the "step", "rank_id" and "rank_set" fields, sum the "elapse_time", "transit_time", "wait_time",
+        "synchronization_time" and "idle_time" columns, and append the summary rows to communication_time.
+        """
+        merged_df = pd.merge(communication_time, rank_set_df, on=TableConstant.GROUP_NAME, how='left')
+        summed_df = merged_df.groupby([TableConstant.STEP, TableConstant.RANK_ID, TableConstant.RANK_SET]).agg({
+            TableConstant.GROUP_NAME: "first",
+            TableConstant.ELAPSED_TIME: "sum",
+            TableConstant.TRANSIT_TIME: "sum",
+            TableConstant.WAIT_TIME: "sum",
+            TableConstant.SYNCHRONIZATION_TIME: "sum",
+            TableConstant.IDLE_TIME: "sum"
+        }).reset_index()
+        summed_df[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
+        summed_df[TableConstant.START_TIMESTAMP] = 0
+        # Compute synchronization_time_ratio and wait_time_ratio
+        summed_df[TableConstant.SYNCHRONIZATION_TIME_RATIO] = (
+                summed_df[TableConstant.SYNCHRONIZATION_TIME] /
+                (summed_df[TableConstant.TRANSIT_TIME] + summed_df[TableConstant.SYNCHRONIZATION_TIME]).replace(0,
+                                                                                                                np.nan)
+        ).fillna(0).round(4)
+        summed_df[TableConstant.WAIT_TIME_RATIO] = (
+                summed_df[TableConstant.WAIT_TIME] /
+                (summed_df[TableConstant.TRANSIT_TIME] + summed_df[TableConstant.WAIT_TIME]).replace(0, np.nan)
+        ).fillna(0).round(4)
+
+        communication_time[TableConstant.SYNCHRONIZATION_TIME_RATIO] = 0
+        communication_time[TableConstant.WAIT_TIME_RATIO] = 0
+        desired_order = [TableConstant.STEP, TableConstant.RANK_ID, TableConstant.HCCL_OP_NAME,
+                         TableConstant.GROUP_NAME, TableConstant.START_TIMESTAMP, TableConstant.ELAPSED_TIME,
+                         TableConstant.TRANSIT_TIME, TableConstant.WAIT_TIME, TableConstant.SYNCHRONIZATION_TIME,
+                         TableConstant.IDLE_TIME, TableConstant.SYNCHRONIZATION_TIME_RATIO,
+                         TableConstant.WAIT_TIME_RATIO]
+        # Append the summary rows to the original DataFrame
+        final_df = pd.concat([communication_time, summed_df], axis=0).reindex(columns=desired_order)
+        final_df.rename(columns={'elapse_time': 'elapsed_time'}, inplace=True)
+        self.communication_time = final_df
+
+    def _compute_bandwidth_info(self, communication_bandwidth, rank_set_df):
+        """
+        communication_bandwidth: ['hccl_op_name', 'group_name', 'transport_type', 'transit_size',
+                                  'transit_time', 'bandwidth', 'large_packet_ratio', 'package_size',
+                                  'count', 'total_duration', 'step', 'type', 'rank_id']
+        output: ['step', 'rank_id', 'hccl_op_name', 'group_name', 'band_type', 'transit_size', 'transit_time',
+                 'bandwidth', 'large_packet_ratio', 'package_size', 'count', 'total_duration']
+        rank_set_df: ['rank_set', 'group_name']
+        Group by 'rank_set', 'step', 'rank_id', 'transport_type' and 'package_size', summing 'count' and
+        'total_duration'; within each 'rank_set', 'step', 'rank_id', 'transport_type' group, also sum
+        'transit_size' and 'transit_time', counting rows with the same 'hccl_op_name' + 'group_name' only once.
+        """
+        merged_df = pd.merge(communication_bandwidth, rank_set_df, on=TableConstant.GROUP_NAME, how='left')
+        # For each rank_set/step/rank_id/transport_type group, sum the de-duplicated transit_size and transit_time
+        sum_transit_size = 'sum_transit_size'
+        sum_transit_time = 'sum_transit_time'
+        sum_transit = merged_df.groupby(
+            [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE]).apply(
+            self._get_sum_distinct_op).reset_index().rename(columns={
+                TableConstant.TRANSIT_SIZE: sum_transit_size,
+                TableConstant.TRANSIT_TIME: sum_transit_time
+            })
+        joined_df = pd.merge(merged_df, sum_transit,
+                             on=[TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID,
+                                 TableConstant.TRANSPORT_TYPE])
+        # Aggregate by 'rank_set', 'step', 'rank_id', 'transport_type' and 'package_size'
+        agg_result = joined_df.groupby(
+            [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE,
+             TableConstant.PACKAGE_SIZE]
+        ).agg({
+            TableConstant.COUNT: 'sum',
+            TableConstant.TOTAL_DURATION: 'sum',
+            TableConstant.HCCL_OP_NAME: 'first',
+            TableConstant.GROUP_NAME: 'first',
+            sum_transit_size: 'first',
+            sum_transit_time: 'first'
+        }).reset_index()
+        agg_result[TableConstant.LARGE_PACKET_RATIO] = 0
+        agg_result[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
+        # Compute the bandwidth of the aggregated rows
+        agg_result[TableConstant.BANDWIDTH] = (
+                agg_result[sum_transit_size] / agg_result[sum_transit_time].replace(0, np.nan)
+        ).fillna(0).round(4)
+        agg_result = agg_result.rename(columns={
+            sum_transit_size: TableConstant.TRANSIT_SIZE,
+            sum_transit_time: TableConstant.TRANSIT_TIME
+        })
+        desired_order = [TableConstant.STEP, TableConstant.RANK_ID, TableConstant.HCCL_OP_NAME,
+                         TableConstant.GROUP_NAME, TableConstant.TRANSPORT_TYPE, TableConstant.TRANSIT_SIZE,
+                         TableConstant.TRANSIT_TIME, TableConstant.BANDWIDTH, TableConstant.LARGE_PACKET_RATIO,
+                         TableConstant.PACKAGE_SIZE, TableConstant.COUNT, TableConstant.TOTAL_DURATION]
+        final_df = pd.concat([communication_bandwidth, agg_result], axis=0).reindex(columns=desired_order)
+        final_df.rename(columns={TableConstant.TRANSPORT_TYPE: TableConstant.BAND_TYPE}, inplace=True)
+        self.communication_bandwidth = final_df
+
+    def _get_sum_distinct_op(self, op_df):
+        return op_df.drop_duplicates(subset=[TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME])[
+            [TableConstant.TRANSIT_SIZE, TableConstant.TRANSIT_TIME]].sum()
+
+    def _mapper_func(self, data_map, analysis_class):
+        analysis_db_path = data_map.get(Constant.ANALYSIS_DB_PATH)
+        rank_id = data_map.get(Constant.RANK_ID)
+        step_range = data_map.get(Constant.STEP_RANGE)
+        data_service = DatabaseService(analysis_db_path, step_range)
+        data_service.add_table_for_query(Constant.TABLE_COMM_ANALYZER_TIME)
+        data_service.add_table_for_query(Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
+        df_dict = data_service.query_data()
+        time_df = df_dict.get(Constant.TABLE_COMM_ANALYZER_TIME)
+        bandwidth_df = df_dict.get(Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
+
+        is_time_df_empty = time_df is None or time_df.empty
+        is_bandwidth_df_empty = bandwidth_df is None or bandwidth_df.empty
+        if is_time_df_empty or is_bandwidth_df_empty:
+            logger.warning(f"There is no stats data in {analysis_db_path}.")
+            return None, None
+        # Add the step and rank_id columns
+        time_df[TableConstant.RANK_ID] = rank_id
+        bandwidth_df[TableConstant.RANK_ID] = rank_id
+        if TableConstant.STEP not in time_df.columns:
+            time_df[TableConstant.STEP] = TableConstant.STEP
+        if TableConstant.STEP not in bandwidth_df.columns:
+            bandwidth_df[TableConstant.STEP] = TableConstant.STEP
+        return time_df, bandwidth_df
diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14094e3f9a77a0970342980ed8de1017f58ce19
--- /dev/null
+++ b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py new file mode 100644 index 0000000000000000000000000000000000000000..74caf90341cc701bfc1c2756043d5cec9d6d546f --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/ep_load_balance/ep_load_balance.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import json
+
+import pandas as pd
+
+from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis
+from msprof_analyze.prof_common.constant import Constant
+from msprof_analyze.prof_common.logger import get_logger
+from msprof_analyze.prof_exports.ep_load_balance_ecport import InputShapeExport
+from msprof_analyze.prof_common.database_service import DatabaseService
+
+logger = get_logger()
+
+
+class EPLoadBalance(BaseRecipeAnalysis):
+
+    EPTOKENSSUMMARY = "EPTokensSummary"
+    TOPEPTOKENSINFO = "TopEPTokensInfo"
+    META_DATA = "META_DATA"
+    Top_Num = 20
+    GROUPEP = "exp"
+
+    def __init__(self, params):
+        super().__init__(params)
+        logger.info("EPLoadBalance init.")
+        self.ep_tokens_summary = None
+        self.top_ep_tokens_map = None
+
+    @property
+    def base_dir(self):
+        return os.path.basename(os.path.dirname(__file__))
+
+    def process_input_shapes(self, df):
+        def calculate_seqlength(shape_str):
+            shape_str = shape_str.strip('"')
+            parts = shape_str.split(";")
+            non_empty_parts = [part for part in parts if part]
+            # Keep the first n-2 non-empty parts
+            if len(non_empty_parts) > 1:
+                non_empty_parts = non_empty_parts[: len(non_empty_parts) - 2]
+            else:
+                return None
+            seqlength = 0
+            for part in non_empty_parts:
+                part = part.strip()
+                try:
+                    first_dim = int(part.split(",")[0])
+                except (IndexError, ValueError):
+                    return None
+                seqlength += first_dim
+            return seqlength
+
+        df["InputShapes"] = df["InputShapes"].apply(calculate_seqlength)
+        return df
+
+    def reducer_func(self, mapper_res):
+        mapper_res = list(filter(lambda df: df is not None, mapper_res))
+        if not mapper_res:
+            logger.error("Mapper data is None.")
+            return
+        for i, df in enumerate(mapper_res):
+            mapper_res[i] = self.process_input_shapes(df)
+        mapper_res = [df.dropna() for df in mapper_res]
+        for df in mapper_res:
+            df["epRanks"] = df["epRanks"].apply(lambda x: ",".join(map(str, x)))
+        combined_df = pd.concat(mapper_res)
+        self.ep_tokens_summary = combined_df.groupby(["Rank", "epRanks"]).agg({"InputShapes": "sum"}).reset_index()
+        self.ep_tokens_summary.columns = ["rank", "epRanks", "inputShapesSummary"]
+        self.top_ep_tokens_map = (
+            self.ep_tokens_summary.groupby("epRanks")["inputShapesSummary"]
+            .agg(tokensDiff=lambda x: x.max() - x.min())
+            .reset_index()
+        )
+        self.top_ep_tokens_map = self.top_ep_tokens_map.sort_values(by="tokensDiff", ascending=False).head(self.Top_Num)
+
+    def run(self, context):
+        mapper_res = self.mapper_func(context)
+        self.reducer_func(mapper_res)
+
+        if self._export_type == "db":
+            self.save_db()
+        else:
+            logger.error("ep_load_balance only supports the db export type.")
+
+    def save_db(self):
+        self.dump_data(self.ep_tokens_summary, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.EPTOKENSSUMMARY)
+        self.dump_data(self.top_ep_tokens_map, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, self.TOPEPTOKENSINFO)
+
+    def _mapper_func(self, data_map, analysis_class):
+        profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH)
+        rank_id = data_map.get(Constant.RANK_ID)
+        step_range = data_map.get(Constant.STEP_RANGE)
+        analysis_data_service = DatabaseService(profiler_db_path, {})
+        analysis_data_service.add_table_for_query(self.META_DATA)
+        meta_map = analysis_data_service.query_data()[self.META_DATA]
+        parallel_group_info = meta_map.loc[meta_map['name'] == 'parallel_group_info', 'value'].iloc[0]
+        try:
+            data_dict = json.loads(parallel_group_info)
+        except json.JSONDecodeError:
+            logger.error(f"{profiler_db_path}'s parallel_group_info is not valid JSON.")
+            return None
+        if not isinstance(data_dict, dict):
+            raise TypeError('{} must be dict, not {}.'.format(data_dict, type(data_dict).__name__))
+        global_ranks = None
+        for _, value in data_dict.items():
+            if value["group_name"] == self.GROUPEP:
+                global_ranks = value["global_ranks"]
+                break
+        if global_ranks is None:
+            logger.warning(f"No parallel group named '{self.GROUPEP}' found in {profiler_db_path}.")
+            return None
+        df = InputShapeExport(profiler_db_path, analysis_class, step_range).read_export_db()
+        if df is None or df.empty:
+            logger.warning(f"There is no stats data in {profiler_db_path}.")
+            return None
+        df["Rank"] = rank_id
+        df["epRanks"] = [global_ranks] * len(df)
+        return df
\ No newline at end of file
diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..985288ca40cb0cab5b6c7677b163d202e71ed23e
--- /dev/null
+++ b/profiler/msprof_analyze/cluster_analyse/recipes/freq_analysis/freq_analysis.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os +from collections import defaultdict +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_common.database_service import DatabaseService + +logger = get_logger() + + +class FreqAnalysis(BaseRecipeAnalysis): + COMMON_FREQ = 1800 + FREE_FREQ = 800 + + def __init__(self, params): + super().__init__(params) + self.free_freq_ranks = [] + self.abnormal_freq_ranks = [] + self.abnormal_freq_ranks_map = {} + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def reducer_func(self, mapper_res): + if self._is_msprof: + logger.warning("Freq analysis does not support msprof db yet.") + return + mapper_res = list(filter(lambda res: res[0] is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None, load profiling data failed.") + return + for freqs, rank_id in mapper_res: + if freqs == [self.COMMON_FREQ]: + continue + elif set(freqs) == {self.COMMON_FREQ, self.FREE_FREQ}: + self.free_freq_ranks.append(rank_id) + else: + self.abnormal_freq_ranks.append(rank_id) + self.abnormal_freq_ranks_map[rank_id] = str(freqs) + self.free_freq_ranks.sort() + self.abnormal_freq_ranks.sort() + + def save_db(self): + if len(self.free_freq_ranks) > 0: + logger.info(f"Found {len(self.free_freq_ranks)} ranks with free time, " + f"aicore frequency in {[self.FREE_FREQ, self.COMMON_FREQ]}.") + free_ranks_df = pd.DataFrame() + free_ranks_df["rankId"] = self.free_freq_ranks + free_ranks_df["aicoreFrequency"] = str([self.FREE_FREQ, self.COMMON_FREQ]) + free_ranks_df.set_index(["rankId"], inplace=True) + self.dump_data(free_ranks_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "FreeFrequencyRanks") + else: + logger.info("No rank found with free time.") + if len(self.abnormal_freq_ranks) > 0: + logger.info(f"Found 
{len(self.abnormal_freq_ranks)} ranks with abnormal aicore frequency.") + + abnormal_ranks_df = pd.DataFrame.from_dict(self.abnormal_freq_ranks_map, + orient="index", columns=["aicoreFrequency"]) + abnormal_ranks_df = abnormal_ranks_df.reset_index().rename(columns={"index": "rankId"}) + abnormal_ranks_df.set_index(["rankId"], inplace=True) + self.dump_data(abnormal_ranks_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "AbnormalFrequencyRanks") + else: + logger.info("No rank found with abnormal aicore frequency.") + if len(self.free_freq_ranks) > 0 or len(self.abnormal_freq_ranks) > 0: + logger.info("Please verify the results in the output file.") + + def run(self, context): + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + + if self._export_type == Constant.DB: + self.save_db() + else: + logger.error("Frequency analysis is not supported for notebook export type.") + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + service = DatabaseService(profiler_db_path, None) + service.add_table_for_query("AICORE_FREQ", ["deviceId", "freq"]) + service.add_table_for_query("RANK_DEVICE_MAP", ["rankId"]) + service_res = service.query_data() + aic_freq = service_res.get("AICORE_FREQ", None) + rank_id = service_res.get("RANK_DEVICE_MAP", None) + if aic_freq is None or aic_freq.empty: + logger.error(f"No aic freq data found in {profiler_db_path}.") + return None, None + if rank_id is None or rank_id.empty: + logger.error(f"No rank_id data found in {profiler_db_path}.") + return None, None + rank_id = rank_id["rankId"].values[0] + freq_arr = aic_freq["freq"].values + freqs = list(set(freq_arr)) + freqs.sort() + return freqs, rank_id diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py index bfbcc6ffb49c6457cd54a9413e8bf7a145ec365b..db6aae0de869853ccc5debac729492bfc9853695 100644 --- 
a/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/mstx_sum/mstx_sum.py @@ -21,7 +21,7 @@ from msprof_analyze.cluster_analyse.common_func.utils import describe_duration from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis from msprof_analyze.prof_common.constant import Constant from msprof_analyze.prof_common.logger import get_logger -from msprof_analyze.prof_exports.mstx_mark_export import MstxMarkExport +from msprof_analyze.prof_exports.mstx_event_export import MstxMarkExport, MstxRangeExport from msprof_analyze.prof_exports.mstx_step_export import MstxStepExport logger = get_logger() @@ -43,16 +43,28 @@ def format_mark_info(df: pd.DataFrame, start_idx, stop_idx, name) -> MarkInfo: ) -def rename_mark_msg_name(mark_stats_df: pd.DataFrame): +def format_range_info(df: pd.DataFrame, idx, name) -> MarkInfo: + range_series = df.iloc[idx] + return MarkInfo( + name=name, + framework_duration=float(0), + cann_duration=float(range_series["cann_end_ts"] - range_series["cann_start_ts"]), + device_duration=float(range_series["device_end_ts"] - range_series["device_start_ts"]), + tid=range_series["tid"], + start_ns=range_series["cann_start_ts"] + ) + + +def rename_mark_msg_name(mstx_stats_df: pd.DataFrame): msg_idx_counter = {} - for idx, mark_info in enumerate(mark_stats_df.itertuples(index=False)): + for idx, mark_info in enumerate(mstx_stats_df.itertuples(index=False)): msg_idx_counter.setdefault(mark_info.step_id, {}).setdefault(mark_info.name, []).append(idx) for msg_dict in msg_idx_counter.values(): for msg, idx_list in msg_dict.items(): if len(idx_list) <= 1: continue for i, idx in enumerate(idx_list): - mark_stats_df.loc[idx, 'name'] = f"{msg}_{i}" + mstx_stats_df.loc[idx, 'name'] = f"{msg}_{i}" def compute_step_id(mark_stat, step_stats_df: pd.DataFrame): @@ -80,6 +92,45 @@ def format_columns(df: pd.DataFrame): return formatted_df[cols] +def 
handle_mark_data(mark_df: pd.DataFrame, rank_id: int) -> list: + res = [] + mark_df["framework_ts"] = mark_df["framework_ts"].astype("int64") + mark_info = {} + mismatch_msg = [] + for idx, row in enumerate(mark_df.itertuples(index=False)): + if row.msg.endswith(MstxSum.START_SUFFIX): + msg = row.msg[:-len(MstxSum.START_SUFFIX)] + mark_info.setdefault(row.tid, {}).setdefault(msg, []).append(idx) + elif row.msg.endswith(MstxSum.STOP_SUFFIX): + msg = row.msg[:-len(MstxSum.STOP_SUFFIX)] + idx_list = mark_info.get(row.tid, {}).get(msg, []) + if not idx_list: + mismatch_msg.append((row.msg, idx)) + continue + start_idx = idx_list.pop() + res.append(format_mark_info(mark_df, start_idx, idx, msg)) + + # Collect mark messages that were never matched + for msg_info in mark_info.values(): + for msg, idx_list in msg_info.items(): + if not idx_list: + continue + mismatch_msg.extend((msg + MstxSum.START_SUFFIX, idx) for idx in idx_list) + if mismatch_msg: + mismatch_msg.sort(key=lambda msg: msg[1]) + logger.warning(f"The following mark messages do not match any counterpart in " + f"rank {rank_id}: {','.join(msg[0] for msg in mismatch_msg)}.") + + return res + + +def handle_range_data(range_df: pd.DataFrame) -> list: + res = [] + for idx, row in enumerate(range_df.itertuples(index=False)): + res.append(format_range_info(range_df, idx, row.msg)) + return res + + class MstxSum(BaseRecipeAnalysis): TABLE_FRAMEWORK_STATS = "MSTXAllFrameworkStats" TABLE_CANN_STATS = "MSTXAllCannStats" @@ -159,40 +210,18 @@ class MstxSum(BaseRecipeAnalysis): if step_df is None or step_df.empty: step_df = pd.DataFrame({"start_ns": [0], "end_ns": [float("inf")], "step_id": [0]}) mark_df = MstxMarkExport(profiler_db_path, analysis_class, step_range).read_export_db() - if mark_df is None or mark_df.empty: - logger.warning(f"There is no mark data in {profiler_db_path}.") + range_df = MstxRangeExport(profiler_db_path, analysis_class, step_range).read_export_db() + mstx_res = [] + if mark_df is not None and not mark_df.empty: + mstx_res += handle_mark_data(mark_df, 
rank_id) + if range_df is not None and not range_df.empty: + mstx_res += handle_range_data(range_df) + if not mstx_res: + logger.warning(f"There is no mstx data in {profiler_db_path}.") return None - mark_df["framework_ts"] = mark_df["framework_ts"].astype("int64") - - mark_info = {} - mark_res = [] - mismatch_msg = [] - for idx, row in enumerate(mark_df.itertuples(index=False)): - if row.msg.endswith(MstxSum.START_SUFFIX): - msg = row.msg[:-len(MstxSum.START_SUFFIX)] - mark_info.setdefault(row.tid, {}).setdefault(msg, []).append(idx) - elif row.msg.endswith(MstxSum.STOP_SUFFIX): - msg = row.msg[:-len(MstxSum.STOP_SUFFIX)] - idx_list = mark_info.get(row.tid, {}).get(msg, []) - if not idx_list: - mismatch_msg.append((row.msg, idx)) - continue - start_idx = idx_list.pop() - mark_res.append(format_mark_info(mark_df, start_idx, idx, msg)) - - # 统计未匹配上的mark信息 - for msg_info in mark_info.values(): - for msg, idx_list in msg_info.items(): - if not idx_list: - continue - mismatch_msg.extend((msg + MstxSum.START_SUFFIX, idx) for idx in idx_list) - if mismatch_msg: - mismatch_msg.sort(key=lambda msg: msg[1]) - logger.warning(f"The following mark messages do not match anyone in " - f"rank {rank_id}: {','.join(msg[0] for msg in mismatch_msg)}.") - - mark_stats_df = pd.DataFrame(mark_res).assign(Rank=rank_id) - mark_stats_df["step_id"] = mark_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df) - rename_mark_msg_name(mark_stats_df) - mark_stats_df = format_columns(mark_stats_df).set_index("Name", drop=True) - return mark_stats_df + + mstx_stats_df = pd.DataFrame(mstx_res).assign(Rank=rank_id) + mstx_stats_df["step_id"] = mstx_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df) + rename_mark_msg_name(mstx_stats_df) + mstx_stats_df = format_columns(mstx_stats_df).set_index("Name", drop=True) + return mstx_stats_df diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..a355e5a7f08206fc39dda4646817224c067f29f7 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf7e2c80621f6756b2d9ad82051eefac263d141 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/dixon_table.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
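The classification rule in `FreqAnalysis.reducer_func` above is a three-way split on the set of distinct aicore frequencies observed per rank. A minimal sketch (constants mirror `COMMON_FREQ`/`FREE_FREQ` from the recipe; the category names are illustrative):

```python
COMMON_FREQ = 1800  # nominal aicore frequency, per the recipe
FREE_FREQ = 800     # idle (downclocked) frequency, per the recipe

def classify_rank(freqs):
    """freqs: sorted list of distinct aicore frequencies seen on one rank."""
    if freqs == [COMMON_FREQ]:
        return "normal"    # always at the nominal frequency
    if set(freqs) == {COMMON_FREQ, FREE_FREQ}:
        return "free"      # downclocked only while idle
    return "abnormal"      # any other frequency suggests throttling
```

"free" ranks hint at load imbalance (the NPU idles long enough to downclock), while "abnormal" ranks point at hardware or power issues.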
+ + +# One-sided Dixon test critical value table, 99.5% confidence level +DIXON_TABLE_995 = { + 3: 0.994, + 4: 0.920, + 5: 0.823, + 6: 0.744, + 7: 0.680, + 8: 0.723, + 9: 0.676, + 10: 0.638, + 11: 0.707, + 12: 0.675, + 13: 0.649, + 14: 0.672, + 15: 0.649, + 16: 0.629, + 17: 0.611, + 18: 0.595, + 19: 0.580, + 20: 0.568, + 21: 0.556, + 22: 0.545, + 23: 0.536, + 24: 0.526, + 25: 0.519, + 26: 0.510, + 27: 0.503, + 28: 0.496, + 29: 0.489, + 30: 0.484, + 31: 0.478, + 32: 0.473, + 33: 0.468, + 34: 0.463, + 35: 0.458, + 36: 0.454, + 37: 0.450, + 38: 0.446, + 39: 0.442, + 40: 0.439, + 41: 0.435, + 42: 0.432, + 43: 0.429, + 44: 0.425, + 45: 0.423, + 46: 0.420, + 47: 0.417, + 48: 0.414, + 49: 0.412, + 50: 0.409, + 51: 0.407, + 52: 0.405, + 53: 0.402, + 54: 0.400, + 55: 0.398, + 56: 0.396, + 57: 0.394, + 58: 0.392, + 59: 0.391, + 60: 0.388, + 61: 0.387, + 62: 0.385, + 63: 0.383, + 64: 0.382, + 65: 0.380, + 66: 0.379, + 67: 0.377, + 68: 0.376, + 69: 0.374, + 70: 0.372, + 71: 0.371, + 72: 0.370, + 73: 0.368, + 74: 0.368, + 75: 0.366, + 76: 0.365, + 77: 0.364, + 78: 0.363, + 79: 0.361, + 80: 0.360, + 81: 0.359, + 82: 0.358, + 83: 0.356, + 84: 0.356, + 85: 0.355, + 86: 0.353, + 87: 0.352, + 88: 0.352, + 89: 0.351, + 90: 0.350, + 91: 0.349, + 92: 0.348, + 93: 0.347, + 94: 0.346, + 95: 0.345, + 96: 0.344, + 97: 0.344, + 98: 0.343, + 99: 0.341, + 100: 0.341, +} \ No newline at end of file diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c5b89c8b875dd28a1b3a02744239041833d345 --- /dev/null +++ b/profiler/msprof_analyze/cluster_analyse/recipes/slow_rank/slow_rank.py @@ -0,0 +1,179 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from collections import defaultdict + +import pandas as pd +import numpy as np + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.logger import get_logger +from msprof_analyze.prof_exports.cluster_time_summary_export import CommunicationTimeExport +from msprof_analyze.cluster_analyse.recipes.slow_rank.dixon_table import DIXON_TABLE_995 + +logger = get_logger() + + +def judge_norm(time_list, threshold=3): + t_max = max(time_list) + t_min = min(time_list) + t_mean = np.mean(time_list) + t_std = np.std(time_list) + threshold_high = t_mean + threshold * t_std + threshold_low = t_mean - threshold * t_std + + # Ranks whose time falls below the lower threshold are considered slow + outliers_idx = [i for i, time in enumerate(time_list) if time < threshold_low] + + # If any rank exceeds the upper threshold, add the rank with the shortest time to the slow-rank list + if t_max > threshold_high: + if time_list.index(t_min) not in outliers_idx: + outliers_idx.append(time_list.index(t_min)) + return outliers_idx + + +def judge_dixon(time_list): + n = len(time_list) + if n in [0, 1, 2]: + return [] + sorted_list = sorted(time_list) + + # Check whether the denominator of the test statistic could be zero + if len(set(sorted_list)) <= 3: + return [] + + # Compute the Dixon test statistic: the gap between the second-smallest and smallest values over the gap between the largest and smallest; which order statistics are used depends on the sample size + if n <= Constant.MAX_DIXON_NUM: + if n <= Constant.DIXON_THRESHOLD_1: + flag = (sorted_list[1] - sorted_list[0]) / (sorted_list[-1] - sorted_list[0]) + elif n <= Constant.DIXON_THRESHOLD_2: + flag = (sorted_list[1] - sorted_list[0]) / (sorted_list[-2] - sorted_list[0]) + elif n <= 
Constant.DIXON_THRESHOLD_3: + flag = (sorted_list[2] - sorted_list[0]) / (sorted_list[-2] - sorted_list[0]) + else: + flag = (sorted_list[2] - sorted_list[0]) / (sorted_list[-3] - sorted_list[0]) + + # Look up the critical value for the sample size; if the statistic exceeds it, the shortest-time rank is treated as an outlier (a slow rank) + if flag > DIXON_TABLE_995[n]: + return [time_list.index(sorted_list[0])] + return [] + + +def judge_slow_rank(time_list): + """Choose the Dixon test or the three-sigma rule based on the length of time_list""" + if len(time_list) <= Constant.MAX_DIXON_NUM: + return judge_dixon(time_list) + else: + return judge_norm(time_list) + + +class SlowRankAnalysis(BaseRecipeAnalysis): + def __init__(self, params): + super().__init__(params) + logger.info("Slow Rank Analysis init.") + + @property + def base_dir(self): + return os.path.basename(os.path.dirname(__file__)) + + def reducer_func(self, mapper_res): + mapper_res = list(filter(lambda df: df is not None, mapper_res)) + if not mapper_res: + logger.error("Mapper data is None.") + return None + concated_df = pd.concat(mapper_res) + return concated_df + + def run(self, context): + if self._is_msprof: + logger.warning("Slow rank analysis does not support msprof db yet.") + return + + mapper_res = self.mapper_func(context) + comm_ops_df = self.reducer_func(mapper_res) + if comm_ops_df is None: + return + + analyzer = SlowRankVoteAnalysis(comm_ops_df) + perpector_df = analyzer.run() + + if self._export_type == Constant.DB: + self.save_db(perpector_df) + else: + logger.error("SlowRank analysis is not supported for notebook export type.") + + def save_db(self, perpector_df): + self.dump_data(perpector_df, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER, "SlowRank") + + def _mapper_func(self, data_map, analysis_class): + profiler_db_path = data_map.get(Constant.PROFILER_DB_PATH) + step_range = data_map.get(Constant.STEP_RANGE) + df = CommunicationTimeExport(profiler_db_path, analysis_class, step_range).read_export_db() + return df + + +class SlowRankVoteAnalysis: + def __init__(self, comm_ops): + self.comm_ops = comm_ops + + def grouping_ops(self): 
+ """按照通信域、算子名称对通信算子进行分组""" + grouped_ops_dict = defaultdict(lambda: defaultdict(list)) + self.comm_ops = self.comm_ops[~self.comm_ops["opName"].str.contains("send")] + self.comm_ops = self.comm_ops[~self.comm_ops["opName"].str.contains("receive")] + grouped_df = self.comm_ops.groupby("groupName") + exclude_groups = [] + for group_name in grouped_df.groups.keys(): + ops_groupby_group_name = grouped_df.get_group(group_name) + ops_num = ops_groupby_group_name.groupby("opName").size().values + if len(set(ops_num)) > 1: + exclude_groups.append(group_name) + for exclude_group in exclude_groups: + self.comm_ops.drop(self.comm_ops[self.comm_ops["groupName"] == exclude_group].index, inplace=True) + self.comm_ops.reset_index(drop=True, inplace=True) + n = len(self.comm_ops) + group_name_arr = self.comm_ops["groupName"].values + op_name_arr = self.comm_ops["opName"].values + for idx in range(n): + group_name = group_name_arr[idx] + op_name = op_name_arr[idx] + grouped_ops_dict[group_name][op_name].append(idx) + return grouped_ops_dict + + def run(self): + grouped_ops_dict = self.grouping_ops() + perpector_dict = self.analysis(grouped_ops_dict) + return perpector_dict + + def analysis(self, grouped_ops_dict): + rank_id_arr = self.comm_ops["rankId"].values + comm_time_arr = self.comm_ops["communication_time"].values + perpector_dict = defaultdict(lambda: 0) + for _, ops_same_group in grouped_ops_dict.items(): + for _, ops_list in ops_same_group.items(): + time_list = [comm_time_arr[op_idx] for op_idx in ops_list] + perpector_rank_idx = judge_slow_rank(time_list) + if perpector_rank_idx: + for rank_idx in perpector_rank_idx: + slow_rank = rank_id_arr[ops_list[rank_idx]] + perpector_dict[slow_rank] += 1 + + perpector_df = pd.DataFrame(columns=["rankId", "slowAffectCount"]) + for rank, perpector_times in perpector_dict.items(): + perpector_df.loc[len(perpector_df)] = [rank, perpector_times] + perpector_df.set_index(["rankId"], inplace=True) + return perpector_df diff --git 
a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/overall_metrics_bean.py b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/overall_metrics_bean.py index 059416ec15e54b7732eb40cbf8b0e6ef1227bf90..4c040612033c6a92905ce09bd02e5c4a181ec988 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/overall_metrics_bean.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/overall_metrics_bean.py @@ -68,11 +68,39 @@ class OverallMetricsBean: base_group_data = self._base_data.get("group", {}) comparison_group_data = self._comparison_data.get("group", {}) + base_pg_name_dict = self._base_data.get("pg_name_dict", {}) + comparison_pg_name_dict = self._comparison_data.get("pg_name_dict", {}) default_value = [0, 0, "/"] + # Handle base and comparison data that can be matched by pg_name + for base_pg_name, base_group_name_list in base_pg_name_dict.items(): + if len(base_group_name_list) != 1 or base_pg_name == Constant.UNKNOWN: + continue + comparison_group_name_list = comparison_pg_name_dict.get(base_pg_name, []) + if len(comparison_group_name_list) != 1: + continue + + base_data = base_group_data.pop(base_group_name_list[0], {}) + comparison_data = comparison_group_data.pop(comparison_group_name_list[0], {}) + description = f"\t{base_pg_name}: Communication" + ExcelConfig.ROW_STYLE_MAP[description] = CellFormatType.LIGHT_BLUE_NORMAL + self._append_data(rows_data, + self._get_row_data(description, + base_data.get(ExcelConfig.COMMUNICATION_TIME, default_value), + comparison_data.get(ExcelConfig.COMMUNICATION_TIME, default_value))) + self._append_data(rows_data, + self._get_row_data(ExcelConfig.WAIT, base_data.get(ExcelConfig.WAIT, default_value), + comparison_data.get(ExcelConfig.WAIT, default_value))) + self._append_data(rows_data, + self._get_row_data(ExcelConfig.TRANSMIT, + base_data.get(ExcelConfig.TRANSMIT, default_value), + comparison_data.get(ExcelConfig.TRANSMIT, default_value))) + for 
group_name, base_data in base_group_data.items(): comparison_data = comparison_group_data.pop(group_name, {}) - self._append_data(rows_data, self._get_row_data(group_name, base_data.get("group", default_value), - comparison_data.get("group", default_value))) + self._append_data(rows_data, + self._get_row_data(base_data.get("description", group_name), + base_data.get(ExcelConfig.COMMUNICATION_TIME, default_value), + comparison_data.get(ExcelConfig.COMMUNICATION_TIME, default_value))) self._append_data(rows_data, self._get_row_data(ExcelConfig.WAIT, base_data.get(ExcelConfig.WAIT, default_value), comparison_data.get(ExcelConfig.WAIT, default_value))) @@ -81,8 +109,10 @@ class OverallMetricsBean: base_data.get(ExcelConfig.TRANSMIT, default_value), comparison_data.get(ExcelConfig.TRANSMIT, default_value))) for group_name, comparison_data in comparison_group_data.items(): - self._append_data(rows_data, self._get_row_data(group_name, default_value, - comparison_data.get("group", default_value))) + self._append_data(rows_data, + self._get_row_data(comparison_data.get("description", group_name), + default_value, + comparison_data.get(ExcelConfig.COMMUNICATION_TIME, default_value))) self._append_data(rows_data, self._get_row_data(ExcelConfig.WAIT, default_value, comparison_data.get(ExcelConfig.WAIT, default_value))) self._append_data(rows_data, self._get_row_data(ExcelConfig.TRANSMIT, default_value, @@ -373,13 +403,17 @@ class OverallMetricsInfo: } if self._comm_group_list: for group_name in self._comm_group_list: - group_name_index = f"\t{group_name}" - ExcelConfig.ROW_STYLE_MAP[group_name_index] = CellFormatType.LIGHT_BLUE_NORMAL - overall_metrics_data.setdefault("group", {})[group_name_index] = { - "group": self.communication_data_by_group(group_name), + pg_name = self._profiling_info.get_pg_name_by_group(group_name) + description = " ".join([pg_name + ":" if pg_name != Constant.UNKNOWN else "", group_name]).strip() + ExcelConfig.ROW_STYLE_MAP[f"\t{description}"] = 
CellFormatType.LIGHT_BLUE_NORMAL + overall_metrics_data.setdefault("group", {})[group_name] = { + "description": f"\t{description}", + ExcelConfig.COMMUNICATION_TIME: self.communication_data_by_group(group_name), + ExcelConfig.WAIT: self.wait_data_by_group(group_name), + ExcelConfig.TRANSMIT: self.transmit_data_by_group(group_name) + } + overall_metrics_data.setdefault("pg_name_dict", {}).setdefault(pg_name, []).append(group_name) + + for kernel_name in self._profiling_info.mc2_time_dict.keys(): mc2_name_index = f"\t{kernel_name}" ExcelConfig.ROW_STYLE_MAP[mc2_name_index] = CellFormatType.LIGHT_BLUE_NORMAL diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/profiling_info.py b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/profiling_info.py index bcbce59c016338195d49044508614307f2db1896..36f8a0c7186f8e3c79123c4ee7f8abbe8f623720 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/profiling_info.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/compare_bean/profiling_info.py @@ -27,7 +27,7 @@ class ProfilingInfo: 'page_attention_time', 'page_attention_num', 'vector_time_trans', 'vector_num_trans', 'vector_time_notrans', 'vector_num_notrans', 'sdma_time_tensor_move', 'sdma_num_tensor_move', 'sdma_time_stream', 'sdma_num_stream', 'other_cube_time', 'other_cube_num', 'rdma_bandwidth', - 'sdma_bandwidth', 'communication_group_time', 'mc2_time_dict'] + 'sdma_bandwidth', 'communication_group_time', 'mc2_time_dict', 'pg_name_dict'] TABLE_NAME = Constant.PERFORMANCE_TABLE HEADERS = [] OVERHEAD = [] @@ -93,6 +93,9 @@ class ProfilingInfo: # 按group展示通信的卡间等待和传输耗时 self.communication_group_time = {} + # Mapping between communication_group and pg_name + self.pg_name_dict = {} + @property def e2e_time_ms(self): @@ -334,6 +337,9 @@ class ProfilingInfo: for time in time_dict.values(): self.wait_time += time.get(Constant.WAIT_TIME, 0) + def update_communication_group_pg_name(self, pg_name_dict: dict): + self.pg_name_dict 
= pg_name_dict + def set_memory_used(self, memory: float): self.memory_used = memory @@ -401,3 +407,6 @@ class ProfilingInfo: def get_mc2_number_by_name(self, kernel_name: str): return self.mc2_time_dict.get(kernel_name, {}).get(Constant.MC2_NUMBER, 0) + + def get_pg_name_by_group(self, group: str): + return self.pg_name_dict.get(group, Constant.UNKNOWN) \ No newline at end of file diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py b/profiler/msprof_analyze/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py index b3d9a29944fe7f722b2fc54ac0381f8c23c1af14..844528f3029d053cdf8330f8fe7009bae1a366e6 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from abc import abstractmethod, ABC from decimal import Decimal @@ -156,6 +157,7 @@ class BaseProfilingParser(ABC): self._dispatch_events() self._update_kernel_dict() self._update_communication_dict() + self._update_pg_name_map() if self._enable_memory_compare: self._update_memory_list() if self._enable_profiling_compare: @@ -369,3 +371,17 @@ class BaseProfilingParser(ABC): with open(self._json_path, 'r') as file: for event in ijson.items(file, item): yield TraceEventBean(event) + + def _update_pg_name_map(self): + meta_file = os.path.join(self._profiling_path, Constant.PROFILER_METADATA) + if not os.path.exists(meta_file): + return + meta_data = FileManager.read_json_file(meta_file) + if Constant.PARALLEL_GROUP_INFO not in meta_data: + return + pg_name_map = {} + for group_id, group_info in meta_data[Constant.PARALLEL_GROUP_INFO].items(): + format_group_id = " ".join(["Group", group_id, "Communication"]) + if format_group_id not in pg_name_map: + pg_name_map[format_group_id] = group_info.get('group_name', "") + self._result_data.overall_metrics.update_communication_group_pg_name(pg_name_map) diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/utils/common_func.py b/profiler/msprof_analyze/compare_tools/compare_backend/utils/common_func.py index ac9b4726eadb2cb5998ae10b4eaa995b6d7f0252..b26bdb826250b1b233fc4fa6dcbb8655758a1a8b 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/utils/common_func.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/utils/common_func.py @@ -39,7 +39,7 @@ def convert_to_float(data: any) -> float: try: float_value = float(data) except Exception: - logger.error('Invalid profiling data which failed to convert data to float.') + logger.warning('Invalid profiling data which failed to convert data to float.') return 0.0 return float_value @@ -48,8 +48,8 @@ def convert_to_decimal(data: any) -> Decimal: try: decimal_value = Decimal(data) except Exception: - logger.error('Invalid profiling data which 
failed to convert data to decimal.') - return 0.0 + logger.warning('Invalid profiling data which failed to convert data to decimal.') + return Decimal(0) return decimal_value diff --git a/profiler/msprof_analyze/prof_common/constant.py b/profiler/msprof_analyze/prof_common/constant.py index 5353fc6d40f25cee9f1a10e9b734dc95573b3b2c..1e514b7c3468d8539257a7c631807eb26ed0d5ac 100644 --- a/profiler/msprof_analyze/prof_common/constant.py +++ b/profiler/msprof_analyze/prof_common/constant.py @@ -43,6 +43,7 @@ class Constant(object): FRAMEWORK_DIR = "FRAMEWORK" CLUSTER_ANALYSIS_OUTPUT = "cluster_analysis_output" SINGLE_OUTPUT = "ASCEND_PROFILER_OUTPUT" + ANALYZE_DIR = "analyze" COMM_JSON = "communication.json" COMM_MATRIX_JSON = "communication_matrix.json" STEP_TIME_CSV = "step_trace_time.csv" @@ -61,6 +62,7 @@ class Constant(object): # communication P2P = "p2p" COLLECTIVE = "collective" + TOTAL = "total" STEP_ID = "step_id" RANK_ID = "rank_id" GROUP_NAME = "group_name" @@ -86,6 +88,7 @@ class Constant(object): ELAPSE_TIME_MS = "Elapse Time(ms)" IDLE_TIME_MS = "Idle Time(ms)" LARGE_PACKET_RATIO = "Large Packet Ratio" + TYPE = "type" # params DATA_MAP = "data_map" @@ -97,6 +100,8 @@ class Constant(object): TRANSPORT_TYPE = "Transport Type" COMM_DATA_DICT = "comm_data_dict" DATA_TYPE = "data_type" + IS_MSPROF = "is_prof" + IS_MINDSPORE = "is_mindspore" # step time RANK = "rank" @@ -126,14 +131,17 @@ class Constant(object): TABLE_HOST_INFO = "HostInfo" TABLE_RANK_DEVICE_MAP = "RankDeviceMap" TABLE_CLUSTER_BASE_INFO = "ClusterBaseInfo" + TABLE_META_DATA = "META_DATA" # data config key CONFIG = "config" EXPER_CONFIG = "experimental_config" EXPER_EXPORT_TYPE = "_export_type" + PROFILER_PARAMETER = "profiler_parameters" # metadata key DISTRIBUTED_ARGS = "distributed_args" + PARALLEL_GROUP_INFO = "parallel_group_info" # mode ALL = "all" @@ -161,6 +169,7 @@ class Constant(object): BLUE_COLOR = "00BFFF" LIGHT_BLUE_COLOR = "87CEFA" US_TO_MS = 1000 + NS_TO_US = 1000 KB_TO_MB = 1024 
INVALID_VALUE = -1 MILLISECONDS_TO_SECONDS = 10 ** 3 @@ -421,6 +430,7 @@ class Constant(object): CONCURRENT_MODE = "concurrent" PROFILER_DB_PATH = "profiler_db_path" + ANALYSIS_DB_PATH = "analysis_db_path" RANK_LIST = "rank_list" EXPORT_TYPE = "export_type" EXTRA_ARGS = "args" @@ -430,4 +440,14 @@ class Constant(object): # hccl_sum UINT32_BITS = 32 - UINT32_MASK = 0xffffffff \ No newline at end of file + UINT32_MASK = 0xffffffff + + # slow rank + MAX_DIXON_NUM = 100 + DIXON_THRESHOLD_1 = 7 + DIXON_THRESHOLD_2 = 10 + DIXON_THRESHOLD_3 = 13 + + UNKNOWN = "unknown" + + SQL_PLACEHOLDER_PATTERN = r"\?|\%s" diff --git a/profiler/msprof_analyze/prof_common/database_service.py b/profiler/msprof_analyze/prof_common/database_service.py index 6b776d4d957a9491aeb5690cf456038c114c3590..8cd4cdd2a1f414ba6d7945a22dea8fc6c312f85e 100644 --- a/profiler/msprof_analyze/prof_common/database_service.py +++ b/profiler/msprof_analyze/prof_common/database_service.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import re
+
 import pandas as pd

 from msprof_analyze.prof_common.db_manager import DBManager
@@ -48,6 +50,8 @@ class DatabaseService:
         self._db_path = db_path
         self._step_range = step_range
         self._table_info = {}
+        self._param = (self._step_range.get(Constant.START_NS),
+                       self._step_range.get(Constant.END_NS)) if self._step_range else None

     def add_table_for_query(self, table_name: str, columns=None):
         if not isinstance(table_name, str):
@@ -71,15 +75,26 @@
             if not DBManager.judge_table_exists(cursor, table_name):
                 logger.warning(f"This table {table_name} does not exist in this database {self._db_path}.")
                 continue
-            columns_str = "*" if not columns else ",".join(columns)
+            table_columns = DBManager.get_table_columns_name(cursor, table_name)
+            if not columns:
+                columns_str = ",".join(table_columns)
+            else:
+                columns = [column for column in columns if column in table_columns]
+                columns_str = ",".join(columns)
+            if not columns_str:
+                logger.error(f"The fields to be queried in Table {table_name} are invalid.")
+                return result_data
             if table_name in self.TABLE_TS_DICT and self._step_range:
-                where_str = f"where {self.TABLE_TS_DICT.get(table_name)} >= {self._step_range.get(Constant.START_NS)}" \
-                            f" and {self.TABLE_TS_DICT.get(table_name)} <= {self._step_range.get(Constant.END_NS)}"
+                where_str = f"where {self.TABLE_TS_DICT.get(table_name)} >= ? " \
+                            f"and {self.TABLE_TS_DICT.get(table_name)} <= ?"
             else:
                 where_str = ""
             query_sql = f"select {columns_str} from {table_name} {where_str}"
             try:
-                data = pd.read_sql(query_sql, conn)
+                if self._param is not None and re.search(Constant.SQL_PLACEHOLDER_PATTERN, query_sql):
+                    data = pd.read_sql(query_sql, conn, params=self._param)
+                else:
+                    data = pd.read_sql(query_sql, conn)
                 result_data[table_name] = data
             except Exception as err:
                 logger.error(err)
diff --git a/profiler/msprof_analyze/prof_common/db_manager.py b/profiler/msprof_analyze/prof_common/db_manager.py
index ac24ec8144f7a67c1796906d7e75ab25a7a7f71c..151cb19ea5359fde07b52c8534f06d7697ac5bff 100644
--- a/profiler/msprof_analyze/prof_common/db_manager.py
+++ b/profiler/msprof_analyze/prof_common/db_manager.py
@@ -189,6 +189,17 @@ class DBManager:
         cls.destroy_db_connect(conn, curs)
         return res

+    @classmethod
+    def get_table_columns_name(cls, curs: any, table: any) -> list:
+        sql = f"PRAGMA table_info({table})"
+        try:
+            curs.execute(sql)
+            columns = curs.fetchall()
+        except sqlite3.Error as err:
+            logger.error(err)
+            return []
+        return [column[1] for column in columns]
+
     @classmethod
     def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, is_dict: bool = True) -> list:
         """
diff --git a/profiler/msprof_analyze/prof_common/file_manager.py b/profiler/msprof_analyze/prof_common/file_manager.py
index 7329d1d9f3cd11588bf63300e581260205b400cb..c8d3f0827a277cff984c720fe17e6bd8ff2d85fa 100644
--- a/profiler/msprof_analyze/prof_common/file_manager.py
+++ b/profiler/msprof_analyze/prof_common/file_manager.py
@@ -114,6 +114,18 @@ class FileManager:
             raise RuntimeError(f"Failed to read the file: {base_name}, reason is {str(e)}") from e
         return content

+    @classmethod
+    def create_common_file(cls, file_path: str, content: str) -> None:
+        base_name = os.path.basename(file_path)
+        PathManager.check_path_writeable(os.path.dirname(file_path))
+        try:
+            with os.fdopen(
+                    os.open(file_path, os.O_WRONLY | os.O_CREAT, Constant.FILE_AUTHORITY),
+                    'w') as file:
+                file.write(content)
+        except Exception as e:
+            raise RuntimeError(f"Can't create file: {base_name}") from e
+
     @classmethod
     def create_csv_file(cls, profiler_path: str, data: list, file_name: str, headers: list = None) -> None:
         if not data:
diff --git a/profiler/msprof_analyze/prof_common/utils.py b/profiler/msprof_analyze/prof_common/utils.py
index 005d8505c9ccd750d4856518961c62b4407eea1e..5c0832566330bc3ba219e2de654d4b3d7af46e52 100644
--- a/profiler/msprof_analyze/prof_common/utils.py
+++ b/profiler/msprof_analyze/prof_common/utils.py
@@ -91,3 +91,10 @@ def convert_to_int(num):
     except (ValueError, NameError):
         logger.error(f"Can not convert %s to int", num)
         return 0
+
+
+def compute_ratio(dividend: float, divisor: float):
+    if abs(divisor) < 1e-15:
+        return 0
+    else:
+        return round(dividend / divisor, 4)
diff --git a/profiler/msprof_analyze/prof_exports/base_stats_export.py b/profiler/msprof_analyze/prof_exports/base_stats_export.py
index 65ccd69ecde0acb296308e0c37bec3323468ae34..2d17c41cb511c6e8c75c6ad89c478d20eab26600 100644
--- a/profiler/msprof_analyze/prof_exports/base_stats_export.py
+++ b/profiler/msprof_analyze/prof_exports/base_stats_export.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
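
The refactor running through these exports replaces f-string interpolation of `startNs`/`endNs` into SQL text with `?` placeholders plus a separate `params` tuple handed to `pd.read_sql`. A minimal standalone sketch of that pattern (the in-memory table below is illustrative, not the real profiler schema):

```python
import sqlite3

import pandas as pd

# Stand-in for a profiler DB: one table with a nanosecond timestamp column.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE CANN_API (id INTEGER, startNs INTEGER)")
conn.executemany("INSERT INTO CANN_API VALUES (?, ?)",
                 [(1, 100), (2, 500), (3, 900)])

# The step range travels separately from the SQL text, so the values are
# never string-formatted into the statement.
query = "SELECT id FROM CANN_API WHERE startNs >= ? AND startNs <= ?"
df = pd.read_sql(query, conn, params=(200, 1000))
print(df["id"].tolist())  # [2, 3]
```

This is why `DatabaseService` and `BaseStatsExport` only pass `params` when `SQL_PLACEHOLDER_PATTERN` actually matches the query: `pd.read_sql` raises if parameters are supplied for a statement with no placeholders.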
+import re

 import pandas as pd

@@ -29,18 +30,26 @@ class BaseStatsExport:
         self._analysis_class = analysis_class
         self._step_range = step_range
         self._query = None
+        self._param = (self._step_range.get(Constant.START_NS),
+                       self._step_range.get(Constant.END_NS)) if self._step_range else None

     def get_query(self):
         return self._query

     def read_export_db(self):
         try:
+            if not self._db_path:
+                logger.error("db path is None.")
+                return None
             query = self.get_query()
             if query is None:
                 logger.error("query is None.")
                 return None
             conn, cursor = DBManager.create_connect_db(self._db_path, Constant.ANALYSIS)
-            data = pd.read_sql(query, conn)
+            if self._param is not None and re.search(Constant.SQL_PLACEHOLDER_PATTERN, query):
+                data = pd.read_sql(query, conn, params=self._param)
+            else:
+                data = pd.read_sql(query, conn)
             DBManager.destroy_db_connect(conn, cursor)
             return data
         except Exception as e:
diff --git a/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py b/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py
index 0d3da94a001609cdbaed7d3f4646dc908d2b8c23..456aac95f07fec30f9f87a7e401d14a7edc4ea05 100644
--- a/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py
+++ b/profiler/msprof_analyze/prof_exports/cann_api_sum_export.py
@@ -64,12 +64,5 @@ class CannApiSumExport(BaseStatsExport):

     def __init__(self, db_path, recipe_name, step_range):
         super().__init__(db_path, recipe_name, step_range)
-        self._query = self.get_query_statement()
-
-    def get_query_statement(self):
-        if self._step_range:
-            filter_statement = f"WHERE CANN_API.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                               f"and CANN_API.startNs <= {self._step_range.get(Constant.END_NS)}"
-        else:
-            filter_statement = ""
-        return QUERY.format(filter_statement)
+        filter_statement = "WHERE CANN_API.startNs >= ? and CANN_API.startNs <= ?" if step_range else ""
+        self._query = QUERY.format(filter_statement)
diff --git a/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py b/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..36b86301dc73ffb8e6a206f9b61bac2b159fc747
--- /dev/null
+++ b/profiler/msprof_analyze/prof_exports/cluster_time_summary_export.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
+from msprof_analyze.prof_common.constant import Constant
+
+
+class CommunicationTimeExport(BaseStatsExport):
+    QUERY = """
+    SELECT
+        RANK_DEVICE_MAP.rankId,
+        si_group.value AS groupName,
+        si_op.value AS opName,
+        (COMMUNICATION_OP.endNs - COMMUNICATION_OP.startNs) / 1000.0 AS communication_time
+    FROM COMMUNICATION_OP
+    CROSS JOIN RANK_DEVICE_MAP
+    JOIN STRING_IDS si_group ON COMMUNICATION_OP.groupName = si_group.id
+    JOIN STRING_IDS si_op ON COMMUNICATION_OP.opName = si_op.id
+    JOIN CANN_API ON CANN_API.connectionId = COMMUNICATION_OP.connectionId
+    {}
+    """
+
+    def __init__(self, db_path, recipe_name, step_range):
+        super().__init__(db_path, recipe_name, step_range)
+        filter_statement = "WHERE CANN_API.startNs >= ? and CANN_API.startNs <= ?" if step_range else ""
+        self._query = self.QUERY.format(filter_statement)
diff --git a/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py b/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py
index f337925dc36ff8e26c782ab1ea1c00618ebf271c..24a2cc2d990976b5fa6e4c26a8b4dd745a88a767 100644
--- a/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py
+++ b/profiler/msprof_analyze/prof_exports/compute_op_sum_export.py
@@ -69,27 +69,13 @@ class ComputeOpSumExport(BaseStatsExport):

     def __init__(self, db_path, recipe_name, step_range):
         super().__init__(db_path, recipe_name, step_range)
-        self._query = self.get_query_statement()
-
-    def get_query_statement(self):
-        if self._step_range:
-            filter_statement = f"WHERE TASK.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                               f"and TASK.startNs <= {self._step_range.get(Constant.END_NS)}"
-        else:
-            filter_statement = ""
-        return QUERY.format(filter_statement)
+        filter_statement = "WHERE TASK.startNs >= ? and TASK.startNs <= ?" if step_range else ""
+        self._query = QUERY.format(filter_statement)


 class ComputeOpSumExportExcludeOpName(BaseStatsExport):

     def __init__(self, db_path, recipe_name, step_range):
         super().__init__(db_path, recipe_name, step_range)
-        self._query = self.get_query_statement()
-
-    def get_query_statement(self):
-        if self._step_range:
-            filter_statement = f"WHERE TASK.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                               f"and TASK.startNs <= {self._step_range.get(Constant.END_NS)}"
-        else:
-            filter_statement = ""
-        return QUERY_EXCLUDE_OPNAME.format(filter_statement)
+        filter_statement = "WHERE TASK.startNs >= ? and TASK.startNs <= ?" if step_range else ""
+        self._query = QUERY_EXCLUDE_OPNAME.format(filter_statement)
diff --git a/profiler/msprof_analyze/prof_exports/ep_load_balance_ecport.py b/profiler/msprof_analyze/prof_exports/ep_load_balance_ecport.py
new file mode 100644
index 0000000000000000000000000000000000000000..59acd6bdde7d23ad6d645d420c3f6f9fcf7cd2b6
--- /dev/null
+++ b/profiler/msprof_analyze/prof_exports/ep_load_balance_ecport.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
+from msprof_analyze.prof_common.constant import Constant
+
+GROUPED_MATMUL_QUERY = """
+SELECT
+    InputShapes_IDS.value AS "InputShapes"
+FROM COMPUTE_TASK_INFO
+JOIN TASK
+    ON COMPUTE_TASK_INFO.globalTaskId = TASK.globalTaskId
+LEFT JOIN STRING_IDS AS InputShapes_IDS
+    ON InputShapes_IDS.id = COMPUTE_TASK_INFO.inputShapes
+WHERE COMPUTE_TASK_INFO.opType = (
+    SELECT id
+    FROM STRING_IDS
+    WHERE value = 'GroupedMatmul'
+)
+{}
+    """
+
+
+class InputShapeExport(BaseStatsExport):
+
+    def __init__(self, db_path, recipe_name, step_range):
+        super().__init__(db_path, recipe_name, step_range)
+        filter_statement = "And TASK.startNs >= ? And TASK.endNs <= ?" if step_range else ""
+        self._query = GROUPED_MATMUL_QUERY.format(filter_statement)
diff --git a/profiler/msprof_analyze/prof_exports/hccl_sum_export.py b/profiler/msprof_analyze/prof_exports/hccl_sum_export.py
index c577d40c0f5ae1289d196bdd6d7cd306ebcbf01e..80750ef88ddfeceb54b35f3083ce6f6a0ae318eb 100644
--- a/profiler/msprof_analyze/prof_exports/hccl_sum_export.py
+++ b/profiler/msprof_analyze/prof_exports/hccl_sum_export.py
@@ -41,12 +41,5 @@ class HcclSumExport(BaseStatsExport):

     def __init__(self, db_path, recipe_name, step_range):
         super().__init__(db_path, recipe_name, step_range)
-        self._query = self.get_query_statement()
-
-    def get_query_statement(self):
-        if self._step_range:
-            filter_statement = f"WHERE COMMUNICATION_OP.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                               f"and COMMUNICATION_OP.startNs <= {self._step_range.get(Constant.END_NS)}"
-        else:
-            filter_statement = ""
-        return QUERY.format(filter_statement)
+        filter_stat = "WHERE COMMUNICATION_OP.startNs >= ? and COMMUNICATION_OP.startNs <= ?" if step_range else ""
+        self._query = QUERY.format(filter_stat)
diff --git a/profiler/msprof_analyze/prof_exports/mstx_mark_export.py b/profiler/msprof_analyze/prof_exports/mstx_event_export.py
similarity index 56%
rename from profiler/msprof_analyze/prof_exports/mstx_mark_export.py
rename to profiler/msprof_analyze/prof_exports/mstx_event_export.py
index 6a7f8d0c6d2f1b4cbbceb9323157215421b58464..76bf4672b616b94d639f75c89782b07f3a175dcc 100644
--- a/profiler/msprof_analyze/prof_exports/mstx_mark_export.py
+++ b/profiler/msprof_analyze/prof_exports/mstx_event_export.py
@@ -16,7 +16,7 @@
 from msprof_analyze.prof_exports.base_stats_export import BaseStatsExport
 from msprof_analyze.prof_common.constant import Constant

-QUERY = """
+MARK_QUERY = """
 WITH
     FRAMEWORK_API AS (
         SELECT
@@ -46,7 +46,8 @@ LEFT JOIN
 LEFT JOIN
     STRING_IDS AS MSG_IDS
     ON MSTX_EVENTS.message == MSG_IDS.id
-{}
+WHERE
+    MSTX_EVENTS.eventType == 3 {}
 ORDER BY
     MSTX_EVENTS.startNs
 """
@@ -57,13 +58,47 @@ class MstxMarkExport(BaseStatsExport):

     def __init__(self, db_path, recipe_name, step_range):
         super().__init__(db_path, recipe_name, step_range)
         self._query = self.get_query_statement()
+        self._param = (step_range.get(Constant.START_NS), step_range.get(Constant.END_NS),
+                       step_range.get(Constant.START_NS),
+                       step_range.get(Constant.END_NS)) if step_range else None

     def get_query_statement(self):
         if self._step_range:
-            filter_statement_1 = f"WHERE PYTORCH_API.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                                 f"and PYTORCH_API.startNs <= {self._step_range.get(Constant.END_NS)}"
-            filter_statement_2 = f"WHERE MSTX_EVENTS.startNs >= {self._step_range.get(Constant.START_NS)} " \
-                                 f"and MSTX_EVENTS.startNs <= {self._step_range.get(Constant.END_NS)}"
+            filter_statement_1 = "WHERE PYTORCH_API.startNs >= ? AND PYTORCH_API.startNs <= ?"
+            filter_statement_2 = "AND MSTX_EVENTS.startNs >= ? AND MSTX_EVENTS.startNs <= ?"
else: filter_statement_1, filter_statement_2 = "", "" - return QUERY.format(filter_statement_1, filter_statement_2) + return MARK_QUERY.format(filter_statement_1, filter_statement_2) + + +RANGE_QUERY = ''' +SELECT + MSG_IDS.value AS "msg", + MSTX_EVENTS.startNs AS "cann_start_ts", + MSTX_EVENTS.endNs AS "cann_end_ts", + TASK.startNs AS "device_start_ts", + TASK.endNs AS "device_end_ts", + MSTX_EVENTS.globalTid AS "tid" +FROM + MSTX_EVENTS +LEFT JOIN + TASK + ON MSTX_EVENTS.connectionId == TASK.connectionId +LEFT JOIN + STRING_IDS AS MSG_IDS + ON MSTX_EVENTS.message == MSG_IDS.id +WHERE + MSTX_EVENTS.eventType == 2 {} +AND + MSTX_EVENTS.connectionId != 4294967295 +ORDER BY + MSTX_EVENTS.startNs + ''' + + +class MstxRangeExport(BaseStatsExport): + + def __init__(self, db_path, recipe_name, step_range): + super().__init__(db_path, recipe_name, step_range) + filter_statement = "AND MSTX_EVENTS.startNs >= ? AND MSTX_EVENTS.startNs <= ?" if step_range else "" + self._query = RANGE_QUERY.format(filter_statement) diff --git a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py index bbc07adebfb5fd96da6f61cdd9a24e107b3653c8..61557153494dee90eec95d67efa87f3ce03eb60f 100644 --- a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py +++ b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py @@ -73,7 +73,7 @@ class TestClusterAnalysePytorchDb(TestCase): text_cluster_step_trace_time = ClusterStepTraceTimeDb(*df.iloc[0]) self.assertEqual(text_cluster_step_trace_time.type, db_cluster_step_trace_time.type, "Cluster step trace time db vs text 'type' property wrong.") - self.assertEqual(text_cluster_step_trace_time.index, db_cluster_step_trace_time.index, + self.assertEqual(str(text_cluster_step_trace_time.index), str(db_cluster_step_trace_time.index), "Cluster step trace time db vs text 'index' property 
wrong.") self.assertEqual(round(text_cluster_step_trace_time.computing), round(db_cluster_step_trace_time.computing), "Cluster step trace time db vs text 'computing' property wrong.") diff --git a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_text.py b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_text.py index d6f8d470109bf21a33bfde683bafb0f41660dcb0..be88b8f1e77b7855889cb5f3dcec7fbe8fef532a 100644 --- a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_text.py +++ b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_text.py @@ -139,9 +139,13 @@ class TestClusterAnalyseCmdPytorchText(TestCase): "Communication(Not Overlapped and Exclude Receive)", "Preparing"] self.assertEqual(headers, df.columns.tolist(), "PyTorch text result columns wrong.") - data_base = ["rank", "7", 14945901.573999925, 50289541.49199608, 14462809.01400388, 64752350.50599996, + data_base = ["rank", 7, 14945901.573999925, 50289541.49199608, 14462809.01400388, 64752350.50599996, 377397.078000026, 65612840.25, 0.0, 50289541.49199608, 1726054679554437.8] - self.assertEqual(data_base, df.iloc[0].loc["Type":"Preparing"].tolist(), "PyTorch text result data wrong.") + + rank_7_df = df[df["Index"] == 7] + self.assertEqual(len(rank_7_df), 1, "PyTorch text result data wrong.") + data_compare = rank_7_df.iloc[0].loc["Type":"Preparing"].values.tolist() + self.assertEqual(data_base, data_compare, "PyTorch text result data wrong.") def communication_matrix_compare(self): """ @@ -153,10 +157,11 @@ class TestClusterAnalyseCmdPytorchText(TestCase): for header in headers: result_data = result_data.get(header, {}) compare_data = [] - for data in list(result_data.values())[:12]: + result_data = {k: result_data[k] for k in sorted(result_data.keys(), reverse=True)} + for data in list(result_data.values()): compare_data.append(data.get("Bandwidth(GB/s)", -1)) - data_base = [25.0568, 641.8677, 23.4726, 
23.2394, 626.9544, 24.9039, - 22.7738, 23.0614, 640.6486, 25.7812, 23.1025, 23.2896] + data_base = [641.8677, 23.4726, 23.2394, 25.0568, 22.7738, 626.9544, 23.0614, 24.9039, + 23.2896, 23.1025, 640.6486, 25.7812, 23.1077, 22.9017, 23.2811, 629.2938] self.assertEqual(data_base, compare_data, "PyTorch text result data wrong.") def communication_compare(self): @@ -170,6 +175,7 @@ class TestClusterAnalyseCmdPytorchText(TestCase): for header in headers: result_data = result_data.get(header, {}) board_datas = [] + result_data = {k: result_data[k] for k in sorted(result_data.keys(), reverse=True)} for data in list(result_data.values())[:2]: board_datas.append(data.get("Communication Time Info", {})) compare_data = [] diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_group_generator.py b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_group_generator.py deleted file mode 100644 index 517327b81117f100a1bf0f71edbe9ebd45ef605e..0000000000000000000000000000000000000000 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_group_generator.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from unittest import mock - -from msprof_analyze.cluster_analyse.communication_group.communication_group_generator import CommunicationGroupGenerator -from msprof_analyze.prof_common.constant import Constant - - -class TestCommunicationGroupGenerator(unittest.TestCase): - DIR_PATH = '' - PARAMS = { - Constant.DATA_SIMPLIFICATION: "ORIGINAL", - Constant.DATA_TYPE: Constant.TEXT - } - - def test_generate_p2p_communication_when_given_group_1p_return_1p2p(self): - check = CommunicationGroupGenerator(self.PARAMS).processor - check.collective_group_dict = { - 'group1': {0} - } - with mock.patch("msprof_analyze.prof_common.file_manager.FileManager.read_json_file", - return_value=True): - check.generate_p2p_communication_group() - ret = {0} - self.assertEqual(ret, set(check.communication_group[Constant.P2P][0])) - - def test_generate_p2p_communication_when_given_group_8p_return_correct_value(self): - check = CommunicationGroupGenerator(self.PARAMS).processor - check.collective_group_dict = { - 'group1': {1, 2, 3, 4}, - 'group2': {5, 6, 7, 8}, - } - with mock.patch("msprof_analyze.prof_common.file_manager.FileManager.read_json_file", - return_value=True): - check.generate_p2p_communication_group() - ret_a = {1, 2, 3, 4} - ret_b = {5, 6, 7, 8} - self.assertEqual(ret_a, set(check.communication_group[Constant.P2P][0])) - self.assertEqual(ret_b, set(check.communication_group[Constant.P2P][1])) - - def test_generate_p2p_communication_when_given_group_16p_expect_4_group(self): - check = CommunicationGroupGenerator(self.PARAMS).processor - check.collective_group_dict = { - 'group1': {0, 1}, - 'group2': {0, 2}, - 'group3': {2, 3}, - 'group4': {3, 1}, - 'group5': {4, 5}, - 'group6': {4, 6}, - 'group7': {5, 7}, - 'group8': {6, 7}, - 'group9': {8, 9}, - 'group10': {8, 10}, - 'group11': {11, 10}, - 'group12': {11, 9}, - 'group13': {12, 13}, - 'group14': {12, 14}, - 'group15': {15, 13}, - 'group16': {15, 14} - } - with 
mock.patch("msprof_analyze.prof_common.file_manager.FileManager.read_json_file", - return_value=True): - check.generate_p2p_communication_group() - ret_a = {0, 1, 2, 3} - ret_b = {4, 5, 6, 7} - ret_c = {8, 9, 10, 11} - ret_d = {12, 13, 14, 15} - self.assertEqual(ret_a, set(check.communication_group[Constant.P2P][0])) - self.assertEqual(ret_b, set(check.communication_group[Constant.P2P][1])) - self.assertEqual(ret_c, set(check.communication_group[Constant.P2P][2])) - self.assertEqual(ret_d, set(check.communication_group[Constant.P2P][3])) - - def test_generate_p2p_communication_group_when_given_repeat_group_expect_2_group(self): - check = CommunicationGroupGenerator(self.PARAMS).processor - check.collective_group_dict = { - 'group1': {0, 1, 2, 3}, - 'group2': {0, 1, 2, 3}, - 'group3': {0, 1, 2, 3}, - 'group4': {0, 1, 2, 3}, - 'group5': {3, 2, 4, 5}, - 'group6': {4, 5, 6, 7}, - 'group7': {4, 5, 6, 7}, - 'group8': {4, 5, 6, 7}, - 'group9': {8, 9, 11, 10}, - 'group10': {8, 9, 11, 10}, - 'group11': {11, 10, 12, 13}, - 'group12': {11, 10, 12, 13}, - 'group13': {11, 10, 12, 13}, - 'group14': {12, 13, 14, 15}, - 'group15': {12, 13, 14, 15}, - 'group16': {12, 13, 14, 15} - } - with mock.patch("msprof_analyze.prof_common.file_manager.FileManager.read_json_file", - return_value=True): - check.generate_p2p_communication_group() - ret_a = {0, 1, 2, 3, 4, 5, 6, 7} - ret_b = {8, 9, 10, 11, 12, 13, 14, 15} - self.assertEqual(ret_a, set(check.communication_group[Constant.P2P][0])) - self.assertEqual(ret_b, set(check.communication_group[Constant.P2P][1])) diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_ep_load_balance.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_ep_load_balance.py new file mode 100644 index 0000000000000000000000000000000000000000..577df7bb84193be6d6dc815145e3c2d33d34c455 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_ep_load_balance.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, Huawei 
Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock +import pandas as pd + +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.cluster_analyse.recipes.ep_load_balance.ep_load_balance import EPLoadBalance + + +class TestEPLoadBalance(unittest.TestCase): + + def setUp(self): + self.params = {} + self.ep_load_balance = EPLoadBalance(self.params) + self.mock_db_path = "mock_db_path" + self.mock_rank_id = 0 + self.mock_step_range = {Constant.START_NS: 0, Constant.END_NS: 1000} + self.mock_global_ranks = [0, 1] + + @patch("msprof_analyze.cluster_analyse.recipes.ep_load_balance.ep_load_balance.DatabaseService") + def test_mapper_func_given_valid_data_map_when_called_then_pass(self, mock_db_service): + """ + Test _mapper_func method to ensure it returns a DataFrame with correct Rank and epRanks columns + when provided with a valid data map. 
+ """ + # Mock the DatabaseService and its methods + mock_db_instance = mock_db_service.return_value + mock_db_instance.query_data.return_value = { + "META_DATA": pd.DataFrame( + { + "name": ["parallel_group_info"], + "value": ['{"group1": {"group_name": "exp", "global_ranks": [0, 1]}}'], + } + ) + } + + # Mock the InputShapeExport + + mock_input_shape_export = MagicMock() + mock_input_shape_export.read_export_db.return_value = pd.DataFrame( + {"InputShapes": ["1,3;4,6;;;;;4", "1,3;4,6;;;;;4"]} + ) + + with patch( + "msprof_analyze.cluster_analyse.recipes.ep_load_balance.ep_load_balance.InputShapeExport", + return_value=mock_input_shape_export, + ): + data_map = { + Constant.PROFILER_DB_PATH: self.mock_db_path, + Constant.RANK_ID: self.mock_rank_id, + Constant.STEP_RANGE: self.mock_step_range, + } + result = self.ep_load_balance._mapper_func(data_map, "mock_analysis_class") + + self.assertIsNotNone(result) + self.assertEqual(result["Rank"].tolist(), [self.mock_rank_id] * 2) + self.assertEqual(result["epRanks"].tolist(), [self.mock_global_ranks] * 2) + + def test_reducer_func_given_dataframes_when_called_then_pass(self): + """ + Test reducer_func method to ensure it processes multiple DataFrames and generates + ep_tokens_summary and top_ep_tokens_map correctly. 
+ """ + mock_mapper_res = [ + pd.DataFrame( + {"Rank": [0, 1], "epRanks": [[0, 1], [0, 1]], "InputShapes": ["1,3;4,6;;;;;4", "7,8;10,12;;;;4"]} + ), + pd.DataFrame( + {"Rank": [2, 3], "epRanks": [[0, 1], [0, 1]], "InputShapes": ["1,3;4,6;;;;;4", "1,3;4,6;;;;;4"]} + ), + ] + + self.ep_load_balance.reducer_func(mock_mapper_res) + + self.assertIsNotNone(self.ep_load_balance.ep_tokens_summary) + self.assertIsNotNone(self.ep_load_balance.top_ep_tokens_map) \ No newline at end of file diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_freq_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_freq_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..0a559b79178d03879df6901703c66c7cfcd03663 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_freq_analysis.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import random +import unittest + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.freq_analysis.freq_analysis import FreqAnalysis + + +class TestFreqAnalysis(unittest.TestCase): + + freq = [1800] + free_freq = [800, 1800] + abnormal_freq = [1200, 1300, 1800] + + def test_no_error_freq(self): + params = {} + recipe = FreqAnalysis(params) + mapper_res = [(self.freq, 0)] * 10 + recipe.reducer_func(mapper_res) + self.assertEqual(recipe.free_freq_ranks, []) + self.assertEqual(recipe.abnormal_freq_ranks, []) + self.assertEqual(recipe.abnormal_freq_ranks_map, {}) + + + def test_free_rank_map(self): + params = {} + recipe = FreqAnalysis(params) + mapper_res = [ + (self.freq, 0), + (self.free_freq, 1), + (self.free_freq, 2), + (self.freq, 3) + ] + recipe.reducer_func(mapper_res) + self.assertEqual(recipe.free_freq_ranks, [1, 2]) + self.assertEqual(recipe.abnormal_freq_ranks, []) + self.assertEqual(recipe.abnormal_freq_ranks_map, {}) + + def test_abnormal_rank_map(self): + params = {} + recipe = FreqAnalysis(params) + mapper_res = [ + (self.freq, 0), + (self.abnormal_freq, 1), + (self.abnormal_freq, 2), + (self.freq, 3) + ] + + recipe.reducer_func(mapper_res) + self.assertEqual(recipe.free_freq_ranks, []) + self.assertEqual(recipe.abnormal_freq_ranks, [1, 2]) + + def test_mix_freq_case(self): + params = {} + recipe = FreqAnalysis(params) + mapper_res = [] + rank_case = [[], [], []] + random_freq = {0: self.freq, 1: self.free_freq, 2: self.abnormal_freq} + + for i in range(1000): + random_num = random.choice([0, 1, 2]) + mapper_res.append((random_freq.get(random_num, self.freq), i)) + rank_case[random_num].append(i) + + recipe.reducer_func(mapper_res) + self.assertEqual(recipe.free_freq_ranks, rank_case[1]) + self.assertEqual(recipe.abnormal_freq_ranks, rank_case[2]) diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_slow_rank.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_slow_rank.py new file mode 100644 
index 0000000000000000000000000000000000000000..9fecb4f2a06eb5fcda71836fd6b8ad0f55a424ae --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_slow_rank.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.slow_rank.slow_rank import judge_norm, judge_dixon, SlowRankVoteAnalysis + + +class TestJudgeNorm(unittest.TestCase): + def test_no_outlier(self): + data_list = [10] * 120 + res = judge_norm(data_list) + self.assertEqual(res, []) + + def test_with_outlier(self): + data_with_outlier = [10] * 120 + data_with_outlier.append(0) + res = judge_norm(data_with_outlier) + self.assertEqual(res, [120]) + + +class TestJudgeDixon(unittest.TestCase): + def test_no_outlier(self): + for i in [6, 8, 12, 30]: + data_list = [100 + j for j in range(i)] + res = judge_dixon(data_list) + self.assertEqual(res, []) + + def test_with_outlier(self): + for i in [6, 8, 12, 30]: + data_with_outlier = [100 + j for j in range(i)] + data_with_outlier.append(0) + res = judge_dixon(data_with_outlier) + self.assertEqual(res, [i]) + + +class TestVoteAnalysis(unittest.TestCase): + + @staticmethod + def init_cmm_ops_df(group_0_op_0_num, group_0_op_1_num, group_1_op_0_num): + comm_ops_df = pd.DataFrame(columns=["rankId", "groupName", "opName", "communication_times"]) + for i in 
range(group_0_op_0_num): + comm_ops_df.loc[len(comm_ops_df)] = [i, "group_0", "op_0", 0] + for i in range(group_0_op_1_num): + comm_ops_df.loc[len(comm_ops_df)] = [i, "group_0", "op_1", 0] + for i in range(group_1_op_0_num): + comm_ops_df.loc[len(comm_ops_df)] = [i, "group_1", "op_0", 0] + return comm_ops_df + + def test_grouping_ops(self): + group_0_op_0_num = 10 + group_0_op_1_num = 10 + group_1_op_0_num = 5 + comm_ops_df = self.init_cmm_ops_df(group_0_op_0_num, group_0_op_1_num, group_1_op_0_num) + analyzer = SlowRankVoteAnalysis(comm_ops_df) + res = analyzer.grouping_ops() + res = dict(res) + for key in res.keys(): + res[key] = dict(res[key]) + golden_res = { + "group_0": { + "op_0": [i for i in range(group_0_op_0_num)], + "op_1": [i + group_0_op_0_num for i in range(group_0_op_1_num)] + }, + "group_1": { + "op_0": [i + group_0_op_0_num + group_0_op_1_num for i in range(group_1_op_0_num)] + } + } + self.assertEqual(res, golden_res) + + def test_grouping_ops_with_exclude(self): + group_0_op_0_num = 10 + group_0_op_1_num = 12 + group_1_op_0_num = 5 + comm_ops_df = self.init_cmm_ops_df(group_0_op_0_num, group_0_op_1_num, group_1_op_0_num) + analyzer = SlowRankVoteAnalysis(comm_ops_df) + res = analyzer.grouping_ops() + res = dict(res) + for key in res.keys(): + res[key] = dict(res[key]) + golden_res = { + "group_1": { + "op_0": [i for i in range(group_1_op_0_num)] + } + } + self.assertEqual(res, golden_res) diff --git a/profiler/msprof_analyze/version.txt b/profiler/msprof_analyze/version.txt index 359a5b952d49f3592571e2af081510656029298e..10bf840ed530af123660f5edb1544264d8f2def4 100644 --- a/profiler/msprof_analyze/version.txt +++ b/profiler/msprof_analyze/version.txt @@ -1 +1 @@ -2.0.0 \ No newline at end of file +2.0.1 \ No newline at end of file