From 7aa2d5091c5790835ed84a33d8f27a99be48ad67 Mon Sep 17 00:00:00 2001 From: caobingjie Date: Tue, 12 Aug 2025 17:07:17 +0800 Subject: [PATCH 1/3] Add NZ and BF16 construction methods for demo framework --- example/op_demo/demo_util.h | 89 +++++++++++++++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 3 deletions(-) diff --git a/example/op_demo/demo_util.h b/example/op_demo/demo_util.h index 3edbec07..01b42d5f 100644 --- a/example/op_demo/demo_util.h +++ b/example/op_demo/demo_util.h @@ -98,6 +98,79 @@ atb::Status CastOp(atb::Context *contextPtr, aclrtStream stream, const atb::Tens return atb::ErrorType::NO_ERROR; } + +/** + * @brief 根据输入inShape,判断是NZ格式的Shape还是ND格式的Shape,并得到两种格式下的shape + * @param tensorType tensor的数据类型 + * @param inShape tensor的shape + * @param ndShape ND格式tensor的shape + * @param nzShape NZ格式tensor的shape + * @return atb::Status atb错误码 + */ +atb::Status getShape(const aclDataType tensorType, std::vector inShape, std::vector &ndShape, + std::vector &nzShape) +{ + int64_t n0 = 16; + if (tensorType == ACL_INT8) { + n0 = 32; + } + if (inShape.size() == 4) { // inShape是NZ格式tensor的shape + nzShape.assign(inShape.begin(), inShape.end()); + ndShape = {inShape[0], inShape[2], inShape[1] * inShape[3]}; + } else { // inShape是ND格式tensor的shape + ndShape.assign(inShape.begin(), inShape.end()); + if (inShape.size() == 3) { // 该shape包含batch参数 + nzShape = {inShape[0], inShape[2] / n0, inShape[1], n0}; + } else if(inShape.size() == 2){ + nzShape = {1, inShape[1] / n0, inShape[0], n0}; + } + } + + return atb::ErrorType::NO_ERROR; +} + +/** + * @brief 进行ND到NZ的数据格式转换,调用transdata Op + * @param contextPtr context指针 + * @param stream stream + * @param inTensor 输入tensor + * @param outTensorType tensor的数据类型 + * @param outTensor 输出tensor + * @param shape tensor的shape + * @return atb::Status atb错误码 + */ +atb::Status TransdataOp(atb::Context *contextPtr, aclrtStream stream, const atb::Tensor inTensor, + const aclDataType tensorType, atb::Tensor &outTensor, std::vector shape) +{ + uint64_t workspaceSize = 0; + void *workspace = nullptr; + + atb::infer::TransdataParam opParam; + opParam.transdataType = atb::infer::TransdataParam::TransdataType::ND_TO_FRACTAL_NZ; + + atb::Operation *transdataOp = nullptr; + CHECK_STATUS(atb::CreateOperation(opParam, &transdataOp)); + atb::Tensor tensor; + CHECK_STATUS(CreateTensor(tensorType, aclFormat::ACL_FORMAT_FRACTAL_NZ, shape, tensor)); + atb::VariantPack transdataVariantPack; + transdataVariantPack.inTensors = {inTensor}; + transdataVariantPack.outTensors = {tensor}; + // 在Setup接口调用时对输入tensor和输出tensor进行校验。 + CHECK_STATUS(transdataOp->Setup(transdataVariantPack, workspaceSize, contextPtr)); + uint8_t *workspacePtr = nullptr; + if (workspaceSize > 0) { + CHECK_STATUS(aclrtMalloc(&workspace, workspaceSize, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST)); + } + // Transdata ND to NZ执行 + CHECK_STATUS(transdataOp->Execute(transdataVariantPack, (uint8_t *)workspace, workspaceSize, contextPtr)); + CHECK_STATUS(aclrtSynchronizeStream(stream)); // 流同步,等待device侧任务计算完成 + if (workspaceSize > 0) { + CHECK_STATUS(aclrtFree(workspace)); // 清理工作空间 + } + outTensor = tensor; + return atb::ErrorType::NO_ERROR; +} + /** * @brief 简单封装,拷贝vector data中数据以创建tensor * @details 用于创建outTensorType类型的tensor @@ -128,16 +201,26 @@ atb::Status CreateTensorFromVector(atb::Context *contextPtr, aclrtStream stream, if (inTensorType == outTensorType && inTensorType != ACL_DT_UNDEFINED) { intermediateType = outTensorType; } - CHECK_STATUS(CreateTensor(intermediateType, format, shape, tensor)); + aclFormat tensorFormat = format; + if (intermediateType != outTensorType && format == aclFormat::ACL_FORMAT_FRACTAL_NZ) { + tensorFormat = aclFormat::ACL_FORMAT_ND; + } + CHECK_STATUS(CreateTensor(intermediateType, tensorFormat, shape, tensor)); CHECK_STATUS(aclrtMemcpy(tensor.deviceData, tensor.dataSize, data.data(), sizeof(T) * data.size(), ACL_MEMCPY_HOST_TO_DEVICE)); - CHECK_STATUS(CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, shape, outTensor)); + CHECK_STATUS(CreateTensor(outTensorType, aclFormat::ACL_FORMAT_ND, shape, outTensor)); if (intermediateType == outTensorType) { // 原始创建的tensor类型,不需要转换 outTensor = tensor; return atb::ErrorType::NO_ERROR; } - return CastOp(contextPtr, stream, tensor, outTensorType, outTensor); + CHECK_STATUS(CastOp(contextPtr, stream, tensor, outTensorType, outTensor)); + if(outTensor.desc.format != format) + { + //直接赋值将tensor转成需要的数据格式,或者使用提供的TransdataOp函数进行数据格式转换 + outTensor.desc.format = format; + } + return atb::ErrorType::NO_ERROR; } // 判断soc型号是否为Atlas A2/A3 -- Gitee From ada28cc182bbda5839dd3f6c843daa3ba173fb38 Mon Sep 17 00:00:00 2001 From: Vector Date: Tue, 12 Aug 2025 11:54:53 +0000 Subject: [PATCH 2/3] update example/op_demo/demo_util.h. Signed-off-by: Vector --- example/op_demo/demo_util.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/op_demo/demo_util.h b/example/op_demo/demo_util.h index 01b42d5f..360bfe9d 100644 --- a/example/op_demo/demo_util.h +++ b/example/op_demo/demo_util.h @@ -121,7 +121,7 @@ atb::Status getShape(const aclDataType tensorType, std::vector inShape, ndShape.assign(inShape.begin(), inShape.end()); if (inShape.size() == 3) { // 该shape包含batch参数 nzShape = {inShape[0], inShape[2] / n0, inShape[1], n0}; - } else if(inShape.size() == 2){ + } else if (inShape.size() == 2){ nzShape = {1, inShape[1] / n0, inShape[0], n0}; } } @@ -215,8 +215,7 @@ atb::Status CreateTensorFromVector(atb::Context *contextPtr, aclrtStream stream, return atb::ErrorType::NO_ERROR; } CHECK_STATUS(CastOp(contextPtr, stream, tensor, outTensorType, outTensor)); - if(outTensor.desc.format != format) - { + if (outTensor.desc.format != format) { //直接赋值将tensor转成需要的数据格式,或者使用提供的TransdataOp函数进行数据格式转换 outTensor.desc.format = format; } -- Gitee From b8d27649698acfebcaffd74d9db6466982c49d53 Mon Sep 17 00:00:00 2001 From: Vector Date: Tue, 12 Aug 2025 11:56:21 +0000 Subject: [PATCH 3/3] update example/op_demo/demo_util.h. Signed-off-by: Vector --- example/op_demo/demo_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/op_demo/demo_util.h b/example/op_demo/demo_util.h index 360bfe9d..9bee2694 100644 --- a/example/op_demo/demo_util.h +++ b/example/op_demo/demo_util.h @@ -121,7 +121,7 @@ atb::Status getShape(const aclDataType tensorType, std::vector inShape, ndShape.assign(inShape.begin(), inShape.end()); if (inShape.size() == 3) { // 该shape包含batch参数 nzShape = {inShape[0], inShape[2] / n0, inShape[1], n0}; - } else if (inShape.size() == 2){ + } else if (inShape.size() == 2) { nzShape = {1, inShape[1] / n0, inShape[0], n0}; } } -- Gitee