diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c75c7889ca77c5593bba3ccc1a0c9aeda8aafdb6 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveExtra.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/SaveExtra.h" + +#include +#include +#include + +#include "atb_probe/Override.h" +#include "atb_probe/include/Helper.h" +#include "common/Toolkit.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Log.h" +#include "utils/Path.h" +#include "utils/Str.h" + +namespace atb { + namespace ExtraSpace { + const std::vector TENSOR_INFO_HEADER = { + "CaseNum", "CaseName", "OpName", "OpParam", "InNum", "InDType", "InFormat", + "InShape", "OutNum", "OutDType", "OutFormat", "OutShape", "DataGenType", "DataGenRange", + "InTensorFile", "OutTensorFile", "TestType", "TestLevel", "FromModel", "SocVersion", "ExpectedError", + }; + + static void AddTensorAttributes(const std::vector &tensors, + Types::TensorInfo &info, + std::string *dataGenTypes = nullptr) { + bool first = true; + for (const auto &tensor : tensors) { + if (!first) { + info.dtype += ";"; + info.format += ";"; + info.dims += ";"; + info.filePath += ";"; + if (dataGenTypes) { + *dataGenTypes += ";"; + } + } + info.dtype += tensor.dtype; + info.format += tensor.format; + info.dims += tensor.shape; + info.filePath += Kit::GetNewPathForSave(tensor.filePath, Cst::SUBDIRNAME_DUMP_TENSOR).c_str(); + first = false; + if (dataGenTypes) { + *dataGenTypes += "customize"; + } + } + } + + static std::vector GetIOInfo(const uint16_t &caseNum, const atb::Probe::OpInfo &opInfo) { + const std::string caseName = opInfo.opName + std::to_string(caseNum); + Types::TensorInfo inInfo; + std::string dataGenType; + Types::TensorInfo outInfo; + AddTensorAttributes(opInfo.inTensors, inInfo, &dataGenType); + AddTensorAttributes(opInfo.outTensors, outInfo); + + const std::vector inputStringParts = { + caseName, + opInfo.opName, + opInfo.opParam, + inInfo.dtype, + inInfo.format, + inInfo.dims, + inInfo.filePath, + outInfo.dtype, + outInfo.format, + outInfo.dims, + outInfo.filePath, + dataGenType, + }; + for (const std::string &value : inputStringParts) { + if (Kit::IsSpecialCharInjected(value)) { + return {}; + } + } + + std::vector fields = { + std::to_string(caseNum), + caseName, + opInfo.opName, + opInfo.opParam, + std::to_string(opInfo.inTensors.size()), + inInfo.dtype, + inInfo.format, + inInfo.dims, + std::to_string(opInfo.outTensors.size()), + outInfo.dtype, + outInfo.format, + outInfo.dims, + dataGenType, + "", // DataGenRange + inInfo.filePath, + outInfo.filePath, + "", // TestType + "", // TestLevel + "", // FromModel + "", // SocVersion + "NO_ERROR", + }; + return fields; + } + + bool IsSaveTiling() { + const std::string tilingFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_TILING); + if (tilingFlag.empty()) { + return false; + } + return StrSpace::Str2Int(tilingFlag.c_str(), 0, Cst::LINK_SAVE_TILING) != 0; + } + + bool IsValidParam(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath) { + bool dataFlag = data != nullptr; + bool dataSizeFlag = (dataSize <= Utility::SafePath::SIZE_10G) && (dataSize > 0); + bool pathFlag = Kit::IsPathInGoal(filePath); + return dataFlag && dataSizeFlag && pathFlag; + } + + void SaveTiling(const uint8_t *data, const uint64_t &dataSize, const std::string &filePath) { + Utility::fs::path tilingFilePath = Kit::GetNewPathForSave(filePath, Cst::SUBDIRNAME_TILING); + Utility::SafePath::MakeParentDir(tilingFilePath); + Utility::SaveBytes(data, tilingFilePath, dataSize, std::ios::out | std::ios::binary); + } + + void SaveCpuProf(const uint64_t &executeCount, const std::string &opName, const std::string &st) { + Utility::fs::path cpuFilePath = Utility::GetMsprobeDir() / ("step" + std::to_string(executeCount)) / + ("rank" + Kit::PidTieRank::Get(getpid())) / "cpu_profiling_info.txt"; + Utility::SafePath::MakeParentDir(cpuFilePath); + + std::unordered_map nameNumMap = StrSpace::Str2Map(st, ", ", ":"); + std::vector header = Kit::GetColumns(nameNumMap); + std::vector content = Kit::GetElement(nameNumMap, header); + header.insert(header.begin(), "opName"); + content.insert(content.begin(), opName); + if (!Utility::fs::exists(cpuFilePath)) { + const std::string newHeader = StrSpace::Join(header, "\t"); + Utility::SaveTxt(newHeader, cpuFilePath, std::ios_base::app); + } + const std::string newContent = StrSpace::Join(content, "\t"); + Utility::SaveTxt(newContent, cpuFilePath, std::ios_base::app); + } + + void SaveInfo(const uint64_t &executeCount, const atb::Probe::OpInfo &opInfo, const std::string &fileName) { + Utility::fs::path opPath = Utility::GetMsprobeDir() / ("step" + std::to_string(executeCount)) / + ("rank" + Kit::PidTieRank::Get(getpid())) / fileName; + Utility::SafePath::MakeParentDir(opPath); + uint16_t caseNum = Utility::fs::exists(opPath) ? Kit::GetLineNum(opPath.c_str()) : 1; + if (!Utility::fs::exists(opPath)) { + const std::string newHeader = StrSpace::Join(TENSOR_INFO_HEADER, "\t"); + Utility::SaveTxt(newHeader, opPath, std::ios_base::app); + } + std::vector content = GetIOInfo(caseNum, opInfo); + std::string newContent = StrSpace::Join(content, "\t"); + Utility::SaveTxt(newContent, opPath, std::ios_base::app); + } + + bool IsSaveParam() { + const std::string paramFlag = StrSpace::GetEnvVar(Cst::LINK_SAVE_PARAM); + if (paramFlag.empty()) { + return false; + } + return StrSpace::Str2Int(paramFlag.c_str(), 0, Cst::LINK_SAVE_PARAM) != 0; + } + + } // namespace ExtraSpace +} // namespace atb diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b3cc61f916f14d460d8d71d6575d36013175a61 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveGraph.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/SaveGraph.h" + +#include + +#include "common/Toolkit.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" + +namespace GraphSpace { + uint32_t g_maxDeep = 1000; + + static void AddTensors(Utility::ordered_json &newJson, + std::vector &rootNames, + const std::string &opName, + const std::string &tensorType, + const std::string &jsonKey, + size_t count) { + for (size_t i = 0; i < count; ++i) { + std::string tensorName = opName + "_" + tensorType + "_" + std::to_string(i); + newJson[jsonKey].emplace_back(tensorName); + rootNames.emplace_back(tensorName); + } + } + + static void AddByFields(const std::vector &name, + const Utility::ordered_json &originJson, + Utility::ordered_json &newJson) { + for (const auto &fieldName : name) { + if (originJson.contains(fieldName)) { + newJson[fieldName] = originJson[fieldName]; + } else { + LOG_ERROR << fieldName << " not found in graph."; + return; + } + } + } + + static void ModifyRootNodes(Utility::ordered_json &newJson, + std::vector &rootNames, + const Utility::ordered_json &originJson) { + std::string opName = originJson["opName"].get(); + uint32_t inTensorNum = originJson["inTensorNum"].get(); + uint32_t outTensorNum = originJson["outTensorNum"].get(); + uint32_t internalTensorNum = originJson.value("internalTensorNum", 0); + + AddTensors(newJson, rootNames, opName, "input", "inTensors", inTensorNum); + LOG_DEBUG << "Last inTensors: " << (rootNames.empty() ? "" : rootNames.back()); + AddTensors(newJson, rootNames, opName, "output", "outTensors", outTensorNum); + LOG_DEBUG << "Last outTensors: " << (rootNames.empty() ? "" : rootNames.back()); + AddTensors(newJson, rootNames, opName, "internal", "internalTensors", internalTensorNum); + LOG_DEBUG << "Last internalTensors: " << (rootNames.empty() ? "" : rootNames.back()); + } + + static void ProcessTensorIds(const Utility::ordered_json &tensorIds, + const std::vector &sourceTensorList, + std::vector &tensorNameList, + Utility::ordered_json &tensorContainer, + const std::string &opName, + const std::string &field) { + for (const auto &item : tensorIds) { + uint32_t index = item.get(); + if (index >= sourceTensorList.size()) { + LOG_ERROR << field << " index out of range for op: " << opName; + return; + } + const auto &tensorName = sourceTensorList[index]; + tensorContainer[field].emplace_back(tensorName); + tensorNameList.emplace_back(tensorName); + } + } + + static void DfsToModifyGraphTensors(Utility::ordered_json &curNodeToSave, + const std::vector &fatherNodeTensorNameList, + const Utility::ordered_json &curNodeInput, + uint32_t curDeep = 0) { + if (curDeep >= g_maxDeep) { + LOG_ERROR << "Function " + std::string(__func__) + " has been terminated due to max depth."; + throw Utility::MsprobeException("Exceeded maximum recursion depth of " + std::to_string(g_maxDeep)); + } + + const std::string opName = curNodeInput["opName"].get(); + curNodeToSave["opName"] = opName; + curNodeToSave["opType"] = curNodeInput["opType"]; + curNodeToSave["param"] = curNodeInput["param"]; + + std::vector curNodeTensorNameList; + // Process the input tensor. + ProcessTensorIds(curNodeInput["inTensorIds"], + fatherNodeTensorNameList, + curNodeTensorNameList, + curNodeToSave, + opName, + "inTensors"); + // Process the output tensor. + ProcessTensorIds(curNodeInput["outTensorIds"], + fatherNodeTensorNameList, + curNodeTensorNameList, + curNodeToSave, + opName, + "outTensors"); + + uint32_t internalTensorNum = curNodeInput.value("internalTensorNum", 0U); + for (uint32_t i = 0; i < internalTensorNum; ++i) { + std::string tensorName = opName + "_internal_" + std::to_string(i); + curNodeToSave["internalTensors"].emplace_back(tensorName); + curNodeTensorNameList.emplace_back(tensorName); + } + + if (curNodeInput.contains("nodes")) { + for (const auto &childNodeInput : curNodeInput["nodes"]) { + Utility::ordered_json childNodeToSave; + DfsToModifyGraphTensors(childNodeToSave, curNodeTensorNameList, childNodeInput, curDeep + 1); + curNodeToSave["nodes"].emplace_back(childNodeToSave); + } + } + } + + static void ProcessChildNodes(Utility::ordered_json &newJson, + const std::vector &tensorNameList, + const Utility::ordered_json &graphJson) { + if (graphJson.find("nodes") != graphJson.end()) { + for (const auto &childNodeInput : graphJson["nodes"]) { + Utility::ordered_json childNodeToSave; + try { + DfsToModifyGraphTensors(childNodeToSave, tensorNameList, childNodeInput); + } catch (const std::exception &e) { + LOG_ERROR << "An unexpected error occurred: " << e.what(); + return; + } + newJson["nodes"].emplace_back(childNodeToSave); + } + } + } + + void CheckInputValid(const std::string &opName, const Utility::ordered_json &graphJson) { + bool flag = Kit::IsValidKeys({"opName", "opType", "inTensorNum", "outTensorNum"}, graphJson); + if (!flag) { + return; + } + std::string opNameInJson = graphJson["opName"].get(); + if (opNameInJson != opName) { + LOG_ERROR << "opName is not in json. Currently: " << opName << " vs " << opNameInJson; + return; + } + } + + Utility::ordered_json Build(const Utility::ordered_json &graphJson) { + Utility::ordered_json graphNodeJsonToSave; + AddByFields({"opName", "opType", "param"}, graphJson, graphNodeJsonToSave); + std::vector tensorNameList; + ModifyRootNodes(graphNodeJsonToSave, tensorNameList, graphJson); + ProcessChildNodes(graphNodeJsonToSave, tensorNameList, graphJson); + return graphNodeJsonToSave; + } + +} // namespace GraphSpace diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3fc11ce587092560fb01f9b155af35cabb3e324e --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/SaveTensor.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/SaveTensor.h" + +#include +#include +#include + +#include "atb_probe/include/Helper.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace TensorSpace { + const std::unordered_map levelMap = { + {"L0", "Module"}, + {"L1", "Operation"}, + {"L2", "Kernel"}, + }; + + const std::unordered_map> levelKeywords = { + {"Module", {"Operation/before", "Operation/after", "Kernel/before", "Kernel/after"}}, // exclusion item + {"Operation", {"Operation/before", "Operation/after"}}, + {"Kernel", {"Kernel/before", "Kernel/after"}}, + }; + + static std::string BuildQueryFromIds(const std::vector &ids) { + std::string query = std::to_string(ids[0]); + for (size_t i = 1; i < ids.size(); ++i) { + query += "_" + std::to_string(ids[i]); + } + return query; + } + + static bool IsPathInModeGoal(const std::string &filePath, const std::vector &keywords) { + for (const auto &keyword : keywords) { + if (filePath.find(keyword) == std::string::npos) { + return false; + } + } + return true; + } + + static bool IsPathInLevelGoal(const std::string &filePath, const std::vector &levels) { + for (const auto &level : levels) { + auto it = levelKeywords.find(level); + const auto &keywords = it->second; + if (level == "Module") { + // No excluded keywords in the path. + bool hasExcluded = false; + for (const auto &kw : keywords) { + if (filePath.find(kw) != std::string::npos) { + LOG_DEBUG << "IsPathInLevelGoal of " << level << ": true"; + hasExcluded = true; + break; + } + } + if (!hasExcluded) { + return true; + } + } else { + // Any keyword in the path is acceptable. + for (const auto &kw : keywords) { + if (filePath.find(kw) != std::string::npos) { + LOG_DEBUG << "IsPathInLevelGoal of " << level << ": true"; + return true; + } + } + } + } + return false; + } + + static bool IsDumpMode(const std::string &label) { + return StrSpace::IsValueInGoal(Cst::LINK_DATA_MODE, label) || + StrSpace::IsValueInGoal(Cst::LINK_DATA_MODE, Cst::MODE_ALL); + } + + static std::vector GetDumpLevel() { + const std::string levelStr = StrSpace::GetEnvVar(Cst::LINK_DUMP_LEVEL); // "L1" + if (levelStr.empty()) { + return {}; + } + const std::vector levelVec = StrSpace::Split(levelStr, ","); + if (levelVec.empty()) { + return {}; + } + std::vector res; + for (const auto &level : levelVec) { + auto it = levelMap.find(level); + if (it != levelMap.end()) { + res.push_back(it->second); + } else { + LOG_ERROR << "Unsupported level: " << level; + } + } + return res; + } + + bool IsOpidMatch(const std::vector &ids, const std::string &opidStr) { + std::vector opidVec = StrSpace::Split(opidStr, ","); // 1_1_1,1_2_3 + std::string query = BuildQueryFromIds(ids); + for (const auto &indice : opidVec) { + LOG_DEBUG << "op_id from model: " << query << ", op_id from user: " << indice; + bool isChildMatch = indice.find('_') != std::string::npos; + if (isChildMatch) { + if (StrSpace::IsPrefix(query, indice) && + (query == indice || (query.length() > indice.length() && query[indice.length()] == '_'))) { + return true; + } + } else { + if (indice == query) { + return true; + } + } + } + return false; + } + + bool IsOpNameMatch(const std::string &optype, const std::string &opNameStr) { + std::vector opNameVec = StrSpace::Split(opNameStr, ","); + std::string lowerCaseOptype = StrSpace::ToLower(optype); + for (const auto &indice : opNameVec) { + LOG_DEBUG << "op_name from model: " << lowerCaseOptype << ", op_name from user: " << indice; + if (StrSpace::IsPrefix(lowerCaseOptype, indice)) { + return true; + } + } + return false; + } + + bool IsExpectedTensor(const Types::TensorInfo &inputFile) { + LOG_DEBUG << "Check filePath: " << inputFile.filePath; + bool modeFlag = + (IsDumpMode(Cst::MODE_INPUT) && IsPathInModeGoal(inputFile.filePath, {Cst::BEFORE, Cst::INTENSOR})) || + (IsDumpMode(Cst::MODE_OUTPUT) && IsPathInModeGoal(inputFile.filePath, {Cst::AFTER, Cst::OUTTENSOR})); + bool levelFlag = IsPathInLevelGoal(inputFile.filePath, GetDumpLevel()); + bool rankFlag = Kit::IsValidRank(inputFile.filePath); + bool dataSizeFlag = (inputFile.dataSize <= Utility::SafePath::SIZE_10G) && (inputFile.dataSize > 0); + bool hostDataFlag = (inputFile.hostData != nullptr); + bool headLengthFlag = StrSpace::IsStringLengthSafety({inputFile.format, inputFile.dtype, inputFile.dims}); + bool flag = modeFlag && levelFlag && rankFlag && dataSizeFlag && hostDataFlag && headLengthFlag; + LOG_DEBUG << "IsExpectedTensor: " << flag; + return flag; + } + + void Save(const Types::TensorInfo &inputFile) { + Utility::fs::path newPath = + Kit::GetNewPathForSave(inputFile.filePath, Cst::SUBDIRNAME_DUMP_TENSOR).replace_extension(".npy"); + Utility::SafePath::MakeParentDir(newPath); + std::vector shapeVec = StrSpace::SplitToInt(inputFile.dims, ","); + Utility::SaveNpy(inputFile.dtype, shapeVec, inputFile.hostData, inputFile.dataSize, newPath); + } +} // namespace TensorSpace diff --git a/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp b/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b167a4e47b0381af15bd176635bb74a6f9876ac --- /dev/null +++ b/accuracy_tools/msprobe/csrc/atb_probe/core/Stat.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atb_probe/include/Stat.h" + +#include +#include + +#include + +#include "atb_probe/include/Helper.h" +#include "utils/Constant.h" +#include "utils/DataType.h" +#include "utils/Exception.h" +#include "utils/IO.h" +#include "utils/Log.h" +#include "utils/Str.h" + +namespace Stat { + inline constexpr int K_CRC_HEX_WIDTH = 8; + inline constexpr int K_FLOAT16_EXP_BITS = 5; + inline constexpr int K_FLOAT16_MANT_BITS = 10; + inline constexpr int K_BFLOAT16_EXP_BITS = 8; + inline constexpr int K_BFLOAT16_MANT_BITS = 7; + + template std::string ComputeCRC32(const T *data, size_t count) { + const unsigned char *bytePtr = static_cast(static_cast(data)); + size_t byteSize = count * sizeof(T); + + uLong crc = crc32(0L, Z_NULL, 0); + crc = crc32(crc, bytePtr, byteSize); + + std::stringstream ss; + ss << std::hex << std::setw(K_CRC_HEX_WIDTH) << std::setfill('0') << crc; + return ss.str(); + } + + static uint64_t GetCountFromDType(const std::string &dtypeKey, const uint64_t &dataSize) { + auto it = Utility::dtypeMap.find(dtypeKey); + if (it != Utility::dtypeMap.end()) { + uint64_t elemSize = it->second.elemSize; + uint64_t count = dataSize / elemSize; + return count; + } else { + throw Utility::MsprobeException("Invalid dtype: " + dtypeKey); + } + } + + static std::string GetDtype(const std::string &dtypeKey) { + auto it = Utility::dtypeMap.find(dtypeKey); + if (it != Utility::dtypeMap.end()) { + return it->second.dtypeName; + } else { + LOG_WARNING << "Unsupported dtype: " << dtypeKey; + return "Unknown"; + } + } + + static float ConvertToFloat32(const uint16_t value, const size_t exponentBits, const size_t mantissaBits) { + // Determine the bias of the semi-precision type + int32_t exponentBias = (1 << (exponentBits - 1)) - 1; + // Obtain the mask + uint16_t exponentMask = ((1 << exponentBits) - 1) << mantissaBits; + uint16_t mantissaMask = (1 << mantissaBits) - 1; + uint16_t signMask = 1 << (exponentBits + mantissaBits); + // Extract symbol bits + int sign = (value & signMask) ? -1 : 1; + // Extract index and mantissa + int32_t rawExponent = (value & exponentMask) >> mantissaBits; + uint32_t mantissa = value & mantissaMask; + // Handle special values + if (rawExponent == (1 << exponentBits) - 1) { // All 1s represent NaN or infinity + if (mantissa != 0) { + return std::numeric_limits::quiet_NaN(); // NaN + } else { + return sign * std::numeric_limits::infinity(); // Infinity + } + } else if (rawExponent == 0) { // Exponents with all zeros indicate non-normalized numbers or zeros + if (mantissa == 0) { + return sign * 0.0f; // Zero + } else { + // Unnormalized number + float result = sign * std::ldexp(static_cast(mantissa), + 1 - static_cast(exponentBias) - static_cast(mantissaBits)); + return result; + } + } + // Normalized number + float normalizedMantissa = 1.0f + static_cast(mantissa) / (1 << mantissaBits); + float result = sign * std::ldexp(normalizedMantissa, rawExponent - exponentBias); + return result; + } + + template Types::TensorStats ComputeStats(Types::TensorInfo info) { + const T *ptr = static_cast(info.hostData); + uint64_t count = GetCountFromDType(info.dtype, info.dataSize); + + double mean = 0.0; + double m2 = 0.0; // sum of squares of differences from the current mean + double minVal = static_cast(ptr[0]); + double maxVal = static_cast(ptr[0]); + + for (uint64_t i = 0; i < count; ++i) { + double val = static_cast(ptr[i]); + // Update min/max + if (val > maxVal) { + maxVal = val; + } + if (val < minVal) { + minVal = val; + } + // Welford's online update + double delta = val - mean; + mean += delta / (i + 1); + m2 += delta * (val - mean); // (val - new_mean) + } + + Types::TensorStats stats; + stats.type = "Tensor"; + stats.dtype = GetDtype(info.dtype); + stats.shape = StrSpace::SplitToInt(info.dims, ","); + stats.mean = mean; + stats.min = minVal; + stats.max = maxVal; + stats.norm = std::sqrt(m2 + count * mean * mean); // Equivalent to sqrt(sum(x^2)) + std::string summary_mode = StrSpace::GetEnvVar(Cst::LINK_SUMMARY_MODE); + stats.crc32 = (summary_mode == "md5") ? ComputeCRC32(ptr, count * sizeof(T)) : ""; + return stats; + } + + Types::TensorStats ComputeStatsFloat16(const Types::TensorInfo &info) { + const uint16_t *ptr = static_cast(info.hostData); + uint64_t count = GetCountFromDType(info.dtype, info.dataSize); + + size_t exponentBits, mantissaBits; + if (info.dtype == "1") { // float16 + exponentBits = K_FLOAT16_EXP_BITS; + mantissaBits = K_FLOAT16_MANT_BITS; + } else if (info.dtype == "27") { // bfloat16 + exponentBits = K_BFLOAT16_EXP_BITS; + mantissaBits = K_BFLOAT16_MANT_BITS; + } else { + throw Utility::MsprobeException("Unsupported dtype in float16-compatible handler."); + } + + // Initialize Welford variables + double mean = 0.0; + double m2 = 0.0; + float firstVal = ConvertToFloat32(ptr[0], exponentBits, mantissaBits); + double minVal = firstVal; + double maxVal = firstVal; + + for (uint64_t i = 0; i < count; ++i) { + float val = ConvertToFloat32(ptr[i], exponentBits, mantissaBits); + + if (val > maxVal) { + maxVal = val; + } + if (val < minVal) { + minVal = val; + } + // Welford update + double delta = val - mean; + mean += delta / (i + 1); + m2 += delta * (val - mean); + } + + Types::TensorStats stats; + stats.type = "Tensor"; + stats.dtype = GetDtype(info.dtype); + stats.shape = StrSpace::SplitToInt(info.dims, ","); + stats.mean = mean; + stats.min = minVal; + stats.max = maxVal; + stats.norm = std::sqrt(m2 + count * mean * mean); // Equivalent to sqrt(sum(val^2)) + + std::string summary_mode = StrSpace::GetEnvVar(Cst::LINK_SUMMARY_MODE); + stats.crc32 = (summary_mode == "md5") ? ComputeCRC32(ptr, count * sizeof(uint16_t)) : ""; + + return stats; + } + + Types::TensorStats Compute(const Types::TensorInfo &info) { + const std::string &dtype = info.dtype; + if (dtype == "1" || dtype == "27") { + // float16 or bfloat16 + return ComputeStatsFloat16(info); + } + if (dtype == "0") { + return ComputeStats(info); + } else if (dtype == "2") { + return ComputeStats(info); + } else if (dtype == "3") { + return ComputeStats(info); + } else if (dtype == "4") { + return ComputeStats(info); + } else if (dtype == "6") { + return ComputeStats(info); + } else if (dtype == "7") { + return ComputeStats(info); + } else if (dtype == "8") { + return ComputeStats(info); + } else if (dtype == "9") { + return ComputeStats(info); + } else if (dtype == "10") { + return ComputeStats(info); + } else if (dtype == "11") { + return ComputeStats(info); + } else if (dtype == "12") { + return ComputeStats(info); + } else { + throw Utility::MsprobeException("Unsupported dtype in ComputeByDType: " + dtype); + } + } + +} // namespace Stat