From e5e6edb3724f9bc41d641f45bcdfcea4a6361b00 Mon Sep 17 00:00:00 2001
From: Vector
Date: Tue, 12 Aug 2025 11:29:00 +0800
Subject: [PATCH] Add NZ and BF16 construction methods for demo framework

---
 example/op_demo/demo_util.h | 100 +++++++++++++++++++++++++++++++++++--
 1 file changed, 97 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 example/op_demo/demo_util.h

diff --git a/example/op_demo/demo_util.h b/example/op_demo/demo_util.h
old mode 100644
new mode 100755
index 3edbec07..abbdb7d5
--- a/example/op_demo/demo_util.h
+++ b/example/op_demo/demo_util.h
@@ -98,6 +98,83 @@ atb::Status CastOp(atb::Context *contextPtr, aclrtStream stream, const atb::Tens
     return atb::ErrorType::NO_ERROR;
 }
 
+/**
+ * @brief Derive both the ND shape and the NZ shape of a tensor from a shape given in either layout
+ * @details A 4-D inShape is taken as an NZ shape [batch, n1, m, n0]; a 3-D [batch, m, n] or 2-D [m, n]
+ *          inShape is taken as an ND shape. Any other rank is rejected instead of being indexed blindly.
+ *          NOTE(review): the ND -> NZ branches use a truncating division and do not pad m, so they
+ *          presumably expect n to be a multiple of n0 and m to be aligned already - confirm against
+ *          the shape requirements of the Transdata operation.
+ * @param tensorType data type of the tensor (n0 is 32 for ACL_INT8 and 16 for every other type)
+ * @param inShape shape of the tensor, in ND or NZ layout
+ * @param ndShape output: shape of the tensor in ND layout
+ * @param nzShape output: shape of the tensor in NZ layout
+ * @return atb::Status atb error code; ERROR_INVALID_TENSOR_DIM when inShape is not 2-D, 3-D or 4-D
+ */
+atb::Status getShape(const aclDataType tensorType, std::vector<int64_t> inShape, std::vector<int64_t> &ndShape,
+                     std::vector<int64_t> &nzShape)
+{
+    const int64_t n0 = (tensorType == ACL_INT8) ? 32 : 16;
+    const size_t rank = inShape.size();
+    if (rank == 4) { // inShape is the shape of an NZ tensor
+        nzShape = inShape;
+        ndShape = {inShape[0], inShape[2], inShape[1] * inShape[3]};
+    } else if (rank == 3) { // inShape is the shape of an ND tensor with a batch dimension
+        ndShape = inShape;
+        nzShape = {inShape[0], inShape[2] / n0, inShape[1], n0};
+    } else if (rank == 2) { // inShape is the shape of an ND tensor without a batch dimension
+        ndShape = inShape;
+        nzShape = {1, inShape[1] / n0, inShape[0], n0};
+    } else {
+        // Ranks below 2 would index inShape out of bounds; ranks above 4 have no ND/NZ interpretation here.
+        return atb::ErrorType::ERROR_INVALID_TENSOR_DIM;
+    }
+
+    return atb::ErrorType::NO_ERROR;
+}
+
+/**
+ * @brief Convert a tensor from ND to NZ data format by running the Transdata operation
+ * @param contextPtr context pointer
+ * @param stream stream
+ * @param inTensor input tensor
+ * @param tensorType data type of the output tensor
+ * @param outTensor output tensor; overwritten with the NZ tensor created by this function
+ * @param shape shape used to create the NZ output tensor
+ * @return atb::Status atb error code
+ */
+atb::Status TransdataOp(atb::Context *contextPtr, aclrtStream stream, const atb::Tensor inTensor,
+                        const aclDataType tensorType, atb::Tensor &outTensor, std::vector<int64_t> shape)
+{
+    uint64_t workspaceSize = 0;
+    void *workspace = nullptr;
+
+    atb::infer::TransdataParam opParam;
+    opParam.transdataType = atb::infer::TransdataParam::TransdataType::ND_TO_FRACTAL_NZ;
+
+    atb::Operation *transdataOp = nullptr;
+    CHECK_STATUS(atb::CreateOperation(opParam, &transdataOp));
+    atb::Tensor tensor;
+    CHECK_STATUS(CreateTensor(tensorType, aclFormat::ACL_FORMAT_FRACTAL_NZ, shape, tensor));
+    atb::VariantPack transdataVariantPack;
+    transdataVariantPack.inTensors = {inTensor};
+    transdataVariantPack.outTensors = {tensor};
+    // The input and output tensors are validated when Setup is called.
+    CHECK_STATUS(transdataOp->Setup(transdataVariantPack, workspaceSize, contextPtr));
+    if (workspaceSize > 0) {
+        CHECK_STATUS(aclrtMalloc(&workspace, workspaceSize, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST));
+    }
+    // Run the ND -> NZ Transdata
+    CHECK_STATUS(transdataOp->Execute(transdataVariantPack, static_cast<uint8_t *>(workspace), workspaceSize, contextPtr));
+    CHECK_STATUS(aclrtSynchronizeStream(stream)); // stream sync: wait for the device-side task to finish
+    if (workspaceSize > 0) {
+        CHECK_STATUS(aclrtFree(workspace)); // release the workspace
+    }
+    CHECK_STATUS(atb::DestroyOperation(transdataOp)); // release the operation created above
+    outTensor = tensor;
+    return atb::ErrorType::NO_ERROR;
+}
+
 /**
  * @brief 简单封装,拷贝vector data中数据以创建tensor
  * @details 用于创建outTensorType类型的tensor
@@ -128,16 +205,33 @@ atb::Status CreateTensorFromVector(atb::Context *contextPtr, aclrtStream stream,
     if (inTensorType == outTensorType && inTensorType != ACL_DT_UNDEFINED) {
         intermediateType = outTensorType;
     }
-    CHECK_STATUS(CreateTensor(intermediateType, format, shape, tensor));
+    aclFormat tensorFormat = format;
+    std::vector<int64_t> ndShape, nzShape;
+    if (intermediateType != outTensorType && format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
+        // The data is staged and cast as an ND tensor, then converted to NZ at the end.
+        tensorFormat = aclFormat::ACL_FORMAT_ND;
+        CHECK_STATUS(getShape(outTensorType, shape, ndShape, nzShape));
+    } else {
+        ndShape.assign(shape.begin(), shape.end());
+    }
+    // ndShape, not shape: the Cast input must match the Cast output even when shape is a 4-D NZ shape.
+    CHECK_STATUS(CreateTensor(intermediateType, tensorFormat, ndShape, tensor));
     CHECK_STATUS(aclrtMemcpy(tensor.deviceData, tensor.dataSize, data.data(), sizeof(T) * data.size(),
                              ACL_MEMCPY_HOST_TO_DEVICE));
-    CHECK_STATUS(CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, shape, outTensor));
+    CHECK_STATUS(CreateTensor(outTensorType, aclFormat::ACL_FORMAT_ND, ndShape, outTensor));
     if (intermediateType == outTensorType) {
         // 原始创建的tensor类型,不需要转换
         outTensor = tensor;
         return atb::ErrorType::NO_ERROR;
     }
-    return CastOp(contextPtr, stream, tensor, outTensorType, outTensor);
+    CHECK_STATUS(CastOp(contextPtr, stream, tensor, outTensorType, outTensor));
+    if (format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
+        const atb::Tensor ndTensor = outTensor; // keep the ND handle: TransdataOp overwrites outTensor
+        CHECK_STATUS(TransdataOp(contextPtr, stream, ndTensor, outTensorType, outTensor, nzShape));
+        CHECK_STATUS(aclrtFree(ndTensor.deviceData)); // the ND cast buffer is no longer referenced
+    }
+
+    return atb::ErrorType::NO_ERROR;
 }
 
 // 判断soc型号是否为Atlas A2/A3
--
Gitee