diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp index faf223f02b3835a42d94c7ac45f7a126cf8f0b52..fdd5050ee39ac109d47a47399548e16da64ff226 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom.cpp @@ -19,23 +19,23 @@ static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData tiling; uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize(); - ge::DataType dtype_x = context->GetInputDesc(0)->GetDataType(); - ge::DataType dtype_y = context->GetInputDesc(1)->GetDataType(); - ge::DataType dtype_z = context->GetOutputDesc(0)->GetDataType(); + ge::DataType dataTypeX = context->GetInputDesc(0)->GetDataType(); + ge::DataType dataTypeY = context->GetInputDesc(1)->GetDataType(); + ge::DataType dataTypeZ = context->GetOutputDesc(0)->GetDataType(); uint32_t D_T_X = ADD_TPL_FP32, D_T_Y=ADD_TPL_FP32, D_T_Z=ADD_TPL_FP32, TILE_NUM=1, IS_SPLIT=0; - if(dtype_x == ge::DataType::DT_FLOAT){ + if(dataTypeX == ge::DataType::DT_FLOAT){ D_T_X = ADD_TPL_FP32; - }else if(dtype_x == ge::DataType::DT_FLOAT16){ + }else if(dataTypeX == ge::DataType::DT_FLOAT16){ D_T_X = ADD_TPL_FP16; } - if(dtype_y == ge::DataType::DT_FLOAT){ + if(dataTypeY == ge::DataType::DT_FLOAT){ D_T_Y = ADD_TPL_FP32; - }else if(dtype_y == ge::DataType::DT_FLOAT16){ + }else if(dataTypeY == ge::DataType::DT_FLOAT16){ D_T_Y = ADD_TPL_FP16; } - if(dtype_z == ge::DataType::DT_FLOAT){ + if(dataTypeZ == ge::DataType::DT_FLOAT){ D_T_Z = ADD_TPL_FP32; - }else if(dtype_z == ge::DataType::DT_FLOAT16){ + }else if(dataTypeZ == ge::DataType::DT_FLOAT16){ D_T_Z = ADD_TPL_FP16; } if(totalLength< MIN_LENGTH_FOR_SPLIT){ @@ -45,10 +45,14 @@ static ge::graphStatus 
TilingFunc(gert::TilingContext *context) IS_SPLIT = 1; TILE_NUM = DEFAULT_TILE_NUM; } + if(D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32){ + TilingDataFp *tiling = context->GetTilingData(); + tiling->totalLength = totalLength; + }else if(D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16){ + TilingDataFp16 *tiling = context->GetTilingData(); + tiling->totalLength = totalLength; + } context->SetBlockDim(BLOCK_DIM); - tiling.set_totalLength(totalLength); - tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); const uint64_t tilingKey = GET_TPL_TILING_KEY(D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT); // 模板参数tilingkey配置 context->SetTilingKey(tilingKey); size_t *currentWorkspace = context->GetWorkspaceSizes(1); diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp index 36bffd6638687e2e0a075de66e29ccd472ef819d..48aa818ae58def2d7b22eef2910fcb2df16d3a4b 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom.cpp @@ -9,6 +9,7 @@ */ #include "kernel_operator.h" #include "tiling_key_add_custom.h" +#include "add_custom_tiling.h" constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue template @@ -90,12 +91,22 @@ private: template __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling) { - GET_TILING_DATA(tiling_data, tiling); + //注册默认tiling结构体 + REGISTER_TILING_DEFAULT(optiling::TilingData); + //注册数据类型为FP32的tilingData结构体,此处必须和模板参数中定义保持一致,否则会有oom问题 + REGISTER_TILING_FOR_TILINGKEY( + "D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 
&& D_T_Z == ADD_TPL_FP32", optiling::TilingDataFp); + //注册数据类型为FP16的tilingData结构体,此处必须和模板参数中定义保持一致,否则会有oom问题 + REGISTER_TILING_FOR_TILINGKEY( + "D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16", optiling::TilingDataFp16); + if(D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32){ + GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp, tiling_data, tiling); KernelAdd op; op.Init(x, y, z, tiling_data.totalLength, TILE_NUM); op.Process1(); }else if(D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16){ + GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp16, tiling_data, tiling); KernelAdd op; if(IS_SPLIT == 0){ op.Init(x, y, z, tiling_data.totalLength, TILE_NUM); diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h similarity index 69% rename from operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h rename to operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h index 7e9e79d1d2b9b7da9fb2bec8d0914013dce3a59f..5e53d9d4fab48b86e379978909c0a1d9f8c63cf3 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_host/add_custom_tiling.h +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/add_custom_tiling.h @@ -9,13 +9,22 @@ */ #ifndef ADD_CUSTOM_TILING_H #define ADD_CUSTOM_TILING_H -#include "register/tilingdata_base.h" +#include namespace optiling { -BEGIN_TILING_DATA_DEF(TilingData) -TILING_DATA_FIELD_DEF(uint32_t, totalLength); -END_TILING_DATA_DEF; +class TilingData{ +public: + uint32_t totalLength; +}; -REGISTER_TILING_DATA_CLASS(AddCustom, TilingData) +class TilingDataFp{ +public: + uint32_t totalLength; +}; + +class TilingDataFp16{ +public: 
+ uint32_t totalLength; +}; } // namespace optiling #endif // ADD_CUSTOM_TILING_H diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h index eae21744440e1706cddde0186f5f964d0ab77c4f..b47203895a809daf46a6bd568e7543102b83d52b 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/AddTemplateCustom/op_kernel/tiling_key_add_custom.h @@ -11,11 +11,11 @@ #define TILING_KEY_ADD_CUSTOM_H #include "ascendc/host_api/tiling/template_argument.h" -#define ADD_TPL_FP16 10 -#define ADD_TPL_FP32 20 +#define ADD_TPL_FP16 1 //数据类型定义 +#define ADD_TPL_FP32 0 -#define ADD_TPL_ND 15 -#define ADD_TPL_NZ 25 +#define ADD_TPL_ND 2 //数据格式定义 +#define ADD_TPL_NZ 29 /** ASCENDC_TPL_ARGS_DECL(args0, ...):算子的模板参数定义, args0表示算子唯一标识, 建议与opType保持一致,后续为若干个DTYPE、FORMAT、UINT、BOOL的模板参数定义 ASCENDC_TPL_DTYPE_DECL(args0, ...): DTYPE类型的模板参数定义,args0表示参数名,后续若干个参数为穷举的DTYPE枚举值 @@ -49,6 +49,7 @@ ASCENDC_TPL_SEL(...):算子的模板参数整体组合,可设置多个模板参 ASCENDC_TPL_FORMAT_SEL(args0, ...): FORMAT类型的模板参数组合,args0表示参数名,后续若干个参数为对应的ASCENDC_TPL_FORMAT_DECL定义的参数范围子集 ASCENDC_TPL_UINT_SEL(args0, args1, args2, ...): UINT类型的模板参数定义,args0表示参数名,args1是参数的表示类型,支持的表示类型为ASCENDC_TPL_UI_RANGE,ASCENDC_TPL_UI_LIST,ASCENDC_TPL_UI_MIX,后续的数值定义参考ASCENDC_TPL_UINT_DECL的规则 ASCENDC_TPL_BOOL_SEL(args0, ...): bool类型的模板参数定义,args0表示参数名,后续若干个参数为对应的ASCENDC_TPL_BOOL_DECL定义的参数范围子集 + ASCENDC_TPL_TILING_STRUCT_SEL(args0): 此模板参数组合对应的自定义tiling结构体,此处需要和kernel侧的判断逻辑保持一致,否则会有oom问题,args0表示tiling结构体名 */ ASCENDC_TPL_SEL( ASCENDC_TPL_ARGS_SEL( @@ -56,14 +57,16 @@ ASCENDC_TPL_SEL( ASCENDC_TPL_DTYPE_SEL(D_T_Y, ADD_TPL_FP16), ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP16), ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8), - ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 
0, 1) + ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1), + ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp16) ), ASCENDC_TPL_ARGS_SEL( ASCENDC_TPL_DTYPE_SEL(D_T_X, ADD_TPL_FP32), ASCENDC_TPL_DTYPE_SEL(D_T_Y, ADD_TPL_FP32), ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP32), ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8), - ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1) + ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1), + ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp) ) ); #endif // TILING_KEY_ADD_CUSTOM_H \ No newline at end of file diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md index 975426ef12794f417f26b844a8f4f7b024c2d98f..bd153e5dc078dc7de053faf2973cfe85b60c50df 100644 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/README.md @@ -139,7 +139,8 @@ CANN软件包中提供了工程创建工具msOpGen,AddTemplateCustom算子工 ### 4. 
调用执行算子工程 - [aclnn调用AddTemplateCustom算子工程](./AclNNInvocation/README.md) ## 更新说明 -| 时间 | 更新事项 | -| ---------- |----------| -| 2024/10/25 | 新增模板参数算子样例 | -| 2024/11/18 | 样例目录调整 | +| 时间 | 更新事项 | +|------------|--------------| +| 2024/10/25 | 新增模板参数算子样例 | +| 2024/11/18 | 样例目录调整 | +| 2025/11/7 | 新增自定义结构体调用示例 | diff --git a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh index 4b74830f0f3fcf2b8f3a956bf9122c41006e67d7..b64510b32f538ee1e15185cace3b99cee9d1b7f8 100755 --- a/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh +++ b/operator/ascendc/0_introduction/6_addtemplate_frameworklaunch/install.sh @@ -49,6 +49,8 @@ OP_NAME=AddTemplateCustom rm -rf CustomOp # Generate the op framework msopgen gen -i $OP_NAME.json -c ai_core-${SOC_VERSION} -lan cpp -out CustomOp +# Delete the op_host tiling header generated by msopgen; the sample provides its own add_custom_tiling.h under op_kernel +rm -rf CustomOp/op_host/add_custom_tiling.h # Copy op implementation files to CustomOp cp -rf $OP_NAME/* CustomOp # Build CustomOp project