diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp index 0fe3443fa1f9286fe77c710c955d543d94c4b3a4..3f66094c624da9389b151bc7ef0912751b28831e 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclDumpDataProcessor.cpp @@ -75,6 +75,11 @@ static const std::map summaryOptionHeaderStr {DebuggerSummaryOption::MD5, kStatsHeaderMD5}, }; +const static std::map kDtypeTransMap = { + {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, + {AclDtype::DT_INT4, AclDtype::DT_INT8}, +}; + class AclTensorStats { public: AclTensorStats() = default; @@ -603,7 +608,7 @@ static std::string GenDataPath(const std::string& path) { inline std::string GetTensorInfoSuffix(AclTensorInfo& tensor) { return "." + tensor.inout + "." + std::to_string(tensor.slot) + - "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.dtype); + "." + DataUtils::GetFormatString(tensor.hostFmt) + "." + DataUtils::GetDTypeString(tensor.oriDtype); } static DebuggerErrno DumpOneAclTensorFmtBin(AclTensorInfo& tensor) @@ -640,10 +645,13 @@ static DebuggerErrno DumpOneAclTensorFmtNpy(AclTensorInfo& tensor) return DebuggerErrno::OK; } - if (tensor.dtype == AclDtype::DT_BF16) { - ret = AclTensor::TransDtype(tensor, AclDtype::DT_FLOAT); + auto it = kDtypeTransMap.find(tensor.dtype); + if (it != kDtypeTransMap.end()) { + AclDtype dstDtype = it->second; + ret = AclTensor::TransDtype(tensor, dstDtype); if (ret != DebuggerErrno::OK) { - LOG_ERROR(ret, tensor + ": Failed to transform dtype from bf16 to fp32."); + LOG_ERROR(ret, tensor + ": Failed to transform dtype from " + DataUtils::GetDTypeString(it->first) + " to " + + DataUtils::GetDTypeString(it->second)+ "."); return ret; } } @@ -736,7 +744,9 @@ static DebuggerErrno DumpOneAclTensor(AclTensorInfo& tensor, std::vector( @@ -763,34 +767,80 @@ static void TransBf16ToFp32(const uint8_t* input, size_t num, uint8_t* output, s } } -DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +static void TransInt4ToInt8(const uint8_t* input, size_t elemNums, uint8_t* output, size_t bufferSize) { + if (bufferSize < elemNums * sizeof(int8_t)) { + LOG_ERROR(DebuggerErrno::ERROR_BUFFER_OVERFLOW, "Insufficient space for converting data from int4 to int8."); + return; + } + const int8_t *srcData = reinterpret_cast(input); + int8_t *dstData = reinterpret_cast(output); + size_t inputLength = elemNums / 2; + int maxValue = 7; + int minValue = -8; + int signBitShift = 3; + int signBitMask = 0x08; + for (size_t i = 0; i < inputLength; ++i) { + int8_t s = *srcData; + int8_t t = s & 0xf; + // keep the sign bit not change + int8_t signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 value."); + } + *dstData = t; + ++dstData; + + int highByteShift = 4; + t = s >> highByteShift; + signBit = (t & signBitMask) >> signBitShift; + if (signBit == 1) { + t = t | 0xf0; + } else { + t = t & 0x0f; + } + if (t < minValue || t > maxValue) { + LOG_ERROR(DebuggerErrno::ERROR_INVALID_VALUE, "Invalid int4 value."); + } + *dstData = t; + ++dstData; + ++srcData; + } + return; +} - const static std::set> kSupportedDtypeTrans = { - {AclDtype::DT_BF16, AclDtype::DT_FLOAT}, - }; +DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to) +{ if (tensor.dtype == to) { return DebuggerErrno::OK; } - if (kSupportedDtypeTrans.find({tensor.dtype, to}) == kSupportedDtypeTrans.end()) { - return DebuggerErrno::ERROR_UNKNOWN_TRANS; - } - + tensor.oriDtype = tensor.dtype; std::vector buffer; AssertConsis(tensor); size_t bufferSize = EleNumOfTensor(tensor) * SizeOfAclDType(to); - buffer.reserve(bufferSize); + buffer.resize(bufferSize); const uint8_t* input = tensor.transBuf.empty() ? tensor.aclData : tensor.transBuf.data(); uint8_t* output = buffer.data(); - /* 目前仅支持bf16->fp32,若有通用转换需求再用更泛化的方式重写 */ if (tensor.dtype == AclDtype::DT_BF16 && to == AclDtype::DT_FLOAT) { TransBf16ToFp32(input, EleNumOfTensor(tensor), output, bufferSize); + } else if (tensor.dtype == AclDtype::DT_INT4 && to == AclDtype::DT_INT8) { + TransInt4ToInt8(input, EleNumOfTensor(tensor), output, bufferSize); + } else { + LOG_ERROR(DebuggerErrno::ERROR_UNKNOWN_TRANS, tensor + ": Trans " + DataUtils::GetDTypeString(tensor.dtype) + + " to " + DataUtils::GetDTypeString(to) + " is not supported."); + return DebuggerErrno::ERROR_UNKNOWN_TRANS; } tensor.transBuf = std::move(buffer); + tensor.dtype = to; return DebuggerErrno::OK; } diff --git a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp index 8b5ba5b06d935d5aaa2dff35e921b9072db6aa1a..f2ac429a7f14370ea1721369c7f9089cb971bb6e 100644 --- a/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp +++ b/debug/accuracy_tools/msprobe/ccsrc/core/AclTensor.hpp @@ -40,6 +40,7 @@ struct AclTensorInfo { std::string dumpPath; const uint8_t* aclData; AclDtype dtype; + AclDtype oriDtype; AclFormat deviceFmt; AclFormat hostFmt; AclShape deviceShape; @@ -52,7 +53,7 @@ struct AclTensorInfo { std::vector transBuf; std::string ToString() const { - return "AclTensor(path=" + dumpPath + ",dtype=" + std::to_string(dtype) + ",inout=" + inout + ")"; + return "AclTensor(path=" + dumpPath + ",dtype=" + DataUtils::GetDTypeString(dtype) + ",inout=" + inout + ")"; } }; @@ -71,6 +72,7 @@ AclTensorInfo ParseAttrsFromDumpData(const std::string &dumpPath, const uint8_t* const std::string& io, uint32_t slot); DebuggerErrno TransFormatD2H(AclTensorInfo& tensor); DebuggerErrno TransDtype(AclTensorInfo& tensor, AclDtype to); +bool IsDtypeSupportTrans(AclDtype dtype); } }