diff --git a/tutorials/source_zh_cn/advanced_use/custom_operator.md b/tutorials/source_zh_cn/advanced_use/custom_operator.md index 9607027cebb2e79f58f359518e4d50de4595e577..ece258ad5f7f1a4575ca2b8a568fea993c8b8f05 100644 --- a/tutorials/source_zh_cn/advanced_use/custom_operator.md +++ b/tutorials/source_zh_cn/advanced_use/custom_operator.md @@ -83,13 +83,19 @@ class CusSquare(PrimitiveWithInfer): 算子信息是指导后端选择算子实现的关键信息,同时也指导后端为算子插入合适的类型和格式转换。它通过`TBERegOp`接口定义,通过`op_info_register`装饰器将算子信息与算子实现入口函数绑定。当算子实现py文件被导入时,`op_info_register`装饰器会将算子信息注册到后端的算子信息库中。更多关于算子信息的使用方法请参考`TBERegOp`的成员方法的注释说明。 -> 算子信息中定义输入输出信息的个数和顺序、算子实现入口函数的参数中的输入输出信息的个数和顺序、算子原语中输入输出名称列表的个数和顺序,三者要完全一致。 - -> 算子如果带属性,在算子信息中需要用`attr()`描述属性信息,属性的名称与算子原语定义中的属性名称要一致。 +> - 算子信息中定义输入输出信息的个数和顺序、算子实现入口函数的参数中的输入输出信息的个数和顺序、算子原语中输入输出名称列表的个数和顺序,三者要完全一致。 +> - 算子如果带属性,在算子信息中需要用`attr()`描述属性信息,属性的名称与算子原语定义中的属性名称要一致。 ### 示例 -下面以`Square`算子的TBE实现`square_impl.py`为例进行介绍。`square_compute`是算子实现的计算函数,通过调到`te.lang.cce`提供的API描述了`x * x`的计算逻辑。`cus_square_op_info `是算子信息,通过`TBERegOp`来定义。`TBERegOp`中的`dtype_format`是用来描述算子支持的数据类型,下面示例中注册了两项说明该算子支持两种数据类型,而每一项需按照输入和输出的顺序依次描述支持的格式。第一个`dtype_format`说明支持的第一种数据类型是input0为F32_Default格式,output0为F32_Default格式。第二个`dtype_format`说明支持的第二种数据类型是input0为F16_Default格式,output0为F16_Default格式。 +下面以`Square`算子的TBE实现`square_impl.py`为例进行介绍。`square_compute`是算子实现的计算函数,通过调到`te.lang.cce`提供的API描述了`x * x`的计算逻辑。`cus_square_op_info `是算子信息,通过`TBERegOp`来定义。 + +在`TBERegOp`中: + +- `TBERegOp("CusSquare")`中算子注册名称`CusSquare`需要与算子名称一致。 +- `fusion_type("OPAQUE")`中`OPAQUE`是说明自定义算子采取不融合策略。 +- `kernel_name("CusSquareImpl")`中"CusSquareImpl"需要与算子入口函数名称一致。 +- `dtype_format`是用来描述算子支持的数据类型,下面示例中注册了两项说明该算子支持两种数据类型,而每一项需按照输入和输出的顺序依次描述支持的格式。第一个`dtype_format`说明支持的第一种数据类型是input0为F32_Default格式,output0为F32_Default格式。第二个`dtype_format`说明支持的第二种数据类型是input0为F16_Default格式,output0为F16_Default格式。 ```python from __future__ import absolute_import @@ -107,13 +113,13 @@ def square_compute(input_x, output_y): return res # Define the kernel info of CusSquare. -cus_square_op_info = TBERegOp("CusSquare") \ # The registered op name should be same with primitive name. - .fusion_type("OPAQUE") \ # Setting kernel fusion strategy. The default is not infusible. +cus_square_op_info = TBERegOp("CusSquare") \ + .fusion_type("OPAQUE") \ .partial_flag(True) \ .async_flag(False) \ .binfile_name("square.so") \ .compute_cost(10) \ - .kernel_name("CusSquareImpl") \ # The kernel name should be same with the name of the entry function. + .kernel_name("CusSquareImpl") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \