From 0a0a6fb938677784ef5cd787b200f1c25895bb86 Mon Sep 17 00:00:00 2001 From: ruoshuisixue Date: Tue, 10 Dec 2024 19:47:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E6=9D=BF=E5=8F=82=E6=95=B0=E7=B1=BB?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E6=96=B0=E5=A2=9E=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BD=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../AddTemplateCustom/op_host/add_custom.cpp | 28 +++++++++++-------- .../op_kernel/add_custom.cpp | 13 ++++++++- .../add_custom_tiling.h | 19 +++++++++---- .../op_kernel/tiling_key_add_custom.h | 11 +++++--- .../6_addtemplate_frameworklaunch/README.md | 9 +++--- .../6_addtemplate_frameworklaunch/install.sh | 2 ++ 6 files changed, 56 insertions(+), 26 deletions(-) rename operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/{op_host => op_kernel}/add_custom_tiling.h (69%) diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp index faf223f02..fdd5050ee 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp @@ -19,23 +19,23 @@ static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData tiling; uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize(); - ge::DataType dtype_x = context->GetInputDesc(0)->GetDataType(); - ge::DataType dtype_y = context->GetInputDesc(1)->GetDataType(); - ge::DataType dtype_z = context->GetOutputDesc(0)->GetDataType(); + ge::DataType dataTypeX = context->GetInputDesc(0)->GetDataType(); + ge::DataType dataTypeY = context->GetInputDesc(1)->GetDataType(); + ge::DataType dataTypeZ = context->GetOutputDesc(0)->GetDataType(); uint32_t D_T_X = ADD_TPL_FP32, D_T_Y=ADD_TPL_FP32, D_T_Z=ADD_TPL_FP32, TILE_NUM=1, IS_SPLIT=0; - if(dtype_x == ge::DataType::DT_FLOAT){ + if(dataTypeX == ge::DataType::DT_FLOAT){ D_T_X = ADD_TPL_FP32; - }else if(dtype_x == ge::DataType::DT_FLOAT16){ + }else if(dataTypeX == ge::DataType::DT_FLOAT16){ D_T_X = ADD_TPL_FP16; } - if(dtype_y == ge::DataType::DT_FLOAT){ + if(dataTypeY == ge::DataType::DT_FLOAT){ D_T_Y = ADD_TPL_FP32; - }else if(dtype_y == ge::DataType::DT_FLOAT16){ + }else if(dataTypeY == ge::DataType::DT_FLOAT16){ D_T_Y = ADD_TPL_FP16; } - if(dtype_z == ge::DataType::DT_FLOAT){ + if(dataTypeZ == ge::DataType::DT_FLOAT){ D_T_Z = ADD_TPL_FP32; - }else if(dtype_z == ge::DataType::DT_FLOAT16){ + }else if(dataTypeZ == ge::DataType::DT_FLOAT16){ D_T_Z = ADD_TPL_FP16; } if(totalLength< MIN_LENGTH_FOR_SPLIT){ @@ -45,10 +45,14 @@ static ge::graphStatus TilingFunc(gert::TilingContext *context) IS_SPLIT = 1; TILE_NUM = DEFAULT_TILE_NUM; } + if(D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32){ + TilingDataFp *tiling = context->GetTilingData(); + tiling->totalLength = totalLength; + }else if(D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16){ + TilingDataFp16 *tiling = context->GetTilingData(); + tiling->totalLength = totalLength; + } context->SetBlockDim(BLOCK_DIM); - tiling.set_totalLength(totalLength); - tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); const uint64_t tilingKey = GET_TPL_TILING_KEY(D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT); // 模板参数tilingkey配置 context->SetTilingKey(tilingKey); size_t *currentWorkspace = context->GetWorkspaceSizes(1); diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp index e67483876..4e00578a0 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp @@ -9,6 +9,7 @@ */ #include "kernel_operator.h" #include "tiling_key_add_custom.h" +#include "add_custom_tiling.h" constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue template @@ -90,12 +91,22 @@ private: template __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) { - GET_TILING_DATA(tiling_data, tiling); + //注册默认tiling结构体 + REGISTER_TILING_DEFAULT(optiling::TilingData); + //注册数据类型为FP32的tilingData结构体,此处必须和模板参数中定义保持一致,否则会有oom问题 + REGISTER_TILING_FOR_TILINGKEY( + "D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32", optiling::TilingDataFp); + //注册数据类型为FP16的tilingData结构体,此处必须和模板参数中定义保持一致,否则会有oom问题 + REGISTER_TILING_FOR_TILINGKEY( + "D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16", optiling::TilingDataFp16); + if(D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32){ + GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp, tiling_data, tiling); KernelAdd op; op.Init(x, y, z, tiling_data.totalLength, TILE_NUM); op.Process1(); }else if(D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16){ + GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp16, tiling_data, tiling); KernelAdd op; if(IS_SPLIT == 0){ op.Init(x, y, z, tiling_data.totalLength, TILE_NUM); diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h similarity index 69% rename from operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h rename to operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h index 7e9e79d1d..5e53d9d4f 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h @@ -9,13 +9,22 @@ */ #ifndef ADD_CUSTOM_TILING_H #define ADD_CUSTOM_TILING_H -#include "register/tilingdata_base.h" +#include namespace optiling { -BEGIN_TILING_DATA_DEF(TilingData) -TILING_DATA_FIELD_DEF(uint32_t, totalLength); -END_TILING_DATA_DEF; +class TilingData{ +public: + uint32_t totalLength; +}; -REGISTER_TILING_DATA_CLASS(AddCustom, TilingData) +class TilingDataFp{ +public: + uint32_t totalLength; +}; + +class TilingDataFp16{ +public: + uint32_t totalLength; +}; } // namespace optiling #endif // ADD_CUSTOM_TILING_H diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h index 61dcb08ce..63a30d0df 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h @@ -11,11 +11,11 @@ #define TILING_KEY_ADD_H #include "ascendc/host_api/tiling/template_argument.h" -#define ADD_TPL_FP16 10 -#define ADD_TPL_FP32 20 +#define ADD_TPL_FP16 1 //数据类型定义 +#define ADD_TPL_FP32 0 -#define ADD_TPL_ND 15 -#define ADD_TPL_NZ 25 +#define ADD_TPL_ND 2 //数据格式定义 +#define ADD_TPL_NZ 29 /** ASCENDC_TPL_ARGS_DECL(args0, ...):算子的模板参数定义, args0表示算子唯一标识, 建议与opType保持一致,后续为若干个DTYPE、FORMAT、UINT、BOOL的模板参数定义 ASCENDC_TPL_DTYPE_DECL(args0, ...): DTYPE类型的模板参数定义,args0表示参数名,后续若干个参数为穷举的DTYPE枚举值 @@ -49,6 +49,7 @@ ASCENDC_TPL_SEL(...):算子的模板参数整体组合,可设置多个模板参 ASCENDC_TPL_FORMAT_SEL(args0, ...): FORMAT类型的模板参数组合,args0表示参数名,后续若干个参数为对应的ASCENDC_TPL_FORMAT_DECL定义的参数范围子集 ASCENDC_TPL_UINT_SEL(args0, args1, args2, ...): UINT类型的模板参数定义,args0表示参数名,args1是参数的表示类型,支持的表示类型为ASCENDC_TPL_UI_RANGE,ASCENDC_TPL_UI_LIST,ASCENDC_TPL_UI_MIX,后续的数值定义参考ASCENDC_TPL_UINT_DECL的规则 ASCENDC_TPL_BOOL_SEL(args0, ...): bool类型的模板参数定义,args0表示参数名,后续若干个参数为对应的ASCENDC_TPL_BOOL_DECL定义的参数范围子集 + ASCENDC_TPL_TILING_STRUCT_SEL(args0): 此模板参数组合对应的自定义tiling结构体,此处需要和kernel侧的判断逻辑保持一致,否则会有oom问题,args0表示tiling结构体名 */ ASCENDC_TPL_SEL( ASCENDC_TPL_ARGS_SEL( @@ -57,6 +58,7 @@ ASCENDC_TPL_SEL( ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP16), ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8), ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1), + ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp16) ), ASCENDC_TPL_ARGS_SEL( ASCENDC_TPL_DTYPE_SEL(D_T_X, ADD_TPL_FP32), @@ -64,6 +66,7 @@ ASCENDC_TPL_SEL( ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP32), ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8), ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1), + ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp) ), ); #endif \ No newline at end of file diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md index ed6800185..1b7e346bf 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md @@ -139,7 +139,8 @@ CANN软件包中提供了工程创建工具msopgen,AddTemplateCustom算子工 ### 4. 调用执行算子工程 - [aclnn调用AddTemplateCustom算子工程](./AclNNInvocation/README.md) ## 更新说明 -| 时间 | 更新事项 | -| ---------- |----------| -| 2024/10/25 | 新增模板参数算子样例 | -| 2024/11/18 | 样例目录调整 | +| 时间 | 更新事项 | +|------------|--------------| +| 2024/10/25 | 新增模板参数算子样例 | +| 2024/11/18 | 样例目录调整 | +| 2024/12/10 | 新增自定义结构体调用示例 | diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh index 41f6be73c..1dcb3f853 100755 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh @@ -49,6 +49,8 @@ OP_NAME=AddTemplateCustom rm -rf CustomOp # Generate the op framework msopgen gen -i $OP_NAME.json -c ai_core-${SOC_VERSION} -lan cpp -out CustomOp +# Delete gen tiling.h +rm -rf CustomOp/op_host/add_custom_tiling.h # Copy op implementation files to CustomOp cp -rf $OP_NAME/* CustomOp # Build CustomOp project -- Gitee