diff --git a/CMakeLists.txt b/CMakeLists.txt index f1606b572de7f1f0338db54d0e89cbd87af32f66..fcccc64a7fbb96e29c50713a4fe787487fe1de08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,7 +88,6 @@ include_directories( ${PROJECT_SOURCE_DIR}/src/kernels/include/lcal/tiling ${PROJECT_SOURCE_DIR}/3rdparty/mki/include ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include - $ENV{ASCEND_HOME_PATH}/include $ENV{ASCEND_HOME_PATH}/include/aclnn $ENV{PYTHON_INCLUDE_PATH} $ENV{PYTORCH_INSTALL_PATH}/include @@ -102,7 +101,8 @@ link_directories( $ENV{ASCEND_HOME_PATH}/lib64 $ENV{PYTHON_LIB_PATH} $ENV{PYTORCH_INSTALL_PATH}/lib - $ENV{PYTORCH_NPU_INSTALL_PATH}/lib) + $ENV{PYTORCH_NPU_INSTALL_PATH}/lib + $ENV{ATB_HOME_PATH}/lib) if(BUILD_TEST_FRAMEWORK OR USE_UNIT_TEST OR USE_PYTHON_TEST OR USE_FUZZ_TEST OR USE_CSV_OPS_TEST OR USE_INFRA_TEST OR USE_ALL_TEST) if(USE_FUZZ_TEST OR USE_ALL_TEST) diff --git a/configs/build_config.json b/configs/build_config.json new file mode 100644 index 0000000000000000000000000000000000000000..8f8b66dc18827e4fc5191ed5067535c715af4980 --- /dev/null +++ b/configs/build_config.json @@ -0,0 +1,8 @@ +{ + "targets": { + "ascend310b": true, + "ascend310p": true, + "ascend910b": true, + "ascend910": true + } +} diff --git a/configs/mixops/tbe_tactic_info.ini b/configs/mixops/tbe_tactic_info.ini new file mode 100644 index 0000000000000000000000000000000000000000..2917d03e7a81e6a59744bfbe408d804e8de34452 --- /dev/null +++ b/configs/mixops/tbe_tactic_info.ini @@ -0,0 +1,19 @@ +[RopeKernel] +ops=rotary_pos_emb_infer +operationName=RopeOperation +inputCount=5 +outputCount=2 +dtypeIn=float16,float16,float16,float16,int32 +dtypeOut=float16,float16 +attrs=None,None +mode=high_performance + +[ToppsampleKernel] +ops=top_p_sample +operationName=ToppsampleOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=int32 +mode=high_performance +socSupport=ascend910b,ascend310p \ No newline at end of file diff --git a/configs/ops/tbe_tactic_info.ini 
b/configs/ops/tbe_tactic_info.ini new file mode 100644 index 0000000000000000000000000000000000000000..de54195d1df17fdb88c956716da75f00c662c6ee --- /dev/null +++ b/configs/ops/tbe_tactic_info.ini @@ -0,0 +1,1879 @@ +[GeluF32Kernel] +ops=gelu_v2 +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +mode=high_precision +attrs=none + +[GeluF16Kernel] +ops=gelu_v2 +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_precision +attrs=none + +[GeluBF16Kernel] +ops=gelu_v2 +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +attrs=none +socSupport=ascend910b + +[GeluApproxF32Kernel] +ops=gelu +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +mode=high_precision +socSupport=ascend310p,ascend910b,ascend910 + +[GeluApproxF16Kernel] +ops=gelu +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_precision +socSupport=ascend310p,ascend910b,ascend910 + +[GeluApproxBF16Kernel] +ops=gelu +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[ReluF32Kernel] +ops=relu +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 + +[ReluBF16Kernel] +ops=relu +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[AsStridedF16Int64Kernel] +ops=as_strided +operationName=AsStridedOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,int64,int64,int64 +dtypeOut=float16 + +[AsStridedF32Int64Kernel] +ops=as_strided +operationName=AsStridedOperation +inputCount=4 +outputCount=1 +dtypeIn=float32,int64,int64,int64 +dtypeOut=float32 +socSupport=ascend910b + +[AsStridedInt64Int64Kernel] +ops=as_strided +operationName=AsStridedOperation +inputCount=4 
+outputCount=1 +dtypeIn=int64,int64,int64,int64 +dtypeOut=int64 + +[AddI32Kernel] +ops=add +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int32,int32 +dtypeOut=int32 + +[AddI64Kernel] +ops=add +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,int64 +dtypeOut=int64 + +[AddF16Kernel] +ops=add +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=float16 + +[AddF32Kernel] +ops=add +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=float32 + +[AddBF16Kernel] +ops=add +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[SubF16Kernel] +ops=sub +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=float16 + +[SubBF16Kernel] +ops=sub +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[SubInt64Kernel] +ops=sub +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,int64 +dtypeOut=int64 + +[MaskedFillF16Kernel] +ops=masked_fill +operationName=FillOperation +inputCount=3 +outputCount=1 +dtypeIn=float16,int8/bool,float16 +dtypeOut=float16 + +[MaskedFillBF16Kernel] +ops=masked_fill +operationName=FillOperation +inputCount=3 +outputCount=1 +dtypeIn=bfloat16,int8/bool,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[MaskedFillInt32Kernel] +ops=masked_fill +operationName=FillOperation +inputCount=3 +outputCount=1 +dtypeIn=int32,int8/bool,int32 +dtypeOut=int32 + +[FillF16Kernel] +ops=fill +operationName=FillOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,float16 +dtypeOut=float16 + +[FillBF16Kernel] +ops=fill +operationName=FillOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[MulF16Kernel] +ops=mul +operationName=ElewiseOperation +inputCount=2 
+outputCount=1 +dtypeIn=float16,float16 +dtypeOut=float16 + +[MulF32Kernel] +ops=mul +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=float32 + +[OnehotInt32Kernel] +ops=one_hot +operationName=OnheHotOperation +inputCount=4 +outputCount=1 +dtypeIn=int32,int64,int32,int32 +dtypeOut=int32 +attrs=None +socSupport=ascend310p,ascend910b,ascend910 + +[OnehotInt64Kernel] +ops=one_hot +operationName=OnheHotOperation +inputCount=4 +outputCount=1 +dtypeIn=int64,int64,int64,int64 +dtypeOut=int64 +attrs=None +socSupport=ascend310p,ascend910b,ascend910 + +[MulBF16Kernel] +ops=mul +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[RealDivF32Kernel] +ops=real_div +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=float32 + +[RealDivF16Kernel] +ops=real_div +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=float16 + +[RealDivBF16Kernel] +ops=real_div +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[ExpandF16Kernel] +ops=expand +operationName=ExpandOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int64 +dtypeOut=float16 + +[ExpandBF16Kernel] +ops=expand +operationName=ExpandOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +socSupport=ascend910b + +[ConcatF16Input2Kernel] +ops=concat_d +operationName=ConcatOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=float16 + +[ConcatF32Input2Kernel] +ops=concat_d +operationName=ConcatOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=float32 +socSupport=ascend910b + +[CastF16F32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float32 + +[CastF32F16Kernel] +ops=cast +operationName=ElewiseOperation 
+inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float16 + +[CastI32I64Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int32 +dtypeOut=int64 + +[CastI64I32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int64 +dtypeOut=int32 + +[CastF16I32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=int32 + +[CastI32F16Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int32 +dtypeOut=float16 + +[CastBF16F32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=float32 +socSupport=ascend910b + +[CastF32BF16Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=bfloat16 +socSupport=ascend910b + +[CastI32F32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int32 +dtypeOut=float32 + +[CastF32I32Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=int32 + +[CastI8F16Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int8 +dtypeOut=float16 + +[CastF16I8Kernel] +ops=cast +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=int8 + +[CosF16Kernel] +ops=cos +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_performance + +[CosBF16Kernel] +ops=cos +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +mode=high_performance +socSupport=ascend910b + +[CosF32Kernel] +ops=cos +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +mode=high_performance + +[EqualF16Kernel] +ops=equal +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=int8/bool + +[EqualBF16Kernel] 
+ops=equal +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=int8/bool +socSupport=ascend910b + +[EqualF32Kernel] +ops=equal +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=int8/bool + +[FastGeluF16Kernel] +ops=fast_gelu +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_precision +socSupport=ascend310p,ascend910b,ascend910 + +[FastGeluBF16Kernel] +ops=fast_gelu +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +mode=high_precision +socSupport=ascend910b + +[GreaterF16Kernel] +ops=greater +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=int8/bool + +[GreaterBF16Kernel] +ops=greater +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=int8/bool +socSupport=ascend910b + +[GreaterF32Kernel] +ops=greater +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=int8/bool + +[GreaterInt64Kernel] +ops=greater +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,int64 +dtypeOut=int8/bool + +[LessF16Kernel] +ops=less +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,float16 +dtypeOut=int8/bool + +[LessBF16Kernel] +ops=less +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,bfloat16 +dtypeOut=int8/bool +socSupport=ascend910b + +[LessF32Kernel] +ops=less +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,float32 +dtypeOut=int8/bool + +[LessInt64Kernel] +ops=less +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,int64 +dtypeOut=int8/bool + +[LogicalAndInt8Kernel] +ops=logical_and +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int8/bool,int8/bool +dtypeOut=int8/bool + +[LogicalNotKernel] 
+ops=logical_not +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=int8/bool +dtypeOut=int8/bool + +[LogicalOrInt8Kernel] +ops=logical_or +operationName=ElewiseOperation +inputCount=2 +outputCount=1 +dtypeIn=int8/bool,int8/bool +dtypeOut=int8/bool + +[MulsF16Kernel] +ops=muls +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 + +[MulsBF16Kernel] +ops=muls +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[MulsF32Kernel] +ops=muls +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 + +[NegF16Kernel] +ops=neg +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 + +[NegBF16Kernel] +ops=neg +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[SinF16Kernel] +ops=sin +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_performance + +[SinBF16Kernel] +ops=sin +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +mode=high_performance +socSupport=ascend910b + +[SinF32Kernel] +ops=sin +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +mode=high_performance + +[SwishF16Kernel] +ops=swish +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 + + +[SwishBF16Kernel] +ops=swish +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[TanhF16Kernel] +ops=tanh +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 + +[TanhBF16Kernel] +ops=tanh +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[Gather16I64Kernel] 
+ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=float16,int64,int64 +dtypeOut=float16 +attrs=None,False + +[Gather16I32Kernel] +ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=float16,int32,int64 +dtypeOut=float16 +attrs=None,False + +[Gather32I64Kernel] +ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=float32,int64,int64 +dtypeOut=float32 +attrs=None,False + +[Gather32I32Kernel] +ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=float32,int32,int64 +dtypeOut=float32 +attrs=None,False + +[Gather64I64Kernel] +ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=int64,int64,int64 +dtypeOut=int64 +attrs=None,False + +[Gather64I32Kernel] +ops=gather_v2 +operationName=GatherOperation +inputCount=3 +outputCount=1 +dtypeIn=int64,int32,int64 +dtypeOut=int64 +attrs=None,False + +[IndexAddF16Kernel] +ops=inplace_index_add +operationName=IndexOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,int32,float16,float16 +dtypeOut=float16 +attrs=None +deterministic=False +socSupport=ascend310p,ascend910b,ascend910 + +[IndexAddBF16Kernel] +ops=inplace_index_add +operationName=IndexOperation +inputCount=4 +outputCount=1 +dtypeIn=bfloat16,int32,bfloat16,bfloat16 +dtypeOut=bfloat16 +attrs=None +deterministic=False +socSupport=ascend910b + +[MatMulNdF16Kernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +attrs=False,False +mode=high_performance + +[MatMulNdF32Kernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float32,float32,, +dtypeOut=float32 +attrs=False,False +mode=high_performance +socSupport=ascend910b +deterministic=False + +[MatMulNdF16TbKernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +attrs=False,True 
+mode=high_performance + +[BatchMatMulNdF16Kernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +attrs=False,False +mode=high_performance + +[BatchMatMulNdF32Kernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float32,float32,, +dtypeOut=float32 +attrs=False,False +mode=high_performance +socSupport=ascend910b +deterministic=False + +[BatchMatMulNdF16TbKernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +attrs=False,True +mode=high_performance + +[MatMulNzF16Kernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=False,False +mode=high_performance + +[MatMulNzF16TAKernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=True,False +mode=high_performance + +[MatMulNzF16TBKernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=False,True +mode=high_performance + +[MatMulNzF16TATBKernel] +ops=mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=True,True +mode=high_performance + +[BatchMatMulNzF16Kernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=False,False +mode=high_performance + +[BatchMatMulNzF16TAKernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, 
+dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=True,False +mode=high_performance + +[BatchMatMulNzF16TBKernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=False,True +mode=high_performance + +[BatchMatMulNzF16TATBKernel] +ops=batch_matmul_v2 +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,, +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ +formatOut=FRACTAL_NZ +attrs=True,True +mode=high_performance + +[PpMatMulF16NDF16NDF16NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,float32,float32 +dtypeOut=float16 +formatIn=ND,ND,ND,ND +formatOut=ND +attrs=False,False,False,0 +mode=high_performance +socSupport=ascend910b + +[PpMatMulBF16NDBF16NDBF16NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=bfloat16,bfloat16,float32,float32 +dtypeOut=bfloat16 +formatIn=ND,ND,ND,ND +formatOut=ND +attrs=False,False,False,0 +mode=high_performance +socSupport=ascend910b + +[PpMatMulF16NDF16NDF32NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,float32,float32 +dtypeOut=float32 +formatIn=ND,ND,ND,ND +formatOut=ND +attrs=False,False,False,0 +mode=high_performance +socSupport=ascend910b + +[PpMatMulBF16NDBF16NDF32NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=bfloat16,bfloat16,float32,float32 +dtypeOut=float32 +formatIn=ND,ND,ND,ND +formatOut=ND +attrs=False,False,False,0 +mode=high_performance +socSupport=ascend910b + +[PpMatMulF16NDF16NZF16NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,float32,float32 +dtypeOut=float16 +formatIn=ND,FRACTAL_NZ,ND,ND +formatOut=ND +attrs=False,False,False,3 
+mode=high_performance +socSupport=ascend910b + +[PpMatMulBF16NDBF16NZBF16NDKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=bfloat16,bfloat16,float32,float32 +dtypeOut=bfloat16 +formatIn=ND,FRACTAL_NZ,ND,ND +formatOut=ND +attrs=False,False,False,3 +mode=high_performance +socSupport=ascend910b + +[PpMatMulF16NZF16NZF16NZKernel] +ops=pp_mat_mul +operationName=MatMulOperation +inputCount=4 +outputCount=1 +dtypeIn=float16,float16,float32,float32 +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ,ND,ND +formatOut=FRACTAL_NZ +attrs=False,False,False,0 +mode=high_performance +socSupport=ascend310p + +[PpMatmulW8A8Kernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,uint64,float32 +dtypeOut=float16 +formatIn=ND,ND,ND,ND,ND +formatOut=ND +attrs=False,False +mode=high_performance +socSupport=ascend910b + +[PpMatmulW8A8WeightNzKernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,uint64,float32 +dtypeOut=float16 +formatIn=ND,FRACTAL_NZ,ND,ND,ND +formatOut=ND +attrs=False,False +mode=high_performance +socSupport=ascend910b + +[PpMatmulW8A8Bf16NDNDKernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,float32,float32 +dtypeOut=bfloat16 +formatIn=ND,ND,ND,ND,ND +formatOut=ND +attrs=False,False +mode=high_performance +socSupport=ascend910b + +[PpMatmulW8A8Bf16NDNZKernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,float32,float32 +dtypeOut=bfloat16 +formatIn=ND,FRACTAL_NZ,ND,ND,ND +formatOut=ND +attrs=False,False +mode=high_performance +socSupport=ascend910b + +[PpMatmulW8A8PertokenFP16Kernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,float32,float32 +dtypeOut=float16 +formatIn=ND,ND,ND,ND,ND +formatOut=ND +attrs=False,False 
+mode=high_performance +socSupport=ascend910b + +[PpMatmulW8A8NzKernel] +ops=pp_matmul_w8a8 +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,uint64,float32 +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ,ND,ND,ND +formatOut=FRACTAL_NZ +attrs=False,True +mode=high_performance +socSupport=ascend310p + +[PpMatmulW8A8NzCompressKernel] +ops=pp_matmul_w8a8_compress +operationName=MatMulOperation +inputCount=5 +outputCount=1 +dtypeIn=int8,int8,int32,uint64,int8 +dtypeOut=float16 +formatIn=FRACTAL_NZ,FRACTAL_NZ,ND,ND,ND +formatOut=FRACTAL_NZ +attrs=False,True,0,0 +mode=high_performance +socSupport=ascend310p + +[LayernormF16Kernel] +ops=layer_norm_v3 +operationName=NormOperation +inputCount=3 +outputCount=3 +dtypeIn=float16,float16,float16 +dtypeOut=float16,float16,float16 + +[LayernormBF16Kernel] +ops=layer_norm_v3 +operationName=NormOperation +inputCount=3 +outputCount=3 +dtypeIn=bfloat16,bfloat16,bfloat16 +dtypeOut=bfloat16,bfloat16,bfloat16 +socSupport=ascend910b + +[LayernormF32Kernel] +ops=layer_norm_v3 +operationName=NormOperation +inputCount=3 +outputCount=3 +dtypeIn=float32,float32,float32 +dtypeOut=float32,float32,float32 + +[RmsNormForwardKernelCanndev] +ops=rms_norm +operationName=NormOperation +inputCount=2 +outputCount=2 +dtypeIn=float16,float16 +dtypeOut=float16,float32 +mode=high_performance +socSupport=ascend910b + +[ReverseF16Kernel] +ops=reverse_v2 +operationName=ReverseOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int32 +dtypeOut=float16 +socSupport=ascend310p,ascend910b,ascend910 + +[ReverseF32Kernel] +ops=reverse_v2 +operationName=ReverseOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,int32 +dtypeOut=float32 +socSupport=ascend910b + +[ReverseBF16Kernel] +ops=reverse_v2 +operationName=ReverseOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int32 +dtypeOut=bfloat16 +socSupport=ascend910b + +[SoftmaxF16Kernel] +ops=softmax_v2 +operationName=SoftmaxOperation +inputCount=1 +outputCount=1 
+dtypeIn=float16 +dtypeOut=float16 + +[SoftmaxBF16Kernel] +ops=softmax_v2 +operationName=SoftmaxOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +socSupport=ascend910b + +[SoftmaxF32Kernel] +ops=softmax_v2 +operationName=SoftmaxOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 + +[SplitF16Output2Kernel] +ops=split_d +operationName=SplitOperation +inputCount=1 +outputCount=2 +dtypeIn=float16 +dtypeOut=float16,float16 + +[SplitF16Output3Kernel] +ops=split_d +operationName=SplitOperation +inputCount=1 +outputCount=3 +dtypeIn=float16 +dtypeOut=float16,float16,float16 + +[SplitInt64Output2Kernel] +ops=split_d +operationName=SplitOperation +inputCount=1 +outputCount=2 +dtypeIn=int64 +dtypeOut=int64,int64 + +[SplitVF16Output2Kernel] +ops=split_v +operationName=SplitOperation +inputCount=3 +outputCount=2 +dtypeIn=float16,int32,int32 +dtypeOut=float16,float16 + +[SplitVF16Output3Kernel] +ops=split_v +operationName=SplitOperation +inputCount=3 +outputCount=3 +dtypeIn=float16,int32,int32 +dtypeOut=float16,float16,float16 + +[SplitVInt64Output2Kernel] +ops=split_v +operationName=SplitOperation +inputCount=3 +outputCount=2 +dtypeIn=int64,int32,int32 +dtypeOut=int64,int64 + +[TopKDescF16Kernel] +ops=top_k_v2 +operationName=SortOperation +inputCount=2 +outputCount=2 +dtypeIn=float16,int32 +dtypeOut=float16,int32 +attrs=True,None,True + +[TopKDescBF16Kernel] +ops=top_k_v2 +operationName=SortOperation +inputCount=2 +outputCount=2 +dtypeIn=bfloat16,int32 +dtypeOut=bfloat16,int32 +attrs=True,None,True +socSupport=ascend910b + +[TopKDescF32Kernel] +ops=top_k_v2 +operationName=SortOperation +inputCount=2 +outputCount=2 +dtypeIn=float32,int32 +dtypeOut=float32,int32 +attrs=True,None,True +socSupport=ascend910b + +[TransdataNzToNdKernel] +ops=trans_data +operationName=TransdataOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +formatIn=FRACTAL_NZ +formatOut=ND + +[TransdataNdToNzKernel] +ops=trans_data 
+operationName=TransdataOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +formatIn=ND +formatOut=FRACTAL_NZ + +[TransdataNdToNzInt8Kernel] +ops=trans_data +operationName=TransdataOperation +inputCount=1 +outputCount=1 +dtypeIn=int8 +dtypeOut=int8 +formatIn=ND +formatOut=FRACTAL_NZ + +[Transpose8Kernel] +ops=transpose +operationName=TransposeOperation +inputCount=2 +outputCount=1 +dtypeIn=int8,int32 +dtypeOut=int8 +socSupport=ascend910b,ascend310p + +[Transpose16Kernel] +ops=transpose +operationName=TransposeOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int32 +dtypeOut=float16 + +[Transpose32Kernel] +ops=transpose +operationName=TransposeOperation +inputCount=2 +outputCount=1 +dtypeIn=float32,int32 +dtypeOut=float32 +socSupport=ascend910b,ascend310p + +[Transpose64Kernel] +ops=transpose +operationName=TransposeOperation +inputCount=2 +outputCount=1 +dtypeIn=int64,int32 +dtypeOut=int64 + +[SliceF16Int64Kernel] +ops=slice +operationName=SliceOperation +inputCount=3 +outputCount=1 +dtypeIn=float16,int64,int64 +dtypeOut=float16 + +[SliceInt8Int64Kernel] +ops=slice +operationName=SliceOperation +inputCount=3 +outputCount=1 +dtypeIn=int8,int64,int64 +dtypeOut=int8 + +[SliceInt32Int64Kernel] +ops=slice +operationName=SliceOperation +inputCount=3 +outputCount=1 +dtypeIn=int32,int64,int64 +dtypeOut=int32 + +[ViewCopyInt64Kernel] +ops=view_copy +operationName=CopyOperation +inputCount=8 +outputCount=1 +dtypeIn=int64,int64,int64,int64,int64,int64,int64,int64 +dtypeOut=int64 +mode=high_performance + +[ViewCopyInt32Kernel] +ops=view_copy +operationName=CopyOperation +inputCount=8 +outputCount=1 +dtypeIn=int32,int64,int64,int64,int32,int64,int64,int64 +dtypeOut=int32 +mode=high_performance + +[ViewCopyF32Kernel] +ops=view_copy +operationName=CopyOperation +inputCount=8 +outputCount=1 +dtypeIn=float32,int64,int64,int64,float32,int64,int64,int64 +dtypeOut=float32 +mode=high_performance + +[ViewCopyF16Kernel] +ops=view_copy 
+operationName=CopyOperation +inputCount=8 +outputCount=1 +dtypeIn=float16,int64,int64,int64,float16,int64,int64,int64 +dtypeOut=float16 +mode=high_performance + +[ViewCopyBF16Kernel] +ops=view_copy +operationName=CopyOperation +inputCount=8 +outputCount=1 +dtypeIn=bfloat16,int64,int64,int64,bfloat16,int64,int64,int64 +dtypeOut=bfloat16 +mode=high_performance +socSupport=ascend910b + +[LogF32Kernel] +ops=log +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +attrs=-1.0,1.0,0.0 + +[LogF16Kernel] +ops=log +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +attrs=-1.0,1.0,0.0 + +[LogBF16Kernel] +ops=log +operationName=ElewiseOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +attrs=-1.0,1.0,0.0 +socSupport=ascend910b + +[QuantPerChannelKernel] +ops=quant_per_channel +operationName=ElewiseOperation +inputCount=3 +outputCount=1 +dtypeIn=float16,float16,int8 +dtypeOut=int8 +mode=high_performance +socSupport=ascend910b,ascend310p + +[DequantPerChannelKernel] +ops=dequant_per_channel +operationName=ElewiseOperation +inputCount=3 +outputCount=1 +dtypeIn=int8,float16,int8 +dtypeOut=float16 +mode=high_performance +socSupport=ascend910b,ascend310p + +[ReduceMaxInt32Kernel] +ops=reduce_max +operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=int32,int64 +dtypeOut=int32 +attrs=False + +[ReduceMaxBF16Kernel] +ops=reduce_max +operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +attrs=False +socSupport=ascend910b + +[ReduceMinInt32Kernel] +ops=reduce_min +operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=int32,int64 +dtypeOut=int32 +attrs=False + +[ReduceMinBF16Kernel] +ops=reduce_min +operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +attrs=False +socSupport=ascend910b + +[ReduceSumF16Kernel] +ops=reduce_sum 
+operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int64 +dtypeOut=float16 + +[ReduceSumBF16Kernel] +ops=reduce_sum +operationName=ReduceOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +socSupport=ascend910b + +[CumsumF16Kernel] +ops=cumsum +operationName=CumsumOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int64 +dtypeOut=float16 +attrs=False,False +deterministic=False + +[CumsumBF16Kernel] +ops=cumsum +operationName=CumsumOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +attrs=False,False +deterministic=False +socSupport=ascend910b + +[CumsumF16DtmKernel] +ops=cumsum +operationName=CumsumOperation +inputCount=2 +outputCount=1 +dtypeIn=float16,int64 +dtypeOut=float16 +attrs=False,False +deterministic=True + +[CumsumBF16DtmKernel] +ops=cumsum +operationName=CumsumOperation +inputCount=2 +outputCount=1 +dtypeIn=bfloat16,int64 +dtypeOut=bfloat16 +attrs=False,False +deterministic=True +socSupport=ascend910b + +[SigmoidF16Kernel] +ops=sigmoid +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=float16 +dtypeOut=float16 +mode=high_precision + +[SigmoidBF16Kernel] +ops=sigmoid +operationName=ActivationOperation +inputCount=1 +outputCount=1 +dtypeIn=bfloat16 +dtypeOut=bfloat16 +mode=high_precision +socSupport=ascend910b + +[ZerosLikeF32Kernel] +ops=zeros_like +operationName=ZerosLikeOperation +inputCount=1 +outputCount=1 +dtypeIn=float32 +dtypeOut=float32 +socSupport=ascend910b + +[ScatterElementsV2Int32Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int32,int32,int32 +dtypeOut=int32 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Int32Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance 
+dtypeIn=int32,int64,int32 +dtypeOut=int32 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Bfloat16Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=bfloat16,int32,bfloat16 +dtypeOut=bfloat16 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Uint8Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=uint8,int64,uint8 +dtypeOut=uint8 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Uint8Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=uint8,int32,uint8 +dtypeOut=uint8 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Int8Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int8,int32,int8 +dtypeOut=int8 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Int32Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int32,int64,int32 +dtypeOut=int32 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Float16Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float16,int32,float16 +dtypeOut=float16 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Float32Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float32,int32,float32 
+dtypeOut=float32 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Bfloat16Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=bfloat16,int32,bfloat16 +dtypeOut=bfloat16 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Bfloat16Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=bfloat16,int64,bfloat16 +dtypeOut=bfloat16 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Bfloat16Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=bfloat16,int64,bfloat16 +dtypeOut=bfloat16 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Float32Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float32,int64,float32 +dtypeOut=float32 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Float32Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float32,int64,float32 +dtypeOut=float32 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Float16Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float16,int64,float16 +dtypeOut=float16 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Float16Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance 
+dtypeIn=float16,int32,float16 +dtypeOut=float16 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Int8Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int8,int32,int8 +dtypeOut=int8 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Uint8Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=uint8,int64,uint8 +dtypeOut=uint8 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Int32Int32AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int32,int32,int32 +dtypeOut=int32 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Int8Int64AddKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int8,int64,int8 +dtypeOut=int8 +attrs=None,add +socSupport=ascend910b + +[ScatterElementsV2Float16Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float16,int64,float16 +dtypeOut=float16 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Int8Int64NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=int8,int64,int8 +dtypeOut=int8 +attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Uint8Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=uint8,int32,uint8 +dtypeOut=uint8 
+attrs=None,none +socSupport=ascend910b + +[ScatterElementsV2Float32Int32NoneKernel] +ops=scatter_elements_v2 +operationName=ScatterElementsV2Operation +inputCount=3 +outputCount=1 +formatIn=ND,ND,ND +formatOut=ND +mode=high_performance +dtypeIn=float32,int32,float32 +dtypeOut=float32 +attrs=None,none +socSupport=ascend910b diff --git a/scripts/build.sh b/scripts/build.sh index 8f67a76d798d810c7d606c7977c58804bcff9d13..cfb24e93f4f5368120add067cf83f90124e1705f 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -12,8 +12,8 @@ set -e SCRIPT_DIR=$(cd $(dirname $0); pwd) cd $SCRIPT_DIR/.. -CODE_ROOT=$(pwd) -CACHE_DIR=$CODE_ROOT/build +export CODE_ROOT=$(pwd) +export CACHE_DIR=$CODE_ROOT/build OUTPUT_DIR=$CODE_ROOT/output THIRD_PARTY_DIR=$CODE_ROOT/3rdparty @@ -26,6 +26,10 @@ USE_ASAN=OFF USE_MSSANITIZER=OFF USE_MSDEBUG=OFF SKIP_BUILD=OFF +BUILD_TBE_ADAPTER=OFF +DEPENDENCY_DIR=2025-06-18 +ASDOPS_SOURCE_DIR=/tmp/asdops_dependency/$DEPENDENCY_DIR +CMC_URL=https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/Baize%20C/AscendTransformerBoost/1.0.0/asdops_dependency/$DEPENDENCY_DIR CSVOPSTEST_OPTIONS="" BUILD_PYBIND=ON SRC_ONLY=OFF @@ -161,6 +165,7 @@ function fn_build_asdops() rm -f $THIRD_PARTY_DIR/asdops/lib/libatb_mixops.so 2> /dev/null rm -f $THIRD_PARTY_DIR/asdops/lib/libatb_mixops_static.a 2> /dev/null rm -f $THIRD_PARTY_DIR/asdops/lib/libmki.so 2> /dev/null + rm -f $THIRD_PARTY_DIR/asdops/lib/libtbe_adapter.so 2> /dev/null return 0 fi cd $THIRD_PARTY_DIR @@ -304,6 +309,119 @@ function fn_build_3rdparty_for_test() fn_build_stub } +function fn_get_cxx_abi_string() +{ + if [ "${USE_CXX11_ABI}" == "ON" ]; then + echo "cxx_abi_1" + else + echo "cxx_abi_0" + fi +} + +function fn_copy_tbe_adapter() +{ + if [ ! -f $ATB_BUILD_DEPENDENCY_PATH/lib/libtbe_adapter.so ]; then + echo "error:$ATB_BUILD_DEPENDENCY_PATH/lib/libtbe_adapter.so dose not exist, please source set_env.sh." 
+ return 0 + fi + + LOCAL_ABI=$(fn_get_cxx_abi_string) + + LOCAL_TBE_ADAPTER_PATH=$CODE_ROOT/output/atb/$LOCAL_ABI + TARGET_ABI=$(basename "$ATB_BUILD_DEPENDENCY_PATH") + if [ "${TARGET_ABI}" != "${LOCAL_ABI}" ];then + echo "$ATB_BUILD_DEPENDENCY_PATH use $TARGET_ABI, but $LOCAL_TBE_ADAPTER_PATH use $LOCAL_ABI, abi error." + return 0 + fi + + if [ "${ATB_BUILD_DEPENDENCY_PATH}" != "${LOCAL_TBE_ADAPTER_PATH}" ]; then + if [ -d $LOCAL_TBE_ADAPTER_PATH ]; then + cp ${ATB_BUILD_DEPENDENCY_PATH}/lib/libtbe_adapter.so $LOCAL_TBE_ADAPTER_PATH/lib + fi + fi +} + +function fn_check_dependency_cache() +{ + [[ ! -d "$ASDOPS_SOURCE_DIR" ]] && mkdir -p $ASDOPS_SOURCE_DIR + cd $ASDOPS_SOURCE_DIR + if [ ! -d "$ASDOPS_SOURCE_DIR/opp_kernel" ]; then + [[ ! -f "$ASDOPS_SOURCE_DIR/asdops_opp_kernel.tar.gz" ]] && wget --no-check-certificate $CMC_URL/asdops_opp_kernel.tar.gz + tar xf asdops_opp_kernel.tar.gz + rm asdops_opp_kernel.tar.gz + fi + echo "$ASDOPS_SOURCE_DIR!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + if [ ! -d "$ASDOPS_SOURCE_DIR/canndev" ]; then + [[ ! -f "$ASDOPS_SOURCE_DIR/canndev.tar.gz" ]] && wget --no-check-certificate $CMC_URL/canndev.tar.gz + tar xf canndev.tar.gz + rm canndev.tar.gz + fi + if [ ! -d "$ASDOPS_SOURCE_DIR/metadef" ]; then + [[ ! -f "$ASDOPS_SOURCE_DIR/metadef.tar.gz" ]] && wget --no-check-certificate $CMC_URL/metadef.tar.gz + tar xf metadef.tar.gz + rm metadef.tar.gz + fi + if [ ! -d "$ASDOPS_SOURCE_DIR/cann-ops-adv" ]; then + [[ ! -f "$ASDOPS_SOURCE_DIR/cann-ops-adv.tar.gz" ]] && wget --no-check-certificate $CMC_URL/cann-ops-adv.tar.gz + tar xf cann-ops-adv.tar.gz + rm cann-ops-adv.tar.gz + fi + if [ ! -d "$ASDOPS_SOURCE_DIR/api" ]; then + [[ ! 
-f "$ASDOPS_SOURCE_DIR/api.tar.gz" ]] && wget --no-check-certificate $CMC_URL/api.tar.gz + tar xf api.tar.gz + rm api.tar.gz + fi + echo "dependency_cache is ready" +} + +function fn_build_tbe_adapter_dependency() +{ + CCEC_COMPILER_DIR=$THIRD_PARTY_DIR/compiler/ccec_compiler + TIKCPP_DIR=$THIRD_PARTY_DIR/compiler/tikcpp + + CANNDEV_DIR=$THIRD_PARTY_DIR/canndev + METADEF_DIR=$THIRD_PARTY_DIR/metadef + API_DIR=$THIRD_PARTY_DIR/api + CANN_OPS_DIR=$THIRD_PARTY_DIR/cann-ops-adv + TBE_ADAPTER_DIR=$CODE_ROOT/src/kernels/tbe_adapter + + # dev + fn_check_dependency_cache + export ASCEND_KERNEL_PATH=$ASDOPS_SOURCE_DIR/opp_kernel + + #tbe_adapter dependency + SRC_FILE_LINE_NUM=$(wc -l < "$TBE_ADAPTER_DIR/stubs/include/canndev/ops/built-in/op_tiling/op_tiling.h") + DST_FILE_LINE_NUM=$(wc -l < "$THIRD_PARTY_DIR/canndev/ops/built-in/op_tiling/op_tiling.h") + [[ ! -d "$CANNDEV_DIR" ]] && cp -r $ASDOPS_SOURCE_DIR/canndev $CANNDEV_DIR + [[ ! -d "$API_DIR" ]] && cp -r $ASDOPS_SOURCE_DIR/api $API_DIR + [[ ! -d "$CANN_OPS_DIR" ]] && cp -r $ASDOPS_SOURCE_DIR/cann-ops-adv $CANN_OPS_DIR + #determine whether these two files are identical + [[ "$SRC_FILE_LINE_NUM" != "$DST_FILE_LINE_NUM" ]] && cp -r $TBE_ADAPTER_DIR/stubs/include/canndev $THIRD_PARTY_DIR + [[ "$SRC_FILE_LINE_NUM" != "$DST_FILE_LINE_NUM" ]] && cp -r $TBE_ADAPTER_DIR/stubs/include/api $THIRD_PARTY_DIR + if [ ! -d "$METADEF_DIR" ];then + cp -r $ASDOPS_SOURCE_DIR/metadef $METADEF_DIR + fi +} + +function fn_build_tbe_dependency() +{ + LOCAL_ABI=$(fn_get_cxx_abi_string) + if [ -f $OUTPUT_DIR/atb/$LOCAL_ABI/lib/libtbe_adapter.so ];then + echo "libtbe_adapter.so is already exist, skip build process." 
+ BUILD_TBE_ADAPTER=OFF + return 0 + fi + + if [ -n "$ATB_BUILD_DEPENDENCY_PATH" ];then + #copy from nnal + fn_copy_tbe_adapter + else + #build by source code + BUILD_TBE_ADAPTER=ON + fn_build_tbe_adapter_dependency + fi +} + function fn_build_3rdparty_for_compile() { fn_build_nlohmann_json @@ -311,9 +429,11 @@ function fn_build_3rdparty_for_compile() fn_build_catlass fn_build_asdops fn_build_cann_dependency + fn_build_tbe_dependency if [ "$BUILD_PYBIND" == "ON" -a "$USE_CXX11_ABI" != "ON" ]; then fn_build_pybind11 fi + COMPILE_OPTIONS="${COMPILE_OPTIONS} -DBUILD_TBE_ADAPTER=$BUILD_TBE_ADAPTER" } function fn_build_3rdparty_for_doc() @@ -765,7 +885,8 @@ function fn_main() fn_init_env COMPILE_OPTIONS="${COMPILE_OPTIONS} -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE \ - -DUSE_CXX11_ABI=$USE_CXX11_ABI -DUSE_ASAN=$USE_ASAN -DBUILD_PYBIND=$BUILD_PYBIND -DUSE_MSSANITIZER=$USE_MSSANITIZER" + -DUSE_CXX11_ABI=$USE_CXX11_ABI -DUSE_ASAN=$USE_ASAN -DBUILD_PYBIND=$BUILD_PYBIND -DUSE_MSSANITIZER=$USE_MSSANITIZER \ + -DPACKAGE_COMPILE=OFF" case "${arg1}" in "default") MKI_BUILD_MODE=Dev diff --git a/scripts/build_util.py b/scripts/build_util.py new file mode 100644 index 0000000000000000000000000000000000000000..424038f5ae2b6f733b3489ea4e2674cc6e3b7149 --- /dev/null +++ b/scripts/build_util.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# + +import os +import configparser +import json +import logging +import shutil +import stat +import re + + +# sycl-target --show-targets +def get_build_target_list(): + usr_config_file_path = os.getenv("BUILD_CONFIG_FILE", '') + if usr_config_file_path == '': + script_file_path = os.path.realpath(__file__) + build_config_json_file_path = os.path.join(os.path.dirname( + script_file_path), "../configs/build_config.json") + else: + build_config_json_file_path = usr_config_file_path + device_list = [] + try: + with open(build_config_json_file_path) as conf_file: + conf = json.load(conf_file) + target_option = conf['targets'] + for target, switch in target_option.items(): + if switch is True: + device_list.append(target) + except FileNotFoundError: + logging.error("file %s is not found!", build_config_json_file_path) + exit(1) + except json.decoder.JSONDecodeError: + logging.error("file %s is not json file!", build_config_json_file_path) + exit(1) + except KeyError: + logging.error("key 'targets' is not found in %s!", build_config_json_file_path) + exit(1) + + if len(device_list) == 0: + logging.error("no target device is set") + exit(1) + + device_list = list(set(device_list)) + return device_list + + +def get_info_from_file(file_path): + result = True + tactic_info = dict() + magic_dict = {"RT_DEV_BINARY_MAGIC_ELF": str(0x43554245), + "RT_DEV_BINARY_MAGIC_ELF_AIVEC": str(0x41415246), + "RT_DEV_BINARY_MAGIC_ELF_AICUBE": str(0x41494343)} + try: + with open(file_path) as f: + text = json.load(f) + tactic_info["binFileName"] = text["binFileName"] + tactic_info["compileInfo"] = text["compileInfo"] + tactic_info["opParaSize"] = text["opParaSize"] + tactic_info["coreType"] = text["coreType"] if text["coreType"] else "" + magic = text["magic"] + if magic not in magic_dict: + logging.error("magic %s is invalid", magic) + result = False + else: + tactic_info["magic"] = magic_dict[magic] + if "kernelList" in text: + tactic_info["kernelList"] = ','.join( + item["kernelName"] 
for item in text["kernelList"] + ) + else: + tactic_info["kernelList"] = text["kernelName"] + except FileNotFoundError: + logging.error("file %s is not found!", file_path) + result = False + except json.decoder.JSONDecodeError: + logging.error("file %s is not json file!", file_path) + result = False + except KeyError: + logging.error("keyerror in file %s!", file_path) + result = False + return tactic_info, result + + +def write_meta(meta_info, output_path, target_version): + meta_path = os.path.join(output_path, 'meta.ini') + with os.fdopen(os.open(meta_path, os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w+') as fmeta: + fmeta.write("$Version=1.0\n") + fmeta.write(f"DeviceKernelVersion={target_version}\n") + fmeta.write(f"$Object.Count={len(meta_info)}\n") + for key, value in meta_info.items(): + fmeta.write(f"{key}.Object={value[0]}\n") + fmeta.write(f"{key}.OpName={value[1]}\n") + fmeta.write(f"{key}.KernelList={value[2]}\n") + fmeta.write(f"{key}.CompileInfo={value[3]}\n") + fmeta.write(f"{key}.TilingSize={value[4]}\n") + fmeta.write(f"{key}.CoreType={value[5]}\n") + fmeta.write(f"{key}.Magic={value[6]}\n") + fmeta.write("$End=1\n") + + +# 目前只支持一个tactic文件夹下一个.o和.json文件 +def copy_ascendc_code(meta_info, env_cache_dir, target_version, output_path): + op_kernels_version_dir = os.path.join( + env_cache_dir, "asdops_kernels", target_version) + if not os.path.exists(op_kernels_version_dir): + return 0 + code_file_count = 0 + for operation in os.listdir(op_kernels_version_dir): + operation_dir = os.path.join(op_kernels_version_dir, operation) + output_operation_dir = os.path.join(output_path, operation) + if not os.path.exists(output_operation_dir): + os.makedirs(output_operation_dir) + for tactic in os.listdir(operation_dir): + tactic_dir = os.path.join(operation_dir, tactic) + for file in os.listdir(tactic_dir): + if not file.endswith('.json'): + continue + code_file = os.path.join(tactic_dir, "".join([file[:-4], 'o'])) + if not os.path.exists(code_file): 
+ logging.error("file %s has no object file.", file) + exit(1) + json_file = os.path.join(tactic_dir, file) + tactic_info, result = get_info_from_file(json_file) + if not result: + logging.error("failed to parse file %s.", json_file) + exit(1) + relative_to_path = os.path.join(operation, file) + to_path = os.path.join(output_operation_dir, file) + shutil.copyfile(code_file, to_path) + try: + compile_info_str = json.dumps(tactic_info["compileInfo"]) + meta_info[tactic] = ( + relative_to_path, tactic, tactic_info["kernelList"], compile_info_str, + tactic_info["opParaSize"], tactic_info["coreType"], tactic_info["magic"]) + except KeyError: + logging.error("%s get compile or meta info error", tactic) + exit(1) + code_file_count += 1 + return code_file_count + + +def copy_tbe_code_all_version(input_paras): + tbe_sections = input_paras["tbe_ini"].sections() + for target_version in input_paras["target_version_list"]: + output_path = os.path.join( + input_paras["env_cache_dir"], "mix_obj", target_version) + if not os.path.exists(output_path): + os.makedirs(output_path) + meta_info = {} + target_version_path = os.path.join( + input_paras["tbe_kernel_path"], target_version) + + for op_name in tbe_sections: + op_dir_path = os.path.join(output_path, op_name) + if not os.path.exists(op_dir_path): + os.mkdir(op_dir_path) + items = dict(input_paras["tbe_ini"].items(op_name)) + for op_key, relative_op_path in items.items(): + if '.' 
in op_key: + op_key, version_op_key = op_key.split('.') + if version_op_key != target_version: + continue + + tactic_info, ret = get_info_from_file(os.path.join( + target_version_path, relative_op_path)) + if not ret: + logging.error("failed to parse json file %s", relative_op_path) + exit(1) + + from_path = os.path.join( + target_version_path, "".join([relative_op_path[:-4], 'o'])) + object_name = os.path.basename(from_path) + to_path = os.path.join(op_dir_path, object_name) + relative_to_path = os.path.join(op_name, object_name) + shutil.copyfile(from_path, to_path) + if op_key not in meta_info: + try: + compile_info_str = json.dumps(tactic_info["compileInfo"]) + meta_info[op_key] = ( + relative_to_path, op_name, tactic_info["kernelList"], compile_info_str, + tactic_info["opParaSize"], tactic_info["coreType"], tactic_info["magic"]) + except KeyError: + logging.error("%s get compile or meta info error", op_name) + exit(1) + + ascendc_file_count = copy_ascendc_code( + meta_info, input_paras["env_cache_dir"], target_version, output_path) + logging.info( + f"{target_version} has {ascendc_file_count} AscendC tactics.") + + write_meta(meta_info, output_path, target_version) + + +def copy_tbe_device_code(): + env_code_root = os.getenv("CODE_ROOT") + env_cache_dir = os.getenv("CACHE_DIR") + tbe_kernel_path = os.getenv("ASCEND_KERNEL_PATH") + if not (env_code_root and env_cache_dir and tbe_kernel_path): + logging.error( + "env CODE_ROOT | OUTPUT_DIR | ASDOPS_SOURCE_DIR not exist!") + exit(1) + logging.info(f"tbe_kernel_path: {tbe_kernel_path}") + input_path = os.path.join(env_code_root, "configs/tbe_tactic_json.ini") + if not os.path.exists(input_path): + logging.error("ini file: %s not exist!", input_path) + exit(1) + tbe_ini = configparser.RawConfigParser() + tbe_ini.optionxform = lambda option: option + try: + tbe_ini.read(input_path) + except configparser.MissingSectionHeaderError: + logging.error("ini file: %s format error!", input_path) + exit(1) + except 
configparser.ParsingError: + logging.error("ini file: %s format error!", input_path) + exit(1) + + target_version_list = get_build_target_list() + copy_tbe_code_all_version({"env_code_root": env_code_root, + "target_version_list": target_version_list, + "env_cache_dir": env_cache_dir, + "tbe_kernel_path": tbe_kernel_path, + "tbe_ini": tbe_ini}) + os.remove(input_path) + + +def get_build_target_list_for_shell(): + return "\n".join(get_build_target_list()) diff --git a/scripts/release.sh b/scripts/release.sh index 5fbf126099723132b9d0130554e2c1d7dbaebe7d..8fd9abee8ec78f58e98b08039d16dc278fe5acf5 100644 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -261,6 +261,37 @@ EOF echo "Ascend-cann-atb_${VERSION}_linux-${ARCH}.run is successfully generated in $OUTPUT_DIR" } +function fn_build_tbe_dependency() +{ + CANNDEV_DIR=$THIRD_PARTY_DIR/canndev + METADEF_DIR=$THIRD_PARTY_DIR/metadef + API_DIR=$THIRD_PARTY_DIR/api + CANN_OPS_DIR=$THIRD_PARTY_DIR/cann-ops-adv + export ASCEND_KERNEL_PATH=$ASCEND_HOME_PATH/opp/built-in/op_impl/ai_core/tbe/kernel + COMPILE_OPTIONS="${COMPILE_OPTIONS} -DBUILD_TBE_ADAPTER=ON" + + # release + if [ ! -d "$CANNDEV_DIR" ];then + echo "Failed to find canndev" + exit 1 + fi + if [ ! -d "$API_DIR" ];then + echo "Failed to find api" + exit 1 + fi + if [ ! -d "$CANN_OPS_DIR" ];then + echo "Failed to find cann-ops-adv" + exit 1 + fi + cp -r $CODE_ROOT/src/kernels/tbe_adapter/stubs/include/canndev $THIRD_PARTY_DIR + cp -r $CODE_ROOT/src/kernels/tbe_adapter/stubs/include/api $THIRD_PARTY_DIR + if [ ! 
-d "$METADEF_DIR" ];then + echo "Failed to find metadef" + exit 1 + fi + return +} + function fn_main() { if [[ "$1" == "pack" ]]; then @@ -291,6 +322,8 @@ function fn_main() "--build_customize_ops") COMPILE_OPTIONS="${COMPILE_OPTIONS} -DBUILD_CUSTOMIZE_OPS=ON" ;; + "--local_release_compile") + LOCAL_RELEASE_COMPILE=ON esac shift } @@ -309,9 +342,11 @@ function fn_main() fn_build_asdops fn_build_catlass fn_build_cann_dependency + fn_build_tbe_dependency [[ "$USE_CXX11_ABI" == "ON" ]] && COMPILE_OPTIONS="${COMPILE_OPTIONS} -DUSE_CXX11_ABI=ON" [[ "$USE_CXX11_ABI" == "OFF" ]] && COMPILE_OPTIONS="${COMPILE_OPTIONS} -DUSE_CXX11_ABI=OFF" - COMPILE_OPTIONS="${COMPILE_OPTIONS} -DCMAKE_BUILD_TYPE=Release" + COMPILE_OPTIONS="${COMPILE_OPTIONS} -DCMAKE_BUILD_TYPE=Release \ + -DLOCAL_RELEASE_COMPILE=$LOCAL_RELEASE_COMPILE -DPACKAGE_COMPILE=ON" config_atb_version fn_build_nlohmann_json fn_build_pybind11 @@ -325,7 +360,7 @@ fn_init_env SCRIPT_DIR=$(cd $(dirname $0); pwd) cd $SCRIPT_DIR cd .. -CODE_ROOT=$(pwd) +export CODE_ROOT=$(pwd) export CACHE_DIR=$CODE_ROOT/build OUTPUT_DIR=$CODE_ROOT/output THIRD_PARTY_DIR=$CODE_ROOT/3rdparty @@ -333,6 +368,7 @@ LOG_PATH="/var/log/cann_atb_log/" LOG_NAME="cann_atb_install.log" ATB_DIR=$CODE_ROOT RELEASE_DIR=$CODE_ROOT/ci/release +LOCAL_RELEASE_COMPILE=OFF cann_default_install_path="/usr/local/Ascend/ascend-toolkit" diff --git a/scripts/update_tbe_tactic_json.py b/scripts/update_tbe_tactic_json.py new file mode 100644 index 0000000000000000000000000000000000000000..965e447634c0ec27d9d85d98ff4e523b540ffaba --- /dev/null +++ b/scripts/update_tbe_tactic_json.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. 
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# +import argparse +import configparser +import json +import logging +import os +import stat +from collections import namedtuple + +from build_util import get_build_target_list + +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') + +JsonSpecification = namedtuple( + "JsonSpecification", ["mode", "inputs", "outputs", "attrs", "dir", "deterministic"]) + +TacticDef = namedtuple("TacticDef", [ + "ops_name", "operation", "input_num", "output_num", "dtypes_in", + "dtypes_out", "formats_in", "formats_out", "mode", "attrs", "soc_support", "deterministic"]) + + + +def get_code_root(): + current_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.dirname(current_dir) + + +def get_tbe_kernel_path(): + result = True + tbe_kernel_path = os.getenv("ASCEND_KERNEL_PATH") + if not os.path.exists(tbe_kernel_path): + result = False + return tbe_kernel_path, result + + +def get_build_cache_path(): + result = True + build_cache_dir = os.getenv("CACHE_DIR") + if not os.path.exists(build_cache_dir): + result = False + return build_cache_dir, result + + +def read_tbe_config_file(input_args): + tbe_config_ini, result = None, True + if not os.path.exists(input_args.src_ini_path): + result = False + logging.error("ini file: %s not exist!", input_args.src_ini_path) + return tbe_config_ini, result + tbe_config_ini = configparser.RawConfigParser() + tbe_config_ini.optionxform = lambda option: option + try: + tbe_config_ini.read(input_args.src_ini_path) + except configparser.MissingSectionHeaderError: + result = False + logging.error("ini file: %s format error!", input_args.src_ini_path) + except configparser.ParsingError: + result = False + 
logging.error("ini file: %s format error!", input_args.src_ini_path) + return tbe_config_ini, result + + +def read_tbe_json_file(json_file_path): + result = True + ops_specification_list = [] + try: + json_list = os.listdir(json_file_path) + for file_name in json_list: + if not file_name.endswith(".json") or file_name.endswith("failed.json"): + continue + json_file = os.path.join(json_file_path, file_name) + with open(json_file) as f: + text = json.load(f) + item = text["supportInfo"] + inputs = item["inputs"] + outputs = item["outputs"] + mode = item["implMode"] if "implMode" in item else None + attrs = item["attrs"] if "attrs" in item else None + deterministic = item["deterministic"] if "deterministic" in item else None + json_info = JsonSpecification( + mode=mode, inputs=inputs, outputs=outputs, attrs=attrs, dir=file_name, deterministic=deterministic) + ops_specification_list.append(json_info) + except FileNotFoundError: + logging.error("file %s is not found!", json_file) + result = False + except json.decoder.JSONDecodeError: + logging.error("file %s is not json file!", json_file) + result = False + except KeyError: + logging.error("keyerror in file %s!", json_file) + result = False + return ops_specification_list, result + + +def impl_mode_matched_or_not(json_mode, tactic_mode): + if not json_mode: + return True + if not tactic_mode and "high_precision" not in json_mode: + return False + if tactic_mode and tactic_mode not in json_mode: + return False + return True + + +def check_dtype_matched_or_not(data_defs, data_dtypes): + try: + for i, dtype in enumerate(data_dtypes): + if dtype == "" and data_defs[i] is None: + continue + if data_defs[i] is None: + return False + dtypes = dtype.split("/") + tensor_dtype_not_match = isinstance(data_defs[i], dict) and data_defs[i]["dtype"] not in dtypes + tensorlist_dtype_not_match = isinstance(data_defs[i], list) and data_defs[i][0]["dtype"] not in dtypes + if tensor_dtype_not_match or tensorlist_dtype_not_match: + return 
False + except IndexError: + return False + except TypeError: + return False + return True + + +def check_format_matched_or_not(data_defs, data_formats): + try: + if not data_formats: + for data_def in data_defs: + if data_def is None: + continue + tensor_format_not_nd = isinstance(data_def, dict) and data_def["format"] != "ND" + tensorlist_format_not_nd = isinstance(data_def, list) and data_def[0]["format"] != "ND" + if tensor_format_not_nd or tensorlist_format_not_nd: + return False + else: + for i, dformat in enumerate(data_formats): + if data_defs[i] is None: + return False + tensor_format_not_match = isinstance(data_defs[i], dict) and data_defs[i]["format"] != dformat + tensorlist_format_not_match = isinstance(data_defs[i], list) and data_defs[i][0]["format"] != dformat + if tensor_format_not_match or tensorlist_format_not_match: + return False + except IndexError: + return False + except TypeError: + return False + return True + + +def inputs_outputs_matched_or_not(data_defs, data_num, data_dtypes, data_formats): + if len(data_defs) == 0 and data_num == 0: + return True + if len(data_defs) == 1 and "name" not in data_defs[0]: + # input paramType is dynamic + data_defs = data_defs[0] + + if len(data_defs) != data_num: + return False + + return check_dtype_matched_or_not(data_defs, data_dtypes) and check_format_matched_or_not(data_defs, data_formats) + + +def attrs_matched_or_not(json_attrs, tactic_attrs): + if not tactic_attrs: + return True + for i, attr in enumerate(tactic_attrs): + try: + if attr != str(json_attrs[i]["value"]): + return False + except IndexError: + return False + except TypeError: + return False + return True + + +def deterministic_matched_or_not(json_deterministic, tactic_deterministic): + if json_deterministic == "ignore": + return True + if json_deterministic.lower() == tactic_deterministic.lower(): + return True + else: + return False + + +def get_match_json(json_info_dir, tactic_info): + result = False + match_json_dir = "" + 
ops_specification_list, ret = read_tbe_json_file(json_info_dir) + if not ret: + return match_json_dir, result + count_check = 0 + for json_info in ops_specification_list: + matched = impl_mode_matched_or_not(json_info.mode, tactic_info.mode) \ + and inputs_outputs_matched_or_not( + json_info.inputs, tactic_info.input_num, tactic_info.dtypes_in, tactic_info.formats_in) \ + and inputs_outputs_matched_or_not( + json_info.outputs, tactic_info.output_num, tactic_info.dtypes_out, tactic_info.formats_out) \ + and attrs_matched_or_not(json_info.attrs, tactic_info.attrs) \ + and deterministic_matched_or_not(json_info.deterministic, tactic_info.deterministic) + + if matched: + match_json_dir, result = json_info.dir, True + count_check += 1 + + if count_check != 1: + logging.error( + f"{json_info_dir}: matched json file number is {count_check}, which should be 1") + result = False + return match_json_dir, result + + +def get_tbe_tactic_json(input_args, tbe_config_ini): + result = True + json_paths_info = configparser.ConfigParser() + json_paths_info.optionxform = lambda option: option + + try: + json_paths_info.read(input_args.dst_ini_path) + except configparser.MissingSectionHeaderError: + result = False + logging.error("ini file: %s format error!", input_args.dst_ini_path) + except configparser.ParsingError: + result = False + logging.error("ini file: %s format error!", input_args.dst_ini_path) + + tbe_kernel_path, ret = get_tbe_kernel_path() + if not ret: + result = False + logging.error("get tbe kernel path failed") + return json_paths_info, result + build_cache_dir, ret = get_build_cache_path() + if not ret: + result = False + logging.error("get build cache dir failed") + return json_paths_info, result + build_cache_obj_dir = os.path.join(build_cache_dir, "obj") + target_version_list = get_build_target_list() + logging.info("target version list: %s", target_version_list) + for target_version in target_version_list: + try: + for tactic_name in tbe_config_ini.sections(): + 
try: + ops = tbe_config_ini.get(tactic_name, "ops") + operation_name = tbe_config_ini.get( + tactic_name, "operationName") + input_num = int(tbe_config_ini.get( + tactic_name, "inputCount")) + output_num = int(tbe_config_ini.get( + tactic_name, "outputCount")) + input_dtypes = tbe_config_ini.get(tactic_name, "dtypeIn") + output_dtypes = tbe_config_ini.get(tactic_name, "dtypeOut") + input_formats = tbe_config_ini.get( + tactic_name, "formatIn", fallback=None) + output_formats = tbe_config_ini.get( + tactic_name, "formatOut", fallback=None) + mode = tbe_config_ini.get( + tactic_name, "mode", fallback=None) + attrs = tbe_config_ini.get( + tactic_name, "attrs", fallback=None) + soc_support = tbe_config_ini.get(tactic_name, "socSupport", fallback=None) + deterministic = tbe_config_ini.get( + tactic_name, "deterministic", fallback='ignore') + except configparser.NoOptionError: + logging.error("configparser option is not found: %s", tactic_name) + continue + except ValueError: + logging.error("string-to-int failed!") + continue + except configparser.InterpolationError: + result = False + logging.error("invalid interpolation syntax!") + break + except configparser.Error as e: + result = False + logging.error("Error: %s", e) + break + + input_dtype_arr = input_dtypes.split(",") + output_dtype_arr = output_dtypes.split(",") + input_format_arr = input_formats.split( + ",") if input_formats else None + output_format_arr = output_formats.split( + ",") if output_formats else None + attr_arr = attrs.split(',') if attrs else None + + tactic_info = TacticDef(ops_name=ops, operation=operation_name, + input_num=input_num, output_num=output_num, + dtypes_in=input_dtype_arr, dtypes_out=output_dtype_arr, + formats_in=input_format_arr, formats_out=output_format_arr, + mode=mode, attrs=attr_arr, soc_support=soc_support, + deterministic=deterministic) + if tactic_info.soc_support and target_version not in tactic_info.soc_support.split(","): + continue + json_info_dir = os.path.join( + 
tbe_kernel_path, target_version, ops) + match_json_dir, ret = get_match_json( + json_info_dir, tactic_info) + if not ret: + logging.error( + f"[{target_version}] get tactic failed: {tactic_name}") + exit(1) + if not json_paths_info.has_section(operation_name): + json_paths_info.add_section(operation_name) + json_paths_info.set( + operation_name, tactic_name + "." + target_version, + os.path.join(ops, match_json_dir)) + print(os.path.join(build_cache_obj_dir, target_version, operation_name, + match_json_dir)[:-5] + '_' + tactic_name.lower() + '.cpp') + + except configparser.NoSectionError: + result = False + logging.error("configparser section is not found") + except configparser.Error as e: + result = False + logging.error("Error: %s", e) + return json_paths_info, result + + +def write_tbe_tactic_json(input_args, json_paths_info): + fd = os.open(input_args.dst_ini_path, os.O_WRONLY | os.O_CREAT | + os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR) + with os.fdopen(fd, 'w+') as f: + try: + json_paths_info.write(f, space_around_delimiters=False) + except configparser.Error as e: + logging.error("Error: %s", e) + return + logging.info(f"write {input_args.dst_ini_path} success") + + +def main(): + code_root_dir = get_code_root() + tactic_info_path = os.path.join(code_root_dir, "configs/ops/tbe_tactic_info.ini") + + build_cache_dir, _ = get_build_cache_path() + tactic_json_path = os.path.join(build_cache_dir, "tbe_tactic_json.ini") + + parser = argparse.ArgumentParser() + parser.add_argument('--src_ini_path', type=str, required=False, + default=tactic_info_path) + parser.add_argument('--dst_ini_path', type=str, required=False, + default=tactic_json_path) + input_args = parser.parse_args() + + tbe_config_ini, ret = read_tbe_config_file(input_args) + if not ret: + logging.error("get tbe tactic info failed!") + exit(1) + json_paths_info, ret = get_tbe_tactic_json(input_args, tbe_config_ini) + if not ret: + logging.error("get tbe tactic json failed!") + exit(1) + 
write_tbe_tactic_json(input_args, json_paths_info) + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4730ce693fb70f9b6d9ff35cac05562de4f98e3c..208da858e4daadc035afe6e159db53b62ba2d2af 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,7 @@ set(ops_common_directory ${CMAKE_CURRENT_LIST_DIR}/ops_common) set(atb_directory ${CMAKE_CURRENT_LIST_DIR}/atb) set(c_interface_directory ${CMAKE_CURRENT_LIST_DIR}/cinterface) set(MSTX_PATH $ENV{ASCEND_HOME_PATH}/tools/mstx/include) +set(ATB_INCLUDE_DIR $ENV{ASCEND_HOME_PATH}/include) add_compile_options(-Wfloat-equal -fno-common) @@ -39,7 +40,9 @@ if(USE_ASAN) endif() target_link_libraries(atb PUBLIC dl mki asdops atb_mixops ascendcl profapi lcal hccl pthread acl_op_compiler nnopbase) target_link_libraries(atb_train PUBLIC atb) -target_include_directories(atb PUBLIC ${MSTX_PATH}) -target_include_directories(atb_static PUBLIC ${MSTX_PATH}) +target_include_directories(atb PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) +target_include_directories(atb_static PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) +target_include_directories(atb_train PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) +target_include_directories(atb_train_static PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) install(TARGETS atb atb_static atb_train atb_train_static DESTINATION lib) \ No newline at end of file diff --git a/src/kernels/CMakeLists.txt b/src/kernels/CMakeLists.txt index 77bc97237a4f626b32f8aa30ef52790e3ef7956f..be97ac93b1a2a0d9d9fc76d4bc82f82ccae90d59 100644 --- a/src/kernels/CMakeLists.txt +++ b/src/kernels/CMakeLists.txt @@ -23,4 +23,7 @@ include_directories( ) add_subdirectory(mixkernels) -add_subdirectory(kernels) \ No newline at end of file +add_subdirectory(kernels) +if (BUILD_TBE_ADAPTER) + add_subdirectory(tbe_adapter) +endif() \ No newline at end of file diff --git a/src/kernels/tbe_adapter/CMakeLists.txt b/src/kernels/tbe_adapter/CMakeLists.txt new file mode 100644 index 
0000000000000000000000000000000000000000..c5c8aaf3be12464e582c6a08815a7f9718a540e5 --- /dev/null +++ b/src/kernels/tbe_adapter/CMakeLists.txt @@ -0,0 +1,576 @@ +include(${MKI_PACKAGE_DIR}/cmake/host_config.cmake) +remove_definitions(-DOpSpace=Mki) +# add_definitions(-w) +file(GLOB_RECURSE SOURCE_FILES + ${CMAKE_CURRENT_LIST_DIR}/stubs/*.cpp + ${CMAKE_CURRENT_LIST_DIR}/platform/*.cpp + ${CMAKE_CURRENT_LIST_DIR}/tiling_runner/*.cpp +) + +set(ASCEND_PATH $ENV{ASCEND_HOME_PATH}) +set(METADEF_DIR ${PROJECT_SOURCE_DIR}/3rdparty/metadef) +set(CANN_OPS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cann-ops-adv/src/transformer) +set(TBE_TILING_DIR ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in/op_tiling) + +if("${PACKAGE_COMPILE}" STREQUAL "ON" AND NOT "${LOCAL_RELEASE_COMPILE}" STREQUAL "ON") + set(ASDOPS_METADEF_FILES + ${METADEF_DIR}/base/any_value.cc + ${METADEF_DIR}/base/type/ascend_string_impl.cc + ${METADEF_DIR}/base/utils/aligned_ptr.cc + ${METADEF_DIR}/base/utils/type_utils_impl.cc + ${METADEF_DIR}/base/runtime/compute_node_info.cc + ${METADEF_DIR}/base/runtime/runtime_attrs.cc + ${METADEF_DIR}/base/runtime/tiling_data.cc + ${METADEF_DIR}/base/registry/op_impl_registry.cc + ${METADEF_DIR}/exe_graph/lowering/getcdim.cc + ${METADEF_DIR}/exe_graph/lowering/shape_utils.cc + ${METADEF_DIR}/graph/attr/ge_attr_define.cc + ${METADEF_DIR}/graph/normal_graph/tensor.cc + ${METADEF_DIR}/graph/type/types.cc + ${METADEF_DIR}/register/op_tiling/op_tiling_attr_utils.cc + ${METADEF_DIR}/register/op_tiling/op_tiling_info.cc + ${METADEF_DIR}/register/op_binary_resource_manager.cc + ${METADEF_DIR}/register/tuning_bank_key_registry.cc + ${METADEF_DIR}/register/tuning_tiling_registry.cc + ${METADEF_DIR}/register/ascendc/tilingdata_base.cc + ${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_utils.cc + 
${METADEF_DIR}/third_party/transformer/src/axis_util.cc + ) + + # Add tbe_tiling.cc here + set(ASDOPS_TILING_FILES + ${TBE_TILING_DIR}/auto_tiling_context.cc + ${TBE_TILING_DIR}/auto_tiling_rt2.cc + ${TBE_TILING_DIR}/broadcast_v3.cc + ${TBE_TILING_DIR}/cache_tiling.cc + ${TBE_TILING_DIR}/compress_dequant_cache_tiling.cc + ${TBE_TILING_DIR}/concat_dsl.cc + ${TBE_TILING_DIR}/cube_tiling_runtime.cc + ${TBE_TILING_DIR}/elewise_v3.cc + ${TBE_TILING_DIR}/fusion.cc + ${TBE_TILING_DIR}/gather_dsl.cc + ${TBE_TILING_DIR}/gemm_ub_cache_tiling.cc + ${TBE_TILING_DIR}/gemm.cc + ${TBE_TILING_DIR}/lock.cc + ${TBE_TILING_DIR}/norm.cc + ${TBE_TILING_DIR}/reduce_tiling_v3.cc + ${TBE_TILING_DIR}/reduce_tiling_v3_compile_info.cc + ${TBE_TILING_DIR}/slice_dsl.cc + ${TBE_TILING_DIR}/sort_dsl.cc + ${TBE_TILING_DIR}/split_dsl.cc + ${TBE_TILING_DIR}/trans_data_fz2fzg.cc + ${TBE_TILING_DIR}/trans_data_fzg_to_fz.cc + ${TBE_TILING_DIR}/transdata_dsl.cc + ${TBE_TILING_DIR}/transdata_dsl_c04.cc + ${TBE_TILING_DIR}/transdata_dsl_c04_backward.cc + ${TBE_TILING_DIR}/transdata_dsl_entrance.cc + ${TBE_TILING_DIR}/transdata_dsl_borrow.cc + ${TBE_TILING_DIR}/transdata_dsl_general.cc + ${TBE_TILING_DIR}/transdata_dsl_util.cc + ${TBE_TILING_DIR}/vector_op_info.cc + ${TBE_TILING_DIR}/vector_tiling_key.cc + ${TBE_TILING_DIR}/vector_tiling_rt2.cc + ${TBE_TILING_DIR}/vector_tiling_util.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/runtime_bank_manager.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/op_hash.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/op_runtime_bank.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/configuration.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/file_utils.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/system_utils.cc + ${TBE_TILING_DIR}/cube/algorithm/cache_tiling_impl.cc + ${TBE_TILING_DIR}/cube/algorithm/calculator/calculator.cc + ${TBE_TILING_DIR}/cube/algorithm/entity/shape.cc + ${TBE_TILING_DIR}/cube/algorithm/entity/status.cc + ${TBE_TILING_DIR}/cube/algorithm/hash/cache_runinfo.cc + 
${TBE_TILING_DIR}/cube/algorithm/hash/hash.cc + ${TBE_TILING_DIR}/cube/algorithm/hash/tiling_cache.cc + ${TBE_TILING_DIR}/cube/impl/cache_tiling.cc + ${TBE_TILING_DIR}/cube/impl/cube_run_info.cc + ${TBE_TILING_DIR}/cube/impl/cube_tiling_param.cc + ${TBE_TILING_DIR}/cube/impl/cube_tiling.cc + ${TBE_TILING_DIR}/cube/platform/instruction_param.cc + ${TBE_TILING_DIR}/cube/platform/platform_info.cc + ${TBE_TILING_DIR}/cube/util/cube_util.cc + ${TBE_TILING_DIR}/cube/util/math_util.cc + ${TBE_TILING_DIR}/cube/util/timer.cc + ${TBE_TILING_DIR}/gemm/cache_tiling_basic_block.cc + ${TBE_TILING_DIR}/gemm/cache_tiling_basic_block_calc.cc + ${TBE_TILING_DIR}/gemm/common/cache_tiling_align_count.cc + ${TBE_TILING_DIR}/gemm/common/cache_tiling_request_bytes.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_cycle_model.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_basic_block_est.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_est.cc + ${TBE_TILING_DIR}/runtime/as_strided.cc + ${TBE_TILING_DIR}/runtime/as_strided_tiling.cc + ${TBE_TILING_DIR}/runtime/cumsum.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling_ascendc.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling_ascendc_int.cc + ${TBE_TILING_DIR}/runtime/concat_d.cc + ${TBE_TILING_DIR}/runtime/concat_tiling.cc + ${TBE_TILING_DIR}/runtime/expand.cc + ${TBE_TILING_DIR}/runtime/elewise_tiling.cc + ${TBE_TILING_DIR}/runtime/elementwise_template/fill_tiling.cc + ${TBE_TILING_DIR}/runtime/gatherv2.cc + ${TBE_TILING_DIR}/runtime/gather_v2_tiling.cpp + ${TBE_TILING_DIR}/runtime/inplace_index_add.cc + ${TBE_TILING_DIR}/runtime/inplace_index_add_tiling.cc + ${TBE_TILING_DIR}/runtime/layer_norm_v3.cc + ${TBE_TILING_DIR}/runtime/layer_norm_v3/layer_norm_v3_tiling.cc + ${TBE_TILING_DIR}/runtime/layer_norm_v3/layer_norm_v3_tiling_base.cc + ${OPS_THIRD_PARTY_DIR}/canndev/ops/norm/layer_norm_v4/op_host/layer_norm_v4_tiling.cpp + 
${OPS_THIRD_PARTY_DIR}/canndev/ops/norm/layer_norm_v4/op_host/layer_norm_v4_tiling_base.cpp + ${TBE_TILING_DIR}/runtime/one_hot/one_hot.cc + ${TBE_TILING_DIR}/runtime/one_hot/one_hot_tiling.cc + ${TBE_TILING_DIR}/runtime/pack_tiling.cc + ${TBE_TILING_DIR}/runtime/reverse.cc + ${TBE_TILING_DIR}/runtime/reverse_v2_tiling.cc + ${TBE_TILING_DIR}/runtime/tensor_move_tiling.cc + ${TBE_TILING_DIR}/runtime/runtime2_util.cc + ${TBE_TILING_DIR}/runtime/slice.cc + ${TBE_TILING_DIR}/runtime/slice_tiling.cc + ${TBE_TILING_DIR}/runtime/strided_slice_tiling.cc + ${TBE_TILING_DIR}/runtime/top_k.cc + ${TBE_TILING_DIR}/runtime/top_k_v2_tiling.cc + ${TBE_TILING_DIR}/runtime/trans_data.cc + ${TBE_TILING_DIR}/runtime/trans_data_negative_target_ntc.cc + ${TBE_TILING_DIR}/runtime/trans_data_negative_target_tc_201.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_ntc_100.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_tc_1010.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_tc_1011.cc + ${TBE_TILING_DIR}/runtime/transdata_tiling.cc + ${TBE_TILING_DIR}/runtime/transpose.cc + ${TBE_TILING_DIR}/runtime/transpose_tiling.cc + ${TBE_TILING_DIR}/runtime/depth_to_space_tiling.cc + ${TBE_TILING_DIR}/runtime/space_to_depth_tiling.cc + ${TBE_TILING_DIR}/runtime/view_copy.cc + ${TBE_TILING_DIR}/runtime/view_copy_tiling.cc + ${OPS_THIRD_PARTY_DIR}/canndev/ops/index/scatter_elements_v2/op_host/scatter_elements_v2_tiling.cpp + ${TBE_TILING_DIR}/runtime/runtime2_util.cc + ${TBE_TILING_DIR}/../fusion_pass/common/fp16_t.cc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/src/op_util.cc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/src/getcdim.cc + ) +else() + set(ASDOPS_METADEF_FILES + ${METADEF_DIR}/exe_graph/lowering/getcdim.cc + ${METADEF_DIR}/exe_graph/lowering/shape_utils.cc + ${METADEF_DIR}/exe_graph/runtime/compute_node_info.cc + ${METADEF_DIR}/exe_graph/runtime/runtime_attrs.cc + ${METADEF_DIR}/exe_graph/runtime/tiling_data.cc + 
${METADEF_DIR}/graph/attr/ge_attr_define.cc + ${METADEF_DIR}/graph/normal_graph/any_value.cc + ${METADEF_DIR}/graph/normal_graph/tensor.cc + ${METADEF_DIR}/graph/type/ascend_string.cc + ${METADEF_DIR}/graph/type/types.cc + ${METADEF_DIR}/graph/utils/aligned_ptr.cc + ${METADEF_DIR}/graph/utils/type_utils.cc + ${METADEF_DIR}/graph/utils/type_utils_ex.cc + ${METADEF_DIR}/register/op_impl_registry.cc + ${METADEF_DIR}/register/op_tiling/op_tiling_attr_utils.cc + ${METADEF_DIR}/register/op_tiling/op_tiling_info.cc + ${METADEF_DIR}/register/op_binary_resource_manager.cc + ${METADEF_DIR}/register/tuning_bank_key_registry.cc + ${METADEF_DIR}/register/tuning_tiling_registry.cc + ${METADEF_DIR}/register/ascendc/tilingdata_base.cc + ${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_utils.cc + ${METADEF_DIR}/third_party/transformer/src/axis_util.cc + ) + + # Add tbe_tiling.cc here + set(ASDOPS_TILING_FILES + ${TBE_TILING_DIR}/auto_tiling_context.cc + ${TBE_TILING_DIR}/auto_tiling_rt2.cc + ${TBE_TILING_DIR}/broadcast_v3.cc + ${TBE_TILING_DIR}/cache_tiling.cc + ${TBE_TILING_DIR}/compress_dequant_cache_tiling.cc + ${TBE_TILING_DIR}/concat_dsl.cc + ${TBE_TILING_DIR}/cube_tiling_runtime.cc + ${TBE_TILING_DIR}/elewise_v3.cc + ${TBE_TILING_DIR}/fusion.cc + ${TBE_TILING_DIR}/gather_dsl.cc + ${TBE_TILING_DIR}/gemm_ub_cache_tiling.cc + ${TBE_TILING_DIR}/gemm.cc + ${TBE_TILING_DIR}/lock.cc + ${TBE_TILING_DIR}/norm.cc + ${TBE_TILING_DIR}/reduce_tiling_v3.cc + ${TBE_TILING_DIR}/reduce_tiling_v3_compile_info.cc + ${TBE_TILING_DIR}/slice_dsl.cc + ${TBE_TILING_DIR}/sort_dsl.cc + ${TBE_TILING_DIR}/split_dsl.cc + ${TBE_TILING_DIR}/trans_data_fz2fzg.cc + ${TBE_TILING_DIR}/trans_data_fzg_to_fz.cc + ${TBE_TILING_DIR}/transdata_dsl.cc + ${TBE_TILING_DIR}/transdata_dsl_c04.cc + ${TBE_TILING_DIR}/transdata_dsl_c04_backward.cc + 
${TBE_TILING_DIR}/transdata_dsl_entrance.cc + ${TBE_TILING_DIR}/transdata_dsl_borrow.cc + ${TBE_TILING_DIR}/transdata_dsl_general.cc + ${TBE_TILING_DIR}/transdata_dsl_util.cc + ${TBE_TILING_DIR}/vector_op_info.cc + ${TBE_TILING_DIR}/vector_tiling_key.cc + ${TBE_TILING_DIR}/vector_tiling_rt2.cc + ${TBE_TILING_DIR}/vector_tiling_util.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/runtime_bank_manager.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/op_hash.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/op_runtime_bank.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/configuration.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/file_utils.cc + ${TBE_TILING_DIR}/aoe/runtime_kb/common/utils/system_utils.cc + ${TBE_TILING_DIR}/cube/algorithm/cache_tiling_impl.cc + ${TBE_TILING_DIR}/cube/algorithm/calculator/calculator.cc + ${TBE_TILING_DIR}/cube/algorithm/entity/shape.cc + ${TBE_TILING_DIR}/cube/algorithm/entity/status.cc + ${TBE_TILING_DIR}/cube/algorithm/hash/cache_runinfo.cc + ${TBE_TILING_DIR}/cube/algorithm/hash/hash.cc + ${TBE_TILING_DIR}/cube/algorithm/hash/tiling_cache.cc + ${TBE_TILING_DIR}/cube/impl/cache_tiling.cc + ${TBE_TILING_DIR}/cube/impl/cube_run_info.cc + ${TBE_TILING_DIR}/cube/impl/cube_tiling_param.cc + ${TBE_TILING_DIR}/cube/impl/cube_tiling.cc + ${TBE_TILING_DIR}/cube/platform/instruction_param.cc + ${TBE_TILING_DIR}/cube/platform/platform_info.cc + ${TBE_TILING_DIR}/cube/util/cube_util.cc + ${TBE_TILING_DIR}/cube/util/math_util.cc + ${TBE_TILING_DIR}/cube/util/timer.cc + ${TBE_TILING_DIR}/gemm/cache_tiling_basic_block.cc + ${TBE_TILING_DIR}/gemm/cache_tiling_basic_block_calc.cc + ${TBE_TILING_DIR}/gemm/common/cache_tiling_align_count.cc + ${TBE_TILING_DIR}/gemm/common/cache_tiling_request_bytes.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_cycle_model.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_basic_block_est.cc + ${TBE_TILING_DIR}/gemm/estimate/cache_tiling_est.cc + ${TBE_TILING_DIR}/runtime/as_strided.cc + ${TBE_TILING_DIR}/runtime/as_strided_tiling.cc + 
${TBE_TILING_DIR}/runtime/cumsum.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling_ascendc.cc + ${TBE_TILING_DIR}/runtime/cumsum_tiling_ascendc_int.cc + ${TBE_TILING_DIR}/runtime/concat_d.cc + ${TBE_TILING_DIR}/runtime/concat_tiling.cc + ${TBE_TILING_DIR}/runtime/expand.cc + ${TBE_TILING_DIR}/runtime/elewise_tiling.cc + ${TBE_TILING_DIR}/runtime/elementwise_template/fill_tiling.cc + ${TBE_TILING_DIR}/runtime/gatherv2.cc + ${TBE_TILING_DIR}/runtime/gather_v2_tiling.cpp + ${TBE_TILING_DIR}/runtime/inplace_index_add.cc + ${TBE_TILING_DIR}/runtime/inplace_index_add_tiling.cc + ${TBE_TILING_DIR}/runtime/layer_norm_v3.cc + ${TBE_TILING_DIR}/runtime/one_hot/one_hot.cc + ${TBE_TILING_DIR}/runtime/one_hot/one_hot_tiling.cc + ${TBE_TILING_DIR}/runtime/pack_tiling.cc + ${TBE_TILING_DIR}/runtime/reverse.cc + ${TBE_TILING_DIR}/runtime/reverse_v2_tiling.cc + ${TBE_TILING_DIR}/runtime/tensor_move_tiling.cc + ${TBE_TILING_DIR}/runtime/runtime2_util.cc + ${TBE_TILING_DIR}/runtime/slice.cc + ${TBE_TILING_DIR}/runtime/slice_tiling.cc + ${TBE_TILING_DIR}/runtime/strided_slice_tiling.cc + ${TBE_TILING_DIR}/runtime/top_k.cc + ${TBE_TILING_DIR}/runtime/top_k_v2_tiling.cc + ${TBE_TILING_DIR}/runtime/trans_data.cc + ${TBE_TILING_DIR}/runtime/trans_data_negative_target_ntc.cc + ${TBE_TILING_DIR}/runtime/trans_data_negative_target_tc_201.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_ntc_100.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_tc_1010.cc + ${TBE_TILING_DIR}/runtime/trans_data_positive_source_tc_1011.cc + ${TBE_TILING_DIR}/runtime/transdata_tiling.cc + ${TBE_TILING_DIR}/runtime/transpose.cc + ${TBE_TILING_DIR}/runtime/transpose_tiling.cc + ${TBE_TILING_DIR}/runtime/depth_to_space_tiling.cc + ${TBE_TILING_DIR}/runtime/space_to_depth_tiling.cc + ${TBE_TILING_DIR}/runtime/view_copy.cc + ${TBE_TILING_DIR}/runtime/view_copy_tiling.cc + ${TBE_TILING_DIR}/runtime/scatter_elements_v2_tiling.cc + 
${TBE_TILING_DIR}/runtime/runtime2_util.cc + ${TBE_TILING_DIR}/../fusion_pass/common/fp16_t.cc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/src/op_util.cc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/src/getcdim.cc + ) +endif() + +# Sink kernels +file(GLOB_RECURSE SINK_TILING_SRCS + ${TBE_TILING_DIR}/runtime/rotary_pos_emb_infer/* +) +list(APPEND ASDOPS_TILING_FILES ${SINK_TILING_SRCS}) + +# Tiling Api +set(TILING_API_BASE_DIR ${PROJECT_SOURCE_DIR}/3rdparty/api) +set(TILING_API_SRC_DIR ${PROJECT_SOURCE_DIR}/3rdparty/api/impl) + +# Generate kernel_tiling.h +set(GEN_KERENL_TILING_DATA_SCRIPT ${TILING_API_BASE_DIR}/cmake/scripts/gen_kernel_tiling_data_def.py) +set(TILING_DATA_DEF_DIR ${TILING_API_BASE_DIR}/lib) +set(KERNEL_TILING_HEAD ${CMAKE_BINARY_DIR}/generated_include/kernel_tiling/kernel_tiling.h) + +add_custom_command(OUTPUT ${KERNEL_TILING_HEAD} + COMMAND python3 ${GEN_KERENL_TILING_DATA_SCRIPT} ${TILING_DATA_DEF_DIR} ${KERNEL_TILING_HEAD} + DEPENDS ${GEN_KERENL_TILING_DATA_SCRIPT}) + +add_custom_target(GEN_KERNEL_TILING ALL + DEPENDS ${KERNEL_TILING_HEAD}) + +set(TILING_API_SRCS + ${TILING_API_SRC_DIR}/quantization/dequant/ascend_dequant_tiling_impl.cpp + ${TILING_API_SRC_DIR}/quantization/quant/ascend_quant_tiling_impl.cpp + ${TILING_API_SRC_DIR}/quantization/antiquant/ascend_antiquant_tiling_impl.cpp + ${TILING_API_SRC_DIR}/filter/dropout/dropout_tiling_impl.cpp + ${TILING_API_SRC_DIR}/activation/gelu/gelu_tiling_impl.cpp + ${TILING_API_SRC_DIR}/matmul/tiling/bmm_tiling.cpp + ${TILING_API_SRC_DIR}/matmul/tiling/matmul_tiling.cpp + ${TILING_API_SRC_DIR}/matmul/tiling/matmul_tiling_base.cpp + ${TILING_API_SRC_DIR}/matmul/tiling/matmul_tiling_algorithm.cpp + ${TILING_API_SRC_DIR}/matmul/tiling/math_util.cpp + ${TILING_API_SRC_DIR}/hccl/hccl_tiling.cpp + ${TILING_API_SRC_DIR}/math/clamp/clamp_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/acos/acos_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/acosh/acosh_tiling.cpp + 
${TILING_API_SRC_DIR}/math/asin/asin_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/asinh/asinh_tiling.cpp + ${TILING_API_SRC_DIR}/math/atan/atan_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/atanh/atanh_tiling.cpp + ${TILING_API_SRC_DIR}/math/cos/cos_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/cosh/cosh_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/erf/erf_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/erfc/erfc_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/exp/exp_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/frac/frac_tiling_impl.cpp + ${TILING_API_SRC_DIR}/activation/geglu/geglu_tiling.cpp + ${TILING_API_SRC_DIR}/math/lgamma/lgamma_tiling.cpp + ${TILING_API_SRC_DIR}/math/digamma/digamma_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/log/log_tiling.cpp + ${TILING_API_SRC_DIR}/math/sin/sin_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/sinh/sinh_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/power/power_tiling_impl.cpp + ${TILING_API_SRC_DIR}/activation/sigmoid/sigmoid_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/round/round_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/tan/tan_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/tanh/tanh_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/trunc/trunc_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/axpy/axpy_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/hypot/hypot_tiling_impl.cpp + ${TILING_API_SRC_DIR}/activation/swiglu/swiglu_tiling.cpp + ${TILING_API_SRC_DIR}/math/ceil/ceil_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/floor/floor_tiling_impl.cpp + ${TILING_API_SRC_DIR}/activation/softmax/softmax_tiling.cpp + ${TILING_API_SRC_DIR}/activation/softmax/logsoftmax_tiling.cpp + ${TILING_API_SRC_DIR}/normalization/rmsnorm/rmsnorm_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/batchnorm/batchnorm_tiling_impl.cpp + ${TILING_API_SRC_DIR}/sort/sort/sort_tiling_impl.cpp + ${TILING_API_SRC_DIR}/sort/topk/topk_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/deepnorm/deepnorm_tiling_impl.cpp + 
${TILING_API_SRC_DIR}/select/selectwithbytesmask/selectwithbytesmask_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/layernorm/layernorm_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/normalize/normalize_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/layernormgrad/layernorm_grad_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/layernormgrad/layernorm_grad_beta_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/groupnorm/groupnorm_tiling_impl.cpp + ${TILING_API_SRC_DIR}/normalization/welfordfinalize/welfordfinalize_tiling_impl.cpp + ${TILING_API_SRC_DIR}/pad/pad/pad_tiling_impl.cpp + ${TILING_API_SRC_DIR}/transpose/confusion_transpose/confusion_transpose_tiling_impl.cpp + ${TILING_API_SRC_DIR}/pad/broadcast/broadcast_tiling.cpp + ${TILING_API_SRC_DIR}/pad/broadcast/broadcast_tiling.cpp + ${TILING_API_SRC_DIR}/math/xor/xor_tiling.cpp + ${TILING_API_SRC_DIR}/math/cumsum/cumsum_tiling.cpp + ${TILING_API_SRC_DIR}/reduce/mean/mean_tiling.cpp + ${TILING_API_SRC_DIR}/math/sign/sign_tiling.cpp + ${TILING_API_SRC_DIR}/activation/reglu/reglu_tiling_impl.cpp + ${TILING_API_SRC_DIR}/reduce/reduce_xor_sum/reduce_xor_sum_tiling.cpp + ${TILING_API_SRC_DIR}/reduce/sum/sum_tiling.cpp + ${TILING_API_SRC_DIR}/reduce/reduce_tiling.cpp + ${TILING_API_SRC_DIR}/index/arithprogression/arithprogression_tiling_impl.cpp + ${TILING_API_SRC_DIR}/math/fmod/fmod_tiling_impl.cpp +) + +if("${PACKAGE_COMPILE}" STREQUAL "ON" AND NOT "${LOCAL_RELEASE_COMPILE}" STREQUAL "ON") + set(ASDOPS_INC_DIRS + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_CURRENT_LIST_DIR}/stubs/include + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef/inc + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef/inc/graph + ${CMAKE_CURRENT_LIST_DIR}/platform + ${CMAKE_CURRENT_LIST_DIR}/platform/tiling + ${KERNEL_TILING_DIR} + ${TILING_API_BASE_DIR} + ${TILING_API_BASE_DIR}/tiling + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops + 
${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/conversion/fill/op_kernel + ${METADEF_DIR} + ${METADEF_DIR}/base/runtime + ${METADEF_DIR}/exe_graph/lowering + ${METADEF_DIR}/exe_graph/runtime + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/common + ${METADEF_DIR}/inc/exe_graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/exe_graph + ${METADEF_DIR}/inc/external/exe_graph/lowering + ${METADEF_DIR}/inc/external/exe_graph/runtime + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/register + ${METADEF_DIR}/third_party/transformer/inc + ${METADEF_DIR}/third_party/transformer/src + ${METADEF_DIR}/inc/common/ge_common/debug + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/utils/inc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in/op_tiling + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in/op_tiling/runtime + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/ + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/inc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/third_party/fwkacllib/inc + ${METADEF_DIR}/inc/graph + $ENV{ASCEND_HOME_PATH}/include + $ENV{ASCEND_HOME_PATH}/include/tiling + ) +else() + set(ASDOPS_INC_DIRS + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_CURRENT_LIST_DIR}/stubs/include + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef/inc + ${CMAKE_CURRENT_LIST_DIR}/stubs/include/metadef/inc/graph + ${CMAKE_CURRENT_LIST_DIR}/platform + ${CMAKE_CURRENT_LIST_DIR}/platform/tiling + ${KERNEL_TILING_DIR} + ${TILING_API_BASE_DIR} + ${TILING_API_BASE_DIR}/tiling + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/conversion/fill/op_kernel + ${METADEF_DIR} + ${METADEF_DIR}/exe_graph/lowering + ${METADEF_DIR}/exe_graph/runtime + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/common + ${METADEF_DIR}/inc/exe_graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/exe_graph + ${METADEF_DIR}/inc/external/exe_graph/lowering + 
${METADEF_DIR}/inc/external/exe_graph/runtime + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/register + ${METADEF_DIR}/third_party/transformer/inc + ${METADEF_DIR}/inc/common/ge_common/debug + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/utils/inc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in/op_tiling + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/built-in/op_tiling/runtime + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/ + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/ops/common/inc + ${PROJECT_SOURCE_DIR}/3rdparty/canndev/third_party/fwkacllib/inc + ${METADEF_DIR}/inc/graph + $ENV{ASCEND_HOME_PATH}/include + $ENV{ASCEND_HOME_PATH}/include/tiling + ) +endif() + +# Ignore warnings from 3rdparty srcs +# DO NOT USE THIS FOR OTHER SRCS +set_source_files_properties( + ${TILING_API_SRCS} + ${ASDOPS_TILING_FILES} + ${ASDOPS_METADEF_FILES} + PROPERTIES + COMPILE_FLAGS "-w" +) + +# MIX OPS +file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/mixops/) +execute_process(COMMAND python3 ${PROJECT_SOURCE_DIR}/scripts/update_tbe_tactic_json.py + --src_ini_path ${PROJECT_SOURCE_DIR}/configs/mixops/tbe_tactic_info.ini + --dst_ini_path ${CMAKE_BINARY_DIR}/mixops/tbe_tactic_json.ini + OUTPUT_VARIABLE MIX_PYTHON_OUTPUT + ERROR_VARIABLE RESULT_INFO + RESULT_VARIABLE RESULT) +if(NOT RESULT EQUAL 0) + message(FATAL_ERROR "tbe info update failed, error code: ${RESULT}, error info:\n${RESULT_INFO}") +endif() + +string(REPLACE "\n" ";" MIX_REUSE_BINARY_LIST "${MIX_PYTHON_OUTPUT}") +list(POP_BACK MIX_REUSE_BINARY_LIST) +set_source_files_properties(${MIX_REUSE_BINARY_LIST} PROPERTIES GENERATED TRUE) + +add_custom_command( + OUTPUT ${MIX_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/mix_wait_flag.cpp + DEPENDS ${PROJECT_SOURCE_DIR}/configs/mixops/tbe_tactic_info.ini + COMMAND python3 ${MKI_PACKAGE_DIR}/scripts/build_util.py --binary_dir ${CMAKE_BINARY_DIR} --op_type tbe + --tbe_ini_path ${CMAKE_BINARY_DIR}/mixops/tbe_tactic_json.ini + COMMAND cmake 
-E sleep 10 + COMMAND cmake -E echo "wait 10 sec done" + COMMAND cmake -E touch ${CMAKE_BINARY_DIR}/mix_wait_flag.cpp +) + +add_custom_target(MIX_REUSE_SRC_TARGET ALL + DEPENDS ${MIX_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/mix_wait_flag.cpp +) +add_library(mix_reuse_kernels OBJECT ${MIX_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/mix_wait_flag.cpp) +add_dependencies(mix_reuse_kernels MIX_REUSE_SRC_TARGET) +target_compile_definitions(mix_reuse_kernels PRIVATE OpSpace=AtbOps) + +# Ops +execute_process(COMMAND python3 ${PROJECT_SOURCE_DIR}/scripts/update_tbe_tactic_json.py + OUTPUT_VARIABLE OPS_PYTHON_OUTPUT + ERROR_VARIABLE RESULT_INFO + RESULT_VARIABLE RESULT) +if(NOT RESULT EQUAL 0) + message(FATAL_ERROR "tbe info update failed, error code: ${RESULT}, error info:\n${RESULT_INFO}") +endif() + +string(REPLACE "\n" ";" OPS_REUSE_BINARY_LIST "${OPS_PYTHON_OUTPUT}") +list(POP_BACK OPS_REUSE_BINARY_LIST) +set_source_files_properties(${OPS_REUSE_BINARY_LIST} PROPERTIES GENERATED TRUE) + +add_custom_command( + OUTPUT ${OPS_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/ops_wait_flag.cpp + DEPENDS ${PROJECT_SOURCE_DIR}/configs/ops/tbe_tactic_info.ini + COMMAND python3 ${MKI_PACKAGE_DIR}/scripts/build_util.py --binary_dir ${CMAKE_BINARY_DIR} --op_type tbe + --tbe_ini_path ${CMAKE_BINARY_DIR}/tbe_tactic_json.ini + COMMAND cmake -E sleep 10 + COMMAND cmake -E echo "wait 10 sec done" + COMMAND cmake -E touch ${CMAKE_BINARY_DIR}/ops_wait_flag.cpp +) + +add_custom_target(OPS_REUSE_SRC_TARGET ALL + DEPENDS ${OPS_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/ops_wait_flag.cpp +) +add_library(ops_reuse_kernels OBJECT ${OPS_REUSE_BINARY_LIST} ${CMAKE_BINARY_DIR}/ops_wait_flag.cpp) +add_dependencies(ops_reuse_kernels OPS_REUSE_SRC_TARGET) +target_compile_definitions(ops_reuse_kernels PRIVATE OpSpace=AsdOps) + +# Target +add_library(tbe_adapter SHARED ${SOURCE_FILES} ${ASDOPS_METADEF_FILES} ${ASDOPS_TILING_FILES} ${TILING_API_SRCS}) +add_dependencies(tbe_adapter GEN_KERNEL_TILING) 
+target_include_directories(tbe_adapter PRIVATE ${ASDOPS_INC_DIRS}) +target_link_libraries(tbe_adapter PRIVATE mmpa c_sec ascendalog mix_reuse_kernels ops_reuse_kernels) + +add_definitions(-DLOG_CPP) + +target_compile_definitions(tbe_adapter PUBLIC + OPS_UTILS_LOG_SUB_MOD_NAME="ASDOPS_TILING" +) + +target_compile_definitions(tbe_adapter PUBLIC + fe=AsdOpsFe + ge=AsdOpsGe + gert=AsdOpsGeRt + optiling=AsdOpsTiling + error_message=AsdOpsErrorMessage + ErrorManager=AsdOpsErrorManager + td_dsl=AsdOpsTdDsl + tuningtiling=AsdOpsTuningTiling + RuntimeKb=AsdOpsRuntimeKb + gemm_cache_tiling=AsdOpsGemmCacheTiling + transformer=AsdOpsTransformer + ops=AsdOpsOps + AscendC=AsdOpsAscendC + matmul_tiling=AsdOpsMatmulTiling + platform_ascendc=AsdOpsPlatformAscendC +) + +install(TARGETS tbe_adapter DESTINATION lib) diff --git a/src/kernels/tbe_adapter/platform/platform_ascendc.cpp b/src/kernels/tbe_adapter/platform/platform_ascendc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1d924f2cdce43ece07b9efa32d71561ed0807ab --- /dev/null +++ b/src/kernels/tbe_adapter/platform/platform_ascendc.cpp @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#include <csignal>
+#include <cstdint>
+#include <cstdlib>
+#include <map>
+#include <mutex>
+#include <string>
+#include "securec.h"
+#include "platform/platform_info.h"
+#include "platform/platform_infos_def.h"
+
+namespace platform_ascendc {
+const static uint64_t LOCAL_RESERV_SIZE = 256;
+const static uint32_t WORKSPACE_SIZE_910B = 16 * 1024 * 1024;
+const static uint32_t WORKSPACE_SIZE = 2 * 1024 * 1024;
+const static uint32_t MIX_AIC_AIV_RATION_910B1 = 2;
+const static uint32_t CUBE_GROUP_WORKSPACE_SIZE_910B = 1 * 1024 * 1024;
+const static uint32_t GROUP_BARRIER_WORKSPACE_SIZE_910B = 1 * 1024 * 1024;
+const static std::string STR_VERSION = "version";
+const static std::string STR_SOC_INFO = "SoCInfo";
+const static std::string SHORT_SOC_VERSION = "Short_SoC_version";
+const static std::string STR_SPLIT_KEY = "core_type_list";
+const static std::string STR_SPLIT_VAL = "CubeCore,VectorCore";
+const static std::string STR_CORE_CNT_CUB = "cube_core_cnt";
+const static std::string STR_CORE_CNT_VEC = "vector_core_cnt";
+const static std::string STR_CORE_CNT_AICORE = "ai_core_cnt";
+const static std::map<std::string, SocVersion> CONVERT_MAP = {
+    {"Ascend310P", SocVersion::ASCEND310P},
+    {"Ascend310B", SocVersion::ASCEND310B},
+    {"Ascend910", SocVersion::ASCEND910},
+    {"Ascend910B", SocVersion::ASCEND910B},
+    {"Ascend910_93", SocVersion::ASCEND910B},
+};
+
+static inline uint32_t GetCoreNumByType(fe::PlatFormInfos *platformInfo, bool isAiv)
+{
+    std::string key;
+    std::string val;
+    bool ret = platformInfo->GetPlatformResWithLock(STR_SOC_INFO, STR_SPLIT_KEY, val);
+    MKI_LOG_IF(!ret, ERROR) << "get platform failed, val is " << val;
+
+    if (STR_SPLIT_VAL.compare(val) != 0) {
+        key = STR_CORE_CNT_AICORE;
+    } else if (isAiv) {
+        key = STR_CORE_CNT_VEC;
+    } else {
+        key = STR_CORE_CNT_CUB;
+    }
+    ret = platformInfo->GetPlatformResWithLock(STR_SOC_INFO, key, val);
+    MKI_LOG_IF(!ret, ERROR) << "get platform failed, key is " << key << ", val is" << val;
+    return val.empty() ? 
0 : static_cast(std::atoi(val.c_str())); +} + +uint32_t PlatformAscendC::GetCoreNumVector(void) const +{ + if (GetSocVersion() == SocVersion::ASCEND310P) { + std::string val; + bool ret = GetPlatFormInfo()->GetPlatformResWithLock(STR_SOC_INFO, STR_CORE_CNT_VEC, val); + MKI_LOG_IF(!ret, ERROR) << "get platform vector num failed, val is " << val; + return val.empty() ? 0 : std::atoi(val.c_str()); + } + return 0; +} + +uint32_t PlatformAscendC::GetCoreNumAic(void) const +{ + return GetCoreNumByType(GetPlatFormInfo(), false); +} + +uint32_t PlatformAscendC::GetCoreNumAiv(void) const +{ + return GetCoreNumByType(GetPlatFormInfo(), true); +} + +uint32_t PlatformAscendC::GetCoreNum(void) const +{ + return GetPlatFormInfo()->GetCoreNum(); +} + +void PlatformAscendC::GetCoreMemSize(const CoreMemType &memType, uint64_t &size) const +{ + const fe::LocalMemType localType = static_cast(memType); + GetPlatFormInfo()->GetLocalMemSize(localType, size); + // only ascend910B need UB/L1 local reserved buf for kfc + if ((memType == CoreMemType::UB || memType == CoreMemType::L1) + && GetSocVersion() == SocVersion::ASCEND910B) { + size -= LOCAL_RESERV_SIZE; + } +} + +SocVersion PlatformAscendC::GetSocVersion(void) const +{ + std::string socVersionStr; + const auto ret = GetPlatFormInfo()->GetPlatformResWithLock(STR_VERSION, SHORT_SOC_VERSION, socVersionStr); + MKI_CHECK(ret, "get platform failed, socVersionStr is " << socVersionStr, + return SocVersion::RESERVED_VERSION); + + const auto &it = CONVERT_MAP.find(socVersionStr); + if (it != CONVERT_MAP.cend()) { + return it->second; + } + MKI_LOG(ERROR) << "get platform failed, convertMap do not find soc " << socVersionStr << " version"; + return SocVersion::RESERVED_VERSION; +} +void PlatformAscendC::GetCoreMemBw(const CoreMemType &memType, uint64_t &bwSize) const +{ + const fe::LocalMemType localType = static_cast(memType); + GetPlatFormInfo()->GetLocalMemBw(localType, bwSize); +} + +fe::PlatFormInfos* 
PlatformAscendC::GetPlatFormInfo(void) const +{ + MKI_CHECK(platformInfo_, "PlatformInfo cannot be initialized to nulltpr!!", raise(SIGABRT)); + return platformInfo_; +} + +uint32_t PlatformAscendC::CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) const +{ + if (aicCoreNum == 0 || aivCoreNum == 0 || aicCoreNum > aivCoreNum) { + return sliceNum; + } + uint32_t ration = aivCoreNum / aicCoreNum; + uint32_t blockDim = (sliceNum + (ration - 1)) / ration; + // in mix case: 910B1(ration = 2), blockDim should not be greater than physical aic core num + if ((ration == MIX_AIC_AIV_RATION_910B1) && (blockDim > aicCoreNum)) { + MKI_LOG(ERROR) << "CalcTschBlockDim failed, calc blockDim " << blockDim + << " should not be greater than aicCoreNum " << aicCoreNum; + return 0; + } + return blockDim; +} + +uint32_t PlatformAscendC::GetLibApiWorkSpaceSize(void) const +{ + auto socVersion = GetSocVersion(); + if (socVersion == SocVersion::RESERVED_VERSION) { + MKI_LOG(ERROR) << "get platform failed, socVersionStr is " << static_cast(socVersion); + return -1; + } else if (socVersion == SocVersion::ASCEND910B) { + return WORKSPACE_SIZE_910B; + } + return WORKSPACE_SIZE; +} + +uint32_t PlatformAscendC::GetResCubeGroupWorkSpaceSize(void) const +{ + auto socVersion = GetSocVersion(); + if (socVersion == SocVersion::ASCEND910B) { + return CUBE_GROUP_WORKSPACE_SIZE_910B; + } else { + MKI_LOG(ERROR) << "get platform failed, socVersionStr is " << static_cast(socVersion); + return -1; + } +} + +uint32_t PlatformAscendC::GetResGroupBarrierWorkSpaceSize(void) const +{ + auto socVersion = GetSocVersion(); + if (socVersion == SocVersion::ASCEND910B) { + return GROUP_BARRIER_WORKSPACE_SIZE_910B; + } else { + MKI_LOG(ERROR) << "get platform failed, socVersionStr is " << static_cast(socVersion); + return -1; + } +} + +PlatformAscendC* PlatformAscendCManager::platformInfo = nullptr; +std::mutex PlatformAscendCManager::platformInitMtx; +SocVersion 
PlatformAscendCManager::SocVersionMap(const char *socVersionStr) +{ + const auto &iter = CONVERT_MAP.find(socVersionStr); + if (iter != CONVERT_MAP.cend()) { + return iter->second; + } + MKI_LOG(ERROR) << "get platform failed, convertMap do not find soc " << socVersionStr << " version"; + return SocVersion::RESERVED_VERSION; +} +fe::PlatFormInfos* PlatformAscendCManager::PlatformAscendCInit(const char *customSocVersion) +{ + static fe::PlatFormInfos gPlatformInfo; + std::string socVersion; + + if (customSocVersion == nullptr) { + const uint32_t maxLen = 50; + MKI_CHECK(Mki::MkiRtDeviceGetSocVersion(socVersion, maxLen) == MKIRT_SUCCESS, + "failed to get soc version", return nullptr); + } else { + socVersion = customSocVersion; + } + + fe::PlatformInfoManager::GeInstance().InitRuntimePlatformInfos(socVersion); + fe::OptionalInfos optionalInfos; + fe::PlatformInfoManager::GeInstance().GetPlatformInfos(socVersion, + gPlatformInfo, optionalInfos); + std::string socVersionStr; + const auto ret = gPlatformInfo.GetPlatformResWithLock(STR_VERSION, SHORT_SOC_VERSION, socVersionStr); + MKI_LOG_IF(!ret, ERROR) << "get platform short version failed, socVersion is " << socVersion; + SocVersion version = SocVersionMap(socVersionStr.c_str()); + if (version == SocVersion::RESERVED_VERSION) { + MKI_LOG(ERROR) << "Invalid SocVersion."; + return nullptr; + } else if ((version == SocVersion::ASCEND310P) || (version == SocVersion::ASCEND910)) { + gPlatformInfo.SetCoreNumByCoreType("AiCore"); + } else { + gPlatformInfo.SetCoreNumByCoreType("VectorCore"); + } + return &gPlatformInfo; +} +PlatformAscendC* PlatformAscendCManager::PlatformAscendCManagerInit(const char *customSocVersion) +{ + static fe::PlatFormInfos* gPlatformAscendCInfo = PlatformAscendCInit(customSocVersion); + MKI_CHECK(gPlatformAscendCInfo, "failed to get platformInfo", return nullptr); + + static PlatformAscendC tmp(gPlatformAscendCInfo); + platformInfo = &tmp; + return platformInfo; +} +} diff --git 
a/src/kernels/tbe_adapter/platform/platform_info.cpp b/src/kernels/tbe_adapter/platform/platform_info.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3cb19c3792e872fd2229d100d67f2f84efb35b3d --- /dev/null +++ b/src/kernels/tbe_adapter/platform/platform_info.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "platform/platform_info.h" + +namespace fe { +PlatformInfoManager& PlatformInfoManager::Instance() +{ + static PlatformInfoManager manager; + return manager; +} + +PlatformInfoManager& PlatformInfoManager::GeInstance() +{ + return Instance(); +} + +uint32_t PlatformInfoManager::InitializePlatformInfo() +{ + return 1; +} + +uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion( + PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo) +{ + (void)platformInfo; + (void)optiCompilationInfo; + return 1; +} + +uint32_t PlatformInfoManager::GetPlatformInfoWithOutSocVersion( + PlatFormInfos &platformInfo, OptionalInfos &optiCompilationInfo) +{ + (void)platformInfo; + (void)optiCompilationInfo; + return 1; +} + +uint32_t PlatformInfoManager::InitRuntimePlatformInfos(const std::string &socVersion) +{ + (void)socVersion; + return 1; +} + +uint32_t PlatformInfoManager::GetPlatformInfos( + const string SoCVersion, PlatFormInfos &platform_info, OptionalInfos &opti_compilation_info) +{ + (void)SoCVersion; + (void)platform_info; + 
(void)opti_compilation_info; + return 1; +} + +PlatformInfoManager::PlatformInfoManager() {} +PlatformInfoManager::~PlatformInfoManager() {} +} \ No newline at end of file diff --git a/src/kernels/tbe_adapter/platform/platform_infos_def.cpp b/src/kernels/tbe_adapter/platform/platform_infos_def.cpp new file mode 100644 index 0000000000000000000000000000000000000000..268de985efe061f8d8d55dc10b8b8423e20eecd5 --- /dev/null +++ b/src/kernels/tbe_adapter/platform/platform_infos_def.cpp @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "platform/platform_infos_def.h" +#include +#include +#include "platform_infos_impl.h" + +namespace fe { +constexpr uint32_t MAX_CORE_NUM = 128; +std::mutex g_asdopsFePlatMutex; + +bool PlatFormInfos::Init() +{ + platform_infos_impl_ = std::make_shared(); + if (platform_infos_impl_ == nullptr) { + return false; + } + return true; +} + +bool PlatFormInfos::GetPlatformResWithLock(const std::string &label, const std::string &key, std::string &val) +{ + std::lock_guard lockGuard(g_asdopsFePlatMutex); + if (platform_infos_impl_ == nullptr) { + return false; + } + return platform_infos_impl_->GetPlatformRes(label, key, val); +} + +bool PlatFormInfos::GetPlatformResWithLock(const std::string &label, std::map &res) +{ + std::lock_guard lockGuard(g_asdopsFePlatMutex); + if (platform_infos_impl_ == nullptr) { + return false; + } + return platform_infos_impl_->GetPlatformRes(label, res); +} + +std::map> PlatFormInfos::GetAICoreIntrinsicDtype() +{ + if (platform_infos_impl_ == nullptr) { + return {}; + } + return platform_infos_impl_->GetAICoreIntrinsicDtype(); +} + +std::map> PlatFormInfos::GetVectorCoreIntrinsicDtype() +{ + if (platform_infos_impl_ == nullptr) { + return {}; + } + return platform_infos_impl_->GetVectorCoreIntrinsicDtype(); +} + +bool PlatFormInfos::GetPlatformRes(const std::string &label, const std::string &key, std::string &val) +{ + if (platform_infos_impl_ == nullptr) { + return false; + } + return platform_infos_impl_->GetPlatformRes(label, key, val); +} + +bool PlatFormInfos::GetPlatformRes(const std::string &label, std::map &res) +{ + if (platform_infos_impl_ == nullptr) { + return false; + } + return platform_infos_impl_->GetPlatformRes(label, res); +} + +void PlatFormInfos::SetAICoreIntrinsicDtype(std::map> &intrinsicDtypes) +{ + if (platform_infos_impl_ == nullptr) { + return; + } + platform_infos_impl_->SetAICoreIntrinsicDtype(intrinsicDtypes); +} + +void PlatFormInfos::SetVectorCoreIntrinsicDtype(std::map> &intrinsicDtypes) +{ + if 
(platform_infos_impl_ == nullptr) { + return; + } + platform_infos_impl_->SetVectorCoreIntrinsicDtype(intrinsicDtypes); +} + +void PlatFormInfos::SetFixPipeDtypeMap(const std::map> &fixpipeDtypeMap) +{ + if (platform_infos_impl_ == nullptr) { + return; + } + platform_infos_impl_->SetFixPipeDtypeMap(fixpipeDtypeMap); +} + +void PlatFormInfos::SetCoreNumByCoreType(const std::string &core_type) +{ + std::string coreNumStr; + std::string coreTypeStr; + if (core_type == "VectorCore") { + coreTypeStr = "vector_core_cnt"; + } else { + coreTypeStr = "ai_core_cnt"; + } + std::lock_guard lockGuard(g_asdopsFePlatMutex); + (void)GetPlatformRes("SoCInfo", coreTypeStr, coreNumStr); + MKI_LOG(DEBUG) << "Set PlatFormInfos::core_num_ to " << coreTypeStr << ": " << coreNumStr; + if (coreNumStr.empty()) { + core_num_ = 1; + MKI_LOG(ERROR) << "CoreNumStr is empty!"; + } else { + core_num_ = std::strtoul(coreNumStr.c_str(), nullptr, 10); // 10 进制 + if (core_num_ > MAX_CORE_NUM) { + core_num_ = 1; + MKI_LOG(ERROR) << "core_num is out of range : " << core_num_; + } + } +} + +uint32_t PlatFormInfos::GetCoreNumByType(const std::string &core_type) +{ + std::string coreNumStr; + std::string coreTypeStr = core_type == "VectorCore" ? 
"vector_core_cnt" : "ai_core_cnt"; + std::lock_guard lockGuard(g_asdopsFePlatMutex); + (void)GetPlatformRes("SoCInfo", coreTypeStr, coreNumStr); + MKI_LOG(DEBUG) << "Get PlatFormInfos::core_num_ to " << coreTypeStr << ": " << coreNumStr; + if (coreNumStr.empty()) { + MKI_LOG(ERROR) << "CoreNumStr is empty!"; + return 1; + } else { + uint32_t coreTypeNum = std::strtoul(coreNumStr.c_str(), nullptr, 10); // 10 进制 + if (coreTypeNum > MAX_CORE_NUM) { + MKI_LOG(ERROR) << "core_num is out of range : " << coreTypeNum; + return 1; + } + return coreTypeNum; + } +} + +void PlatFormInfos::SetCoreNum(const uint32_t &coreNum) +{ + MKI_LOG(DEBUG) << "Set PlatFormInfos::core_num_: " << coreNum; + core_num_ = coreNum; +} + +uint32_t PlatFormInfos::GetCoreNum() const +{ + MKI_LOG(DEBUG) << "Get PlatFormInfos::core_num_: " << core_num_; + return core_num_; +} + +void PlatFormInfos::GetLocalMemSize(const LocalMemType &memType, uint64_t &size) +{ + std::string sizeStr; + switch (memType) { + case LocalMemType::L0_A: { + (void)GetPlatformRes("AICoreSpec", "l0_a_size", sizeStr); + break; + } + case LocalMemType::L0_B: { + (void)GetPlatformRes("AICoreSpec", "l0_b_size", sizeStr); + break; + } + case LocalMemType::L0_C: { + (void)GetPlatformRes("AICoreSpec", "l0_c_size", sizeStr); + break; + } + case LocalMemType::L1: { + (void)GetPlatformRes("AICoreSpec", "l1_size", sizeStr); + break; + } + case LocalMemType::L2: { + (void)GetPlatformRes("SoCInfo", "l2_size", sizeStr); + break; + } + case LocalMemType::UB: { + (void)GetPlatformRes("AICoreSpec", "ub_size", sizeStr); + break; + } + case LocalMemType::HBM: { + (void)GetPlatformRes("SoCInfo", "memory_size", sizeStr); + break; + } + default: { + break; + } + } + + if (sizeStr.empty()) { + size = 0; + } else { + try { + size = static_cast(std::stoll(sizeStr.c_str())); + } catch (const std::invalid_argument &e) { + size = 0; + } catch (const std::out_of_range &e) { + size = 0; + } + } +} + +void PlatFormInfos::GetLocalMemBw(const LocalMemType 
&memType, uint64_t &bwSize) +{ + std::string bwSizeStr; + switch (memType) { + case LocalMemType::L2: { + (void)GetPlatformRes("AICoreMemoryRates", "l2_rate", bwSizeStr); + break; + } + case LocalMemType::HBM: { + (void)GetPlatformRes("AICoreMemoryRates", "ddr_rate", bwSizeStr); + break; + } + default: { + break; + } + } + + if (bwSizeStr.empty()) { + bwSize = 0; + } else { + try { + bwSize = static_cast(std::stoll(bwSizeStr.c_str())); + } catch (const std::invalid_argument &e) { + bwSize = 0; + } catch (const std::out_of_range &e) { + bwSize = 0; + } + } +} + +std::map> PlatFormInfos::GetFixPipeDtypeMap() +{ + if (platform_infos_impl_ == nullptr) { + return {}; + } + return platform_infos_impl_->GetFixPipeDtypeMap(); +} + +void PlatFormInfos::SetPlatformRes(const std::string &label, std::map &res) +{ + if (platform_infos_impl_ == nullptr) { + return; + } + platform_infos_impl_->SetPlatformRes(label, res); +} +} // namespace fe diff --git a/src/kernels/tbe_adapter/platform/platform_infos_impl.cpp b/src/kernels/tbe_adapter/platform/platform_infos_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ebb6cd63fddc38b2cdd52a5364a84726e1873eda --- /dev/null +++ b/src/kernels/tbe_adapter/platform/platform_infos_impl.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "platform_infos_impl.h" + +namespace fe { +std::map> PlatFormInfosImpl::GetAICoreIntrinsicDtype() +{ + return aiCoreIntrinsicDtypeMap_; +} + +std::map> PlatFormInfosImpl::GetVectorCoreIntrinsicDtype() +{ + return vectorCoreIntrinsicDtypeMap_; +} + +bool PlatFormInfosImpl::GetPlatformRes(const std::string &label, const std::string &key, std::string &value) +{ + const auto itLabel = platformResMap_.find(label); + if (itLabel == platformResMap_.cend()) { + return false; + } + + auto itKey = itLabel->second.find(key); + if (itKey == itLabel->second.end()) { + return false; + } + + value = itKey->second; + return true; +} + +bool PlatFormInfosImpl::GetPlatformRes(const std::string &label, std::map &result) +{ + auto itLabel = platformResMap_.find(label); + if (itLabel == platformResMap_.end()) { + return false; + } + + result = itLabel->second; + return true; +} + +void PlatFormInfosImpl::SetAICoreIntrinsicDtype(std::map> &dtypes) +{ + aiCoreIntrinsicDtypeMap_ = dtypes; +} + +void PlatFormInfosImpl::SetVectorCoreIntrinsicDtype(std::map> &dtypes) +{ + vectorCoreIntrinsicDtypeMap_ = dtypes; +} + +void PlatFormInfosImpl::SetPlatformRes(const std::string &label, std::map &result) +{ + platformResMap_[label] = result; +} + +void PlatFormInfosImpl::SetFixPipeDtypeMap(const std::map> &dtypeMap) +{ + fixpipeDtypeMap_ = dtypeMap; +} + +std::map> PlatFormInfosImpl::GetFixPipeDtypeMap() { return fixpipeDtypeMap_; } +} // namespace fe diff --git a/src/kernels/tbe_adapter/platform/platform_infos_impl.h b/src/kernels/tbe_adapter/platform/platform_infos_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..2785950412a2a4a7ff95d3ce269e16a40a807d34 --- /dev/null +++ b/src/kernels/tbe_adapter/platform/platform_infos_impl.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef UTILS_PLATFORM_PLATFORM_INFOS_IMPL_H +#define UTILS_PLATFORM_PLATFORM_INFOS_IMPL_H + +#include +#include +#include +#include +#include "platform/platform_infos_def.h" + +namespace fe { +class PlatFormInfosImpl { +using PlatInfoMapType = std::map>; + +public: + PlatInfoMapType GetAICoreIntrinsicDtype(); + PlatInfoMapType GetVectorCoreIntrinsicDtype(); + PlatInfoMapType GetFixPipeDtypeMap(); + + void SetPlatformRes(const std::string &label, std::map &result); + bool GetPlatformRes(const std::string &label, const std::string &key, std::string &value); + bool GetPlatformRes(const std::string &label, std::map &result); + + void SetFixPipeDtypeMap(const PlatInfoMapType &dtypeMap); + void SetAICoreIntrinsicDtype(PlatInfoMapType &dtypes); + void SetVectorCoreIntrinsicDtype(PlatInfoMapType &dtypes); + +private: + PlatInfoMapType aiCoreIntrinsicDtypeMap_; + PlatInfoMapType vectorCoreIntrinsicDtypeMap_; + std::map> platformResMap_; + PlatInfoMapType fixpipeDtypeMap_; +}; +} // namespace fe + +#endif diff --git a/src/kernels/tbe_adapter/platform/tiling/platform/platform_ascendc.h b/src/kernels/tbe_adapter/platform/tiling/platform/platform_ascendc.h new file mode 100644 index 0000000000000000000000000000000000000000..5a01251e7d5a1514ae8d5541c5a0885243af6b30 --- /dev/null +++ b/src/kernels/tbe_adapter/platform/tiling/platform/platform_ascendc.h @@ -0,0 +1,162 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file platform_ascendc.h
 * \brief
 */

#ifndef PLATFORM_ASCENDC_H
#define PLATFORM_ASCENDC_H

// NOTE(review): angle-bracket include targets were stripped during extraction;
// reconstructed from the identifiers used below (raise/SIGABRT, uint32_t, std::mutex).
#include <csignal>
#include <cstdint>
#include <mutex>

#define ASCENDC_ASSERT(cond, behavior) \
    do {                               \
        if (!(cond)) {                 \
            behavior;                  \
            raise(SIGABRT);            \
        }                              \
    } while (0)

namespace fe {
class PlatFormInfos;
}

namespace platform_ascendc {
enum class CoreMemType {
    L0_A = 0,
    L0_B = 1,
    L0_C = 2,
    L1 = 3,
    L2 = 4,
    UB = 5,
    HBM = 6,
    RESERVED
};

enum class SocVersion {
    ASCEND910 = 0,   // Ascend910A, Ascend910B
    ASCEND910B,      // Ascend910B1~4, Ascend910B2C, Ascend910_93 Serials
    ASCEND310P,      // Ascend310P1, Ascend310P3
    ASCEND310B,      // Ascend310B1, Ascend310B2, Ascend310B3, Ascend310B4
    ASCEND910_95,    // ASCEND910_95, __DAV_C310__
    ASCEND910_55,    // Ascend910_55, __DAV_310R6__
    AS31XM1,
    ASCEND031,
    ASCEND035,
    ASCEND310,
    ASCEND610,
    ASCEND610Lite,
    ASCEND910_93,
    BS9SX1A,
    BS9SX2A,
    HI3796CV300CS,
    HI3796CV300ES,
    MC61AM21A,
    SD3403,
    KIRIN9010,
    RESERVED_VERSION = 99999
};

// Thin, non-owning wrapper over fe::PlatFormInfos exposing the AscendC tiling
// platform queries (core counts, memory sizes/bandwidth, workspace sizes).
class PlatformAscendC {
public:
    PlatformAscendC() = delete;
    ~PlatformAscendC() = default;
    explicit PlatformAscendC(fe::PlatFormInfos *platformInfo): platformInfo_(platformInfo) {}
    /**
     * Get Core Number
     * On Ascend910B MIX model, return AICore number
     * @return core number by core type
     */
    uint32_t GetCoreNum(void) const;
    /**
     * Get Core Number AiCore
     * @return ai_core_num
     */
    uint32_t GetCoreNumAic(void) const;
    /**
     * Get Core Number VectorCore
     * @return vector_core_num
     */
    uint32_t GetCoreNumAiv(void) const;
    /**
     * Get Core Number VectorCore for m200
     * @return vector_core_num if m200, otherwise 0
     */
    uint32_t GetCoreNumVector(void) const;
    /**
     * Calc task schedule block dim
     * @sliceNum number slice of data division
     * @aicCoreNum value of GetCoreNumAic() if used cube API, otherwise 0
     * @aivCoreNum value of GetCoreNumAiv() if used vector API, otherwise 0
     * @return task schedule block dim
     */
    uint32_t CalcTschBlockDim(uint32_t sliceNum, uint32_t aicCoreNum, uint32_t aivCoreNum) const;
    /**
     * Get Work Space Size
     * @return work space size by chip type
     */
    uint32_t GetLibApiWorkSpaceSize(void) const;
    uint32_t GetResCubeGroupWorkSpaceSize(void) const;
    uint32_t GetResGroupBarrierWorkSpaceSize(void) const;
    void GetCoreMemSize(const CoreMemType &memType, uint64_t &size) const;
    void GetCoreMemBw(const CoreMemType &memType, uint64_t &bwSize) const;
    /**
     * Get Soc Version Enum
     * @return Enum SocVersion
     */
    SocVersion GetSocVersion(void) const;

private:
    fe::PlatFormInfos *platformInfo_;  // not owned; must outlive this wrapper
    fe::PlatFormInfos* GetPlatFormInfo(void) const;
};

// Process-wide lazily-initialized PlatformAscendC singleton; GetInstance()
// returns nullptr when platform initialization fails.
class PlatformAscendCManager {
public:
    static PlatformAscendC* GetInstance()
    {
        const std::lock_guard<std::mutex> lock(platformInitMtx);
        if (platformInfo == nullptr) {
            PlatformAscendCManagerInit(nullptr);
            if (platformInfo == nullptr) {
                return nullptr;
            }
        }
        return platformInfo;
    }
    static PlatformAscendC* GetInstance(const char *customSocVersion)
    {
        const std::lock_guard<std::mutex> lock(platformInitMtx);
        if (platformInfo == nullptr) {
            PlatformAscendCManagerInit(customSocVersion);
            if (platformInfo == nullptr) {
                return nullptr;
            }
        }
        return platformInfo;
    }
private:
    static PlatformAscendC *platformInfo;
    static std::mutex platformInitMtx;
    static PlatformAscendC* PlatformAscendCManagerInit(const char *customSocVersion);
    static SocVersion SocVersionMap(const char *socVersionStr);
    static fe::PlatFormInfos* PlatformAscendCInit(const char *customSocVersion);
    PlatformAscendCManager();
    ~PlatformAscendCManager() = default;
};
}
#endif
SocVersionMap(const char *socVersionStr); + static fe::PlatFormInfos* PlatformAscendCInit(const char *customSocVersion); + PlatformAscendCManager(); + ~PlatformAscendCManager() = default; +}; +} +#endif diff --git a/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/auto_tiling_register.cpp b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/auto_tiling_register.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31f3e2a412fc1aed1e85967d78468e70ba1514ec --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/auto_tiling_register.cpp @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "auto_tiling_register.h" + +std::array &AutoTilingRegister::RegisterParser() +{ + static std::array autoTilingParsers; + return autoTilingParsers; +} + +std::array &AutoTilingRegister::RegisterTiling() +{ + static std::array autoTilingFuncs; + return autoTilingFuncs; +} \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/error_util.cpp b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/error_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..36f7e1f0c5b4dd2cf75718a589283d7b78640704 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/error_util.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "error_util.h" + +std::string ge::GetAttrValueErrMsg(const std::string &attr_name, const std::string &wrong_val, + const std::string &correct_val) +{ + std::string msg = + ConcatString("attr[", attr_name, "], has wrong value[", wrong_val, "], it should be ", correct_val); + return msg; +} \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/op_tiling_util.cpp b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/op_tiling_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef06e942cf97b51e5459087f6e8ee23a792236c6 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/canndev/ops/built-in/op_tiling/op_tiling_util.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "op_tiling_util.h" + +namespace optiling { +const std::map STR_TO_DATATYPE = {{"float", DT_FLOAT}, + {"float32", DT_FLOAT}, + {"float16", DT_FLOAT16}, + {"int8", DT_INT8}, + {"int16", DT_INT16}, + {"int32", DT_INT32}, + {"int64", DT_INT64}, + {"uint8", DT_UINT8}, + {"uint16", DT_UINT16}, + {"uint32", DT_UINT32}, + {"uint64", DT_UINT64}, + {"bool", DT_BOOL}, + {"double", DT_DOUBLE}, + {"dual", DT_DUAL}, + {"dual_sub_int8", DT_DUAL_SUB_INT8}, + {"dual_sub_uint8", DT_DUAL_SUB_UINT8}, + {"int4", DT_INT4}, + {"bfloat16", DT_BF16}}; + +ge::DataType GetGeTypeFromStr(const std::string &dtype_str) +{ + auto it = STR_TO_DATATYPE.find(dtype_str); + if (it != STR_TO_DATATYPE.end()) { + return it->second; + } + OP_LOGW("GetGeTypeFromStr", "con not get the dtype[%s] in ge::DataType list. will return DT_MAX", + dtype_str.c_str()); + return DT_MAX; +} +} // namespace optiling \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/include/api/impl/host_log.h b/src/kernels/tbe_adapter/stubs/include/api/impl/host_log.h new file mode 100644 index 0000000000000000000000000000000000000000..c45191c762110d5eb33bf4841b3b74f3b31b107d --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/api/impl/host_log.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
/*
 * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file host_log.h
 * \brief AscendC host-side logging shims mapped onto the MKI logging macros.
 */

#ifndef IMPL_HOST_LOG_H
#define IMPL_HOST_LOG_H
// NOTE(review): angle-bracket include targets were stripped during extraction;
// reconstructed conservatively — confirm against the pristine source.
#include <cstdarg>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Assert that logs through MKI and then executes `ret` (typically a return
// statement) instead of aborting.
// NOTE(review): `#format` stringizes the format argument (so a literal gains
// surrounding quotes in the output); kept byte-for-byte since changing it
// alters the emitted log text — confirm it is intentional.
#define ASCENDC_HOST_ASSERT(cond, ret, format, ...)              \
    do {                                                         \
        if (!(cond)) {                                           \
            MKI_FLOG_ERROR("[%s] " #format, __FUNCTION__, ##__VA_ARGS__); \
            ret;                                                 \
        }                                                        \
    } while (0)

// 0 debug, 1 info, 2 warning, 3 error
#define TILING_LOG_ERROR(format, ...) MKI_FLOG_ERROR("[%s] " #format, __FUNCTION__, ##__VA_ARGS__)

#define TILING_LOG_INFO(format, ...) MKI_FLOG_INFO("[%s] " #format, __FUNCTION__, ##__VA_ARGS__)

#define TILING_LOG_WARNING(format, ...) MKI_FLOG_WARN("[%s] " #format, __FUNCTION__, ##__VA_ARGS__)

#define TILING_LOG_DEBUG(format, ...) MKI_FLOG_DEBUG("[%s] " #format, __FUNCTION__, ##__VA_ARGS__)
#endif  // IMPL_HOST_LOG_H

/* ==== file: aoe/runtime_kb/common/kb_log.h (next file in the original patch) ==== */
/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef ASCEND_OPS_STUB_KB_LOG_H
#define ASCEND_OPS_STUB_KB_LOG_H

#include <cstdint>
#include <cstdio>
#include <string>

// Stubbed-out CANN knowledge-base logging: all levels compile to nothing.
#define CANNKB_LOGD(format, ...)
#define CANNKB_LOGI(format, ...)
#define CANNKB_LOGW(format, ...)
#define CANNKB_LOGE(format, ...)
#define CANNKB_LOGEVENT(format, ...)

#endif
+ * \file auto_tiling_register.h + * \brief + */ +#ifndef ASCEND_OPS_STUB_AUTO_TILING_REGISTER_H +#define ASCEND_OPS_STUB_AUTO_TILING_REGISTER_H + +#include "vector_tiling_rt2.h" + +#include + +using AutoTilingFunc = bool (*)(gert::TilingContext *, const optiling::OpInfoImpl *); +using AutoTilingParseFunc = optiling::AutoTilingCompileInfo *(*)(const char *op_type, + const nlohmann::json &json_compile_info); + +#define REGISTER_AUTO_TILING(pattern, tilingfunc, parsefunc) \ + static AutoTilingRegister g_auto_tiling_register_##__COUNTER__(pattern, tilingfunc, parsefunc) + +constexpr size_t PATTERN_BASE = 0x10; +constexpr size_t PATTERN_SIZE = static_cast(optiling::SchPattern::DEFAULT) - PATTERN_BASE; + +inline size_t PatternIndex(optiling::SchPattern _pattern) { return static_cast(_pattern) - PATTERN_BASE; } + +class AutoTilingRegister { +public: + AutoTilingRegister(optiling::SchPattern _pattern, AutoTilingFunc _tiling_func, AutoTilingParseFunc _parser) + { + size_t index = PatternIndex(_pattern); + auto ®ister_parser = RegisterParser(); + register_parser[index] = _parser; + auto ®ister_tiling = RegisterTiling(); + register_tiling[index] = _tiling_func; + }; + ~AutoTilingRegister() = default; + static std::array &RegisterParser(); + static std::array &RegisterTiling(); +}; + +#endif // ASCEND_OPS_STUB_AUTO_TILING_REGISTER_H diff --git a/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/error_log.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/error_log.h new file mode 100644 index 0000000000000000000000000000000000000000..4408043752552365fee1cb84aa01c717b4ee873c --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/error_log.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file error_log.h + * \brief + */ +#ifndef ASCEND_OPS_STUB_ERROR_LOG_H +#define ASCEND_OPS_STUB_ERROR_LOG_H + +#include +#include "op_log.h" + +using namespace Mki; +using namespace std; + +namespace optiling { +// ADD FORMAT LOG +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + REPORT_INNER_ERROR("E89999", "op[%s], " err_msg, get_cstr(get_op_info(op_name)), ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // ASCEND_OPS_STUB_ERROR_LOG_H diff --git a/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..07f31e594397240e815603c120fb1ee05ca55753 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ASCEND_OPS_STUB_OP_TILING_H +#define ASCEND_OPS_STUB_OP_TILING_H + +#include +#include + +#include "register/op_tiling_info.h" +#include "op_log.h" +#include + +using namespace std; + +namespace optiling { +const static bool prof_switch = false; +} + +#endif \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling_util.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling_util.h new file mode 100644 index 0000000000000000000000000000000000000000..6974328457b5bf11351f6c41fb5050b41aa6cd35 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/built-in/op_tiling/op_tiling_util.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef ASCEND_OPS_STUB_OP_TILING_UTIL_H +#define ASCEND_OPS_STUB_OP_TILING_UTIL_H + +#include +#include "op_attr.h" +#include "op_util.h" +#include "op_tiling.h" +#include "error_util.h" + +using namespace ge; + +namespace optiling { +ge::DataType GetGeTypeFromStr(const std::string &dtype_str); + +template +bool GetCompileValue(const nlohmann::json& all_vars, const std::string& name, T& value) +{ + if (all_vars.empty()) { + return false; + } + + if (all_vars.count(name) == 0) { + return false; + } + + value = all_vars[name].get(); + return true; +} + +template +bool GetCompileValue(const nlohmann::json& all_vars, const std::string& name, T1& value, const T2 default_value) +{ + if (!GetCompileValue(all_vars, name, value)) { + value = static_cast(default_value); + } + return true; +} +} + +#endif \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/error_util.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/error_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a5608736e56bf0f48b9305a7d889bc37a3f3f433 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/error_util.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef ASCEND_OPS_STUB_ERROR_UTIL_H +#define ASCEND_OPS_STUB_ERROR_UTIL_H + +#include +#include + +#include "error_code.h" +#include "graph/operator.h" +#include "op_log.h" + +#define VECTOR_INFER_SHAPE_INNER_ERR_REPORT(op_name, err_msg) \ + do { \ + OP_LOGE(op_name, "%s", get_cstr(err_msg)); \ + } while (0) + +#define CUBE_INNER_ERR_REPORT(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + REPORT_INNER_ERROR("E69999", "op[%s], " err_msg, get_cstr(get_op_info(op_name)), ##__VA_ARGS__); \ + } while (0) + +namespace ge { +template std::string DebugString(const std::vector &v) +{ + std::ostringstream oss; + oss << "["; + if (v.size() > 0) { + for (size_t i = 0; i < v.size() - 1; ++i) { + oss << v[i] << ", "; + } + oss << v[v.size() - 1]; + } + oss << "]"; + return oss.str(); +} + +template std::string DebugString(const std::vector> &v) +{ + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < v.size(); ++i) { + if (i != 0) { + oss << ", "; + } + oss << "(" << v[i].first << ", " << v[i].second << ")"; + } + oss << "]"; + return oss.str(); +} + +inline std::ostream &operator<<(std::ostream &os, const ge::Operator &op) { return os << get_op_info(op); } + +/* + * str cat util function + * param[in] params need concat to string + * return concatted string + */ +template std::string ConcatString(const T &arg) +{ + std::ostringstream oss; + oss << arg; + return oss.str(); +} + +template std::string ConcatString(const T &arg, const Ts &...arg_left) +{ + std::ostringstream oss; + oss << arg; + oss << ConcatString(arg_left...); + return oss.str(); +} + +template std::string Shape2String(const T &shape) +{ + std::ostringstream oss; + oss << "["; + if (shape.GetDimNum() > 0) { + for (size_t i = 0; i < shape.GetDimNum() - 1; ++i) { + oss << shape.GetDim(i) << ", "; + } + oss << shape.GetDim(shape.GetDimNum() - 1); + } + oss << "]"; + return oss.str(); +} + +std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode); + 
+std::string GetShapeErrMsg(uint32_t index, const std::string &wrong_shape, const std::string &correct_shape); + +std::string GetAttrValueErrMsg(const std::string &attr_name, const std::string &wrong_val, + const std::string &correct_val); + +std::string GetAttrSizeErrMsg(const std::string &attr_name, const std::string &wrong_size, + const std::string &correct_size); + +std::string GetInputInvalidErrMsg(const std::string ¶m_name); +std::string GetShapeSizeErrMsg(uint32_t index, const std::string &wrong_shape_size, + const std::string &correct_shape_size); + +std::string GetInputFormatNotSupportErrMsg(const std::string ¶m_name, const std::string &expected_format_list, + const std::string &data_format); + +std::string GetInputDtypeNotSupportErrMsg(const std::string ¶m_name, const std::string &expected_dtype_list, + const std::string &data_dtype); + +std::string GetInputDTypeErrMsg(const std::string ¶m_name, const std::string &expected_dtype, + const std::string &data_dtype); + +std::string GetInputFormatErrMsg(const std::string ¶m_name, const std::string &expected_format, + const std::string &data_format); + +std::string SetAttrErrMsg(const std::string ¶m_name); +std::string UpdateParamErrMsg(const std::string ¶m_name); + +template +std::string GetParamOutRangeErrMsg(const std::string ¶m_name, const T &real_value, const T &min, const T &max); + +std::string OtherErrMsg(const std::string &error_detail); + +void TbeInputDataTypeErrReport(const std::string &op_name, const std::string ¶m_name, + const std::string &expected_dtype_list, const std::string &dtype); + +void GeInfershapeErrReport(const std::string &op_name, const std::string &op_type, const std::string &value, + const std::string &reason); +/* + * log common runtime error + * param[in] opname op name + * param[in] error description + * return void + */ +void CommonRuntimeErrLog(const std::string &opname, const std::string &description); +} // namespace ge + +#endif \ No newline at end of file diff --git 
a/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_const.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_const.h new file mode 100644 index 0000000000000000000000000000000000000000..052cbe13a54b4af73af7e402c7d9b3ba2f15c7db --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_const.h @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ASCEND_OPS_STUB_OP_CONST_H +#define ASCEND_OPS_STUB_OP_CONST_H + +#include +#include "external/graph/operator.h" +#include "graph/utils/op_desc_utils.h" +#include "runtime/tiling_context.h" +#include "runtime/infer_shape_context.h" +#include "op_util.h" +#include "context_util.h" + +namespace ops { +using namespace ge; + +template +static void GetDataToVector(const uint8_t *const_data, size_t data_size, std::vector &result) +{ + if (const_data == nullptr || data_size == 0) { + return; + } + + size_t size = data_size / sizeof(T2); + result.resize(size); + const T2 *data = reinterpret_cast(const_data); + for (size_t i = 0; i < size; i++) { + result[i] = *(data + i); + } +} + +/* + * @brief: read constvalue from paras store into values + * @param [in] paras: ge::Operator + * @param [in] const_input_idx: constvalue axes index + * @param [out] values: vector to store return values. 
+ * @return bool: flag of success or not + */ +template +bool GetConstIntData(const ge::Operator ¶s, const int64_t const_input_idx, std::vector &values) +{ + return true; +} + +/* + * @brief: read constvalue from paras store into value + * @param [in] paras: ge::Operator + * @param [in] const_input_idx: constvalue axes index + * @param [out] value: integer to store return value. + * @return bool: flag of success or not + */ +template bool GetConstInt(const ge::Operator ¶s, const int64_t const_input_idx, T &value) +{ + return true; +} + +/* + * @brief: read constvalue from paras store into value + * @param [in] context: gert::InferShapeContext + * @param [in] const_input_idx: constvalue axes index + * @param [out] value: integer to store return value. + * @return bool: flag of success or not + */ +template bool GetConstInt(gert::InferShapeContext *context, const int64_t const_input_idx, T &value) +{ + return true; +} + +/* + * @brief: read constvalue from paras store into value + * @param [in] context: gert::TilingContext + * @param [in] const_input_idx: constvalue axes index + * @param [out] value: integer to store return value. 
+ * @return bool: flag of success or not + */ +template bool GetConstInt(gert::TilingContext *context, const int64_t const_input_idx, T &value) +{ + if (context == nullptr) { + return false; + } + + const gert::Tensor* const_tensor = context->GetInputTensor(const_input_idx); + OPS_CHECK_NULL_WITH_CONTEXT_RET(context, const_tensor, false); + if (!IsConstTensor(const_tensor)) { + OP_LOGW(context->GetNodeName(), "the input[%ld] is not const tensor, will return failed.", const_input_idx); + return false; + } + + ge::DataType dtype = const_tensor->GetDataType(); + switch (dtype) { + case ge::DT_UINT64: + value = static_cast(const_tensor->GetData()[0]); + break; + case ge::DT_INT64: + value = static_cast(const_tensor->GetData()[0]); + break; + case ge::DT_UINT32: + value = static_cast(const_tensor->GetData()[0]); + break; + case ge::DT_INT32: + value = static_cast(const_tensor->GetData()[0]); + break; + default: { + OP_LOGW(context->GetNodeName(), "GetConstInt only support [int32, int64, uint64, uint32]. 
but is %s", + ops::ToString(dtype).c_str()); + return false; + } + } + OP_LOGD("GetConstInt", "GetConstInt of value is %d", static_cast(value)); + return true; +} + +template static void GetConstValueToShape(const gert::Tensor *tensor, size_t size, gert::Shape *shape) +{ + if (tensor == nullptr) { + return; + } + + const T *value = tensor->GetData(); + if (value == nullptr) { + return; + } + + shape->SetDimNum(size); + for (size_t i = 0; i < size; i++) { + shape->SetDim(i, value[i]); + } +} + +template void GetValueToShape(const gert::Tensor *const_tensor, gert::Shape *const_shape) +{ + if (const_tensor == nullptr || const_shape == nullptr) { + return; + } + + const T *const_value = const_tensor->GetData(); + if (const_value == nullptr) { + return; + } + + const size_t const_num = const_tensor->GetShapeSize(); + const_shape->SetDimNum(0); + for (size_t i = 0; i < const_num; ++i) { + const_shape->AppendDim(const_value[i]); + } +} + +template void GetValueToShape(const gert::Tensor *const_tensor, gert::Shape &const_shape) +{ + if (const_tensor == nullptr) { + return; + } + + const T *const_value = const_tensor->GetData(); + if (const_value == nullptr) { + return; + } + + const size_t const_num = const_tensor->GetShapeSize(); + const_shape.SetDimNum(0); + for (size_t i = 0; i < const_num; ++i) { + const_shape.AppendDim(const_value[i]); + } +} + +template bool GetConstIntToShape(T *context, const int64_t const_idx, gert::Shape &const_shape) +{ + if (context == nullptr) { + return false; + } + + const gert::Tensor *const_tensor = context->GetInputTensor(const_idx); + OPS_CHECK_NULL_WITH_CONTEXT_RET(context, const_tensor, false); + if (!IsConstTensor(const_tensor)) { + OP_LOGW(context->GetNodeName(), "the input[%lld] is not const tensor, will return failed.", const_idx); + return false; + } + + ge::DataType const_dtype = const_tensor->GetDataType(); + + switch (const_dtype) { + case ge::DT_INT32: { + GetValueToShape(const_tensor, const_shape); + break; + } + case 
ge::DT_INT64: { + GetValueToShape(const_tensor, const_shape); + break; + } + case ge::DT_UINT64: { + GetValueToShape(const_tensor, const_shape); + break; + } + case ge::DT_UINT32: { + GetValueToShape(const_tensor, const_shape); + break; + } + default: + OP_LOGW(context->GetNodeName(), "GetConstIntToShape only support [int32, int64, uint64, uint32]. but is %s", + ops::ToString(const_dtype).c_str()); + return false; + } + + OP_LOGI(context->GetNodeName(), "GetConstIntToShape: output shape is %s", ToString(const_shape).c_str()); + return true; +} +} // namespace ops +#endif // CANN_OPS_BUILT_IN_OPS_CONST_H_ diff --git a/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_log.h b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_log.h new file mode 100644 index 0000000000000000000000000000000000000000..e4b618d931cfad5f6e7ae6599d5e7edb3397f25f --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/canndev/ops/common/inc/op_log.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef ASCEND_OPS_TBE_STUB_OP_LOG_H +#define ASCEND_OPS_TBE_STUB_OP_LOG_H + +#include +#include + +#include + +#include "ascend_string.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/operator.h" +#include "graph/node.h" +#include "toolchain/slog.h" +#include "base/err_msg.h" + + +#define OPPROTO_SUBMOD_NAME "OP_PROTO" + +inline const char *get_cstr(const std::string &str) { return str.c_str(); } + +inline const char *get_cstr(const char *str) { return str; } + +inline const std::string &get_op_info(const std::string &str) { return str; } + +inline const char *get_op_info(const char *str) { return str; } + +inline std::string get_op_info(const ge::NodePtr &node) +{ + return node != nullptr ? node->GetType() + ":" + node->GetName() : "nil"; +} + +inline std::string get_op_info(const ge::OpDescPtr &node) +{ + return node != nullptr ? node->GetType() + ":" + node->GetName() : "nil"; +} + +template constexpr bool is_ge_operator_type() +{ + return std::is_base_of::type>::value; +} + +template typename std::enable_if(), std::string>::type get_op_info(const T &op) +{ + ge::AscendString name; + ge::AscendString type; + auto get_name_ret = op.GetName(name); + auto get_type_ret = op.GetOpType(type); + std::string op_info = get_type_ret == ge::GRAPH_SUCCESS ? type.GetString() : "nil"; + op_info += ":"; + op_info += get_name_ret == ge::GRAPH_SUCCESS ? name.GetString() : "nil"; + return op_info; +} + +template constexpr bool is_context_type() +{ + return !std::is_base_of::type>::value && + !std::is_same::type>::value && + !std::is_same::type>::value && + !std::is_same::type>::value && + !std::is_same::type>::value && + !std::is_same::type>::value; +} + +template typename std::enable_if(), std::string>::type get_op_info(T context) +{ + if (context == nullptr) { + return "nil:nil"; + } + std::string op_info = context->GetNodeType() != nullptr ? context->GetNodeType() : "nil"; + op_info += ":"; + op_info += context->GetNodeName() != nullptr ? 
context->GetNodeName() : "nil"; + return op_info; +} + +template std::string TbeGetName(const T &op) +{ + ge::AscendString op_ascend_name; + ge::graphStatus ret = op.GetName(op_ascend_name); + if (ret != ge::GRAPH_SUCCESS) { + std::string op_name = "None"; + return op_name; + } + return op_ascend_name.GetString(); +} + +template std::string TbeGetOpType(const T &op) +{ + ge::AscendString op_ascend_name; + ge::graphStatus ret = op.GetOpType(op_ascend_name); + if (ret != ge::GRAPH_SUCCESS) { + std::string op_name = "None"; + return op_name; + } + return op_ascend_name.GetString(); +} + +#define CHECK_DIVISOR_ZERO(divisor) \ + if ((divisor) == 0) { \ + return; \ + } + +#define CHECK_DIVISOR_ZERO_RET(divisor, ret) \ + if ((divisor) == 0) { \ + return ret; \ + } + +#define OP_CHECK(cond, log_func, return_expr) \ + if (cond) { \ + log_func; \ + return_expr; \ + } + +#define OP_LOGI(opname, ...) D_OP_LOGI(get_op_info(opname), __VA_ARGS__) +#define OP_LOGW(opname, ...) D_OP_LOGW(get_op_info(opname), __VA_ARGS__) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) D_OP_LOGE(get_op_info(opname), __VA_ARGS__) +#define OP_LOGE(op_name, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, ##__VA_ARGS__); \ + REPORT_INNER_ERROR("EZ9999", ##__VA_ARGS__); \ + } while (0) + +#define OP_LOGD(opname, ...) D_OP_LOGD(get_op_info(opname), __VA_ARGS__) +#define OP_EVENT(opname, ...) D_OP_EVENT(get_op_info(opname), __VA_ARGS__) + +#define OP_LOG_SUB_DEBUG(op_info, fmt, ...) \ + MKI_FLOG_DEBUG("[%s] OpName:[%s] " #fmt, __FUNCTION__, get_cstr(op_info), ##__VA_ARGS__) +#define OP_LOG_SUB_INFO(op_info, fmt, ...) \ + MKI_FLOG_INFO("[%s] OpName:[%s] " #fmt, __FUNCTION__, get_cstr(op_info), ##__VA_ARGS__) +#define OP_LOG_SUB_WARN(op_info, fmt, ...) \ + MKI_FLOG_WARN("[%s] OpName:[%s] " #fmt, __FUNCTION__, get_cstr(op_info), ##__VA_ARGS__) +#define OP_LOG_SUB_ERROR(op_info, fmt, ...) 
\ + MKI_FLOG_ERROR("[%s] OpName:[%s] " #fmt, __FUNCTION__, get_cstr(op_info), ##__VA_ARGS__) + +#define D_OP_LOGI(opname, fmt, ...) OP_LOG_SUB_INFO(opname, fmt, ##__VA_ARGS__) +#define D_OP_LOGW(opname, fmt, ...) OP_LOG_SUB_WARN(opname, fmt, ##__VA_ARGS__) +#define D_OP_LOGE(opname, fmt, ...) OP_LOG_SUB_ERROR(opname, fmt, ##__VA_ARGS__) +#define D_OP_LOGD(opname, fmt, ...) OP_LOG_SUB_DEBUG(opname, fmt, ##__VA_ARGS__) +#define D_OP_EVENT(opname, fmt, ...) OP_LOG_SUB_INFO(opname, fmt, ##__VA_ARGS__) + +#define UNLIKELY(x) __builtin_expect((x), 0) +#define LIKELY(x) __builtin_expect((x), 1) + +#define OP_LOGE_IF(condition, return_value, op_name, fmt, ...) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (UNLIKELY(condition)) { \ + OP_LOGE(op_name, fmt, ##__VA_ARGS__); \ + return return_value; \ + } \ + } while (0) + +#define OP_LOGW_IF(condition, op_name, fmt, ...) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (UNLIKELY(condition)) { \ + OP_LOGW(op_name, fmt, ##__VA_ARGS__); \ + } \ + } while (0) + +#define OP_LOGI_IF_RETURN(condition, return_value, op_name, fmt, ...) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (UNLIKELY(condition)) { \ + OP_LOGI(op_name, fmt, ##__VA_ARGS__); \ + return return_value; \ + } \ + } while (0) + +constexpr const int OP_MAX_LOG_SIZE = 16000; +constexpr const int OP_MSG_HEADER_LEN = 200; +// print very long log. long line will be split to multipile lines +#define OP_LOG_FULL(opname, format, ...) 
\ + do { \ + if (Mki::LogLevel::DEBUG < Mki::LogCore::Instance().GetLogLevel()) { \ + break; \ + } \ + char msgbufxyz[OP_MAX_LOG_SIZE]; \ + size_t msgmaxlen = (MSG_LENGTH - OP_MSG_HEADER_LEN); \ + int rettmp = snprintf_s(msgbufxyz, sizeof(msgbufxyz), sizeof(msgbufxyz) - 1, format, ##__VA_ARGS__); \ + if (rettmp == -1) { \ + msgbufxyz[sizeof(msgbufxyz) - 1] = '\0'; \ + } \ + size_t msglength = std::strlen(msgbufxyz); \ + if (msglength < msgmaxlen) { \ + OP_LOG_SUB_DEBUG(opname, "%s", msgbufxyz); \ + break; \ + } \ + char *msgchunkbegin = msgbufxyz; \ + char *msgchunkend = nullptr; \ + while (msgchunkbegin < msgbufxyz + msglength) { \ + if (msgchunkbegin[0] == '\n') { \ + OP_LOG_SUB_DEBUG(opname, ""); \ + msgchunkbegin += 1; \ + continue; \ + } \ + msgchunkend = std::strchr(msgchunkbegin, '\n'); \ + if (msgchunkend == nullptr) { \ + msgchunkend = msgchunkbegin + std::strlen(msgchunkbegin); \ + } \ + while (msgchunkend > msgchunkbegin) { \ + std::string msgchunk(msgchunkbegin, \ + std::min(msgmaxlen, static_cast(msgchunkend - msgchunkbegin))); \ + OP_LOG_SUB_DEBUG(opname, "%s", msgchunk.c_str()); \ + msgchunkbegin += msgchunk.size(); \ + } \ + msgchunkbegin += 1; \ + } \ + } while (0) + +#define OP_LOGD_FULL(opname, ...) OP_LOG_FULL(get_op_info(opname), __VA_ARGS__) + +int CheckLogLevel(int moduleId, int logLevel); + +#endif \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/ge_log.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/ge_log.h new file mode 100644 index 0000000000000000000000000000000000000000..7338ec965a52945be199a53b2fdc80b298c5f173 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/ge_log.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ASCEND_OPS_STUB_GE_LOG_H +#define ASCEND_OPS_STUB_GE_LOG_H + +#include +#include "ge_error_codes.h" +#include "common/util/error_manager/error_manager.h" +#include "external/ge_common/ge_api_error_codes.h" + +using string = std::string; + +#define INTERNAL_ERROR 4 + +#define GELOGE(ERROR_CODE, fmt, ...) \ + MKI_FLOG_ERROR("[%s] error code: %d, " #fmt, __FUNCTION__, ERROR_CODE, ##__VA_ARGS__) + +#define GELOGW(fmt, ...) MKI_FLOG_WARN("[%s] " #fmt, __FUNCTION__, ##__VA_ARGS__) +#define GELOGI(fmt, ...) MKI_FLOG_INFO("[%s] " #fmt, __FUNCTION__, ##__VA_ARGS__) +#define GELOGD(fmt, ...) MKI_FLOG_DEBUG("[%s] " #fmt, __FUNCTION__, ##__VA_ARGS__) + +#define GE_LOG_ERROR(...) +#define GE_CHECK_NOTNULL_JUST_RETURN(...) +#define GE_CHECK_NOTNULL(...) + +#define REPORT_CALL_ERROR REPORT_INNER_ERROR + +#define GE_IF_BOOL_EXEC(...) + +// If expr is not true, print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) 
\ + do { \ + const bool b = (expr); \ + if (!b) { \ + REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ + GELOGE((_status), __VA_ARGS__); \ + return (_status); \ + } \ + } while (false) + +using Status = uint32_t; + +#endif // ASCEND_OPS_STUB_GE_LOG_H diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/log.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/log.h new file mode 100644 index 0000000000000000000000000000000000000000..4791a0e1dc3db64d8fedb322bba4167ccecabce4 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/debug/log.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ASCEND_OPS_STUB_FRAMEWORK_COMMON_DEBUG_LOG_H +#define ASCEND_OPS_STUB_FRAMEWORK_COMMON_DEBUG_LOG_H + +#include + +#include "ge_log.h" + +namespace ge { +using Status = uint32_t; + +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) 
\ + do { \ + const bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } while (0) +} // namespace ge +#endif // ASCEND_OPS_STUB_FRAMEWORK_COMMON_DEBUG_LOG_H diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/util.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/util.h new file mode 100644 index 0000000000000000000000000000000000000000..a91ae4d166e6adf8cdb328946dd6a6cd0b777b16 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/common/ge_common/util.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef ASCEND_OPS_STUB_FRAMEWORK_COMMON_UTIL_H +#define ASCEND_OPS_STUB_FRAMEWORK_COMMON_UTIL_H + +#include + +#include "debug/ge_log.h" +#include "debug/log.h" +#include "external/ge_common/ge_api_error_codes.h" + +namespace ge { +} // namespace ge + +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + return _chk_status; \ + } \ + } while(false) + +#endif // ASCEND_OPS_STUB_FRAMEWORK_COMMON_UTIL_H diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/ge_common/ge_api_error_codes.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/ge_common/ge_api_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..b46f10f58cfa08735d2d715aa771f77377176d6b --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/ge_common/ge_api_error_codes.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef ASCEND_OPS_STUB_GE_API_ERROR_CODES_H +#define ASCEND_OPS_STUB_GE_API_ERROR_CODES_H + +#include +#include "graph/ascend_string.h" + +namespace ge { +constexpr uint32_t SUCCESS = 0; +constexpr uint32_t FAILED = 1; +constexpr uint32_t PARAM_INVALID = 2; +} // namespace ge + +#endif \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/platform/platform_info.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/platform/platform_info.h new file mode 100644 index 0000000000000000000000000000000000000000..234de704ffc94109a7218bf6201d7411c581de43 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/external/platform/platform_info.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef PLATFORM_INFO_H +#define PLATFORM_INFO_H + +#include +#include +#include + +using namespace std; +#define UNUSED_VALUE(x) (void)(x) + +namespace fe { +class PlatformInfo {}; +class PlatFormInfos { +public: + bool GetPlatformResWithLock(const std::string &label, const std::string &key, std::string &val); + + bool GetPlatformResWithLock(const std::string &label, std::map &res); +}; +class OptionalInfo {}; +class OptionalInfos {}; +class PlatformInfoManager { +public: + PlatformInfoManager(const PlatformInfoManager &) = delete; + PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; + + static PlatformInfoManager &Instance(); + static PlatformInfoManager &GeInstance(); + uint32_t InitializePlatformInfo(); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatFormInfos &platformInfo, OptionalInfos &optiCompilationInfo); + + uint32_t InitRuntimePlatformInfos(const std::string &socVersion); + +private: + PlatformInfoManager(); +}; +} // namespace fe + +#endif diff --git a/src/kernels/tbe_adapter/stubs/include/metadef/inc/graph/attribute_group/attr_group_serialize.h b/src/kernels/tbe_adapter/stubs/include/metadef/inc/graph/attribute_group/attr_group_serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..4ccf7cc4ef2d96f9536a7e53d6ac7440338845b2 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/metadef/inc/graph/attribute_group/attr_group_serialize.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef INC_GRAPH_ATTR_GROUP_SERIALIZE_H +#define INC_GRAPH_ATTR_GROUP_SERIALIZE_H + +#include "graph/attr_store.h" +#include "graph/ge_error_codes.h" + +namespace ge { +namespace proto { +class AttrGroups; +} +class AttrGroupSerialize { +public: + static graphStatus SerializeAllAttr(proto::AttrGroups &attrGroups, const AttrStore &attrStore); + static graphStatus DeserializeAllAttr(const proto::AttrGroups &attrGroup, AttrStore &attrStore); + +private: + static graphStatus OtherGroupSerialize(proto::AttrGroups &attrGroups, const AttrStore &attrStore); + static graphStatus OtherGroupDeserialize(const proto::AttrGroups &attrGroups, AttrStore &attrStore); +}; +} // namespace ge + +#endif // INC_GRAPH_ATTR_GROUP_SERIALIZE_H diff --git a/src/kernels/tbe_adapter/stubs/include/opdev/internal/op_binary_resource_manager_impl.h b/src/kernels/tbe_adapter/stubs/include/opdev/internal/op_binary_resource_manager_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1885c37a0b6ca4892561ae1855aec08dcc61f754 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/include/opdev/internal/op_binary_resource_manager_impl.h @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef ASCEND_OPS_NNOPBASE_OP_DEV_OP_BINARY_RESOURCE_MANAGER_IMPL +#define ASCEND_OPS_NNOPBASE_OP_DEV_OP_BINARY_RESOURCE_MANAGER_IMPL + +namespace nnopbase { +} // nnopbase +#endif // ASCEND_OPS_NNOPBASE_OP_DEV_OP_BINARY_RESOURCE_MANAGER_IMPL \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/metadef/error_manager/error_manager.cpp b/src/kernels/tbe_adapter/stubs/metadef/error_manager/error_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5c86444b1f4ceda03b7fea1d05562dddbf57601 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/error_manager/error_manager.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "common/util/error_manager/error_manager.h" + +#define UNUSED_VALUE(x) (void)(x) +using char_t = char; + +namespace error_message { +int32_t FormatErrorMessage(char_t *str_dst, size_t dst_max, const char_t *format, ...) +{ + UNUSED_VALUE(str_dst); + UNUSED_VALUE(dst_max); + UNUSED_VALUE(format); + return 0; +} + +void ReportInnerError(const char_t *file_name, const char_t *func, uint32_t line, const std::string error_code, + const char_t *format, ...) 
+{ + UNUSED_VALUE(file_name); + UNUSED_VALUE(func); + UNUSED_VALUE(line); + UNUSED_VALUE(error_code); + UNUSED_VALUE(format); +} +} + +ErrorManager &ErrorManager::GetInstance() +{ + static ErrorManager errorManager; + return errorManager; +} + +int32_t ErrorManager::ReportInterErrMessage(const std::string error_code, const std::string &error_msg) +{ + UNUSED_VALUE(error_code); + UNUSED_VALUE(error_msg); + return 0; +} + +int32_t ErrorManager::ReportErrMessage(const std::string error_code, + const std::map &args_map) +{ + UNUSED_VALUE(error_code); + UNUSED_VALUE(args_map); + return 0; +} + +std::string ErrorManager::GetErrorMessage() +{ + return ""; +} + +std::string ErrorManager::GetWarningMessage() +{ + return ""; +} + +int32_t ErrorManager::OutputErrMessage(int32_t handle) +{ + UNUSED_VALUE(handle); + return 0; +} + +int32_t ErrorManager::OutputMessage(int32_t handle) +{ + UNUSED_VALUE(handle); + return 0; +} + +namespace ge { +int32_t ReportInnerErrMsg(const char *file_name, const char *func, uint32_t line, const char *error_code, + const char *format, ...) +{ + UNUSED_VALUE(file_name); + UNUSED_VALUE(func); + UNUSED_VALUE(line); + UNUSED_VALUE(error_code); + UNUSED_VALUE(format); + return 0; +} + +int32_t ReportUserDefinedErrMsg(const char *error_code, const char *format, ...) 
+{ + UNUSED_VALUE(error_code); + UNUSED_VALUE(format); + return 0; +} + +int32_t ReportPredefinedErrMsg(const char *error_code, const std::vector &key, + const std::vector &value) +{ + UNUSED_VALUE(error_code); + UNUSED_VALUE(key); + UNUSED_VALUE(value); + return 0; +} +} \ No newline at end of file diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/attr/attr_store.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/attr/attr_store.cpp new file mode 100644 index 0000000000000000000000000000000000000000..88e9378f04b44c157fac8d440b38682eaa6d1dc1 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/graph/attr/attr_store.cpp @@ -0,0 +1,425 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "graph/attr_store.h" +#include "attribute_group/attr_group_base.h" +#include "attribute_group/attr_group_serialize.h" +#include "checker.h" +#include "common/ge_common/debug/ge_log.h" + +namespace ge { +const AttrId CONST_INVALID_ATTR_ID = GetAttrId(0xffffffffU, 0U); + +AnyValue *AttrStore::GetOrCreateAnyValue(const AttrId attr_id) const +{ + return const_cast(GetAnyValue(attr_id)); +} +AnyValue *AttrStore::MutableAnyValue(const AttrId attr_id) const noexcept +{ + return const_cast(GetAnyValue(attr_id)); +} +const AnyValue *AttrStore::GetAnyValue(const AttrId attr_id) const noexcept +{ + const auto attrType = GetAttrType(attr_id); + if (attrType == static_cast(AttrType::kAttrPredefinedInIr)) { + return pre_defined_attrs_.GetAnyValue(GetSubAttrId(attr_id)); + } else if (attrType == static_cast(AttrType::kAttrGeneral)) { + return nullptr; // general不支持 + } else { + // empty + } + return nullptr; +} +AttrStore AttrStore::Create(const size_t pre_defined_attr_count) +{ + AttrStore as; + as.pre_defined_attrs_.Resize(pre_defined_attr_count); + return as; +} + +AttrStore::AttrStore(AttrStore &&other) +{ + names_to_id_.swap(other.names_to_id_); + Swap(other); + other_attrs_.Swap(other.other_attrs_); + + attrs_groups_ptr_ = std::move(other.attrs_groups_ptr_); + other.attrs_groups_ptr_.clear(); +} + +AttrStore &AttrStore::operator=(AttrStore &&other) +{ + if (this == &other) { + return *this; + } + names_to_id_.swap(other.names_to_id_); + Swap(other); + other_attrs_.Swap(other.other_attrs_); + + for (auto &iter : attrs_groups_ptr_) { + if (iter.second != nullptr) { + iter.second.reset(); + } + } + attrs_groups_ptr_.clear(); + attrs_groups_ptr_ = std::move(other.attrs_groups_ptr_); + other.attrs_groups_ptr_.clear(); + return *this; +} + +void AttrStore::CopyAttrStoreAllMembers(const AttrStore &other) +{ + names_to_id_ = other.names_to_id_; + pre_defined_attrs_ = other.pre_defined_attrs_; + general_attrs_ = other.general_attrs_; + other_attrs_ = 
other.other_attrs_; + + for (auto &otherAttrsPtr : other.attrs_groups_ptr_) { + if (otherAttrsPtr.second != nullptr) { + auto attrsGroupPtr = otherAttrsPtr.second->Clone(); + if (attrsGroupPtr == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to alloc memory for attribute group."); + } + attrs_groups_ptr_.emplace(otherAttrsPtr.first, std::move(attrsGroupPtr)); + } + } +} + +AttrStore::AttrStore(const AttrStore &other) { CopyAttrStoreAllMembers(other); } + +AttrStore &AttrStore::operator=(const AttrStore &other) +{ + if (this == &other) { + return *this; + } + for (auto &attrsGroupPtr : attrs_groups_ptr_) { + if (attrsGroupPtr.second != nullptr) { + attrsGroupPtr.second.reset(); + } + } + attrs_groups_ptr_.clear(); + CopyAttrStoreAllMembers(other); + return *this; +} + +const AnyValue *AttrStore::GetAnyValue(const std::string &name) const noexcept +{ + const auto id = GetIdByName(name); + if (id != CONST_INVALID_ATTR_ID) { + return pre_defined_attrs_.GetAnyValue(GetSubAttrId(id)); + } + + const AnyValue *const av = general_attrs_.GetAnyValue(name); + if (av != nullptr) { + return av; + } + + return nullptr; +} +AnyValue *AttrStore::MutableAnyValue(const std::string &name) const noexcept +{ + return const_cast(GetAnyValue(name)); +} +AnyValue *AttrStore::GetOrCreateAnyValue(const std::string &name) +{ + const auto id = GetIdByName(name); + if (id != CONST_INVALID_ATTR_ID) { + return pre_defined_attrs_.GetOrCreateAnyValue(GetSubAttrId(id)); + } + return general_attrs_.GetOrCreateAnyValue(name); +} +AttrId AttrStore::GetIdByName(const std::string &name) const noexcept +{ + const auto iter = names_to_id_.find(name); + if (iter == names_to_id_.end()) { + return CONST_INVALID_ATTR_ID; + } + return iter->second; +} +void AttrStore::SetNameAndId(std::string name, const AttrId id) { names_to_id_[std::move(name)] = id; } +bool AttrStore::Exists(const AttrId attr_id) const noexcept { return GetAnyValue(attr_id) != nullptr; } +bool AttrStore::Exists(const std::string &name) const 
noexcept { return GetAnyValue(name) != nullptr; } +bool AttrStore::Delete(const std::string &name) +{ + const auto iter = names_to_id_.find(name); + if (iter != names_to_id_.end()) { + const auto subId = GetSubAttrId(iter->second); + (void)names_to_id_.erase(iter); + return pre_defined_attrs_.Delete(subId); + } + return general_attrs_.Delete(name); +} +std::set AttrStore::GetAllAttrNames() const +{ + std::set names; + for (const auto &iter : names_to_id_) { + (void)names.insert(iter.first); + } + general_attrs_.GetAllNames(names); + return names; +} +std::map AttrStore::GetAllAttrs() const { return GetAllAttrsWithFilter(nullptr); } +std::map AttrStore::GetAllAttrsWithFilter(const AttrNameFilter &attr_filter) const +{ + std::map attrs; + for (const auto &iter : names_to_id_) { + const auto av = pre_defined_attrs_.GetAnyValue(GetSubAttrId(iter.second)); + if (av == nullptr) { + // error + continue; + } + if (av->IsEmpty()) { + continue; + } + if ((attr_filter != nullptr) && (!attr_filter(iter.first))) { + continue; + } + attrs[iter.first] = *av; + } + general_attrs_.GetAllAttrsWithFilter(attrs, attr_filter); + return attrs; +} +void AttrStore::Swap(AttrStore &other) +{ + pre_defined_attrs_.Swap(other.pre_defined_attrs_); + general_attrs_.Swap(other.general_attrs_); +} + +void AttrStore::PreDefinedAttrStore::Resize(const size_t s) { attrs_.resize(s); } +bool AttrStore::PreDefinedAttrStore::Exists(const AttrSubId index) const noexcept +{ + if (index >= attrs_.size()) { + return false; + } + return !attrs_[static_cast(index)].IsEmpty(); +} +bool AttrStore::PreDefinedAttrStore::Delete(const AttrSubId index) +{ + if (!Exists(index)) { + return false; + } + attrs_[static_cast(index)].Clear(); + return true; +} +AnyValue *AttrStore::PreDefinedAttrStore::GetOrCreateAnyValue(const AttrSubId index) const +{ + return const_cast(GetAnyValue(index)); +} +AnyValue *AttrStore::PreDefinedAttrStore::MutableAnyValue(const AttrSubId index) const noexcept +{ + return 
const_cast(GetAnyValue(index)); +} +const AnyValue *AttrStore::PreDefinedAttrStore::GetAnyValue(const AttrSubId index) const noexcept +{ + if (index >= attrs_.size()) { + return nullptr; + } + return &attrs_[static_cast(index)]; +} +void AttrStore::PreDefinedAttrStore::Swap(AttrStore::PreDefinedAttrStore &other) { attrs_.swap(other.attrs_); } +bool AttrStore::CustomDefinedAttrStore::Exists(const std::string &name) const noexcept +{ + return attrs_.count(name) > 0UL; +} +bool AttrStore::CustomDefinedAttrStore::Delete(const std::string &name) { return attrs_.erase(name) == 1UL; } +AnyValue *AttrStore::CustomDefinedAttrStore::GetOrCreateAnyValue(const std::string &name) { return &attrs_[name]; } +AnyValue *AttrStore::CustomDefinedAttrStore::MutableAnyValue(const std::string &name) const noexcept +{ + return const_cast(GetAnyValue(name)); +} +const AnyValue *AttrStore::CustomDefinedAttrStore::GetAnyValue(const std::string &name) const noexcept +{ + const auto iter = attrs_.find(name); + if (iter != attrs_.end()) { + return &iter->second; + } else { + return nullptr; + } +} +void AttrStore::CustomDefinedAttrStore::GetAllNames(std::set &names) const +{ + for (const auto &iter : attrs_) { + (void)names.insert(iter.first); + } +} +void AttrStore::CustomDefinedAttrStore::GetAllAttrs(std::map &names_to_attr) const +{ + for (const auto &iter : attrs_) { + names_to_attr[iter.first] = iter.second; + } +} +void AttrStore::CustomDefinedAttrStore::GetAllAttrsWithFilter(std::map &names_to_attr, + const AttrNameFilter &attr_filter) const +{ + for (const auto &iter : attrs_) { + if ((attr_filter != nullptr) && (!attr_filter(iter.first))) { + continue; + } + names_to_attr[iter.first] = iter.second; + } +} +void AttrStore::CustomDefinedAttrStore::Swap(AttrStore::CustomDefinedAttrStore &other) { attrs_.swap(other.attrs_); } +bool AttrStore::SetAnyValueByName(const std::string &name, const AnyValue &value) +{ + const auto av = GetOrCreateAnyValue(name); + if (av == nullptr) { + return 
false; + } + *av = value; + return true; +} +void AttrStore::Clear() +{ + pre_defined_attrs_.Clear(); + general_attrs_.Clear(); +} +void AttrStore::PreDefinedAttrStore::Clear() { attrs_.clear(); } +void AttrStore::CustomDefinedAttrStore::Clear() { attrs_.clear(); } + +graphStatus AttrStore::SetAttrToOtherGroup(const std::string &attr, const AnyValue &value) +{ + return other_attrs_.SetAttr(attr, value); +} + +graphStatus AttrStore::GetAttrFromOtherGroup(const std::string &attr, AnyValue &value) const +{ + return other_attrs_.GetAttr(attr, value); +} + +const std::unordered_map &AttrStore::FastGetAllAttrsFromOtherGroup() const +{ + return other_attrs_.FastGetAllAttr(); +} + +std::unordered_map AttrStore::GetAllAttrsFromOtherGroup() const +{ + return other_attrs_.GetAllAttr(); +} + +bool AttrStore::CheckAttrIsExistInOtherGroup(const std::string &attr) const +{ + return other_attrs_.CheckAttrIsExist(attr); +} + +bool AttrStore::DeleteSingleAttrsInOtherGroup(const std::string &attr) { return other_attrs_.DeleteSingleAttr(attr); } + +void AttrStore::DeleteAllAttrsInOtherGroup() { other_attrs_.DeleteAllAttrs(); } + +const AttrGroupsMap &AttrStore::GetAttrsGroupPtr() const { return attrs_groups_ptr_; } + +void AttrStore::ClearAllAttrs() +{ + other_attrs_.DeleteAllAttrs(); + for (auto &attrsGroupPtr : attrs_groups_ptr_) { + if (attrsGroupPtr.second != nullptr) { + attrsGroupPtr.second.reset(); + } + } + attrs_groups_ptr_.clear(); +} + +void AttrStore::ClearAllAttrsInOtherAttrs() { return other_attrs_.DeleteAllAttrs(); } + +bool AttrStore::ClearAttrInOtherAttrs(const std::string &attr_name) { return other_attrs_.DeleteSingleAttr(attr_name); } + +const std::unordered_set OtherAttrs::valid_attrs_ = {"Max memory"}; + +bool OtherAttrs::CheckAttrIsValid(const std::string &attr) const +{ + if (valid_attrs_.find(attr) != valid_attrs_.end()) { + return true; + } + return false; +} + +graphStatus OtherAttrs::SetAttr(const std::string &attr, const AnyValue &value) +{ + if 
(CheckAttrIsValid(attr)) { + keys_to_attrs_[attr] = value; + return GRAPH_SUCCESS; + } + + REPORT_INNER_ERROR("E18888", "Failed to set the %s.", attr.c_str()); + GELOGE(ge::GRAPH_FAILED, "Failed to set the %s.", attr.c_str()); + return GRAPH_FAILED; +} + +graphStatus OtherAttrs::GetAttr(const std::string &attr, AnyValue &value) const +{ + auto iter = keys_to_attrs_.find(attr); + if (iter != keys_to_attrs_.end()) { + value = iter->second; + return GRAPH_SUCCESS; + } + + REPORT_INNER_ERROR("E18888", "Failed to find the %s.", attr.c_str()); + GELOGE(ge::GRAPH_FAILED, "Failed to find the %s.", attr.c_str()); + return GRAPH_FAILED; +} + +const std::unordered_map &OtherAttrs::FastGetAllAttr() const { return keys_to_attrs_; } + +std::unordered_map OtherAttrs::GetAllAttr() const { return keys_to_attrs_; } + +bool OtherAttrs::CheckAttrIsExist(const std::string &attr) const +{ + auto iter = keys_to_attrs_.find(attr); + if (iter != keys_to_attrs_.end()) { + return true; + } + + return false; +} + +bool OtherAttrs::DeleteSingleAttr(const std::string &attr) +{ + auto iter = keys_to_attrs_.find(attr); + if (iter != keys_to_attrs_.end()) { + keys_to_attrs_.erase(iter); + return true; + } + + return false; +} + +void OtherAttrs::DeleteAllAttrs() { keys_to_attrs_.clear(); } + +void OtherAttrs::Swap(OtherAttrs &other) { keys_to_attrs_.swap(other.keys_to_attrs_); } + +graphStatus AttrGroupSerialize::SerializeAllAttr(proto::AttrGroups &attrGroups, const AttrStore &attrStore) +{ + (void)attrStore; + (void)attrGroups; + + return GRAPH_SUCCESS; +} + +graphStatus AttrGroupSerialize::DeserializeAllAttr(const proto::AttrGroups &attrGroup, AttrStore &attrStore) +{ + (void)attrStore; + (void)attrGroup; + + return GRAPH_SUCCESS; +} + +graphStatus AttrGroupSerialize::OtherGroupSerialize(proto::AttrGroups &attrGroups, const AttrStore &attrStore) +{ + (void)attrStore; + (void)attrGroups; + return GRAPH_SUCCESS; +} + +graphStatus AttrGroupSerialize::OtherGroupDeserialize(const proto::AttrGroups 
&attrGroups, AttrStore &attrStore) +{ + (void)attrStore; + (void)attrGroups; + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/detail/attributes_holder.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/detail/attributes_holder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c586095e384bebdb48fa8b16251bd77678de0003 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/graph/detail/attributes_holder.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "graph/detail/attributes_holder.h" + +#define UNUSED_VALUE(x) (void)(x) +namespace ge { +void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) +{ + UNUSED_VALUE(holder); +} +void AttrHolder::CopyFrom(const AttrHolder &holder) +{ + UNUSED_VALUE(holder); +} + +graphStatus AttrHolder::SetAttr(const std::string &name, const AnyValue &value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return GRAPH_SUCCESS; +} +graphStatus AttrHolder::TrySetAttr(const std::string &name, const AnyValue &value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return GRAPH_SUCCESS; +} +graphStatus AttrHolder::AddRequiredAttr(const std::string &name) +{ + UNUSED_VALUE(name); + return GRAPH_SUCCESS; +} + +graphStatus AttrHolder::GetAttr(const std::string &name, AnyValue &value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return GRAPH_SUCCESS; +} + +bool AttrHolder::HasAttr(const std::string &name) const +{ + UNUSED_VALUE(name); + return true; +} + +graphStatus AttrHolder::DelAttr(const std::string &name) +{ + UNUSED_VALUE(name); + return GRAPH_SUCCESS; +} + +const std::map AttrHolder::GetAllAttrs() const +{ + static std::map all; + return all; +} + +const std::set AttrHolder::GetAllAttrNames() const +{ + static std::set all; + return all; +} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} + +template <> void GeIrProtoHelper::InitDefault() {} +} // namespace ge diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/ge_attr_value.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/ge_attr_value.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa3c7ec508bc2a4db693e5e22df41c0bc6382884 --- /dev/null +++ 
b/src/kernels/tbe_adapter/stubs/metadef/graph/ge_attr_value.cpp @@ -0,0 +1,701 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "graph/ge_attr_value.h" + +#include "graph/utils/graph_utils.h" + +#define UNUSED_VALUE(x) (void)(x) +namespace ge { +void NamedAttrs::SetName(const std::string &name) { name_ = name; } + +std::string NamedAttrs::GetName() const { return name_; } + +AnyValue NamedAttrs::GetItem(const std::string &key) const +{ + AnyValue value; + (void)GetAttr(key, value); + return value; +} + +ProtoAttrMap &NamedAttrs::MutableAttrMap() { return attrs_; } + +ConstProtoAttrMap &NamedAttrs::GetAttrMap() const { return attrs_; } + +bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const std::string &name) +{ + if (!obj) { + return false; + } + return obj->HasAttr(name); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + int32_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + uint32_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) +{ + UNUSED_VALUE(org_op_desc); + 
return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) +{ + UNUSED_VALUE(org_op_desc); + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetInt(AttrHolderAdapter &&obj, + const std::string &name, + const int64_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + int64_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + uint64_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetFloat(AttrHolderAdapter &&obj, + const std::string &name, + const float32_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetFloat(ConstAttrHolderAdapter &&obj, + const std::string &name, + float32_t &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListFloat(AttrHolderAdapter &&obj, + const std::string &name, + 
const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListFloat(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetBool(AttrHolderAdapter &&obj, + const std::string &name, + const bool &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetBool(ConstAttrHolderAdapter &&obj, + const std::string &name, + bool &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListBool(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListBool(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetStr(AttrHolderAdapter &&obj, + const std::string &name, + const std::string &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetStr(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::string &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string *AttrUtils::GetStr(ConstAttrHolderAdapter &&obj, + const std::string &name) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + static 
std::string str; + return &str; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListStr(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListStr(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetTensorDesc(AttrHolderAdapter &&obj, + const std::string &name, + const GeTensorDesc &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetTensorDesc(ConstAttrHolderAdapter &&obj, + const std::string &name, + GeTensorDesc &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListTensorDesc(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListTensorDesc(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetNamedAttrs(AttrHolderAdapter &&obj, + const std::string &name, + const NamedAttrs &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetNamedAttrs(ConstAttrHolderAdapter &&obj, + const std::string &name, + NamedAttrs &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return 
true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListNamedAttrs(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListNamedAttrs(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetDataType(AttrHolderAdapter &&obj, + const std::string &name, + const DataType &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetDataType(ConstAttrHolderAdapter &&obj, + const std::string &name, + DataType &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListDataType(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListDataType(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListListInt(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector> &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListListInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector> &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListListFloat(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector> &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListListFloat(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector> &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListInt(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, + const std::string &name, + std::initializer_list &&value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetTensor(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const GeTensor &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, + const std::string &name, + const GeTensorPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, + const std::string &name, + const ConstGeTensorPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListTensor(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, + const std::string &name, + std::initializer_list &&value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +// 所有权UT测试,不能把属性上的GeTensor给错误释放了 +// 而且这里的行为与老版本是不一样的,老版本中,即使属性的owner生命周期结束析构了,通过本接口获取的value仍然是可用的 +// 但是新接口中,owner没有转移,owner析构后,value指向的内存就被释放了,这里需要排查 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, + const std::string &name, + GeTensorPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, + const std::string &name, + 
ConstGeTensorPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetGraph(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const ComputeGraphPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListGraph(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetGraph(AttrUtils::ConstAttrHolderAdapter &&obj, + const std::string &name, ComputeGraphPtr &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListGraph(AttrUtils::ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetBytes(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, const Buffer &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetBytes(ConstAttrHolderAdapter &&obj, + const std::string &name, 
Buffer &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListBytes(AttrUtils::AttrHolderAdapter &&obj, + const std::string &name, + const std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListBytes(AttrUtils::ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &value) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(value); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, + const std::string &name, + Buffer &&buffer) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(buffer); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, + const std::string &name, + Buffer &buffer) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(buffer); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, + const std::string &name, + std::vector &list_buffer) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(list_buffer); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, + const std::string &name, + std::vector &list_buffer) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(name); + UNUSED_VALUE(list_buffer); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::map +AttrUtils::GetAllAttrs(ConstAttrHolderAdapter &&obj) +{ + UNUSED_VALUE(obj); + const std::map empty; + return empty; +} + +std::string AttrUtils::GetAttrsStrAfterRid(ConstAttrHolderAdapter &&obj, const std::set &un_compute_attrs) +{ + UNUSED_VALUE(obj); + UNUSED_VALUE(un_compute_attrs); + return ""; +} 
+std::string AttrUtils::GetAllAttrsStr(ConstAttrHolderAdapter &&obj) +{ + UNUSED_VALUE(obj); + return ""; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::ClearAllAttrs(AttrHolderAdapter &&obj) +{ + UNUSED_VALUE(obj); + return true; +} +} // namespace ge diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/ge_tensor.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/ge_tensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..facb6f16afcc8507c2bb6478904abd7b1c49f57a --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/graph/ge_tensor.cpp @@ -0,0 +1,1382 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "graph/ge_tensor.h" + +#include +#include +#include +#include +#include "common/util/mem_utils.h" +#include "debug/ge_util.h" +#include "graph/normal_graph/ge_tensor_impl.h" +#include "graph/utils/tensor_utils.h" + +#define UNUSED_VALUE(x) (void)(x) +namespace ge { +namespace { +const std::map kDeviceToStrMap = { + {NPU, "NPU"}, + {CPU, "CPU"}, +}; + +const std::map kStrToDeviceMap = {{"NPU", NPU}, {"CPU", CPU}}; + +} // namespace + +void GeTensorSerializeUtils::GeShapeAsProto(const GeShape &shape, proto::ShapeDef *proto) +{ + UNUSED_VALUE(shape); + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in GeShapeAsProto"; +} +void GeTensorSerializeUtils::GeTensorDescAsProto(const GeTensorDescImpl &desc, proto::TensorDescriptor *proto) +{ + UNUSED_VALUE(desc); + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in GeTensorDescAsProto"; +} +void GeTensorSerializeUtils::GeTensorDescAsProto(const GeTensorDesc &desc, proto::TensorDescriptor *proto) +{ + UNUSED_VALUE(desc); + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in GeTensorDescAsProto"; +} +void GeTensorSerializeUtils::GeTensorAsProto(const GeTensorImpl &tensor, proto::TensorDef *proto) +{ + UNUSED_VALUE(tensor); + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in GeTensorAsProto"; +} +void GeTensorSerializeUtils::GeTensorAsProto(const GeTensor &tensor, proto::TensorDef *proto) +{ + UNUSED_VALUE(tensor); + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in GeTensorAsProto"; +} + +void GeTensorSerializeUtils::AssembleGeShapeFromProto(const proto::ShapeDef *proto, GeShape &shape) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(shape); + MKI_LOG(ERROR) << "fail in AssembleGeShapeFromProto"; +} +void GeTensorSerializeUtils::AssembleGeTensorDescFromProto(const proto::TensorDescriptor *const proto, + GeTensorDesc &desc) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(desc); + MKI_LOG(ERROR) << "fail in AssembleGeTensorDescFromProto"; +} +void GeTensorSerializeUtils::AssembleGeTensorFromProto(const proto::TensorDef *proto, 
GeTensor &tensor) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(tensor); + MKI_LOG(ERROR) << "fail in AssembleGeTensorFromProto"; +} + +void GeTensorSerializeUtils::NormalizeGeTensorDescProto(proto::TensorDescriptor *proto) +{ + UNUSED_VALUE(proto); + MKI_LOG(ERROR) << "fail in NormalizeGeTensorDescProto"; +} + +void GeTensorSerializeUtils::GetShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(shape); + MKI_LOG(ERROR) << "fail in GetShapeFromDescProto"; +} + +void GeTensorSerializeUtils::GetOriginShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(shape); + MKI_LOG(ERROR) << "fail in GetOriginShapeFromDescProto"; +} + +void GeTensorSerializeUtils::GetDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(dtype); + MKI_LOG(ERROR) << "fail in GetDtypeFromDescProto"; +} + +void GeTensorSerializeUtils::GetOriginDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(dtype); + MKI_LOG(ERROR) << "fail in GetOriginDtypeFromDescProto"; +} + +void GeTensorSerializeUtils::GetFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(format); + MKI_LOG(ERROR) << "fail in GetFormatFromDescProto"; +} + +void GeTensorSerializeUtils::GetOriginFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format) +{ + UNUSED_VALUE(proto); + UNUSED_VALUE(format); + MKI_LOG(ERROR) << "fail in GetOriginFormatFromDescProto"; +} + +class GeShapeImpl { + using DimsType = SmallVector; + +public: + ~GeShapeImpl() = default; + GeShapeImpl() = default; + explicit GeShapeImpl(proto::ShapeDef *const proto_msg); + explicit GeShapeImpl(const std::vector &dims); + void AppendDim(const int64_t dim_size); + void SetDimNum(const size_t dim_num); + void 
SetIsUnknownDimNum(); + bool IsUnknownDimNum() const; + int64_t GetDim(const size_t idx) const; + size_t GetDimNum() const; + std::vector ShapeImplGetDims() const; + graphStatus SetDim(const size_t idx, const int64_t value); + std::string ShapeImplToString() const; + const DimsType &ShapeImplGetMutableDims() const; + int64_t GetShapeSize() const; + bool IsScalar() const; + bool operator==(const GeShapeImpl &other) const; + bool IsUnknownShape() const; + +private: + DimsType dims_; + friend class GeTensorDesc; +}; + +// Default +GeShapeImpl::GeShapeImpl(const std::vector &dims) +{ + UNUSED_VALUE(dims); + MKI_LOG(ERROR) << "fail in GeShapeImpl"; +} + +void GeShapeImpl::SetDimNum(const size_t dim_num) +{ + MKI_LOG(ERROR) << "fail in SetDimNum"; + dims_.resize(dim_num, UNKNOWN_DIM); +} + +void GeShapeImpl::AppendDim(const int64_t dim_size) +{ + MKI_LOG(ERROR) << "fail in AppendDim"; + dims_.push_back(dim_size); +} + +bool GeShapeImpl::IsUnknownDimNum() const +{ + MKI_LOG(ERROR) << "fail in IsUnknownDimNum"; + return (dims_.size() == 1UL) && (dims_[0UL] == UNKNOWN_DIM_NUM); +} + +void GeShapeImpl::SetIsUnknownDimNum() +{ + MKI_LOG(ERROR) << "fail in SetIsUnknownDimNum"; +} + +size_t GeShapeImpl::GetDimNum() const +{ + MKI_LOG(ERROR) << "fail in GetDimNum"; + return static_cast(-1); +} + +int64_t GeShapeImpl::GetDim(const size_t idx) const +{ + UNUSED_VALUE(idx); + MKI_LOG(ERROR) << "fail in GetDim"; + return 0; +} + +graphStatus GeShapeImpl::SetDim(const size_t idx, const int64_t value) +{ + UNUSED_VALUE(idx); + UNUSED_VALUE(value); + MKI_LOG(ERROR) << "fail in SetDim"; + return GRAPH_FAILED; +} + +std::vector GeShapeImpl::ShapeImplGetDims() const +{ + MKI_LOG(ERROR) << "fail in ShapeImplGetDims"; + return std::vector(); +} + +const SmallVector &GeShapeImpl::ShapeImplGetMutableDims() const +{ + MKI_LOG(ERROR) << "fail in ShapeImplGetMutableDims"; + return dims_; +} + +std::string GeShapeImpl::ShapeImplToString() const +{ + MKI_LOG(ERROR) << "fail in ShapeImplToString"; 
+ return std::string(); +} + +int64_t GeShapeImpl::GetShapeSize() const +{ + MKI_LOG(ERROR) << "fail in GetShapeSize"; + return 0; +} + +bool GeShapeImpl::IsUnknownShape() const +{ + MKI_LOG(ERROR) << "fail in IsUnknownShape"; + return std::any_of(dims_.begin(), dims_.end(), [](const int64_t &dim) { + return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM) || (dim < 0); + }); +} + +bool GeShapeImpl::IsScalar() const +{ + MKI_LOG(ERROR) << "fail in IsScalar"; + return dims_.empty(); +} + +GeShapeImpl::GeShapeImpl(proto::ShapeDef *const proto_msg) +{ + MKI_LOG(ERROR) << "fail in GeShapeImpl"; + UNUSED_VALUE(proto_msg); +} + +bool GeShapeImpl::operator==(const GeShapeImpl &other) const +{ + MKI_LOG(ERROR) << "fail in operator"; + return this->ShapeImplGetDims() == other.ShapeImplGetDims(); +} + +GeShape::GeShape(std::vector s) : impl_(MakeShared(std::move(s))) {} +GeShape::GeShape() : impl_(MakeShared()) {} + +GeShape::GeShape(GeShape &&other) : impl_(MakeShared(std::move(*(other.impl_)))) {} +GeShape::GeShape(const GeShape &other) : impl_(MakeShared(*(other.impl_))) {} + +GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *const proto_msg) + : impl_(MakeShared(proto_msg)) +{ + MKI_LOG(ERROR) << "fail in MutableAttrMap"; + + UNUSED_VALUE(proto_owner); +} +GeShape::~GeShape() = default; + +size_t GeShape::GetDimNum() const +{ + MKI_LOG(ERROR) << "fail in GetDimNum"; + return impl_->GetDimNum(); +} + +void GeShape::SetDimNum(const size_t dim_num) +{ + MKI_LOG(ERROR) << "fail in SetDimNum"; + impl_->SetDimNum(dim_num); +} + +void GeShape::AppendDim(const int64_t dim_size) +{ + MKI_LOG(ERROR) << "fail in AppendDim"; + impl_->AppendDim(dim_size); +} + +bool GeShape::IsUnknownDimNum() const +{ + MKI_LOG(ERROR) << "fail in IsUnknownDimNum"; + return impl_->IsUnknownDimNum(); +} + +void GeShape::SetIsUnknownDimNum() +{ + MKI_LOG(ERROR) << "fail in SetIsUnknownDimNum"; + impl_->SetIsUnknownDimNum(); +} + +int64_t GeShape::GetDim(const size_t idx) const +{ + 
MKI_LOG(ERROR) << "fail in GetDim"; + return impl_->GetDim(idx); +} + +graphStatus GeShape::SetDim(const size_t idx, const int64_t value) +{ + MKI_LOG(ERROR) << "fail in SetDim"; + return impl_->SetDim(idx, value); +} + +std::vector GeShape::GetDims() const +{ + MKI_LOG(ERROR) << "fail in GetDims"; + return impl_->ShapeImplGetDims(); +} + +const SmallVector &GeShape::GetMutableDims() const +{ + MKI_LOG(ERROR) << "fail in GetMutableDims"; + return impl_->ShapeImplGetMutableDims(); +} + +std::string GeShape::ToString() const +{ + MKI_LOG(ERROR) << "fail in ToString"; + return impl_->ShapeImplToString(); +} + +int64_t GeShape::GetShapeSize() const +{ + MKI_LOG(ERROR) << "fail in GetShapeSize"; + return impl_->GetShapeSize(); +} + +bool GeShape::IsUnknownShape() const +{ + MKI_LOG(ERROR) << "fail in IsUnknownShape"; + return impl_->IsUnknownShape(); +} + +bool GeShape::IsScalar() const +{ + MKI_LOG(ERROR) << "fail in IsScalar"; + return impl_->IsScalar(); +} + +GeShape &GeShape::operator=(const GeShape &other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +GeShape &GeShape::operator=(GeShape &&other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +bool GeShape::operator==(const GeShape &other) const { return (*impl_) == (*(other.impl_)); } + +// GeTensorDesc +GeTensorDescImpl::GeTensorDescImpl(const GeShape &shape, const Format format, const DataType dt) : GeTensorDescImpl() +{ + SetFormat(format); + SetDataType(dt); + shape_ = shape; +} + +GeTensorDescImpl::GeTensorDescImpl(proto::TensorDescriptor *const proto_msg) : GeTensorDescImpl() +{ + UNUSED_VALUE(proto_msg); + MKI_LOG(ERROR) << "fail in GeTensorDescImpl"; +} + +void GeTensorDescImpl::SetDataType(const DataType dtype) +{ + MKI_LOG(DEBUG) << "stub SetDataType"; + dtype_ = dtype; +} + +void GeTensorDescImpl::SetOriginDataType(const DataType dtype) +{ + MKI_LOG(ERROR) << "fail in SetOriginDataType"; + origin_dtype_ = dtype; +} + 
+DataType GeTensorDescImpl::GetOriginDataType() const +{ + MKI_LOG(ERROR) << "fail in GetOriginDataType"; + return origin_dtype_; +} + +void GeTensorDescImpl::SetFormat(const Format format) +{ + MKI_LOG(DEBUG) << "stub SetFormat"; + format_ = format; +} + +void GeTensorDescImpl::SetOriginFormat(const Format format) +{ + MKI_LOG(DEBUG) << "stub SetOriginFormat"; + origin_format_ = format; +} + +Format GeTensorDescImpl::GetOriginFormat() const +{ + MKI_LOG(ERROR) << "fail in GetOriginFormat"; + return origin_format_; +} + +GeShape &GeTensorDescImpl::ShapeReference() const +{ + MKI_LOG(ERROR) << "fail in ShapeReference"; + return shape_; +} + +GeShape &GeTensorDescImpl::OriginShapeReference() const +{ + MKI_LOG(ERROR) << "fail in OriginShapeReference"; + return origin_shape_; +} + +bool GeTensorDescImpl::GeTensorDescAttrsAreEqual(const GeTensorDescImpl &other) const +{ + // The definition of attribute equality remains unchanged + MKI_LOG(ERROR) << "fail in GeTensorDescAttrsAreEqual"; + return ((shape_ == other.shape_) && (dtype_ == other.dtype_) && (format_ == other.format_) && + (ext_meta_ == other.ext_meta_)); +} + +bool GeTensorDescImpl::operator==(const GeTensorDescImpl &other) const +{ + // The definition of attribute equality remains unchanged + MKI_LOG(ERROR) << "fail in operator"; + return (origin_shape_ == other.origin_shape_) && (origin_format_ == other.origin_format_) && + (origin_dtype_ == other.origin_dtype_) && (GeTensorDescAttrsAreEqual(other)); +} + +ProtoAttrMap &GeTensorDescImpl::MutableAttrMap() +{ + MKI_LOG(ERROR) << "fail in MutableAttrMap"; + return attrs_; +} + +ConstProtoAttrMap &GeTensorDescImpl::GetAttrMap() const +{ + MKI_LOG(ERROR) << "fail in GetAttrMap"; + return attrs_; +} + +void GeTensorDescImpl::SetShape(GeShape &shape) const +{ + MKI_LOG(ERROR) << "fail in SetShape"; + ShapeReference() = std::move(shape); +} + +Format GeTensorDescImpl::GetFormat() const +{ + return format_; +} + +void GeTensorDescImpl::SetName(const std::string 
&name) +{ + MKI_LOG(ERROR) << "fail in SetName"; + ext_meta_.SetName(name); +} + +const std::string GeTensorDescImpl::GetName() const +{ + MKI_LOG(ERROR) << "fail in GetName"; + return ext_meta_.GetName(); +} + +DataType GeTensorDescImpl::GetDataType() const +{ + return dtype_; +} + +std::string GeTensorDescImpl::ExtMeta::GetDeviceTypeStr() const +{ + MKI_LOG(ERROR) << "fail in GetDeviceTypeStr"; + return std::string(); +} + +GeTensorDesc::GeTensorDesc() : AttrHolder(), impl_(ComGraphMakeShared()) {} + +// Default +GeTensorDesc::GeTensorDesc(const GeShape &shape, const Format format, const DataType dt) + : AttrHolder(), impl_(ComGraphMakeShared(shape, format, dt)) +{ +} + +// Default +GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) + : AttrHolder(desc), impl_(ComGraphMakeShared(*(desc.impl_))) +{ +} + +// Default +GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : AttrHolder(desc), impl_(desc.impl_) {} + +GeTensorDesc::~GeTensorDesc() = default; + +GeTensorDesc::GeTensorDesc(proto::TensorDescriptor *const proto_msg) + : AttrHolder(), impl_(ComGraphMakeShared(proto_msg)) +{ +} + +bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const +{ + MKI_LOG(ERROR)<< "fail in GeTensorDescAttrsAreEqual"; + return impl_->GeTensorDescAttrsAreEqual(*(r_ge_tensor_desc.impl_)); +} + +bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const +{ + MKI_LOG(ERROR)<< "fail in operator"; + return (*impl_) == (*(r_ge_tensor_desc.impl_)); +} + +GeShape &GeTensorDesc::ShapeReference() const +{ + MKI_LOG(ERROR)<< "fail in ShapeReference"; + return impl_->ShapeReference(); +} + +ProtoAttrMap &GeTensorDesc::MutableAttrMap() +{ + MKI_LOG(ERROR)<< "fail in MutableAttrMap"; + return impl_->MutableAttrMap(); +} + +ConstProtoAttrMap &GeTensorDesc::GetAttrMap() const +{ + MKI_LOG(ERROR)<< "fail in GetAttrMap"; + return impl_->GetAttrMap(); +} + +void GeTensorDesc::Update(const GeShape &shape, const Format format, const DataType dt) +{ + 
MKI_LOG(ERROR)<< "fail in Update"; + UNUSED_VALUE(shape); + UNUSED_VALUE(format); + UNUSED_VALUE(dt); +} +const GeShape &GeTensorDesc::GetShape() const +{ + MKI_LOG(ERROR)<< "fail in GetShape"; + return ShapeReference(); +} + +GeShape &GeTensorDesc::MutableShape() +{ + MKI_LOG(ERROR)<< "fail in MutableShape"; + return ShapeReference(); +} + +void GeTensorDesc::SetShape(const GeShape &shape) +{ + MKI_LOG(ERROR)<< "fail in SetShape"; + UNUSED_VALUE(shape); +} + +void GeTensorDesc::SetShape(GeShape &&shape) +{ + MKI_LOG(ERROR)<< "fail in SetShape"; + UNUSED_VALUE(shape); +} + +// set shape with -2, it stand for unknown shape +void GeTensorDesc::SetUnknownDimNumShape() +{ + MKI_LOG(ERROR)<< "fail in SetUnknownDimNumShape"; +} + +// for unknown shape +graphStatus GeTensorDesc::SetValueRange(const std::vector> &range) +{ + UNUSED_VALUE(range); + MKI_LOG(ERROR) << "fail in SetValueRange"; + return GRAPH_FAILED; +} + +graphStatus GeTensorDesc::GetValueRange(std::vector> &range) const +{ + UNUSED_VALUE(range); + std::vector> value_range; + MKI_LOG(ERROR) << "fail in GetValueRange"; + return GRAPH_FAILED; +} + +graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) +{ + UNUSED_VALUE(range); + MKI_LOG(ERROR) << "fail in SetShapeRange"; + return GRAPH_FAILED; +} + +graphStatus GeTensorDesc::SetOriginShapeRange(const std::vector> &range) +{ + UNUSED_VALUE(range); + std::vector> origin_shape_range; + MKI_LOG(ERROR) << "fail in SetOriginShapeRange"; + return GRAPH_FAILED; +} + +graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const +{ + UNUSED_VALUE(range); + std::vector> shape_range; + MKI_LOG(ERROR) << "fail in GetShapeRange"; + return GRAPH_FAILED; +} + +graphStatus GeTensorDesc::GetOriginShapeRange(std::vector> &range) const +{ + UNUSED_VALUE(range); + MKI_LOG(ERROR) << "fail in GetOriginShape"; + std::vector> origin_shape_range; + return GRAPH_FAILED; +} + +const GeShape &GeTensorDesc::GetOriginShape() const +{ + MKI_LOG(ERROR) << "fail in 
GetOriginShape"; + return impl_->OriginShapeReference(); +} + +GeShape &GeTensorDesc::MutableOriginShape() const { return impl_->OriginShapeReference(); } + +void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) +{ + MKI_LOG(ERROR) << "fail in SetOriginShape"; + impl_->OriginShapeReference() = origin_shape; + impl_->SetOriginShapeInited(true); +} + +bool GeTensorDesc::IsOriginShapeInitialized() const +{ + MKI_LOG(ERROR) << "fail in IsOriginShapeInitialized"; + return impl_->IsOriginShapeInited(); +} + +Format GeTensorDesc::GetFormat() const +{ + MKI_LOG(DEBUG) << "stub GetFormat"; + return impl_->GetFormat(); +} + +void GeTensorDesc::SetFormat(const Format format) +{ + MKI_LOG(DEBUG) << "stub SetFormat"; + return impl_->SetFormat(format); +} + +void GeTensorDesc::SetName(const std::string &name) +{ + MKI_LOG(ERROR) << "fail in SetName"; + return impl_->SetName(name); +} + +const std::string GeTensorDesc::GetName() const +{ + MKI_LOG(ERROR) << "fail in GetName"; + return impl_->GetName(); +} + +Format GeTensorDesc::GetOriginFormat() const +{ + MKI_LOG(ERROR) << "fail in GetOriginFormat"; + return impl_->GetOriginFormat(); +} + +void GeTensorDesc::SetOriginFormat(const Format origin_format) +{ + MKI_LOG(DEBUG) << "stub SetOriginFormat"; + impl_->SetOriginFormat(origin_format); +} + +void GeTensorDesc::SetDataType(const DataType data_type) +{ + MKI_LOG(DEBUG) << "stub SetDataType"; + return impl_->SetDataType(data_type); +} + +DataType GeTensorDesc::GetDataType() const +{ + MKI_LOG(DEBUG) << "stub GetDataType"; + return impl_->GetDataType(); +} + +void GeTensorDesc::SetOriginDataType(const DataType origin_data_type) +{ + MKI_LOG(ERROR) << "fail in SetOriginDataType"; + impl_->SetOriginDataType(origin_data_type); +} + +DataType GeTensorDesc::GetOriginDataType() const +{ + MKI_LOG(ERROR) << "fail in GetOriginDataType"; + return impl_->GetOriginDataType(); +} + +std::vector GeTensorDesc::GetRefPortIndex() const +{ + MKI_LOG(ERROR) << "fail in GetRefPortIndex"; + 
std::vector ref_port_index; + return ref_port_index; +} + +void GeTensorDesc::SetRefPortByIndex(const std::vector &index) +{ + MKI_LOG(ERROR) << "fail in SetRefPortByIndex"; + UNUSED_VALUE(index); +} + +Placement GeTensorDesc::GetPlacement() const +{ + MKI_LOG(ERROR) << "fail in GetPlacement"; + int64_t placement = 0; + return static_cast(placement); +} + +void GeTensorDesc::SetPlacement(const Placement placement) +{ + MKI_LOG(ERROR) << "fail in SetPlacement"; + UNUSED_VALUE(placement); +} + +graphStatus GeTensorDesc::IsValid() const +{ + if ((this->GetDataType() != DT_UNDEFINED) || (this->GetFormat() != FORMAT_RESERVED)) { + return GRAPH_SUCCESS; + } + MKI_LOG(ERROR) << "fail in IsValid"; + return GRAPH_PARAM_INVALID; +} + +GeTensorDesc GeTensorDesc::Clone() const +{ + MKI_LOG(ERROR) << "fail in Clone"; + return *this; +} + +GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) +{ + UNUSED_VALUE(desc); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) +{ + UNUSED_VALUE(desc); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +const std::string GeTensorDesc::GetExpandDimsRule() const { return std::string(); } +void GeTensorDesc::SetExpandDimsRule(const std::string &expand_dims_rule) +{ + UNUSED_VALUE(expand_dims_rule); + MKI_LOG(ERROR) << "fail in SetExpandDimsRule"; +} + +TensorData::TensorData() +{ + MKI_LOG(ERROR) << "fail in TensorData"; +} + +TensorData::TensorData(const TensorData &other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in TensorData"; +} + +TensorData::~TensorData() = default; + +TensorData &TensorData::operator=(const TensorData &other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +graphStatus TensorData::SetData(std::vector &&data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} +graphStatus TensorData::SetData(const std::vector &data) +{ + UNUSED_VALUE(data); + 
MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} +graphStatus TensorData::SetData(const Buffer &data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} +graphStatus TensorData::SetData(const TensorData &data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus TensorData::SetData(const uint8_t *const data, const size_t size) +{ + UNUSED_VALUE(data); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus TensorData::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) +{ + UNUSED_VALUE(data); + UNUSED_VALUE(size); + UNUSED_VALUE(delete_fuc); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +void TensorData::SetData(std::shared_ptr aligned_ptr, const size_t size) +{ + UNUSED_VALUE(aligned_ptr); + MKI_LOG(ERROR) << "fail in SetData"; + UNUSED_VALUE(size); +} + +const uint8_t *TensorData::MallocAlignedPtr(const size_t size) +{ + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in MallocAlignedPtr"; + return nullptr; +} + +size_t TensorData::GetSize() const +{ + MKI_LOG(ERROR) << "fail in GetSize"; + return 0; +} + +const uint8_t *TensorData::GetData() const +{ + MKI_LOG(ERROR) << "fail in GetData"; + return nullptr; +} + +uint8_t *TensorData::GetData() +{ + MKI_LOG(ERROR) << "fail in GetData"; + return nullptr; +} + +const std::uint8_t *TensorData::data() const +{ + MKI_LOG(ERROR) << "fail in data"; + return GetData(); +} +std::uint8_t *TensorData::data() +{ + MKI_LOG(ERROR) << "fail in data"; + return GetData(); +} +std::size_t TensorData::size() const +{ + MKI_LOG(ERROR) << "fail in size"; + return GetSize(); +} +void TensorData::clear() +{ + MKI_LOG(ERROR) << "fail in clear"; +} + +uint8_t TensorData::operator[](const size_t index) const +{ + UNUSED_VALUE(index); + MKI_LOG(ERROR) << "fail in operator"; + return 0; +} + +const std::shared_ptr 
&TensorData::GetAlignedPtr() +{ + static std::shared_ptr ptr = nullptr; + MKI_LOG(ERROR) << "fail in GetAlignedPtr"; + return ptr; +} + +GeTensor::GeTensor() {} + +GeTensor::GeTensor(GeTensor &&other) noexcept +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(GeTensorImplPtr impl) +{ + UNUSED_VALUE(impl); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc) +{ + UNUSED_VALUE(tensor_desc); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const std::vector &data) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *const data, const size_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(size); + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(GeTensorDesc &&tensor_desc, std::vector &&data) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, const size_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(aligned_ptr); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const size_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) +{ + UNUSED_VALUE(proto_owner); + UNUSED_VALUE(proto_msg); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor::~GeTensor() = default; + +void GeTensor::BuildAlignerPtrWithProtoData() +{ + MKI_LOG(ERROR) << "fail in BuildAlignerPtrWithProtoData"; +} + 
+const GeTensorDesc &GeTensor::GetTensorDesc() const +{ + MKI_LOG(ERROR) << "fail in GetTensorDesc"; + return DescReference(); +} + +GeTensorDesc &GeTensor::MutableTensorDesc() +{ + MKI_LOG(ERROR) << "fail in MutableTensorDesc"; + return DescReference(); +} + +GeTensorDesc &GeTensor::DescReference() const +{ + MKI_LOG(ERROR) << "fail in DescReference"; + return DescReference(); +} + +void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) +{ + MKI_LOG(ERROR) << "fail in SetTensorDesc"; + UNUSED_VALUE(tensor_desc); +} + +graphStatus GeTensor::SetData(std::vector &&data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const std::vector &data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const uint8_t *const data, const size_t size) +{ + UNUSED_VALUE(data); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const Buffer &data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const TensorData &data) +{ + UNUSED_VALUE(data); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) +{ + UNUSED_VALUE(data); + UNUSED_VALUE(size); + UNUSED_VALUE(delete_fuc); + MKI_LOG(ERROR) << "fail in SetData"; + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) +{ + UNUSED_VALUE(data); + UNUSED_VALUE(size); + UNUSED_VALUE(delete_fuc); + return GRAPH_SUCCESS; +} + +void GeTensor::ClearData() +{ + MKI_LOG(ERROR) << "fail in ClearData"; +} + +GeTensor GeTensor::Clone() const +{ + MKI_LOG(ERROR) << "fail in Clone"; + return GeTensor(); +} + +GeTensor::GeTensor(const GeTensor &other) +{ 
+ UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in GeTensor"; +} + +GeTensor &GeTensor::operator=(const GeTensor &other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +GeTensor &GeTensor::operator=(GeTensor &&other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in operator"; + return *this; +} + +std::shared_ptr GeTensor::GetAlignedPtr() +{ + MKI_LOG(ERROR) << "fail in GetAlignedPtr"; + return nullptr; +} + +const TensorData &GeTensor::GetData() const +{ + static TensorData tensorData; + MKI_LOG(ERROR) << "fail in GetData"; + return tensorData; +} +TensorData &GeTensor::MutableData() +{ + static TensorData tensorData; + MKI_LOG(ERROR) << "fail in MutableData"; + return tensorData; +} +// zero copy SetData +void GeTensor::SetData(std::shared_ptr aligned_ptr, const size_t size) +{ + UNUSED_VALUE(aligned_ptr); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in SetData"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, + int64_t &size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in GetSize"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, const int64_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in SetSize"; +} + +int64_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) +{ + UNUSED_VALUE(tensor_desc); + MKI_LOG(ERROR) << "fail in GetWeightSize"; + return 0; +} + +int64_t TensorUtils::GetWeightSize(const GeTensor &tensor) +{ + UNUSED_VALUE(tensor); + MKI_LOG(ERROR) << "fail in GetWeightSize"; + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) +{ + UNUSED_VALUE(tensor_ptr); + MKI_LOG(ERROR) << "fail in GetWeightSize"; + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t 
*TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, + const uint8_t *const base) +{ + UNUSED_VALUE(tensor_ptr); + UNUSED_VALUE(base); + MKI_LOG(ERROR) << "fail in GetWeightAddr"; + return nullptr; +} + +uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, const uint8_t *const base) +{ + UNUSED_VALUE(tensor); + UNUSED_VALUE(base); + MKI_LOG(ERROR) << "fail in GetWeightAddr"; + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, + const int64_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in SetWeightSize"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, + bool &flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in GetReuseInput"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, + const bool flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in SetReuseInput"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, + bool &flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in GetOutputTensor"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, + const bool flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in SetOutputTensor"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, + DeviceType &type) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(type); + MKI_LOG(ERROR) << "fail in GetDeviceType"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void 
TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, + const DeviceType type) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(type); + MKI_LOG(ERROR) << "fail in SetDeviceType"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, + bool &flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in GetInputTensor"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, + const bool flag) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(flag); + MKI_LOG(ERROR) << "fail in SetInputTensor"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, + uint32_t &cnt) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(cnt); + MKI_LOG(ERROR) << "fail in GetRealDimCnt"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, + const uint32_t cnt) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(cnt); + MKI_LOG(ERROR) << "fail in SetRealDimCnt"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(idx); + MKI_LOG(ERROR) << "fail in GetReuseInputIndex"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, + const uint32_t idx) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(idx); + MKI_LOG(ERROR) << "fail in SetReuseInputIndex"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, + int64_t &offset) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(offset); + MKI_LOG(ERROR) << "fail in GetDataOffset"; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, + const int64_t offset) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(offset); + MKI_LOG(ERROR) << "fail in SetDataOffset"; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, + uint32_t &rc) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(rc); + MKI_LOG(ERROR) << "fail in GetRC"; + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, const uint32_t rc) +{ + UNUSED_VALUE(tensor_desc); + MKI_LOG(ERROR) << "fail in SetRC"; + UNUSED_VALUE(rc); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::IsOriginShapeInited(const GeTensorDesc &tensor_desc) +{ + UNUSED_VALUE(tensor_desc); + MKI_LOG(ERROR) << "fail in IsOriginShapeInited"; + return true; +} + +GeTensor TensorUtils::CreateShareTensor(const GeTensor &other) +{ + GeTensor tensor; + ShareTensor(other, tensor); + MKI_LOG(ERROR) << "fail in CreateShareTensor"; + return tensor; +} + +GeTensor TensorUtils::CreateShareTensor(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, + const size_t size) +{ + UNUSED_VALUE(tensor_desc); + UNUSED_VALUE(aligned_ptr); + UNUSED_VALUE(size); + MKI_LOG(ERROR) << "fail in CreateShareTensor"; + return GeTensor(); +} + +void TensorUtils::ShareTensor(const GeTensor &from, GeTensor &to) +{ + UNUSED_VALUE(from); + UNUSED_VALUE(to); + MKI_LOG(ERROR) << "fail in ShareTensor"; +} + +void TensorUtils::ShareTensorData(const TensorData &from, TensorData &to) +{ + UNUSED_VALUE(from); + UNUSED_VALUE(to); + MKI_LOG(ERROR) << "fail in ShareTensorData"; +} + +TensorData TensorUtils::CreateShareTensorData(const TensorData &other) +{ + UNUSED_VALUE(other); + MKI_LOG(ERROR) << "fail in CreateShareTensorData"; + return TensorData(); +} + +void TensorUtils::ShareAlignedPtr(std::shared_ptr ptr, const size_t size, TensorData &to) +{ + UNUSED_VALUE(ptr); + UNUSED_VALUE(size); + 
UNUSED_VALUE(to); + MKI_LOG(ERROR) << "fail in ShareAlignedPtr"; +} + +void TensorUtils::ShareAlignedPtr(std::shared_ptr ptr, const size_t size, GeTensor &to) +{ + UNUSED_VALUE(ptr); + UNUSED_VALUE(size); + UNUSED_VALUE(to); + MKI_LOG(ERROR) << "fail in ShareAlignedPtr"; +} + +// UT +void TensorUtils::CopyTensor(const GeTensor &from, GeTensor &to) +{ + UNUSED_VALUE(from); + UNUSED_VALUE(to); + MKI_LOG(ERROR) << "fail in CopyTensor"; +} +} // namespace ge diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/op_desc.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/op_desc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..49f8c5ac72cb21cc417232f7d3cef272cc243143 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/graph/op_desc.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "graph/op_desc.h" + +#include "graph/ge_tensor.h" + +#define UNUSED_VALUE(x) (void)(x) +namespace ge { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { return 0; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return 0; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(const uint32_t index) const +{ + UNUSED_VALUE(index); + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const uint32_t index) const +{ + UNUSED_VALUE(index); + return nullptr; +} + +int32_t OpDesc::GetInputIndexByName(const std::string &name) const +{ + UNUSED_VALUE(name); + return 0; +} +} // namespace ge diff --git a/src/kernels/tbe_adapter/stubs/metadef/graph/operator.cpp b/src/kernels/tbe_adapter/stubs/metadef/graph/operator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d61031103c1f08f191f386b152ab6760c890fea4 --- /dev/null +++ b/src/kernels/tbe_adapter/stubs/metadef/graph/operator.cpp @@ -0,0 +1,1367 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include "external/graph/operator.h" + +#include +#include +#include +#include +#include + +#include "debug/ge_util.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/graph_utils_ex.h" +#include "graph/utils/node_utils_ex.h" +#include "graph/utils/op_desc_utils_ex.h" + +#define UNUSED_VALUE(x) (void)(x) +namespace ge { +Graph::Graph(const char_t *name) +{ + UNUSED_VALUE(name); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) +{ + UNUSED_VALUE(node_ptr); + return Operator("default"); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::CopyOperators( + const ComputeGraphPtr &dst_compute_graph, const std::map &node_old_2_new, + const std::map &op_desc_old_2_new, + const std::map &src_op_list, std::map &dst_op_list) +{ + UNUSED_VALUE(dst_compute_graph); + UNUSED_VALUE(node_old_2_new); + UNUSED_VALUE(op_desc_old_2_new); + UNUSED_VALUE(src_op_list); + UNUSED_VALUE(dst_op_list); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::CopyOperatorLinks( + const std::map &src_op_list, std::map &dst_op_list) +{ + UNUSED_VALUE(src_op_list); + UNUSED_VALUE(dst_op_list); + return GRAPH_SUCCESS; +} + +Operator::Operator(const std::string &type) +{ + UNUSED_VALUE(type); +} + +Operator::Operator(const char_t *type) +{ + UNUSED_VALUE(type); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) +{ + UNUSED_VALUE(op_desc); + return Operator("default"); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) +{ + UNUSED_VALUE(oprt); + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstNodePtr NodeUtilsEx::GetNodeFromOperator(const Operator &op) +{ + UNUSED_VALUE(op); + return nullptr; +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const std::string &name, const 
std::string &type) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(type); +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const AscendString &name, const AscendString &type) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(type); +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const char_t *name, const char_t *type) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(type); +} + +Operator::Operator(ge::OperatorImplPtr &&op_impl) +{ + UNUSED_VALUE(op_impl); +} + +bool Operator::IsEmpty() const { return true; } + +std::string Operator::GetName() const { return ""; } + +graphStatus Operator::GetName(AscendString &name) const +{ + UNUSED_VALUE(name); + return GRAPH_SUCCESS; +} + +GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + return *this; +} + +GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::OutHandler &out_handler) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(out_handler); + return *this; +} + +Operator &Operator::SetInput(const char_t *dst_name, const ge::OutHandler &out_handler) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(out_handler); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(name); + return *this; +} + +Operator &Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt, const char_t *name) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(name); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(index); + return *this; +} + +Operator 
&Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt, uint32_t index) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(index); + return *this; +} + +Operator &Operator::SetInput(uint32_t dst_index, const Operator &src_oprt, uint32_t src_index) +{ + UNUSED_VALUE(dst_index); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(src_index); + return *this; +} + +Operator &Operator::AddControlInput(const Operator &src_oprt) +{ + UNUSED_VALUE(src_oprt); + return *this; +} + +graphStatus Operator::GetInputConstData(const std::string &dst_name, Tensor &data) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(data); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputConstData(const char_t *dst_name, Tensor &data) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(data); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputConstDataOut(const std::string &dst_name, Tensor &data) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(data); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputConstDataOut(const char_t *dst_name, Tensor &data) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(data); + return GRAPH_SUCCESS; +} + +std::shared_ptr Operator::GetNode() const { return nullptr; } + +TensorDesc Operator::GetInputDesc(const std::string &name) const +{ + UNUSED_VALUE(name); + return TensorDesc(); +} + +TensorDesc Operator::GetInputDescByName(const char_t *name) const +{ + UNUSED_VALUE(name); + return TensorDesc(); +} + +void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) +{ + UNUSED_VALUE(inference_context); +} + +InferenceContextPtr Operator::GetInferenceContext() const { return nullptr; } + +TensorDesc Operator::GetInputDesc(uint32_t index) const +{ + UNUSED_VALUE(index); + return TensorDesc(); +} + +graphStatus Operator::TryGetInputDesc(const std::string &name, TensorDesc &tensor_desc) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus 
Operator::TryGetInputDesc(const char_t *name, TensorDesc &tensor_desc) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::UpdateInputDesc(const char_t *name, const ge::TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +OutHandler Operator::GetOutput(const std::string &name) const +{ + return GetOutput(name.c_str()); +} + +OutHandler Operator::GetOutput(const char_t *name) const +{ + UNUSED_VALUE(name); + return nullptr; +} + +OutHandler Operator::GetOutput(uint32_t index) const +{ + UNUSED_VALUE(index); + return nullptr; +} + +TensorDesc Operator::GetOutputDesc(const std::string &name) const +{ + UNUSED_VALUE(name); + return TensorDesc(); +} + +TensorDesc Operator::GetOutputDescByName(const char_t *name) const +{ + UNUSED_VALUE(name); + return TensorDesc(); +} + +TensorDesc Operator::GetOutputDesc(uint32_t index) const +{ + UNUSED_VALUE(index); + return TensorDesc(); +} + +graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::UpdateOutputDesc(const char_t *name, const ge::TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +TensorDesc Operator::GetDynamicInputDesc(const std::string &name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return TensorDesc(); +} + +TensorDesc Operator::GetDynamicInputDesc(const char_t *name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return TensorDesc(); +} + +graphStatus Operator::UpdateDynamicInputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc) +{ + 
UNUSED_VALUE(name); + UNUSED_VALUE(index); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::UpdateDynamicInputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +TensorDesc Operator::GetDynamicOutputDesc(const std::string &name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return TensorDesc(); +} + +TensorDesc Operator::GetDynamicOutputDesc(const char_t *name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return TensorDesc(); +} + +graphStatus Operator::UpdateDynamicOutputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::UpdateDynamicOutputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + UNUSED_VALUE(tensor_desc); + return GRAPH_SUCCESS; +} + +graphStatus Operator::InferShapeAndType() { return GRAPH_SUCCESS; } + +graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) +{ + UNUSED_VALUE(disable_common_verifier); + return GRAPH_SUCCESS; +} + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { return 0UL; } + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { return 0UL; } + +const std::map Operator::GetAllAttrNamesAndTypes() const +{ + static std::map attr_types; + return attr_types; +} + +graphStatus Operator::GetAllAttrNamesAndTypes(std::map &attr_name_types) const +{ + UNUSED_VALUE(attr_name_types); + return GRAPH_SUCCESS; +} + +void Operator::InputRegister(const std::string &name) +{ + UNUSED_VALUE(name); +} + +void Operator::InputRegister(const char_t *name) +{ + UNUSED_VALUE(name); +} + +void Operator::InputRegister(const char_t *name, const char_t *datatype_symbol) +{ + 
UNUSED_VALUE(name); + UNUSED_VALUE(datatype_symbol); +} + +void Operator::OptionalInputRegister(const std::string &name) +{ + UNUSED_VALUE(name); +} + +void Operator::OptionalInputRegister(const char_t *name) +{ + UNUSED_VALUE(name); +} + +void Operator::OptionalInputRegister(const char_t *name, const char_t *datatype_symbol) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(datatype_symbol); +} + +void Operator::InferFuncRegister(const std::function &func) +{ + UNUSED_VALUE(func); +} + +void Operator::InferFormatFuncRegister(const std::function &func) +{ + UNUSED_VALUE(func); +} + +void Operator::VerifierFuncRegister(const std::function &func) +{ + UNUSED_VALUE(func); +} + +void Operator::OutputRegister(const std::string &name) +{ + UNUSED_VALUE(name); +} + +void Operator::OutputRegister(const char_t *name) +{ + UNUSED_VALUE(name); +} + +void Operator::OutputRegister(const char_t *name, const char_t *datatype_symbol) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(datatype_symbol); +} + +void Operator::DynamicInputRegister(const std::string &name, const uint32_t num, bool is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(is_push_back); +} + +void Operator::DynamicInputRegister(const char_t *name, const uint32_t num, bool is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(is_push_back); +} + +void Operator::DynamicInputRegisterByIndex(const std::string &name, const uint32_t num, size_t index) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(index); +} + +void Operator::DynamicInputRegisterByIndex(const char_t *name, const uint32_t num, size_t index) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(index); +} + +int32_t Operator::GetDynamicInputNum(const std::string &name) const +{ + UNUSED_VALUE(name); + return 0; +} + +int32_t Operator::GetDynamicInputNum(const char_t *name) const +{ + UNUSED_VALUE(name); + return 0; +} + +void Operator::DynamicOutputRegister(const std::string &name, const uint32_t num, bool 
is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(is_push_back); +} + +void Operator::DynamicOutputRegister(const char_t *name, const uint32_t num, bool is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(is_push_back); +} + +int32_t Operator::GetDynamicOutputNum(const std::string &name) const +{ + UNUSED_VALUE(name); + return 0; +} + +int32_t Operator::GetDynamicOutputNum(const char_t *name) const +{ + UNUSED_VALUE(name); + return 0; +} + +void Operator::RequiredAttrRegister(const std::string &name) +{ + UNUSED_VALUE(name); +} + +void Operator::RequiredAttrRegister(const char_t *name) +{ + UNUSED_VALUE(name); +} + +void Operator::DataTypeRegister(const char_t *datatype_symbol, const TensorType &type_range) +{ + UNUSED_VALUE(datatype_symbol); + UNUSED_VALUE(type_range); +} + +void Operator::DataTypeRegister(const char_t *datatype_symbol, const ListTensorType &list_type_range) +{ + UNUSED_VALUE(datatype_symbol); + UNUSED_VALUE(list_type_range); +} + +graphStatus Operator::VerifyAll() { return GRAPH_SUCCESS; } + +std::string Operator::GetOpType() const { return ""; } + +graphStatus Operator::GetOpType(AscendString &type) const +{ + UNUSED_VALUE(type); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(dst_index); + UNUSED_VALUE(src_oprt); + return *this; +} + +Operator &Operator::SetInput(const char_t *dst_name, uint32_t dst_index, const ge::Operator &src_oprt) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(dst_index); + UNUSED_VALUE(src_oprt); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, + const std::string &name) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(dst_index); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(name); + return *this; +} + +Operator &Operator::SetInput(const char_t *dst_name, uint32_t dst_index, 
const ge::Operator &src_oprt, + const char_t *name) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(dst_index); + UNUSED_VALUE(src_oprt); + UNUSED_VALUE(name); + return *this; +} + +OperatorImplPtr Operator::GetOperatorImplPtr() const { return nullptr; } + +void Operator::BreakConnect() const {} + +void Operator::AttrRegister(const std::string &name, const std::string &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, int64_t attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, float32_t attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, bool attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const vector> &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const NamedAttrs &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, 
const AscendString &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +Operator &Operator::SetAttr(const std::string &name, const std::string &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, std::string &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, int32_t &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, bool &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, int64_t &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, float32_t &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator 
&Operator::SetAttr(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, + const std::vector &attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, + const std::vector &attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetInputAttr(const char_t *dst_name, const char_t *name, + std::vector &attr_value) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetOutputAttr(const char_t *dst_name, const char_t 
*name, + std::vector &attr_value) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetOutputAttr(const int32_t index, const char_t *name, + std::vector &attr_value) const +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetOutputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + 
UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const +{ + UNUSED_VALUE(dst_name); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetOutputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const +{ + UNUSED_VALUE(index); + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const char_t *name, const AscendString &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const char_t *name, AscendString &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_values) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_values); + return *this; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_values) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_values); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const std::string &name, const Tensor &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, const Tensor &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const 
std::string &name, Tensor &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, Tensor &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const std::string &name, const OpBytes &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, const OpBytes &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, OpBytes &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, OpBytes &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const std::string &name, ge::AttrValue &&attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, ge::AttrValue &&attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, ge::AttrValue &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, ge::AttrValue &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const std::string &name, const std::vector &attr_value) +{ + 
UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const std::string &name, const ge::DataType &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +Operator &Operator::SetAttr(const char_t *name, const ge::DataType &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return *this; +} + +graphStatus Operator::GetAttr(const std::string &name, ge::DataType &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char_t *name, ge::DataType &attr_value) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); + return GRAPH_SUCCESS; +} + +void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const ge::DataType &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const ge::DataType &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const Tensor &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const Tensor &attr_value) +{ + 
UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const vector &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const std::string &name, const OpBytes &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::AttrRegister(const char_t *name, const OpBytes &attr_value) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(attr_value); +} + +void Operator::SubgraphRegister(const std::string &ir_name, bool dynamic) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(dynamic); +} + +void Operator::SubgraphRegister(const char_t *ir_name, bool dynamic) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(dynamic); +} + +void Operator::SubgraphCountRegister(const std::string &ir_name, uint32_t count) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(count); +} + +void Operator::SubgraphCountRegister(const char_t *ir_name, uint32_t count) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(count); +} + +void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(index); + UNUSED_VALUE(builder); +} + +void Operator::SetSubgraphBuilder(const char_t *ir_name, uint32_t index, const SubgraphBuilder &builder) +{ + UNUSED_VALUE(ir_name); + UNUSED_VALUE(index); + UNUSED_VALUE(builder); +} + +std::vector Operator::GetSubgraphNames() const +{ + static std::vector names; + return names; +} + +graphStatus Operator::GetSubgraphNames(std::vector &names) const +{ + UNUSED_VALUE(names); + return GRAPH_SUCCESS; +} + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const std::string &name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return nullptr; +} + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const char_t *name, uint32_t index) const 
+{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return nullptr; +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const std::string &name) const +{ + UNUSED_VALUE(name); + return nullptr; +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const char_t *name) const +{ + UNUSED_VALUE(name); + return nullptr; +} + +Graph Operator::GetSubgraphImpl(const std::string &name) const +{ + UNUSED_VALUE(name); + return Graph(); +} + +Graph Operator::GetSubgraphImpl(const char_t *name) const +{ + UNUSED_VALUE(name); + return Graph(); +} + +Graph Operator::GetSubgraph(const std::string &name) const +{ + UNUSED_VALUE(name); + return Graph(); +} + +Graph Operator::GetSubgraph(const char_t *name) const +{ + UNUSED_VALUE(name); + return Graph(); +} + +Graph Operator::GetDynamicSubgraph(const std::string &name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return Graph(); +} + +Graph Operator::GetDynamicSubgraph(const char_t *name, uint32_t index) const +{ + UNUSED_VALUE(name); + UNUSED_VALUE(index); + return Graph(); +} + +size_t Operator::GetSubgraphNamesCount() const +{ + return 0UL; +} + +void Operator::DynamicInputRegister(const char_t *name, const uint32_t num, const char_t *datatype_symbol, + bool is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(datatype_symbol); + UNUSED_VALUE(is_push_back); +} + +void Operator::DynamicOutputRegister(const char_t *name, const uint32_t num, const char_t *datatype_symbol, + bool is_push_back) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(num); + UNUSED_VALUE(datatype_symbol); + UNUSED_VALUE(is_push_back); +} + +static inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) +{ + UNUSED_VALUE(compute_graph); + return true; +} + +ComputeGraphPtr GraphUtilsEx::CreateGraphFromOperator(const std::string &name, const std::vector &inputs) +{ + UNUSED_VALUE(name); + UNUSED_VALUE(inputs); + return nullptr; +} + +void GraphUtilsEx::BreakConnect(const std::map &all_nodes_infos) +{ + 
UNUSED_VALUE(all_nodes_infos); +} +} // namespace ge + diff --git a/src/kernels/tbe_adapter/tiling_runner/tbe_tiling_runner.cpp b/src/kernels/tbe_adapter/tiling_runner/tbe_tiling_runner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dbb88b896aa7332850563d9d9567ad5490592c92 --- /dev/null +++ b/src/kernels/tbe_adapter/tiling_runner/tbe_tiling_runner.cpp @@ -0,0 +1,700 @@ +/* + * Copyright (c) 2024-2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "tbe_tiling_runner.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "asdops/params/params.h" +#include "compute_node_info.h" +#include "continuous_vector.h" +#include "graph/any_value.h" +#include "graph/ge_tensor.h" +#include "graph/debug/ge_util.h" +#include "kernel_run_context.h" +#include "kernel_run_context_builder.h" +#include "platform/platform_infos_def.h" +#include "register/op_impl_registry_base.h" +#include "runtime_attrs_def.h" + +namespace { +const std::vector CORE_TYPE_VEC { + "AiCore", + "AiCore", + "VectorCore", + "AiCore", + "MIX", +}; +} + +namespace AsdOpsGeRt { +static std::unordered_map g_tilingParseCache; +std::mutex g_mutexTilingParseCache; + +class AsdOpsFePlatformInfosManager { +public: + static std::pair GetPlatFormInfos() + { + static std::once_flag initedFlag; + static fe::PlatFormInfos platformInfo; + static bool isSuccess; + + std::call_once(initedFlag, [&]() { isSuccess = InitPlatformInfo(platformInfo); }); + return std::make_pair(isSuccess, platformInfo); + } + + static bool InitPlatformInfo(fe::PlatFormInfos &platformInfo) + { + const uint32_t maxLen = 100; + std::string version; + MKI_CHECK(Mki::MkiRtDeviceGetSocVersion(version, maxLen) == MKIRT_SUCCESS, + "failed to get soc version", return false); + std::string socVersion(version); + MKI_LOG(DEBUG) << "tiling runner get soc version: " << socVersion; + Mki::PlatformManager &platformManager = Mki::PlatformManager::Instance(); + MKI_CHECK(platformManager.InitializePlatformManager() == Mki::PLATFORM_SUCCESS, + "failed to initialize platform manager", return false); + Mki::PlatformConfigs platformConfigs; + MKI_CHECK(platformManager.GetPlatformConfigs(socVersion, platformConfigs) == Mki::PLATFORM_SUCCESS, + "failed to get platform information", return false); + MKI_CHECK(platformInfo.Init(), "failed to initialize platformInfos", return false); + AdaptPlatformInfos(platformInfo, 
platformConfigs); + return true; + } + + static void AdaptPlatformInfos(fe::PlatFormInfos &platformInfo, Mki::PlatformConfigs &platformConfigs) + { + std::map> platformResMap = platformConfigs.GetPlatformSpecMap(); + + platformInfo.SetFixPipeDtypeMap(platformConfigs.GetFixPipeDtypeMap()); + platformInfo.SetAICoreIntrinsicDtype(platformConfigs.GetAICoreIntrinsicDtype()); + platformInfo.SetVectorCoreIntrinsicDtype(platformConfigs.GetVectorCoreIntrinsicDtype()); + for (auto &[label, res]: platformResMap) { + platformInfo.SetPlatformRes(label, res); + } + } +}; + +class TbeTilingRunnerImpl { +public: + TbeTilingRunnerImpl() = default; + ~TbeTilingRunnerImpl() = default; + + void SetName(const char *opType) { opType_ = opType; } + + void SetKernelName(const std::string kernelName) { kernelName_ = kernelName; } + + void AddInput(Mki::TensorDType dtype, Mki::TensorFormat format, const Mki::SVector &dims) + { + AddTensor(dtype, format, dims, inputs_); + } + + void AddConstInput(Mki::TensorDType dtype, Mki::TensorFormat format, + const Mki::SVector &dims, const void *data, size_t size) + { + Shape shape; + for (const auto dim : dims) { + (void)shape.AppendDim(dim); + } + size_t totalSize = 0UL; + auto tensorHolder = Tensor::CreateFollowing(GeDataType(dtype), size, totalSize); + MKI_CHECK(tensorHolder != nullptr, "tensorHolder is nullptr", return); + auto tensor = reinterpret_cast(tensorHolder.get()); + if (memcpy_s(tensor->GetData(), totalSize - sizeof(Tensor), data, size) != EOK) { + MKI_LOG(ERROR) << "Failed to add const input"; + return; + } + tensor->MutableOriginShape() = shape; + tensor->MutableStorageShape() = shape; + tensor->SetDataType(GeDataType(dtype)); + tensor->SetStorageFormat(GeFormat(format)); + tensor->SetOriginFormat(GeFormat(format)); + contextComponent_.indexToTensors.emplace_back(inputs_.size(), std::move(tensorHolder)); + AddTensor(dtype, format, dims, inputs_); + } + + void AddOutput(Mki::TensorDType dtype, Mki::TensorFormat format, const 
Mki::SVector &dims) + { + AddTensor(dtype, format, dims, outputs_); + } + + void AddAttrBool(bool value) + { + auto data = ge::ComGraphMakeUnique(sizeof(uint8_t)); + auto ret = memcpy_s(data.get(), sizeof(uint8_t), &value, sizeof(uint8_t)); + MKI_CHECK(ret == EOK, "failed to copy attr bool", return); + attrs_.emplace_back(std::make_pair(std::move(data), sizeof(uint8_t))); + } + + void AddAttrInt64(int64_t attr) + { + auto data = ge::ComGraphMakeUnique(sizeof(int64_t)); + auto ret = memcpy_s(data.get(), sizeof(int64_t), &attr, sizeof(int64_t)); + MKI_CHECK(ret == EOK, "failed to copy attr int64", return); + attrs_.emplace_back(std::make_pair(std::move(data), sizeof(int64_t))); + } + + void AddAttrFloat(float attr) + { + auto data = ge::ComGraphMakeUnique(sizeof(float)); + auto ret = memcpy_s(data.get(), sizeof(float), &attr, sizeof(float)); + MKI_CHECK(ret == EOK, "failed to copy attr float", return); + attrs_.emplace_back(std::make_pair(std::move(data), sizeof(float))); + } + + void AddAttrStr(const char *attr) + { + size_t dataLen = strlen(attr) + 1; + auto data = ge::ComGraphMakeUnique(dataLen); + auto ret = memcpy_s(data.get(), dataLen, attr, dataLen); + MKI_CHECK(ret == EOK, "failed to copy attr float", return); + attrs_.emplace_back(std::make_pair(std::move(data), dataLen)); + } + + void AddAttrIntList(const int64_t *attr, const size_t num) + { + size_t totalSize = 0; + auto data = ContinuousVector::Create(num, totalSize); + MKI_CHECK(data != nullptr, "failed to create ContinuousVector", return); + auto dataVec = reinterpret_cast(data.get()); + dataVec->SetSize(num); + auto ret = memcpy_s(dataVec->MutableData(), sizeof(int64_t) * num, attr, sizeof(int64_t) * num); + MKI_CHECK(ret == EOK, "failed to copy attr list int", return); + attrs_.emplace_back(std::make_pair(std::move(data), totalSize)); + } + + Mki::Status GetTilingParseContextHolder(KernelContextHolder &tilingParseContextHolder, + std::unique_ptr &computeNode) + { + { + std::lock_guard 
lck(g_mutexTilingParseCache); + auto it = g_tilingParseCache.find(kernelName_); + if (it == g_tilingParseCache.end()) { + MKI_LOG(DEBUG) << "tactic " << kernelName_ << " is first tiling parse"; + g_tilingParseCache[kernelName_] = BuildTilingParseContextHolder(computeNode); + MKI_CHECK(g_tilingParseCache[kernelName_].context_ != nullptr, + "failed to build tiling parse context", return Mki::Status::FailStatus(1)); + MKI_CHECK((opImpl_->tiling_parse)(g_tilingParseCache[kernelName_].context_) == ge::GRAPH_SUCCESS, + "failed to run tiling parse", return Mki::Status::FailStatus(1)); + } + } + tilingParseContextHolder.context_ = g_tilingParseCache[kernelName_].context_; + MKI_CHECK(((tilingParseContextHolder.context_)->GetOutputPointer(0)) != nullptr, + "OutputPointer is nullptr", return Mki::Status::FailStatus(1)); + + return Mki::Status::OkStatus(); + } + + Mki::Status GetTilingData(uint8_t *tilingData, uint64_t tilingDataLen, const BinHandle &binHandle) + { + MKI_CHECK(tilingData != nullptr, "tilingData invalid", return Mki::Status::FailStatus(1)); + MKI_CHECK(tilingDataLen != 0, "tilingData invalid", return Mki::Status::FailStatus(1)); + MKI_CHECK(InitKernelAttrs(binHandle), "failed to init tactic attrs", return Mki::Status::FailStatus(1)); + MKI_CHECK(InitPlatformInfo(), "failed to init platform info", return Mki::Status::FailStatus(1)); + opImpl_ = gert::OpImplRegistry::GetInstance().GetOpImpl(opType_); + MKI_CHECK(opImpl_ != nullptr, "failed to find tiling entry", return Mki::Status::FailStatus(1)); + auto computeNodePtr = CreateComputeNode(); + MKI_CHECK(computeNodePtr != nullptr, "compute node is nullptr", return Mki::Status::FailStatus(1)); + + KernelContextHolder tilingParseContextHolder; + auto status = GetTilingParseContextHolder(tilingParseContextHolder, computeNodePtr); + MKI_CHECK(status.Ok(), "failed to get tiling parse context", return Mki::Status::FailStatus(1)); + + KernelContextHolder tilingContextHolder = BuildTilingContextHolder(computeNodePtr, + 
*((tilingParseContextHolder.context_)->GetOutputPointer(0)), tilingDataLen); + MKI_CHECK(tilingContextHolder.context_ != nullptr, + "failed to build tiling context", return Mki::Status::FailStatus(1)); + + auto tilingContext = reinterpret_cast(tilingContextHolder.context_); + MKI_CHECK(opImpl_->tiling(tilingContext) == ge::GRAPH_SUCCESS, + "failed to run tiling", return Mki::Status::FailStatus(1)); + + auto rawTilingData = tilingContext->GetRawTilingData(); + MKI_CHECK(rawTilingData != nullptr, "failed to get rawtilingdata", return Mki::Status::FailStatus(1)); + auto ret = memcpy_s(tilingData, tilingDataLen, rawTilingData->GetData(), rawTilingData->GetDataSize()); + MKI_CHECK(ret == EOK, "failed to copy tilingdata", return Mki::Status::FailStatus(1)); + + contextComponent_.blockDim = tilingContext->GetBlockDim(); + contextComponent_.tilingId = tilingContext->GetTilingKey(); + contextComponent_.tilingSize = rawTilingData->GetDataSize(); + + return Mki::Status::OkStatus(); + } + + uint32_t GetBlockDim() + { + MKI_LOG(INFO) << kernelName_ << " BlockDim " << contextComponent_.blockDim; + return contextComponent_.blockDim; + } + + uint32_t GetIntercoreSync() + { + MKI_LOG(INFO) << kernelName_ << " IntercoreSync " << intercoreSync_; + return intercoreSync_; + } + + uint64_t GetTilingId() + { + MKI_LOG(INFO) << kernelName_ << " TilingId " << contextComponent_.tilingId; + return contextComponent_.tilingId; + } + + uint64_t GetTilingSize() + { + MKI_LOG(INFO) << kernelName_ << " TilingSize " << contextComponent_.tilingSize; + return contextComponent_.tilingSize; + } + + void GetWorkSpace(Mki::SVector &workspace) // 8 小容量SVECTOR + { + uint8_t *wksp = contextComponent_.workspaceSize.get(); + auto workspaceInfo = reinterpret_cast *>(wksp); + MKI_CHECK(workspaceInfo && workspaceInfo->GetData(), "failed to get workspace info", return); + + size_t workspaceNum = workspaceInfo->GetSize(); + const size_t *workspaceSize = workspaceInfo->GetData(); + MKI_LOG(INFO) << kernelName_ << " 
workspace num " << workspaceNum; + for (size_t i = 0; i < workspaceNum; i++) { + size_t bufferSize = workspaceSize[i]; + MKI_LOG(DEBUG) << "size[" << i << "] " << bufferSize; + workspace.push_back(bufferSize); + } + } + +private: + bool InitKernelAttrs(const BinHandle &binHandle) + { + // compileInfo + compileInfo_ = binHandle.GetKernelCompileInfo(); + MKI_CHECK(compileInfo_ != nullptr, "compile info is nullptr", return false); + // coreType + int32_t coreTypeIdx = binHandle.GetKernelCoreType(); + MKI_CHECK(coreTypeIdx > -1, "core type is empty", return false); + coreType_ = CORE_TYPE_VEC[coreTypeIdx]; + // intercoreSync + intercoreSync_ = binHandle.GetIntercoreSync(); + // taskRatio + cubeRatio_ = binHandle.GetCubeRatio(); + vectorRatio_ = binHandle.GetCubeRatio(); + + return true; + } + + bool InitPlatformInfo() + { + auto [isSuccess, platformInfo] = AsdOpsFePlatformInfosManager::GetPlatFormInfos(); + MKI_CHECK(isSuccess, "failed to get PlatFormInfos", return false); + platformInfo_ = platformInfo; + platformInfo_.SetCoreNumByCoreType(coreType_); + + if (coreType_ == "MIX" && (cubeRatio_ != 0 || vectorRatio_ != 0)) { + uint32_t cubeCoreNum = platformInfo_.GetCoreNumByType("AiCore"); + uint32_t vectorCoreNum = platformInfo_.GetCoreNumByType("VectorCore"); + cubeCoreNum = (cubeRatio_ == 0) ? std::numeric_limits::max() : (cubeCoreNum / cubeRatio_); + vectorCoreNum = (vectorRatio_ == 0) ? std::numeric_limits::max() : (vectorCoreNum / vectorRatio_); + uint32_t coreNum = (cubeCoreNum < vectorCoreNum) ? 
cubeCoreNum : vectorCoreNum; + if (coreNum == 0) { + MKI_LOG(WARN) << "invalid coreNum for MIX with ratio: " << cubeRatio_ << ":" << vectorRatio_ + << ", use 1 instead"; + coreNum = 1; + } + platformInfo_.SetCoreNum(coreNum); + } + + return true; + } + + void AddTensor(Mki::TensorDType dtype, Mki::TensorFormat format, const Mki::SVector &dims, + std::vector, std::unique_ptr>> &tensors) const + { + auto desc = ge::ComGraphMakeUnique(); + MKI_CHECK(desc != nullptr, "desc is nullptr", return); + desc->SetDataType(GeDataType(dtype)); + desc->SetFormat(GeFormat(format)); + desc->SetOriginFormat(GeFormat(format)); + auto shape = ge::ComGraphMakeUnique(); + MKI_CHECK(shape != nullptr, "shape is nullptr", return); + for (const auto dim : dims) { + (void)shape->AppendDim(dim); + } + tensors.emplace_back(std::make_pair(std::move(desc), std::move(shape))); + } + + ge::DataType GeDataType(Mki::TensorDType dtype) const + { + switch (dtype) { + case Mki::TENSOR_DTYPE_FLOAT: return ge::DT_FLOAT; + case Mki::TENSOR_DTYPE_FLOAT16: return ge::DT_FLOAT16; + case Mki::TENSOR_DTYPE_INT8: return ge::DT_INT8; + case Mki::TENSOR_DTYPE_INT32: return ge::DT_INT32; + case Mki::TENSOR_DTYPE_UINT8: return ge::DT_UINT8; + case Mki::TENSOR_DTYPE_INT16: return ge::DT_INT16; + case Mki::TENSOR_DTYPE_UINT16: return ge::DT_UINT16; + case Mki::TENSOR_DTYPE_UINT32: return ge::DT_UINT32; + case Mki::TENSOR_DTYPE_INT64: return ge::DT_INT64; + case Mki::TENSOR_DTYPE_UINT64: return ge::DT_UINT64; + case Mki::TENSOR_DTYPE_DOUBLE: return ge::DT_DOUBLE; + case Mki::TENSOR_DTYPE_BOOL: return ge::DT_BOOL; + case Mki::TENSOR_DTYPE_STRING: return ge::DT_STRING; + case Mki::TENSOR_DTYPE_COMPLEX64: return ge::DT_COMPLEX64; + case Mki::TENSOR_DTYPE_COMPLEX128: return ge::DT_COMPLEX128; + case Mki::TENSOR_DTYPE_BF16: return ge::DT_BF16; + default: + // ERROR LOG + break; + } + return ge::DT_MAX; + } + + ge::Format GeFormat(Mki::TensorFormat format) const + { + switch (format) { + case Mki::TENSOR_FORMAT_NCHW: return 
ge::FORMAT_NCHW; + case Mki::TENSOR_FORMAT_NHWC: return ge::FORMAT_NHWC; + case Mki::TENSOR_FORMAT_ND: return ge::FORMAT_ND; + case Mki::TENSOR_FORMAT_NC1HWC0: return ge::FORMAT_NC1HWC0; + case Mki::TENSOR_FORMAT_FRACTAL_Z: return ge::FORMAT_FRACTAL_Z; + case Mki::TENSOR_FORMAT_NC1HWC0_C04: return ge::FORMAT_NC1HWC0_C04; + case Mki::TENSOR_FORMAT_HWCN: return ge::FORMAT_HWCN; + case Mki::TENSOR_FORMAT_NDHWC: return ge::FORMAT_NDHWC; + case Mki::TENSOR_FORMAT_FRACTAL_NZ: return ge::FORMAT_FRACTAL_NZ; + case Mki::TENSOR_FORMAT_NCDHW: return ge::FORMAT_NCDHW; + case Mki::TENSOR_FORMAT_NDC1HWC0: return ge::FORMAT_NDC1HWC0; + case Mki::TENSOR_FORMAT_FRACTAL_Z_3D: return ge::FORMAT_FRACTAL_Z_3D; + default: + // ERROR LOG + break; + } + return ge::FORMAT_MAX; + } + + std::unique_ptr CreateComputeNode() + { + size_t attrSize = sizeof(RuntimeAttrsDef); + size_t attrNum = attrs_.size(); + attrSize += sizeof(size_t) * attrNum; + for (size_t i = 0; i < attrNum; ++i) { + attrSize += attrs_[i].second; + } + auto attrPtr = ge::ComGraphMakeUnique(attrSize); + MKI_CHECK(attrPtr != nullptr, "attrPtr is nullptr", return nullptr); + auto attrsDef = reinterpret_cast(attrPtr.get()); + attrsDef->attr_num = attrNum; + auto memret = memset_s(attrsDef->reserved_, sizeof(attrsDef->reserved_), 0, sizeof(attrsDef->reserved_)); + if (memret != EOK) { + MKI_LOG(ERROR) << "Memset failed, result:" << memret; + return nullptr; + } + size_t currentOffset = sizeof(RuntimeAttrsDef) + sizeof(size_t) * attrsDef->attr_num; + auto attrPos = attrPtr.get(); + for (size_t i = 0; i < attrs_.size(); ++i) { + attrsDef->offset[i] = currentOffset; + auto ret = memcpy_s(attrPos + currentOffset, attrSize - currentOffset, + attrs_[i].first.get(), attrs_[i].second); + if (ret != EOK) { + MKI_LOG(ERROR) << "Failed to copy attr to AttrDef"; + return nullptr; + } + currentOffset += attrs_[i].second; + } + + size_t inputNum = inputs_.size(); + size_t outputNum = outputs_.size(); + size_t computeNodeSize = 
sizeof(ComputeNodeInfo) + (inputNum + outputNum) * sizeof(CompileTimeTensorDesc); + size_t totalSize = computeNodeSize + attrSize; + auto computeNodePtr = ge::ComGraphMakeUnique(totalSize); + MKI_CHECK(computeNodePtr != nullptr, "computeNodePtr is nullptr", return nullptr); + auto computeNodeDef = reinterpret_cast(computeNodePtr.get()); + computeNodeDef->Init(0, inputNum, outputNum, opType_, opType_); + for (size_t i = 0; i < inputNum; i++) { + auto td = computeNodeDef->MutableInputTdInfo(i); + MKI_CHECK(td != nullptr, "td is nullptr", return nullptr); + td->SetDataType(inputs_[i].first->GetDataType()); + td->SetOriginFormat(inputs_[i].first->GetFormat()); + td->SetStorageFormat(inputs_[i].first->GetFormat()); + } + for (size_t i = 0; i < outputNum; i++) { + auto td = computeNodeDef->MutableOutputTdInfo(i); + MKI_CHECK(td != nullptr, "td is nullptr", return nullptr); + td->SetDataType(outputs_[i].first->GetDataType()); + td->SetOriginFormat(outputs_[i].first->GetFormat()); + td->SetStorageFormat(outputs_[i].first->GetFormat()); + } + auto attr = computeNodeDef->MutableAttrs(); + const auto offset = ge::PtrToPtr(attr) - computeNodePtr.get(); + auto ret = memcpy_s(ge::PtrToPtr(attr), (totalSize - offset), attrPtr.get(), attrSize); + if (ret != EOK) { + MKI_LOG(ERROR) << "Failed to copy AttrDef to ComputeNode"; + return nullptr; + } + return computeNodePtr; + } + + KernelContextHolder BuildTilingParseContextHolder(std::unique_ptr &computeNode) + { + const size_t inputSize = 3; // TilingParse has 3 inputs + const size_t outputSize = 1; // TilingParse has 1 output + KernelContextHolder holder; + size_t size = sizeof(KernelRunContext) + sizeof(Chain *) * (inputSize + outputSize); + holder.context_holder_ = ge::ComGraphMakeUnique(size); + MKI_CHECK(holder.context_holder_ != nullptr, "context holder is nullptr", return holder); + holder.context_ = ge::PtrToPtr(holder.context_holder_.get()); + auto kernelRunContext = holder.context_->GetContext(); + 
kernelRunContext->input_size = inputSize; // TilingParse has 3 inputs + kernelRunContext->output_size = outputSize; // TilingParse has 1 output + kernelRunContext->compute_node_info = ge::PtrToPtr(computeNode.get()); + kernelRunContext->output_start = &(kernelRunContext->values[kernelRunContext->input_size]); + holder.value_holder_.resize(kernelRunContext->input_size + kernelRunContext->output_size); + for (size_t i = 0UL; i < holder.value_holder_.size(); ++i) { + kernelRunContext->values[i] = ge::PtrToPtr(&holder.value_holder_[i]); + } + + size_t i = 0; + holder.value_holder_[i++].Set(const_cast(compileInfo_), nullptr); + holder.value_holder_[i++].Set(reinterpret_cast(&platformInfo_), nullptr); + holder.value_holder_[i++].Set(const_cast(opType_), nullptr); + + holder.value_holder_[i++].Set(opImpl_->compile_info_creator(), opImpl_->compile_info_deleter); + + return holder; + } + + KernelContextHolder BuildTilingContextHolder(std::unique_ptr &computeNode, void *compileInfo, + uint32_t tilingSize) + { + // prepare contextComponent_ + KernelContextHolder holder; + size_t inputNum = inputs_.size(); + size_t outputNum = outputs_.size(); + for (size_t i = 0; i < inputNum; i++) { + StorageShape storageShape; + storageShape.MutableStorageShape() = *(inputs_[i].second); + storageShape.MutableOriginShape() = *(inputs_[i].second); + contextComponent_.storageShapes.emplace_back(storageShape); + } + for (size_t i = 0; i < outputNum; i++) { + StorageShape storageShape; + storageShape.MutableStorageShape() = *(outputs_[i].second); + storageShape.MutableOriginShape() = *(outputs_[i].second); + contextComponent_.storageShapes.emplace_back(storageShape); + } + + contextComponent_.tilingData = TilingData::CreateCap(tilingSize); + MKI_CHECK(contextComponent_.tilingData != nullptr, "tilingData is nullptr", return holder); + contextComponent_.workspaceSize = ContinuousVector::Create(kWorkspaceHolerSize_); + MKI_CHECK(contextComponent_.workspaceSize != nullptr, "workspaceSize is 
nullptr", return holder); + std::vector tilingContextInputs(contextComponent_.storageShapes.size() + kSize_, nullptr); + for (size_t i = 0UL; i < contextComponent_.indexToTensors.size(); ++i) { + tilingContextInputs[contextComponent_.indexToTensors[i].first] = + reinterpret_cast(contextComponent_.indexToTensors[i].second.get()); + } + for (size_t i = 0UL; i < contextComponent_.storageShapes.size(); ++i) { + if (tilingContextInputs[i] == nullptr) { + tilingContextInputs[i] = &contextComponent_.storageShapes[i]; + } + } + tilingContextInputs[contextComponent_.storageShapes.size()] = compileInfo; + tilingContextInputs[contextComponent_.storageShapes.size() + 1] = reinterpret_cast(&platformInfo_); + + // prepare kernelruncontext + size_t contextSize = sizeof(KernelRunContext) + sizeof(Chain *) * (tilingContextInputs.size() + 5); // output 5 + holder.context_holder_ = ge::ComGraphMakeUnique(contextSize); + MKI_CHECK(holder.context_holder_ != nullptr, "context holder is nullptr", return holder); + holder.context_ = ge::PtrToPtr(holder.context_holder_.get()); + auto kernelRunContext = holder.context_->GetContext(); + kernelRunContext->input_size = tilingContextInputs.size(); + kernelRunContext->output_size = 5; // TilingContext has 5 outputs + kernelRunContext->compute_node_info = ge::PtrToPtr(computeNode.get()); + kernelRunContext->output_start = &(kernelRunContext->values[kernelRunContext->input_size]); + holder.value_holder_.resize(kernelRunContext->input_size + kernelRunContext->output_size); + for (size_t i = 0UL; i < holder.value_holder_.size(); ++i) { + kernelRunContext->values[i] = ge::PtrToPtr(&holder.value_holder_[i]); + } + for (size_t i = 0UL; i < tilingContextInputs.size(); ++i) { + holder.value_holder_[i].Set(tilingContextInputs[i], nullptr); + } + + size_t i = tilingContextInputs.size(); + holder.value_holder_[i++].Set(nullptr, nullptr); + holder.value_holder_[i++].Set(nullptr, nullptr); + holder.value_holder_[i++].Set(&contextComponent_.atomicFlag, 
nullptr); + holder.value_holder_[i++].Set(contextComponent_.tilingData.get(), nullptr); + holder.value_holder_[i++].Set(contextComponent_.workspaceSize.get(), nullptr); + + return holder; + } + +private: + struct ContextComponent { + std::vector storageShapes; + std::vector>> indexToTensors; + std::unique_ptr tilingData; + std::unique_ptr workspaceSize; + bool atomicFlag = true; + // tiling extrainfo + uint32_t blockDim = 0; + uint64_t tilingId = 0; + uint64_t tilingSize = 0; + }; + + const size_t kSize_ = 3UL; + const size_t kWorkspaceHolerSize_ = 8UL; + const char *opType_ = "DefaultImpl"; + const char *compileInfo_{nullptr}; + std::string kernelName_; + std::string coreType_ = ""; + uint32_t intercoreSync_ = 0; + uint32_t cubeRatio_ = 0; + uint32_t vectorRatio_ = 0; + fe::PlatFormInfos platformInfo_; + const OpImplRegistry::OpImplFunctions *opImpl_ = nullptr; + ContextComponent contextComponent_; + std::vector, std::unique_ptr>> inputs_; + std::vector, std::unique_ptr>> outputs_; + std::vector, size_t>> attrs_; +}; + +TbeTilingRunner::TbeTilingRunner() : impl_(std::make_shared()) {} + +TbeTilingRunner &TbeTilingRunner::SetName(const char *opType) +{ + impl_->SetName(opType); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::SetKernelName(const std::string kernelName) +{ + impl_->SetKernelName(kernelName); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddInput(Mki::TensorDType dtype, Mki::TensorFormat format, + const Mki::SVector &dims) +{ + impl_->AddInput(dtype, format, dims); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddConstInput(Mki::TensorDType dtype, Mki::TensorFormat format, + std::initializer_list dims, const void *data, size_t size) +{ + Mki::SVector dims1(dims); + impl_->AddConstInput(dtype, format, dims1, data, size); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddConstInput(Mki::TensorDType dtype, Mki::TensorFormat format, + const Mki::SVector &dims, const void *data, size_t size) +{ + 
impl_->AddConstInput(dtype, format, dims, data, size); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddOutput(Mki::TensorDType dtype, Mki::TensorFormat format, + const Mki::SVector &dims) +{ + impl_->AddOutput(dtype, format, dims); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrBool(bool value) +{ + impl_->AddAttrBool(value); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrInt(int32_t attr) +{ + impl_->AddAttrInt64(attr); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrStr(const char *attr) +{ + impl_->AddAttrStr(attr); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrInt64(int64_t attr) +{ + impl_->AddAttrInt64(attr); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrFloat(float attr) +{ + impl_->AddAttrFloat(attr); + return *this; +} + +TbeTilingRunner &TbeTilingRunner::AddAttrIntList(const int64_t *attr, const size_t num) +{ + impl_->AddAttrIntList(attr, num); + return *this; +} + +Mki::Status TbeTilingRunner::GetTilingData(uint8_t *tilingData, uint64_t tilingDataLen, const BinHandle &binHandle) +{ + return impl_->GetTilingData(tilingData, tilingDataLen, binHandle); +} + +uint32_t TbeTilingRunner::GetBlockDim() +{ + return impl_->GetBlockDim(); +} + +uint32_t TbeTilingRunner::GetIntercoreSync() +{ + return impl_->GetIntercoreSync(); +} + +uint64_t TbeTilingRunner::GetTilingId() +{ + return impl_->GetTilingId(); +} + +uint64_t TbeTilingRunner::GetTilingSize() +{ + return impl_->GetTilingSize(); +} + +void TbeTilingRunner::GetWorkSpace(Mki::SVector &workspace) // 8 小容量SVECTOR +{ + impl_->GetWorkSpace(workspace); +} +} // namespace AsdOpsGeRt + +namespace AsdOps { +Mki::Status GetTilingFromRunner(KernelInfo &kernelInfo, AsdOpsGeRt::TbeTilingRunner &runner, const BinHandle &binHandle) +{ + auto status = runner.GetTilingData(kernelInfo.GetTilingHostAddr(), kernelInfo.GetTilingSize(), binHandle); + MKI_CHECK(status.Ok(), "failed to run tiling runner", return status); + + 
kernelInfo.SetBlockDim(runner.GetBlockDim()); + kernelInfo.SetTilingId(runner.GetTilingId()); + kernelInfo.SetTilingUsedSize(runner.GetTilingSize()); + if (runner.GetIntercoreSync() == 1) { + kernelInfo.SetHwsyncIdx(0); + } + runner.GetWorkSpace(kernelInfo.GetScratchSizes()); + + return Mki::Status::OkStatus(); +} +} // namespace AsdOps diff --git a/src/torch_atb/CMakeLists.txt b/src/torch_atb/CMakeLists.txt index 3e525a8b27e71738ae815ae9c587d1b02a387eca..50111f2ed452d9ca986547d0817c475c15de93e8 100644 --- a/src/torch_atb/CMakeLists.txt +++ b/src/torch_atb/CMakeLists.txt @@ -13,4 +13,5 @@ pybind11_add_module(_C ${pybind11_source_files}) set_target_properties(_C PROPERTIES OUTPUT_NAME "_C" SUFFIX ".so") target_link_options(_C PRIVATE -rdynamic -ldl -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,--build-id=none -fexceptions) target_link_libraries(_C PRIVATE torch_npu) +target_include_directories(_C PRIVATE ${ATB_INCLUDE_DIR}) install(TARGETS _C DESTINATION ${CMAKE_SOURCE_DIR}/output/torch_atb) \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d15dc57ac21188aa3137c882e17ea4c4459b95e2..1b6efccf67205c8fffdaa107827e1b33c4cd7005 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,6 +8,10 @@ # See LICENSE in the root of the software repository for the full text of the License. # +include_directories( + $ENV{ASCEND_HOME_PATH}/include +) + add_subdirectory(framework) if(USE_UNIT_TEST OR USE_ALL_TEST) add_subdirectory(unittest)