diff --git a/example/graph_example.py b/example/graph_example.py
index 79630c8b1818a76f4a5884038a587cabf1f3d7f0..be2d27c99af63c70953115dd7705c9f85dae5e87 100644
--- a/example/graph_example.py
+++ b/example/graph_example.py
@@ -13,29 +13,29 @@ import acl
 import torch
 import torch_atb
 
-s = 128  # Sequence Length
-h = 16  # Number of Heads
-d_k = 64  # Head Dimension
-d_v = 64  # Value Dimension (vHiddenSize)
-output_dim = 64
-output_dim_1 = 128
+SEQLEN = 128  # Sequence Length
+HEAD = 16  # Number of Heads
+D_K = 64  # Head Dimension
+D_V = 64  # Value Dimension (vHiddenSize)
+OUTPUT_DIM = 64
+OUTPUT_DIM_1 = 128
 
 def get_inputs():
     torch.manual_seed(233)
     # single-batch case; when batch is not 1, SEQLEN should be seq_len * batch
-    query = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
-    key = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
-    value = (torch.randn((s, 16, d_v), dtype=torch.float16)).npu()
-    seqLen = (torch.tensor([s], dtype=torch.int32))
-    input_0 = (torch.randn((16, d_k), dtype=torch.float16)).npu()
-    gamma = (torch.randn((s, 16, d_k), dtype=torch.float16)).npu()
-    beta = (torch.zeros((s, 16, d_k), dtype=torch.float16)).npu()
-    weight_0 = (torch.randn((output_dim_1, output_dim), dtype=torch.float16)).npu()
-    bias_0 = (torch.randn((output_dim_1,), dtype=torch.float16)).npu()
-    weight_1 = (torch.randn((output_dim_1, output_dim_1), dtype=torch.float16)).npu()
-    bias_1 = (torch.randn((output_dim_1,), dtype=torch.float16)).npu()
-    inputs = [query, key, value, seqLen, input_0, gamma, beta, weight_0, bias_0, weight_1, bias_1]
+    query = (torch.randn((SEQLEN, HEAD, D_K), dtype=torch.float16)).npu()
+    key = (torch.randn((SEQLEN, HEAD, D_K), dtype=torch.float16)).npu()
+    value = (torch.randn((SEQLEN, HEAD, D_V), dtype=torch.float16)).npu()
+    seqlen = (torch.tensor([SEQLEN], dtype=torch.int32))
+    input_0 = (torch.randn((HEAD, D_K), dtype=torch.float16)).npu()
+    gamma = (torch.randn((SEQLEN, HEAD, D_K), dtype=torch.float16)).npu()
+    beta = (torch.zeros((SEQLEN, HEAD, D_K), dtype=torch.float16)).npu()
+    weight_0 = (torch.randn((OUTPUT_DIM_1, OUTPUT_DIM), dtype=torch.float16)).npu()
+    bias_0 = (torch.randn((OUTPUT_DIM_1,), dtype=torch.float16)).npu()
+    weight_1 = (torch.randn((OUTPUT_DIM_1, OUTPUT_DIM_1), dtype=torch.float16)).npu()
+    bias_1 = (torch.randn((OUTPUT_DIM_1,), dtype=torch.float16)).npu()
+    inputs = [query, key, value, seqlen, input_0, gamma, beta, weight_0, bias_0, weight_1, bias_1]
     return inputs
 
@@ -44,12 +44,12 @@ def graph_build():
     query = graph.add_input("query")
     key = graph.add_input("key")
     value = graph.add_input("value")
-    seqLen = graph.add_input("seqLen")
+    seqlen = graph.add_input("seqLen")
     self_attention_param = torch_atb.SelfAttentionParam()
-    self_attention_param.head_num = 16
-    self_attention_param.kv_head_num = 16
+    self_attention_param.head_num = HEAD
+    self_attention_param.kv_head_num = HEAD
     self_attention_param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER
-    self_attention = graph.add_node([query, key, value, seqLen], self_attention_param)
+    self_attention = graph.add_node([query, key, value, seqlen], self_attention_param)
     self_attention_out = self_attention.get_output(0)
 
     input_0 = graph.add_input("input_0")
@@ -81,14 +81,14 @@ def graph_build():
     linear_1_out = linear_1.get_output(0)
     graph.mark_output(linear_1_out)
 
-    Graph = graph.build()
-    return Graph
+    graph_out = graph.build()
+    return graph_out
 
 def run():
-    Graph = graph_build()
+    graph_op = graph_build()
     inputs = get_inputs()
-    results = Graph.forward(inputs)
+    results = graph_op.forward(inputs)
     logging.info(results)
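Note: `get_inputs()` moves every tensor to the device with `.npu()` except `seqlen`, which stays on the host. The C++ demos later in this patch make the same convention explicit by pointing a tensor's `hostData` at a host-side vector instead of allocating device memory; a minimal sketch of that pattern, taken from rope_demo.cpp below:

```cpp
// Host-side sequence lengths: no aclrtMalloc/aclrtMemcpy involved.
std::vector<int32_t> seqLenHost(BATCH_SIZE, SEQLEN_VALUE);
atb::Tensor tensorSeqLen;
CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE}, tensorSeqLen));
tensorSeqLen.hostData = seqLenHost.data(); // the operation reads the lengths from the CPU
```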
diff --git a/example/multiStream/multiStream_multiGraph_demo.cpp b/example/multiStream/multiStream_multiGraph_demo.cpp
index ff6491a004788a06b9e2bb0b2a13f16f34a35f2c..cc732b7654ab0df0a773c6a3f77a6091eceaead2 100644
--- a/example/multiStream/multiStream_multiGraph_demo.cpp
+++ b/example/multiStream/multiStream_multiGraph_demo.cpp
@@ -12,20 +12,37 @@
 #include <iostream>
 #include "atb/atb_infer.h"
 
+namespace {
+const int GRAPH_IN_TENSOR_NUM = 2;
+const int GRAPH_OUT_TENSOR_NUM = 1;
+const int GRAPH_INTERNAL_TENSOR_NUM = 2;
+
+const int DIM_NUM = 2;
+const int DIM2 = 2;
+
+const int TENSOR_ID0 = 0;
+const int TENSOR_ID1 = 1;
+const int TENSOR_ID2 = 2;
+const int TENSOR_ID3 = 3;
+const int TENSOR_ID4 = 4;
+const int TENSOR_ID5 = 5;
+} // namespace
+
 static void CreateInTensorDescs(atb::SVector<atb::TensorDesc> &intensorDescs)
 {
     for (size_t i = 0; i < intensorDescs.size(); i++) {
         intensorDescs.at(i).dtype = ACL_FLOAT16;
         intensorDescs.at(i).format = ACL_FORMAT_ND;
-        intensorDescs.at(i).shape.dimNum = 2;
-        intensorDescs.at(i).shape.dims[0] = 2;
-        intensorDescs.at(i).shape.dims[1] = 2;
+        intensorDescs.at(i).shape.dimNum = DIM_NUM;
+        intensorDescs.at(i).shape.dims[0] = DIM2;
+        intensorDescs.at(i).shape.dims[1] = DIM2;
     }
 }
 
 static aclError CreateInTensors(atb::SVector<atb::Tensor> &inTensors, atb::SVector<atb::TensorDesc> &intensorDescs)
 {
-    std::vector<char> zeroData(8, 0); // a host buffer filled with zeros
+    const int HOST_BUFFER_SIZE = 8;
+    std::vector<char> zeroData(HOST_BUFFER_SIZE, 0); // a host buffer filled with zeros
     int ret;
     for (size_t i = 0; i < inTensors.size(); i++) {
         inTensors.at(i).desc = intensorDescs.at(i);
@@ -64,10 +81,11 @@ static aclError CreateOutTensors(atb::SVector<atb::Tensor> &outTensors, atb::SVector<atb::TensorDesc> &outtensorDescs)
 static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
 {
     // build the sub-graph
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(3);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int GRAPH_NODE_NUM = 3;
+    opGraph.nodes.resize(GRAPH_NODE_NUM);
 
     size_t nodeId = 0;
     atb::Node &addNode = opGraph.nodes.at(nodeId++);
@@ -77,30 +95,31 @@ static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {3};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID3};
 
     atb::infer::ElewiseParam addParam2;
     addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam2, &addNode2.operation);
-    addNode2.inTensorIds = {3, 1};
-    addNode2.outTensorIds = {4};
+    addNode2.inTensorIds = {TENSOR_ID3, TENSOR_ID1};
+    addNode2.outTensorIds = {TENSOR_ID4};
 
     atb::infer::ElewiseParam addParam3;
     addParam3.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     CreateOperation(addParam3, &addNode3.operation);
-    addNode3.inTensorIds = {4, 1};
-    addNode3.outTensorIds = {2};
+    addNode3.inTensorIds = {TENSOR_ID4, TENSOR_ID1};
+    addNode3.outTensorIds = {TENSOR_ID2};
 
     atb::CreateOperation(opGraph, operation);
 }
 
 static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 {
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(5);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int GRAPH_NODE_NUM = 5;
+    opGraph.nodes.resize(GRAPH_NODE_NUM);
 
     size_t nodeId = 0;
     atb::Node &mulNode = opGraph.nodes.at(nodeId++);
@@ -112,8 +131,8 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
     atb::infer::ElewiseParam mulParam;
     mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
     atb::CreateOperation(mulParam, &mulNode.operation);
-    mulNode.inTensorIds = {0, 1};
-    mulNode.outTensorIds = {3};
+    mulNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    mulNode.outTensorIds = {TENSOR_ID3};
 
     atb::common::EventParam waitParam;
     waitParam.event = event;
@@ -122,14 +141,14 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 
     atb::GraphParam graphParam;
     CreateMiniGraphOperation(graphParam, &graphNode.operation);
-    graphNode.inTensorIds = {3, 4};
-    graphNode.outTensorIds = {2};
+    graphNode.inTensorIds = {TENSOR_ID3, TENSOR_ID4};
+    graphNode.outTensorIds = {TENSOR_ID2};
 
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {4};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID4};
 
     atb::common::EventParam recordParam;
     recordParam.event = event;
@@ -141,10 +160,11 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 {
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(5);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int GRAPH_NODE_NUM = 5;
+    opGraph.nodes.resize(GRAPH_NODE_NUM);
 
     size_t nodeId = 0;
     atb::Node &mulNode = opGraph.nodes.at(nodeId++);
@@ -156,8 +176,8 @@ static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
     atb::infer::ElewiseParam mulParam;
     mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
     atb::CreateOperation(mulParam, &mulNode.operation);
-    mulNode.inTensorIds = {0, 1};
-    mulNode.outTensorIds = {3};
+    mulNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    mulNode.outTensorIds = {TENSOR_ID3};
 
     atb::common::EventParam recordParam;
     recordParam.event = event;
@@ -166,14 +186,14 @@ static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 
     atb::GraphParam graphParam;
     CreateMiniGraphOperation(graphParam, &graphNode.operation);
-    graphNode.inTensorIds = {3, 4};
-    graphNode.outTensorIds = {2};
+    graphNode.inTensorIds = {TENSOR_ID3, TENSOR_ID4};
+    graphNode.outTensorIds = {TENSOR_ID2};
 
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {4};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID4};
 
     atb::common::EventParam waitParam;
     waitParam.event = event;
@@ -237,8 +257,7 @@ int main()
     packRW.outTensors.resize(outTensorNum);
     operationWR->InferShape(intensorDescs, outtensorDescs);
 
-    aclError ret;
-    ret = CreateInTensors(packWR.inTensors, intensorDescs);
+    aclError ret = CreateInTensors(packWR.inTensors, intensorDescs);
     if (ret != 0) {
         exit(ret);
     }
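The `TENSOR_ID*` constants introduced in both multi-stream demos follow the graph numbering scheme these examples rely on: IDs `0..inTensorNum-1` name the graph inputs, the next `outTensorNum` IDs name the outputs, and the remaining IDs name internal tensors. A condensed sketch of the convention exactly as the mini-graph above uses it:

```cpp
// 2 inputs, 1 output, 2 internals -> valid tensor IDs are 0..4:
//   0, 1 = graph inputs; 2 = graph output; 3, 4 = internal tensors.
atb::GraphParam opGraph;
opGraph.inTensorNum = 2;
opGraph.outTensorNum = 1;
opGraph.internalTensorNum = 2;
// e.g. an add node consuming both inputs and writing internal tensor 3:
//   addNode.inTensorIds = {0, 1}; addNode.outTensorIds = {3};
```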
diff --git a/example/multiStream/multiStream_singleGraph_demo.cpp b/example/multiStream/multiStream_singleGraph_demo.cpp
index d95873fdc85b3870f0929f2675c6cdbc2244cedc..fae0cbd19f87322184a0b37ba4bfb45e856d5470 100644
--- a/example/multiStream/multiStream_singleGraph_demo.cpp
+++ b/example/multiStream/multiStream_singleGraph_demo.cpp
@@ -12,31 +12,50 @@
 #include <iostream>
 #include "atb/atb_infer.h"
 
+namespace {
+const int GRAPH_IN_TENSOR_NUM = 2;
+const int GRAPH_OUT_TENSOR_NUM = 1;
+const int GRAPH_INTERNAL_TENSOR_NUM = 2;
+
+const int DIM_NUM = 2;
+const int DIM2 = 2;
+
+const int TENSOR_ID0 = 0;
+const int TENSOR_ID1 = 1;
+const int TENSOR_ID2 = 2;
+const int TENSOR_ID3 = 3;
+const int TENSOR_ID4 = 4;
+const int TENSOR_ID5 = 5;
+} // namespace
+
 static void CreateInTensorDescs(atb::SVector<atb::TensorDesc> &intensorDescs)
 {
     for (size_t i = 0; i < intensorDescs.size(); i++) {
         intensorDescs.at(i).dtype = ACL_FLOAT16;
         intensorDescs.at(i).format = ACL_FORMAT_ND;
-        intensorDescs.at(i).shape.dimNum = 2;
-        intensorDescs.at(i).shape.dims[0] = 2;
-        intensorDescs.at(i).shape.dims[1] = 2;
+        intensorDescs.at(i).shape.dimNum = DIM_NUM;
+        intensorDescs.at(i).shape.dims[0] = DIM2;
+        intensorDescs.at(i).shape.dims[1] = DIM2;
     }
 }
 
 static aclError CreateInTensors(atb::SVector<atb::Tensor> &inTensors, atb::SVector<atb::TensorDesc> &intensorDescs)
 {
-    std::vector<char> zeroData(8, 0); // a host buffer filled with zeros
+    const int HOST_BUFFER_SIZE = 8;
+    std::vector<char> zeroData(HOST_BUFFER_SIZE, 0); // a host buffer filled with zeros
     int ret;
     for (size_t i = 0; i < inTensors.size(); i++) {
         inTensors.at(i).desc = intensorDescs.at(i);
         inTensors.at(i).dataSize = atb::Utils::GetTensorSize(inTensors.at(i));
-        ret = aclrtMalloc(&inTensors.at(i).deviceData, inTensors.at(i).dataSize, ACL_MEM_MALLOC_HUGE_FIRST); // allocate NPU memory
+        ret = aclrtMalloc(&inTensors.at(i).deviceData, inTensors.at(i).dataSize,
+                          ACL_MEM_MALLOC_HUGE_FIRST); // allocate NPU memory
         if (ret != 0) {
             std::cout << "alloc error!";
             return ret;
         }
         // copy the host buffer to the NPU side
-        ret = aclrtMemcpy(inTensors.at(i).deviceData, inTensors.at(i).dataSize, zeroData.data(), zeroData.size(), ACL_MEMCPY_HOST_TO_DEVICE);
+        ret = aclrtMemcpy(inTensors.at(i).deviceData, inTensors.at(i).dataSize, zeroData.data(), zeroData.size(),
+                          ACL_MEMCPY_HOST_TO_DEVICE);
         if (ret != 0) {
             std::cout << "memcpy error!";
         }
@@ -61,11 +80,12 @@ static aclError CreateOutTensors(atb::SVector<atb::Tensor> &outTensors, atb::SVector<atb::TensorDesc> &outtensorDescs)
 static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
 {
-    // build the sub-graph
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(3);
+    // build the sub-graph
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int NODE_SIZE = 3;
+    opGraph.nodes.resize(NODE_SIZE);
 
     size_t nodeId = 0;
     atb::Node &addNode = opGraph.nodes.at(nodeId++);
@@ -75,20 +95,20 @@ static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {3};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID3};
 
     atb::infer::ElewiseParam addParam2;
     addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam2, &addNode2.operation);
-    addNode2.inTensorIds = {3, 1};
-    addNode2.outTensorIds = {4};
+    addNode2.inTensorIds = {TENSOR_ID3, TENSOR_ID1};
+    addNode2.outTensorIds = {TENSOR_ID4};
 
     atb::infer::ElewiseParam addParam3;
     addParam3.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     CreateOperation(addParam3, &addNode3.operation);
-    addNode3.inTensorIds = {4, 1};
-    addNode3.outTensorIds = {2};
+    addNode3.inTensorIds = {TENSOR_ID4, TENSOR_ID1};
+    addNode3.outTensorIds = {TENSOR_ID2};
 
     atb::CreateOperation(opGraph, operation);
 }
 
@@ -96,10 +116,11 @@ static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation)
 static void CreateGraphOperationForMultiStream(atb::GraphParam &opGraph, atb::Operation **operation)
 {
     // build the single-graph multi-stream top-level graph
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 2;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(4);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    const int MULTI_STREAM_OUT_TENSOR_NUM = 2; // this graph exposes two outputs (tensor IDs 2 and 3)
+    opGraph.outTensorNum = MULTI_STREAM_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int NODE_SIZE = 4;
+    opGraph.nodes.resize(NODE_SIZE);
 
     size_t nodeId = 0;
     atb::Node &mulNode = opGraph.nodes.at(nodeId++);
@@ -110,24 +131,24 @@ static void CreateGraphOperationForMultiStream(atb::GraphParam &opGraph, atb::Operation **operation)
     atb::infer::ElewiseParam mulParam;
     mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
     atb::CreateOperation(mulParam, &mulNode.operation);
-    mulNode.inTensorIds = {0, 1};
-    mulNode.outTensorIds = {3};
+    mulNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    mulNode.outTensorIds = {TENSOR_ID3};
 
     atb::infer::ElewiseParam addParam2;
     addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam2, &addNode2.operation);
-    addNode2.inTensorIds = {0, 1};
-    addNode2.outTensorIds = {4};
+    addNode2.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode2.outTensorIds = {TENSOR_ID4};
 
     atb::GraphParam graphParam;
     CreateMiniGraphOperation(graphParam, &graphNode.operation);
-    graphNode.inTensorIds = {4, 1};
-    graphNode.outTensorIds = {5};
+    graphNode.inTensorIds = {TENSOR_ID4, TENSOR_ID1};
+    graphNode.outTensorIds = {TENSOR_ID5};
     SetExecuteStreamId(graphNode.operation, 1);
 
     atb::CreateOperation(mulParam, &mulNode1.operation);
-    mulNode1.inTensorIds = {5, 1};
-    mulNode1.outTensorIds = {2};
+    mulNode1.inTensorIds = {TENSOR_ID5, TENSOR_ID1};
+    mulNode1.outTensorIds = {TENSOR_ID2};
     SetExecuteStreamId(mulNode1.operation, 1);
 
     atb::CreateOperation(opGraph, operation);
@@ -135,10 +156,11 @@ static void CreateGraphOperationForMultiStream(atb::GraphParam &opGraph, atb::Operation **operation)
 static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 {
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(5);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int NODE_SIZE = 5;
+    opGraph.nodes.resize(NODE_SIZE);
 
     size_t nodeId = 0;
     atb::Node &mulNode = opGraph.nodes.at(nodeId++);
@@ -150,8 +172,8 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
     atb::infer::ElewiseParam mulParam;
     mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
     atb::CreateOperation(mulParam, &mulNode.operation);
-    mulNode.inTensorIds = {0, 1};
-    mulNode.outTensorIds = {3};
+    mulNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    mulNode.outTensorIds = {TENSOR_ID3};
 
     atb::common::EventParam waitParam;
     waitParam.event = event;
@@ -160,14 +182,14 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 
     atb::GraphParam graphParam;
     CreateMiniGraphOperation(graphParam, &graphNode.operation);
-    graphNode.inTensorIds = {3, 4};
-    graphNode.outTensorIds = {2};
+    graphNode.inTensorIds = {TENSOR_ID3, TENSOR_ID4};
+    graphNode.outTensorIds = {TENSOR_ID2};
 
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {4};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID4};
     SetExecuteStreamId(addNode.operation, 1);
 
     atb::common::EventParam recordParam;
@@ -181,10 +203,11 @@ static void CreateGraphOperationWithWREvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 {
-    opGraph.inTensorNum = 2;
-    opGraph.outTensorNum = 1;
-    opGraph.internalTensorNum = 2;
-    opGraph.nodes.resize(5);
+    opGraph.inTensorNum = GRAPH_IN_TENSOR_NUM;
+    opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM;
+    opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM;
+    const int NODE_SIZE = 5;
+    opGraph.nodes.resize(NODE_SIZE);
 
     size_t nodeId = 0;
     atb::Node &mulNode = opGraph.nodes.at(nodeId++);
@@ -196,8 +219,8 @@ static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
     atb::infer::ElewiseParam mulParam;
     mulParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_MUL;
     atb::CreateOperation(mulParam, &mulNode.operation);
-    mulNode.inTensorIds = {0, 1};
-    mulNode.outTensorIds = {3};
+    mulNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    mulNode.outTensorIds = {TENSOR_ID3};
 
     atb::common::EventParam waitParam;
     waitParam.event = event;
@@ -206,14 +229,14 @@ static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 
     atb::GraphParam graphParam;
     CreateMiniGraphOperation(graphParam, &graphNode.operation);
-    graphNode.inTensorIds = {3, 4};
-    graphNode.outTensorIds = {2};
+    graphNode.inTensorIds = {TENSOR_ID3, TENSOR_ID4};
+    graphNode.outTensorIds = {TENSOR_ID2};
 
     atb::infer::ElewiseParam addParam;
     addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
     atb::CreateOperation(addParam, &addNode.operation);
-    addNode.inTensorIds = {0, 1};
-    addNode.outTensorIds = {4};
+    addNode.inTensorIds = {TENSOR_ID0, TENSOR_ID1};
+    addNode.outTensorIds = {TENSOR_ID4};
     SetExecuteStreamId(addNode.operation, 1);
 
     atb::common::EventParam recordParam;
@@ -228,7 +251,7 @@ static void CreateGraphOperationWithRWEvent(atb::GraphParam &opGraph, atb::Operation **operation, aclrtEvent event)
 int main()
 {
     aclInit(nullptr);
-    // set the device id, create streams, create the context, set the stream
+    // set the device id, create streams, create the context, set the stream
     uint32_t deviceId = 1;
     aclrtSetDevice(deviceId);
     // create multiple streams
@@ -246,7 +269,7 @@ int main()
     atb::GraphParam opGraph;
     CreateGraphOperationForMultiStream(opGraph, &operation);
 
-    // prepare the input and output tensors
+    // prepare the input and output tensors
     atb::VariantPack pack;
     atb::SVector<atb::TensorDesc> intensorDescs;
     atb::SVector<atb::TensorDesc> outtensorDescs;
@@ -258,14 +281,13 @@ int main()
     pack.inTensors.resize(inTensorNum);
     intensorDescs.resize(inTensorNum);
-
+
     CreateInTensorDescs(intensorDescs);
-
+
     outtensorDescs.resize(outTensorNum);
     pack.outTensors.resize(outTensorNum);
     operation->InferShape(intensorDescs, outtensorDescs);
-    aclError ret;
-    ret = CreateOutTensors(pack.outTensors, outtensorDescs);
+    aclError ret = CreateOutTensors(pack.outTensors, outtensorDescs);
     if (ret != 0) {
         exit(ret);
     }
@@ -287,7 +309,7 @@ int main()
             exit(1);
         }
     }
-    operation->Execute(pack, (uint8_t*)workSpace, workspaceSize, context);
+    operation->Execute(pack, (uint8_t *)workSpace, workspaceSize, context);
     // synchronize the stream
     ret = aclrtSynchronizeStream(stream1);
     if (ret != 0) {
@@ -301,7 +323,7 @@ int main()
         exit(1);
     }
 
-    // release resources
+    // release resources
     atb::DestroyOperation(operation);
     atb::DestroyContext(context);
     for (size_t i = 0; i < pack.inTensors.size(); i++) {
diff --git a/example/op_demo/activation/activation_demo.cpp b/example/op_demo/activation/activation_demo.cpp
index e4d9ed0e75efe6e4f4e1e94e60b5817e4fbe08ad..958f01c2b37e7054095323879d9d5f3bde458b46 100644
--- a/example/op_demo/activation/activation_demo.cpp
+++ b/example/op_demo/activation/activation_demo.cpp
@@ -11,9 +11,11 @@
 #include <iostream>
 #include "../demo_util.h"
 
+namespace {
 const uint32_t BATCH_SIZE = 16;    // batch size
 const uint32_t SEQ_LEN = 1024;     // sequence length
 const uint32_t HIDDEN_SIZE = 4096; // hidden layer dimension
+}
 
 /**
  * @brief prepare the input tensors of atb::VariantPack
diff --git a/example/op_demo/all_gather/all_gather_demo.cpp b/example/op_demo/all_gather/all_gather_demo.cpp
index 6cd5826cabb8232931c5f047da6bdc2818179fdd..0dc0865d89802c9345312fb004c2c49e446d725f 100644
--- a/example/op_demo/all_gather/all_gather_demo.cpp
+++ b/example/op_demo/all_gather/all_gather_demo.cpp
@@ -12,6 +12,17 @@
 #include <iostream>
 #include "../demo_util.h"
 
+namespace {
+const int64_t INPUT_DIM_NUM = 2;
+const int64_t OUTPUT_DIM_NUM = 3;
+const int64_t DIM2 = 2;
+const int64_t DIM3 = 3;
+const int64_t DIM5 = 5;
+const int64_t IDX0 = 0;
+const int64_t IDX1 = 1;
+const int64_t IDX2 = 2;
+}
+
 atb::Status ExcuteImpl(atb::Operation *op, atb::VariantPack variantPack, atb::Context *context)
 {
     uint64_t workspaceSize = 0;
@@ -41,18 +52,18 @@ atb::Status AllGatherSample(int rank, int rankSize)
     atb::Tensor input;
     input.desc.dtype = ACL_FLOAT16;
     input.desc.format = ACL_FORMAT_ND;
-    input.desc.shape.dimNum = 2;
-    input.desc.shape.dims[0] = 3;
-    input.desc.shape.dims[1] = 5;
+    input.desc.shape.dimNum = INPUT_DIM_NUM;
+    input.desc.shape.dims[IDX0] = DIM3;
+    input.desc.shape.dims[IDX1] = DIM5;
     input.dataSize = atb::Utils::GetTensorSize(input);
     CHECK_STATUS(aclrtMalloc(&input.deviceData, input.dataSize, ACL_MEM_MALLOC_HUGE_FIRST));
     atb::Tensor output;
     output.desc.dtype = ACL_FLOAT16;
     output.desc.format = ACL_FORMAT_ND;
-    output.desc.shape.dimNum = 3;
-    output.desc.shape.dims[0] = 2;
-    output.desc.shape.dims[1] = 3;
-    output.desc.shape.dims[2] = 5;
+    output.desc.shape.dimNum = OUTPUT_DIM_NUM;
+    output.desc.shape.dims[IDX0] = DIM2;
+    output.desc.shape.dims[IDX1] = DIM3;
+    output.desc.shape.dims[IDX2] = DIM5;
     output.dataSize = atb::Utils::GetTensorSize(output);
     CHECK_STATUS(aclrtMalloc(&output.deviceData, output.dataSize, ACL_MEM_MALLOC_HUGE_FIRST));
     atb::infer::AllGatherParam param;
diff --git a/example/op_demo/all_reduce/all_reduce_demo.cpp b/example/op_demo/all_reduce/all_reduce_demo.cpp
index f502d28feb3cb8640c10d536c8955f30fcfb442a..e0f8094a64318eced3fc61e2063ac2a3a830990e 100644
--- a/example/op_demo/all_reduce/all_reduce_demo.cpp
+++ b/example/op_demo/all_reduce/all_reduce_demo.cpp
@@ -29,7 +29,8 @@ atb::VariantPack PrepareVariantPack(Args &args)
     atb::VariantPack variantPack;
     std::vector<int64_t> shape = {2, 1024};
     // create the host-side data
-    std::vector<float> xHostData(shape[0] * shape[1], 2.0);
+    const float value = 2.0;
+    std::vector<float> xHostData(shape[0] * shape[1], value);
     std::vector<float> outputHostData(shape[0] * shape[1], 0);
     // create the ATB tensors
     atb::Tensor tensorX;
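Every op demo in this patch runs the same four-step lifecycle that `ExcuteImpl` shows for AllGather above. Condensed here for reference, as a sketch with the error handling trimmed (not a drop-in replacement):

```cpp
uint64_t workspaceSize = 0;
CHECK_STATUS(op->Setup(variantPack, workspaceSize, context)); // 1. query the workspace size
void *workspace = nullptr;
if (workspaceSize > 0) {                                      // 2. allocate it on the NPU
    CHECK_STATUS(aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
}
CHECK_STATUS(op->Execute(variantPack, (uint8_t *)workspace, workspaceSize, context)); // 3. launch
CHECK_STATUS(aclrtSynchronizeStream(stream)); // 4. outputs are valid only after the sync
```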
diff --git a/example/op_demo/elewise/elewise_demo.cpp b/example/op_demo/elewise/elewise_demo.cpp
index 98672d0c3bea2b21988870864b97e36b8939c1f7..c4f953c938a1c60390719a41f9b2a6dd3af9fad7 100644
--- a/example/op_demo/elewise/elewise_demo.cpp
+++ b/example/op_demo/elewise/elewise_demo.cpp
@@ -13,8 +13,10 @@
 using namespace atb;
 using namespace std;
 
+namespace {
 const int VECTOR_SIZE = 4;     // size of the vectors
 const float INIT_VALUE = 2.0f; // initial value of the vectors
+}
 
 /**
  * @brief prepare all the input tensors of atb::VariantPack
diff --git a/example/op_demo/layer_norm/layer_norm_demo.cpp b/example/op_demo/layer_norm/layer_norm_demo.cpp
index e6839e6008e55a4bbb8ddd503f4c3f8e16f10f06..d4f8f1d643596db7b1ee54a43e4f1e6f631c4495 100644
--- a/example/op_demo/layer_norm/layer_norm_demo.cpp
+++ b/example/op_demo/layer_norm/layer_norm_demo.cpp
@@ -10,10 +10,12 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
 const uint32_t DIM_0 = 128;
 const uint32_t DIM_1 = 256;
 const int32_t BEGIN_NORM_AXIS = 1;
+}
 
 /**
  * @brief prepare all the input tensors of atb::VariantPack
@@ -24,14 +26,15 @@ const int32_t BEGIN_NORM_AXIS = 1;
 atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
 {
+    const float value = 2.0;
+    const float betaValue = 1.0; // beta keeps its original fill value of 1.0
     atb::Tensor x;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1, 2.0), ACL_FLOAT16,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1, value), ACL_FLOAT16,
                                         aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1}, x));
     atb::Tensor gamma;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_1, 2.0), ACL_FLOAT16,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_1, value), ACL_FLOAT16,
                                         aclFormat::ACL_FORMAT_ND, {DIM_1}, gamma));
     atb::Tensor beta;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_1, 1.0), ACL_FLOAT16,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_1, betaValue), ACL_FLOAT16,
                                         aclFormat::ACL_FORMAT_ND, {DIM_1}, beta));
     inTensors = {x, gamma, beta};
     return atb::ErrorType::NO_ERROR;
@@ -67,7 +70,7 @@ int main(int argc, char **argv)
     CHECK_STATUS(CreateLayerNormOperation(&layerNormOp));
     // prepare the input tensors
     atb::VariantPack variantPack;
-    CHECK_STATUS(PrepareInTensor(context, stream, variantPack.inTensors)); // put in the input tensors
+    CHECK_STATUS(PrepareInTensor(context, stream, variantPack.inTensors));  // put in the input tensors
     atb::Tensor tensorOut;
     CHECK_STATUS(CreateTensor(ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1}, tensorOut));
     variantPack.outTensors = {tensorOut}; // put in the output tensor
diff --git a/example/op_demo/linear/linear_demo.cpp b/example/op_demo/linear/linear_demo.cpp
index 29e03c3813d57cdca92c9e09ed65b7d86787c362..d4e8498ad266d6e784e17e237c3580c3d5863659 100644
--- a/example/op_demo/linear/linear_demo.cpp
+++ b/example/op_demo/linear/linear_demo.cpp
@@ -10,12 +10,14 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
 const uint32_t X_DIM_0 = 2;
 const uint32_t X_DIM_1 = 3;
 const uint32_t WEIGHT_DIM_0 = 3;
 const uint32_t WEIGHT_DIM_1 = 2;
 const uint32_t BIAS_DIM_0 = 2;
+}
 
 /**
  * @brief prepare atb::VariantPack
diff --git a/example/op_demo/linear/linear_dequant_demo.cpp b/example/op_demo/linear/linear_dequant_demo.cpp
index 48486fd2509a95648af21174f933e80e3ce674af..30f0b237d249abb251f199253d7bce95e67461d6 100644
--- a/example/op_demo/linear/linear_dequant_demo.cpp
+++ b/example/op_demo/linear/linear_dequant_demo.cpp
@@ -10,11 +10,14 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
+const int32_t SIZE_2 = 2;
 const uint32_t XDIM_0 = 2;
 const uint32_t XDIM_1 = 3;
 const uint32_t WEIGHTDIM_0 = 3;
 const uint32_t WEIGHTDIM_1 = 2;
+}
 
 /**
  * @brief prepare atb::VariantPack
@@ -36,11 +39,11 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
                                         weight));
     // create the bias tensor with shape [2]
     atb::Tensor bias;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(2, 1), aclDataType::ACL_INT32,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(SIZE_2, 1), aclDataType::ACL_INT32,
                                         aclFormat::ACL_FORMAT_ND, {1, 2}, bias));
     // create the input deqScale tensor with shape [2]
     atb::Tensor deqScale;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(2, 1), aclDataType::ACL_FLOAT,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(SIZE_2, 1), aclDataType::ACL_FLOAT,
                                         aclFormat::ACL_FORMAT_ND, {1, 2}, deqScale));
     inTensors = {x, weight, bias, deqScale};
     return atb::ErrorType::NO_ERROR;
@@ -60,7 +63,7 @@ atb::Status CreateLinearOperation(atb::Operation **linearOp)
     param.outDataType = aclDataType::ACL_BF16;
     param.enAccum = false;
     param.matmulType = atb::infer::LinearParam::MatmulType::MATMUL_UNDEFINED;
-    param.quantMode = atb::infer::LinearParam::QuantMode::PER_CHANNEL;
+    param.quantMode = atb::infer::LinearParam::QuantMode::PER_CHANNEL;
     return atb::CreateOperation(param, linearOp);
 }
diff --git a/example/op_demo/linear/linear_einsum_demo.cpp b/example/op_demo/linear/linear_einsum_demo.cpp
index 22a86b151229d809b853d746375ef88218b35079..a04f900e50ee95de1ba62dec2de0c04baada21be 100644
--- a/example/op_demo/linear/linear_einsum_demo.cpp
+++ b/example/op_demo/linear/linear_einsum_demo.cpp
@@ -10,6 +10,7 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
 const uint32_t XDIM_0 = 2;
 const uint32_t XDIM_1 = 3;
@@ -17,6 +18,7 @@ const uint32_t WEIGHTDIM_0 = 3;
 const uint32_t WEIGHTDIM_1 = 2;
 const uint32_t BIASDIM_0 = 2;
 const uint32_t BATCHSIZE = 1;
+}
 
 /**
  * @brief prepare atb::VariantPack
diff --git a/example/op_demo/linear_parallel/linear_parallel_demo.cpp b/example/op_demo/linear_parallel/linear_parallel_demo.cpp
index 8a853fb30273d434b8bc1305d53deaccc61fe0c2..e5571bf6f10098fc34525a7a20a7dfafe85be35d 100644
--- a/example/op_demo/linear_parallel/linear_parallel_demo.cpp
+++ b/example/op_demo/linear_parallel/linear_parallel_demo.cpp
@@ -13,6 +13,13 @@
 
 #include "../demo_util.h"
 
+namespace {
+const int OUTPUT_DIM_NUM = 2;
+const int DIM_M = 2;
+const int DIM_N = 2;
+const int DIM_K = 32;
+}
+
 atb::Status ExcuteImpl(atb::Operation *op, atb::VariantPack variantPack, atb::Context *context, aclrtStream &stream)
 {
     uint64_t workspaceSize = 0;
@@ -48,18 +55,20 @@ atb::Status LinearParallelSample(int rank, int rankSize)
     context->SetExecuteStream(stream);
 
     atb::Tensor input;
-    CHECK_STATUS(CreateTensorFromVector(context, stream, std::vector<float>(64, 2.0), aclDataType::ACL_FLOAT16,
-                                        aclFormat::ACL_FORMAT_ND, {2, 32}, input));
+    const int TENSOR_SIZE = 64;
+    const float tensorValue = 2.0;
+    CHECK_STATUS(CreateTensorFromVector(context, stream, std::vector<float>(TENSOR_SIZE, tensorValue),
+                                        aclDataType::ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {DIM_M, DIM_K}, input));
     atb::Tensor weight;
-    CHECK_STATUS(CreateTensorFromVector(context, stream, std::vector<float>(64, 2.0), aclDataType::ACL_FLOAT16,
-                                        aclFormat::ACL_FORMAT_ND, {32, 2}, weight));
+    CHECK_STATUS(CreateTensorFromVector(context, stream, std::vector<float>(TENSOR_SIZE, tensorValue),
+                                        aclDataType::ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {DIM_K, DIM_N}, weight));
     atb::Tensor output;
     output.desc.dtype = ACL_FLOAT16;
     output.desc.format = ACL_FORMAT_ND;
-    output.desc.shape.dimNum = 2;
-    output.desc.shape.dims[0] = 2;
-    output.desc.shape.dims[1] = 2;
+    output.desc.shape.dimNum = OUTPUT_DIM_NUM;
+    output.desc.shape.dims[0] = DIM_M;
+    output.desc.shape.dims[1] = DIM_N;
     output.dataSize = atb::Utils::GetTensorSize(output);
     CHECK_STATUS(aclrtMalloc(&output.deviceData, output.dataSize, ACL_MEM_MALLOC_HUGE_FIRST));
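The `DIM_M`/`DIM_N`/`DIM_K` constants introduced in linear_parallel_demo.cpp spell out the matmul contract hidden behind the former magic shapes: input `[M, K]` x weight `[K, N]` -> output `[M, N]`. A quick consistency check on the demo's own values:

```cpp
constexpr int DIM_M = 2, DIM_K = 32, DIM_N = 2; // demo values
constexpr int TENSOR_SIZE = DIM_M * DIM_K;      // 64 host elements for the input
static_assert(TENSOR_SIZE == DIM_K * DIM_N, "the weight host vector reuses the same element count");
// the output holds DIM_M * DIM_N = 4 fp16 values
```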
diff --git a/example/op_demo/linear_parallel/linear_parallel_quant_demo.cpp b/example/op_demo/linear_parallel/linear_parallel_quant_demo.cpp
index 23cdf8e35930a41de33f9bdd4fb955acef9730a4..cd8350debcff04cd0b2317b436672fc29e89178e 100644
--- a/example/op_demo/linear_parallel/linear_parallel_quant_demo.cpp
+++ b/example/op_demo/linear_parallel/linear_parallel_quant_demo.cpp
@@ -13,6 +13,12 @@
 
 #include "../demo_util.h"
 
+namespace {
+const int DIM_M = 2;
+const int DIM_N = 2;
+const int DIM_K = 32;
+}
+
 atb::Status ExcuteImpl(atb::Operation *op, atb::VariantPack variantPack, atb::Context *context, aclrtStream &stream)
 {
     uint64_t workspaceSize = 0;
@@ -47,23 +53,24 @@ atb::Status LinearParallelSample(int rank, int rankSize)
     context->SetExecuteStream(stream);
 
     atb::Tensor input;
-    CreateTensorFromVector(context, stream, std::vector<float>(64, 2.0), aclDataType::ACL_FLOAT16,
-                           aclFormat::ACL_FORMAT_ND, {2, 32}, input);
+    const float value = 2.0;
+    CreateTensorFromVector(context, stream, std::vector<float>(DIM_M * DIM_K, value), aclDataType::ACL_FLOAT16,
+                           aclFormat::ACL_FORMAT_ND, {DIM_M, DIM_K}, input);
     atb::Tensor weight;
-    CreateTensorFromVector(context, stream, std::vector<int8_t>(64, 2), aclDataType::ACL_INT8, aclFormat::ACL_FORMAT_ND,
-                           {32, 2}, weight);
+    CreateTensorFromVector(context, stream, std::vector<int8_t>(DIM_K * DIM_N, value), aclDataType::ACL_INT8,
+                           aclFormat::ACL_FORMAT_ND, {DIM_K, DIM_N}, weight);
     atb::Tensor bias;
-    CreateTensorFromVector(context, stream, std::vector<float>(1, 1.0), aclDataType::ACL_FLOAT16,
+    const float oneValue = 1.0; // bias and deqScale keep their original fill value of 1.0
+    CreateTensorFromVector(context, stream, std::vector<float>(1, oneValue), aclDataType::ACL_FLOAT16,
                            aclFormat::ACL_FORMAT_ND, {1}, bias);
     atb::Tensor deqScale;
-    CreateTensorFromVector(context, stream, std::vector<float>(1, 1.0), aclDataType::ACL_FLOAT16,
+    CreateTensorFromVector(context, stream, std::vector<float>(1, oneValue), aclDataType::ACL_FLOAT16,
                            aclFormat::ACL_FORMAT_ND, {1}, deqScale);
 
     atb::Tensor output;
-    CreateTensor(aclDataType::ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {2, 2}, output);
+    CreateTensor(aclDataType::ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {DIM_M, DIM_N}, output);
 
     atb::infer::LinearParallelParam param;
     param.transWeight = false;
diff --git a/example/op_demo/mla_preprocess/mlapo_demo.cpp b/example/op_demo/mla_preprocess/mlapo_demo.cpp
index 7f93d7126e3d4a1b40624affeba8f982af9b8c4a..7cd6640aa53ff75c5846067bd956552969523839 100644
--- a/example/op_demo/mla_preprocess/mlapo_demo.cpp
+++ b/example/op_demo/mla_preprocess/mlapo_demo.cpp
@@ -10,9 +10,31 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 2;
-const uint32_t blockSize = 128;
-const uint32_t blockNum = 64;
+const uint32_t BLOCK_SIZE = 128;
+const uint32_t BLOCK_NUM = 64;
+
+const int32_t INPUT_NUM = 4;
+const int32_t DTYPE_IDX = 1;
+const int32_t TOKEN_NUM_IDX = 2;
+const int32_t HEAD_NUM_IDX = 3;
+const int32_t KV_CACHE_NOPE_IDX = 19;
+const int32_t KV_CACHE_ROPE_IDX = 20;
+
+const int32_t RUNS = 10;
+
+const int32_t ROPE_DIM64 = 64;
+const int32_t NZ_DIM16 = 16;
+const int32_t NOPE_DIM512 = 512;
+const int32_t RMSNORM_QUANT_DIM7168 = 7168;
+const int32_t RMSNORM_QUANT_DIM1536 = 1536;
+const int32_t MATMUL_DIM224 = 224;
+const int32_t MATMUL_DIM2112 = 2112;
+const int32_t MATMUL_DIM32 = 32;
+const int32_t MATMUL_DIM192 = 192;
+const int32_t MATMUL_DIM48 = 48;
+} // namespace
 
 /**
  * @brief prepare the input tensors of atb::VariantPack
@@ -26,16 +48,17 @@ atb::Status PrepareInTensor1(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum, atb::SVector<atb::Tensor> *inTensors)
 {
     // create the input tensor with shape [tokenNum, 7168]
     atb::Tensor input;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * 7168, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {tokenNum, 7168}, input, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * RMSNORM_QUANT_DIM7168, 0),
+                                        dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, RMSNORM_QUANT_DIM7168}, input,
+                                        dtype));
     // create the input gamma0 tensor with shape [7168]
     atb::Tensor gamma0;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(7168, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {7168}, gamma0, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(RMSNORM_QUANT_DIM7168, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {RMSNORM_QUANT_DIM7168}, gamma0, dtype));
     // create the input beta0 tensor with shape [7168]
     atb::Tensor beta0;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(7168, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {7168}, beta0, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(RMSNORM_QUANT_DIM7168, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {RMSNORM_QUANT_DIM7168}, beta0, dtype));
     // create the input quantScale0 tensor with shape [1]
     atb::Tensor quantScale0;
     CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(1, 0), dtype, aclFormat::ACL_FORMAT_ND,
@@ -46,29 +69,31 @@ atb::Status PrepareInTensor1(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum, atb::SVector<atb::Tensor> *inTensors)
                                         aclFormat::ACL_FORMAT_ND, {1}, quantOffset0));
     // create the input wdqkv tensor with shape [1,224,2112,32]
     atb::Tensor wdqkv;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(224 * 2112 * 32, 1), ACL_INT8,
-                                        aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, 224, 2112, 32}, wdqkv));
+    CHECK_STATUS(CreateTensorFromVector(
+        contextPtr, stream, std::vector<int8_t>(MATMUL_DIM224 * MATMUL_DIM2112 * MATMUL_DIM32, 1), ACL_INT8,
+        aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM224, MATMUL_DIM2112, MATMUL_DIM32}, wdqkv));
     // create the input deScale0 tensor with shape [2112]
     atb::Tensor deScale0;
     if (dtype == ACL_BF16) {
-        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(2112, 1), ACL_FLOAT,
-                                            aclFormat::ACL_FORMAT_ND, {2112}, deScale0));
+        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(MATMUL_DIM2112, 1), ACL_FLOAT,
+                                            aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, deScale0));
     } else {
-        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int64_t>(2112, 10), ACL_INT64,
-                                            aclFormat::ACL_FORMAT_ND, {2112}, deScale0));
+        int64_t deScale0Value = 10;
+        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int64_t>(MATMUL_DIM2112, deScale0Value),
+                                            ACL_INT64, aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, deScale0));
     }
     // create the input bias0 tensor with shape [2112]
     atb::Tensor bias0;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(2112, 1), ACL_INT32,
-                                        aclFormat::ACL_FORMAT_ND, {2112}, bias0));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(MATMUL_DIM2112, 1), ACL_INT32,
+                                        aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, bias0));
     // create the input gamma1 tensor with shape [1536]
     atb::Tensor gamma1;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(1536, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {1536}, gamma1, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(RMSNORM_QUANT_DIM1536, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {RMSNORM_QUANT_DIM1536}, gamma1, dtype));
     // create the input beta1 tensor with shape [1536]
     atb::Tensor beta1;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(1536, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {1536}, beta1, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(RMSNORM_QUANT_DIM1536, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {RMSNORM_QUANT_DIM1536}, beta1, dtype));
     // create the input quantScale1 tensor with shape [1]
     atb::Tensor quantScale1;
     CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(1, 0), dtype, aclFormat::ACL_FORMAT_ND,
@@ -94,47 +119,52 @@ atb::Status PrepareInTensor2(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum, int headNum, atb::SVector<atb::Tensor> *inTensors)
 {
     // create the input wuq tensor with shape [1,48,headNum*192,32]
     atb::Tensor wuq;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(48 * headNum * 192 * 32, 1), ACL_INT8,
-                                        aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, 48, headNum * 192, 32}, wuq));
+    CHECK_STATUS(CreateTensorFromVector(
+        contextPtr, stream, std::vector<int8_t>(MATMUL_DIM48 * headNum * MATMUL_DIM192 * MATMUL_DIM32, 1), ACL_INT8,
+        aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM48, headNum * MATMUL_DIM192, MATMUL_DIM32}, wuq));
     // create the input deScale1 tensor with shape [headNum*192]
     atb::Tensor deScale1;
     if (dtype == ACL_BF16) {
-        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(headNum * 192, 1), ACL_FLOAT,
-                                            aclFormat::ACL_FORMAT_ND, {headNum * 192}, deScale1));
+        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(headNum * MATMUL_DIM192, 1),
+                                            ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1));
     } else {
-        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int64_t>(headNum * 192, 10), ACL_INT64,
-                                            aclFormat::ACL_FORMAT_ND, {headNum * 192}, deScale1));
+        int64_t deScale1Value = 10; // keep the original fill value of 10, matching deScale0
+        CHECK_STATUS(CreateTensorFromVector(contextPtr, stream,
+                                            std::vector<int64_t>(headNum * MATMUL_DIM192, deScale1Value), ACL_INT64,
+                                            aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1));
     }
     // create the input bias1 tensor with shape [headNum*192]
     atb::Tensor bias1;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(headNum * 192, 1), ACL_INT32,
-                                        aclFormat::ACL_FORMAT_ND, {headNum * 192}, bias1));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(headNum * MATMUL_DIM192, 1), ACL_INT32,
+                                        aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, bias1));
     // create the input gamma2 tensor with shape [512]
     atb::Tensor gamma2;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(512, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {512}, gamma2, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(NOPE_DIM512, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {NOPE_DIM512}, gamma2, dtype));
     // create the input cos tensor with shape [tokenNum,64]
     atb::Tensor cos;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * 64, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {tokenNum, 64}, cos, dtype));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * ROPE_DIM64, 0), dtype,
+                                        aclFormat::ACL_FORMAT_ND, {tokenNum, ROPE_DIM64}, cos, dtype));
     // create the input sin tensor with shape [tokenNum,64]
     atb::Tensor sin;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * 64, 0.5), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {tokenNum, 64}, sin, dtype));
+    __fp16 sinValue = 0.5;
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * ROPE_DIM64, sinValue),
+                                        dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, ROPE_DIM64}, sin, dtype));
     // create the input wuk tensor with shape [headNum,32,128,16]
     atb::Tensor wuk;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(headNum * 32 * 128 * 16, 0), dtype,
-                                        aclFormat::ACL_FORMAT_FRACTAL_NZ, {headNum, 32, 128, 16}, wuk, dtype));
-    // create the input kvCache tensor with shape [blockNum, headNum*512/32, block_size, 32]
+    CHECK_STATUS(CreateTensorFromVector(
+        contextPtr, stream, std::vector<__fp16>(headNum * MATMUL_DIM32 * BLOCK_SIZE * NZ_DIM16, 0), dtype,
+        aclFormat::ACL_FORMAT_FRACTAL_NZ, {headNum, MATMUL_DIM32, BLOCK_SIZE, NZ_DIM16}, wuk, dtype));
+    // create the input kvCache tensor with shape [BLOCK_NUM, headNum*512/32, block_size, 32]
     atb::Tensor kvCache;
     CHECK_STATUS(CreateTensorFromVector(
-        contextPtr, stream, std::vector<int8_t>(blockNum * headNum * 512 * blockSize, 1), ACL_INT8,
-        aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, headNum * 512 / 32, blockSize, 32}, kvCache));
-    // create the input kvCacheRope tensor with shape [blockNum, headNum*64/16, block_size, 16]
+        contextPtr, stream, std::vector<int8_t>(BLOCK_NUM * headNum * NOPE_DIM512 * BLOCK_SIZE, 1), ACL_INT8,
+        aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * NOPE_DIM512 / MATMUL_DIM32, BLOCK_SIZE, MATMUL_DIM32},
+        kvCache));
+    // create the input kvCacheRope tensor with shape [BLOCK_NUM, headNum*64/16, block_size, 16]
     atb::Tensor kvCacheRope;
     CHECK_STATUS(CreateTensorFromVector(
-        contextPtr, stream, std::vector<__fp16>(blockNum * headNum * 64 / 16 * blockSize * 16, 0), dtype,
-        aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, headNum * 64 / 16, blockSize, 16}, kvCacheRope, dtype));
+        contextPtr, stream, std::vector<__fp16>(BLOCK_NUM * headNum * ROPE_DIM64 / NZ_DIM16 * BLOCK_SIZE * NZ_DIM16, 0),
+        dtype, aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * ROPE_DIM64 / NZ_DIM16, BLOCK_SIZE, NZ_DIM16},
+        kvCacheRope, dtype));
     auto slotmappingHost = std::vector<int32_t>(1, tokenNum);
     for (size_t i = 0; i < slotmappingHost.size(); i++)
         slotmappingHost[i] = static_cast<int32_t>(i);
@@ -191,11 +221,11 @@ atb::Status RunDemo(atb::Context *context, void *stream, aclDataType dtype, int tokenNum, int headNum)
     CHECK_STATUS(PrepareInTensor2(context, stream, dtype, tokenNum, headNum, &variantPack.inTensors));
     // prepare the output tensors
     atb::Tensor qOut0;
-    CreateTensor(ACL_INT8, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 512}, qOut0);
-    atb::Tensor &kvCacheOut0 = variantPack.inTensors.at(19);
+    CreateTensor(ACL_INT8, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, NOPE_DIM512}, qOut0);
+    atb::Tensor &kvCacheOut0 = variantPack.inTensors.at(KV_CACHE_NOPE_IDX);
     atb::Tensor qOut1;
-    CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 64}, qOut1);
-    atb::Tensor &kvCacheOut1 = variantPack.inTensors.at(20);
+    CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, ROPE_DIM64}, qOut1);
+    atb::Tensor &kvCacheOut1 = variantPack.inTensors.at(KV_CACHE_ROPE_IDX);
     variantPack.outTensors = {qOut0, kvCacheOut0, qOut1, kvCacheOut1}; // put in the output tensors
 
     uint64_t workspaceSize = 0;
@@ -205,7 +235,7 @@ atb::Status RunDemo(atb::Context *context, void *stream, aclDataType dtype, int tokenNum, int headNum)
     if (workspaceSize > 0) {
         CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
     }
-    for (size_t i = 0; i < 10; i++) {
+    for (size_t i = 0; i < RUNS; i++) {
         std::cout << "tokenNum: " << tokenNum << " headNum: " << headNum << " loop: " << i << std::endl;
         // execute mlaPreprocess
         CHECK_STATUS(mlaPreprocessOp->Execute(variantPack, workspacePtr, workspaceSize, context));
@@ -238,10 +268,10 @@ int main(int argc, char **argv)
     int tokenNum = 4;
     int headNum = 128;
     aclDataType dtype = ACL_FLOAT16;
-    if (argc == 4) {
-        dtypeStr = argv[1];
-        tokenNum = std::stoi(argv[2]);
-        headNum = std::stoi(argv[3]);
+    if (argc == INPUT_NUM) {
+        dtypeStr = argv[DTYPE_IDX];
+        tokenNum = std::stoi(argv[TOKEN_NUM_IDX]);
+        headNum = std::stoi(argv[HEAD_NUM_IDX]);
     }
     if (dtypeStr == "bf16") {
         dtype = ACL_BF16;
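The FRACTAL_NZ shapes in mlapo_demo.cpp are consistent with the usual NZ factoring rule (stated here as an assumption, but it matches every shape above): the innermost axis holds C0 = 32 bytes / sizeof(dtype) elements, i.e. 32 lanes for int8 and 16 for fp16, and the dimension being tiled is divided by C0:

```cpp
// ND -> FRACTAL_NZ factoring implied by the demo's cache shapes (assumed rule):
constexpr int64_t C0_INT8 = 32; // 32 bytes / 1 byte per element
constexpr int64_t C0_FP16 = 16; // 32 bytes / 2 bytes per element
// int8 kvCache, logical [BLOCK_NUM, BLOCK_SIZE, headNum * 512]:
//   -> {BLOCK_NUM, headNum * 512 / C0_INT8, BLOCK_SIZE, C0_INT8}
// fp16 rope cache, logical [BLOCK_NUM, BLOCK_SIZE, headNum * 64]:
//   -> {BLOCK_NUM, headNum * 64 / C0_FP16, BLOCK_SIZE, C0_FP16}
```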
diff --git a/example/op_demo/multi_latent_attention/mlapa_demo.cpp b/example/op_demo/multi_latent_attention/mlapa_demo.cpp
index 501d63e9f4388310cd9765930065b1e3f972c4f8..6db4369a7455b2983f0ee5a8956c24040ac276b2 100644
--- a/example/op_demo/multi_latent_attention/mlapa_demo.cpp
+++ b/example/op_demo/multi_latent_attention/mlapa_demo.cpp
@@ -8,13 +8,31 @@
  * See LICENSE in the root of the software repository for the full text of the License.
  */
 
+#include <cmath>
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 1;
-const uint32_t blockSize = 128;
-int32_t blockNum = 64;
+const uint32_t BLOCK_SIZE = 128;
+const int32_t DIM512 = 512;
+const int32_t ROPE_HEAD_SIZE = 64;
+const int32_t KV_HEAD_NUM = 1;
+const int32_t CTKV_HEAD_SIZE_CACHE2 = 32;
+const int32_t ALIGN16 = 16;
+const int32_t NUM4 = 4;
+
 std::vector<int32_t> contextLensHost;
+const int32_t INPUT_NUM = 5;
+const int32_t DTYPE_IDX = 1;
+const int32_t TOKEN_NUM_IDX = 2;
+const int32_t HEAD_NUM_IDX = 3;
+const int32_t K_SEQLEN_IDX = 4;
+
+const int32_t RUNS = 2;
+} // namespace
+
+
 /**
  * @brief prepare atb::VariantPack
  * @param contextPtr context pointer
@@ -27,23 +45,27 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum, int headNum, int kSeqLen, atb::SVector<atb::Tensor> &inTensors)
 {
     // create the input qNope tensor with shape [tokenNum, headNum, 512]
     atb::Tensor qNope;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(tokenNum * headNum * 512, 1), ACL_INT8,
-                                        aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 512}, qNope));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(tokenNum * headNum * DIM512, 1),
+                                        ACL_INT8, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, DIM512}, qNope));
     // create the input qRope tensor with shape [tokenNum, headNum, 64]
     atb::Tensor qRope;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * headNum * 64, 0), dtype,
-                                        aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 64}, qRope, dtype));
-    int maxBlockNumPerSeq = (kSeqLen + blockSize - 1) / blockSize;
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * headNum * ROPE_HEAD_SIZE, 0),
+                                        dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, ROPE_HEAD_SIZE}, qRope,
+                                        dtype));
+    int maxBlockNumPerSeq = (kSeqLen + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    int32_t blockNum = 64;
     blockNum = tokenNum * maxBlockNumPerSeq;
-    // create the input ctKV tensor with shape [blockNum, 16, blockSize, 32]
+    // create the input ctKV tensor with shape [blockNum, 16, BLOCK_SIZE, 32]
     atb::Tensor ctKV;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(blockNum * blockSize * 512, 1),
-                                        ACL_INT8, aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, 16, blockSize, 32},
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(blockNum * BLOCK_SIZE * DIM512, 1),
+                                        ACL_INT8, aclFormat::ACL_FORMAT_FRACTAL_NZ,
+                                        {blockNum, ALIGN16, BLOCK_SIZE, CTKV_HEAD_SIZE_CACHE2},
                                         ctKV));
-    // create the input kRope tensor with shape [blockNum, 4, blockSize, 16]
+    // create the input kRope tensor with shape [blockNum, 4, BLOCK_SIZE, 16]
     atb::Tensor kRope;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(blockNum * blockSize * 64, 0), dtype,
-                                        aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, 4, blockSize, 16}, kRope, dtype));
+    CHECK_STATUS(CreateTensorFromVector(
+        contextPtr, stream, std::vector<__fp16>(blockNum * BLOCK_SIZE * ROPE_HEAD_SIZE, 0), dtype,
+        aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, NUM4, BLOCK_SIZE, ALIGN16}, kRope, dtype));
     // create the input blockTables tensor with shape [tokenNum, maxBlockNumPerSeq]
     auto blockTablesHost = std::vector<int32_t>(tokenNum * maxBlockNumPerSeq);
     for (size_t i = 0; i < tokenNum; i++) {
@@ -80,7 +102,7 @@ atb::Status CreateMultiLatentAttentionOperation(int headNum, atb::Operation **mlaOp)
 {
     atb::infer::MultiLatentAttentionParam param;
     param.headNum = headNum;
-    param.qkScale = 0.0416666679084301;
+    param.qkScale = 1 / sqrt(DIM512 + ROPE_HEAD_SIZE); // 1 / sqrt(576) = 1 / 24, the removed literal
     param.kvHeadNum = 1;
     param.cacheMode = atb::infer::MultiLatentAttentionParam::CacheMode::INT8_NZCACHE;
     return atb::CreateOperation(param, mlaOp);
 }
@@ -108,7 +130,7 @@ atb::Status RunDemo(atb::Context *context, void *stream, aclDataType dtype, int tokenNum, int headNum, int kSeqLen)
         PrepareInTensor(context, stream, dtype, tokenNum, headNum, kSeqLen, variantPack.inTensors)); // put in the input tensors
     // prepare the output tensor
     atb::Tensor attenOut;
-    CHECK_STATUS(CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 512}, attenOut));
+    CHECK_STATUS(CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, DIM512}, attenOut));
     variantPack.outTensors = {attenOut}; // put in the output tensor
 
     uint64_t workspaceSize = 0;
@@ -118,7 +140,7 @@ atb::Status RunDemo(atb::Context *context, void *stream, aclDataType dtype, int tokenNum, int headNum, int kSeqLen)
     if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
     }
-    for (size_t i = 0; i < 2; i++) {
+    for (size_t i = 0; i < RUNS; i++) {
         std::cout << "tokenNum: " << tokenNum << " headNum: " << headNum << " loop: " << i << std::endl;
         // execute mlaPreprocess
         mlaOp->Execute(variantPack, workspacePtr, workspaceSize, context);
@@ -153,11 +175,11 @@ int main(int argc, char **argv)
     int headNum = 128;
     int kSeqLen = 1500;
     aclDataType dtype = ACL_FLOAT16;
-    if (argc == 5) {
-        dtypeStr = argv[1];
-        tokenNum = std::stoi(argv[2]);
-        headNum = std::stoi(argv[3]);
-        kSeqLen = std::stoi(argv[4]);
+    if (argc == INPUT_NUM) {
+        dtypeStr = argv[DTYPE_IDX];
+        tokenNum = std::stoi(argv[TOKEN_NUM_IDX]);
+        headNum = std::stoi(argv[HEAD_NUM_IDX]);
+        kSeqLen = std::stoi(argv[K_SEQLEN_IDX]);
     }
     if (dtypeStr == "bf16") {
         dtype = ACL_BF16;
     }
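Both attention demos replace an opaque scaling literal with the 1/sqrt(d_k) softmax scale it encodes; the arithmetic behind the removed constants:

```cpp
// Where the removed literals come from (attention scaling = 1 / sqrt(d_k)):
//   MLA: the q.k dot product spans the 512-dim nope part plus the 64-dim rope part,
//     d_k = DIM512 + ROPE_HEAD_SIZE = 576, and 1 / sqrt(576) = 1 / 24 = 0.041666667,
//     matching the literal 0.0416666679084301 (hence the corrected formula above).
//   Paged attention: d_k = HEAD_SIZE = 128,
//     1 / sqrt(128) = 0.08838834764831843, exactly the literal removed below.
```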
diff --git a/example/op_demo/paged_attention/paged_attention_demo.cpp b/example/op_demo/paged_attention/paged_attention_demo.cpp
index 7d4a2b19ae7881a243f0968e7d8ab5a895614cc4..b3bda477570b1cefd4cdbc7ac581486dd78fbd38 100644
--- a/example/op_demo/paged_attention/paged_attention_demo.cpp
+++ b/example/op_demo/paged_attention/paged_attention_demo.cpp
@@ -9,18 +9,22 @@
  */
 
 #include <random>
+#include <cmath>
 #include "../demo_util.h"
 
-const uint32_t NTOKENS = 2;                             // number of tokens
-const uint32_t BATCH_SIZE = NTOKENS;                    // batch size
-const uint32_t MAX_SEQ_LEN = 1024;                      // maximum sequence length
-const uint32_t HEAD_NUM = 32;                           // number of heads
-const uint32_t KV_HEAD_NUM = 32;                        // number of kv heads
-const uint32_t HEAD_SIZE = 128;                         // head size
-const uint32_t BLOCK_NUM = 16;                          // number of blocks
-const uint32_t BLOCK_SIZE = 128;                        // block size
-const uint32_t MAX_CONTEXT_LEN = 1024;                  // maximum context length
-std::vector<int32_t> contextLensData(BATCH_SIZE, 256);  // host-side data for contextLens
+namespace {
+const uint32_t NTOKENS = 2;            // number of tokens
+const uint32_t BATCH_SIZE = NTOKENS;   // batch size
+const uint32_t MAX_SEQ_LEN = 1024;     // maximum sequence length
+const uint32_t HEAD_NUM = 32;          // number of heads
+const uint32_t KV_HEAD_NUM = 32;       // number of kv heads
+const uint32_t HEAD_SIZE = 128;        // head size
+const uint32_t BLOCK_NUM = 16;         // number of blocks
+const uint32_t BLOCK_SIZE = 128;       // block size
+const uint32_t MAX_CONTEXT_LEN = 1024; // maximum context length
+const int32_t SEQLEN_VALUE = 256;      // seqlen of each batch
+std::vector<int32_t> contextLensData(BATCH_SIZE, SEQLEN_VALUE); // host-side data for contextLens
+}
 
 /**
  * @brief prepare all the input tensors of atb::VariantPack
@@ -49,7 +53,8 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
     std::vector<int32_t> blockTablesData(NTOKENS * maxNumBlocksPerQuery, 0);
     std::random_device rd;
     std::mt19937 gen(rd());
-    std::uniform_int_distribution<> dist(0, BLOCK_NUM - 2);
+    const int diff = 2;
+    std::uniform_int_distribution<> dist(0, BLOCK_NUM - diff);
     for (size_t i = 0; i < blockTablesData.size(); i++) {
         blockTablesData[i] = dist(gen);
     }
@@ -87,7 +92,7 @@ atb::Status PrepareOperation(atb::Operation **paOp)
     paOpParam.maskType = atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_NORM;
     paOpParam.headNum = HEAD_NUM;
     paOpParam.kvHeadNum = KV_HEAD_NUM;
-    paOpParam.qkScale = 0.08838834764831843;
+    paOpParam.qkScale = 1 / sqrt(HEAD_SIZE);
     return atb::CreateOperation(paOpParam, paOp);
 }
diff --git a/example/op_demo/paged_attention/paged_attention_inference_demo.cpp b/example/op_demo/paged_attention/paged_attention_inference_demo.cpp
index 0513315100a89ae2610d944d11409a7de65b035d..4957853a5aeed137760d56dc1a70da8b870f97bb 100644
--- a/example/op_demo/paged_attention/paged_attention_inference_demo.cpp
+++ b/example/op_demo/paged_attention/paged_attention_inference_demo.cpp
@@ -9,18 +9,22 @@
  */
 
 #include <random>
+#include <cmath>
 #include "../demo_util.h"
 
-const uint32_t NTOKENS = 2;                             // number of tokens
-const uint32_t BATCH_SIZE = NTOKENS;                    // batch size
-const uint32_t MAX_SEQ_LEN = 1024;                      // maximum sequence length
-const uint32_t HEAD_NUM = 32;                           // number of heads
-const uint32_t KV_HEAD_NUM = 32;                        // number of kv heads
-const uint32_t HEAD_SIZE = 128;                         // head size
-const uint32_t BLOCK_NUM = 16;                          // number of blocks
-const uint32_t BLOCK_SIZE = 128;                        // block size
-const uint32_t MAX_CONTEXT_LEN = 1024;                  // maximum context length
-std::vector<int32_t> contextLensData(BATCH_SIZE, 256);  // host-side data for contextLens
+namespace {
+const uint32_t NTOKENS = 2;            // number of tokens
+const uint32_t BATCH_SIZE = NTOKENS;   // batch size
+const uint32_t MAX_SEQ_LEN = 1024;     // maximum sequence length
+const uint32_t HEAD_NUM = 32;          // number of heads
+const uint32_t KV_HEAD_NUM = 32;       // number of kv heads
+const uint32_t HEAD_SIZE = 128;        // head size
+const uint32_t BLOCK_NUM = 16;         // number of blocks
+const uint32_t BLOCK_SIZE = 128;       // block size
+const uint32_t MAX_CONTEXT_LEN = 1024; // maximum context length
+const int32_t SEQLEN_VALUE = 256;      // seqlen of each batch
+std::vector<int32_t> contextLensData(BATCH_SIZE, SEQLEN_VALUE); // host-side data for contextLens
+}
 
 /**
  * @brief prepare all the input tensors of atb::VariantPack
@@ -51,7 +55,8 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
     std::vector<int32_t> blockTablesData(NTOKENS * maxNumBlocksPerQuery, 0);
     std::random_device rd;
     std::mt19937 gen(rd());
-    std::uniform_int_distribution<> dist(0, BLOCK_NUM - 2);
+    const int diff = 2;
+    std::uniform_int_distribution<> dist(0, BLOCK_NUM - diff);
     for (size_t i = 0; i < blockTablesData.size(); i++) {
         blockTablesData[i] = dist(gen);
     }
@@ -78,7 +83,7 @@ atb::Status PrepareOperation(atb::Operation **paOp)
     atb::infer::PagedAttentionParam paOpParam;
     paOpParam.headNum = HEAD_NUM;
     paOpParam.kvHeadNum = KV_HEAD_NUM;
-    paOpParam.qkScale = 0.08838834764831843;
+    paOpParam.qkScale = 1 / sqrt(HEAD_SIZE);
     return atb::CreateOperation(paOpParam, paOp);
 }
diff --git a/example/op_demo/reshape_and_cache/reshape_and_cache_demo.cpp b/example/op_demo/reshape_and_cache/reshape_and_cache_demo.cpp
index 0ca4922d47b570448246ecf8e76157399eb69f9c..8eab59a3b2d879a84141ad07abb3f4d49dd0de34 100644
--- a/example/op_demo/reshape_and_cache/reshape_and_cache_demo.cpp
+++ b/example/op_demo/reshape_and_cache/reshape_and_cache_demo.cpp
@@ -15,12 +15,14 @@
 
 #include "../demo_util.h"
 
+namespace {
 uint32_t NUM_TOKENS = 2;
 uint32_t NUM_HEAD = 32;
 uint32_t K_HEAD_SIZE = 128;
 uint32_t V_HEAD_SIZE = K_HEAD_SIZE;
 uint32_t NUM_BLOCKS = 512;
 uint32_t BLOCK_SIZE = 128;
+}
 
 /**
  * @brief prepare random contents for the input tensor K or V
@@ -29,11 +31,12 @@ uint32_t BLOCK_SIZE = 128;
  */
 std::vector<float> KvGeneration(bool kvflag)
 {
+    const float range = 100;
     // create a random number generator
     std::random_device rd;
     std::mt19937 gen(rd());
     // define the distribution range
-    std::uniform_real_distribution<> dis(-100.0, 100.0);
+    std::uniform_real_distribution<> dis(-range, range);
     // number of random values to generate
     size_t num_elements = kvflag ? NUM_TOKENS * NUM_HEAD * V_HEAD_SIZE : NUM_TOKENS * NUM_HEAD * K_HEAD_SIZE;
     // create a vector and fill it with random numbers
diff --git a/example/op_demo/reshape_and_cache/reshape_and_cache_inference_demo.cpp b/example/op_demo/reshape_and_cache/reshape_and_cache_inference_demo.cpp
index b43c1a1f9d7cbedc150511886796a7d95f5a8b54..421b015a8079368cfcd553ddc36d9d6d4389f1f9 100644
--- a/example/op_demo/reshape_and_cache/reshape_and_cache_inference_demo.cpp
+++ b/example/op_demo/reshape_and_cache/reshape_and_cache_inference_demo.cpp
@@ -15,12 +15,14 @@
 #include <random>
 #include "../demo_util.h"
 
+namespace {
 uint32_t NUM_TOKENS = 3;
 uint32_t NUM_HEAD = 4;
 uint32_t K_HEAD_SIZE = 128;
 uint32_t V_HEAD_SIZE = K_HEAD_SIZE;
 uint32_t NUM_BLOCKS = 512;
 uint32_t BLOCK_SIZE = 128;
+}
 
 /**
  * @brief prepare random contents for the input tensor K or V
@@ -29,11 +31,12 @@ uint32_t BLOCK_SIZE = 128;
  */
 std::vector<float> KvGeneration(bool kvflag)
 {
+    const float range = 100;
     // create a random number generator
     std::random_device rd;
     std::mt19937 gen(rd());
     // define the distribution range
-    std::uniform_real_distribution<> dis(-100.0, 100.0);
+    std::uniform_real_distribution<> dis(-range, range);
     // number of random values to generate
     size_t num_elements = kvflag ? NUM_TOKENS * NUM_HEAD * V_HEAD_SIZE : NUM_TOKENS * NUM_HEAD * K_HEAD_SIZE;
     // create a vector and fill it with random numbers
diff --git a/example/op_demo/rms_norm/rms_norm_demo.cpp b/example/op_demo/rms_norm/rms_norm_demo.cpp
index 5ca64ddff5cb49ba4756c4db5308d5f9dc1b32af..a8581d5ea68794df9fa507c148a2947f18217384 100644
--- a/example/op_demo/rms_norm/rms_norm_demo.cpp
+++ b/example/op_demo/rms_norm/rms_norm_demo.cpp
@@ -10,10 +10,12 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
 const uint32_t DIM_0 = 4;
 const uint32_t DIM_1 = 1024;
 const uint32_t DIM_2 = 5120;
+}
 
 /**
  * @brief prepare all the input tensors of atb::VariantPack
@@ -25,11 +27,12 @@ const uint32_t DIM_2 = 5120;
 atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
 {
     // create a tensor with shape [4, 1024, 5120]
+    const float value = 2.0;
     atb::Tensor x;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, 2.0), ACL_FLOAT16,
-                                        aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, x));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, value),
+                                        ACL_FLOAT16, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, x));
     atb::Tensor gamma;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_2, 2.0), ACL_FLOAT16,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_2, value), ACL_FLOAT16,
                                         aclFormat::ACL_FORMAT_ND, {DIM_2}, gamma));
     inTensors = {x, gamma};
     return atb::ErrorType::NO_ERROR;
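rms_norm_demo.cpp and the backward demo that follows are linked by the rstd tensor. In the standard RMSNorm formulation the forward pass saves one statistic per normalized row, which is why the backward demo feeds rstd with shape {DIM_0, DIM_1, 1}:

```cpp
// Standard RMSNorm relationship assumed by the two demos:
//   rstd = 1 / sqrt(mean(x * x, over the last axis) + eps)  -> shape {DIM_0, DIM_1, 1}
//   y    = x * rstd * gamma                                 // gamma has shape {DIM_2}
// The backward pass consumes (dy, x, rstd, gamma) to produce the input gradients.
```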
diff --git a/example/op_demo/rms_norm_backward/rms_norm_backward_demo.cpp b/example/op_demo/rms_norm_backward/rms_norm_backward_demo.cpp
index 71c22d563218da207b2a033da426ded3d5186332..3da798725c40f0f00990f3649bba5530bbb33f05 100644
--- a/example/op_demo/rms_norm_backward/rms_norm_backward_demo.cpp
+++ b/example/op_demo/rms_norm_backward/rms_norm_backward_demo.cpp
@@ -10,10 +10,12 @@
 
 #include "../demo_util.h"
 
+namespace {
 const int32_t DEVICE_ID = 0;
 const uint32_t DIM_0 = 32;
 const uint32_t DIM_1 = 64;
 const uint32_t DIM_2 = 128;
+}
 
 /**
  * @brief 准备atb::VariantPack中的所有输入tensor
@@ -24,18 +26,19 @@ const uint32_t DIM_2 = 128;
  */
 atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::SVector<atb::Tensor> &inTensors)
 {
+    const float tensorValue = 2.0;
     // 创建shape为[32, 64, 128]的tensor
     atb::Tensor dy;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, 2.0), ACL_FLOAT,
-                                        aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, dy));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, tensorValue),
+                                        ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, dy));
     atb::Tensor x;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, 2.0), ACL_FLOAT,
-                                        aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, x));
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1 * DIM_2, tensorValue),
+                                        ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, DIM_2}, x));
     atb::Tensor rstd;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1, 2.0), ACL_FLOAT,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_0 * DIM_1, tensorValue), ACL_FLOAT,
                                         aclFormat::ACL_FORMAT_ND, {DIM_0, DIM_1, 1}, rstd));
     atb::Tensor gamma;
-    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_2, 2.0), ACL_FLOAT,
+    CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(DIM_2, tensorValue), ACL_FLOAT,
                                         aclFormat::ACL_FORMAT_ND, {DIM_2}, gamma));
     inTensors = {dy, x, rstd, gamma};
     return atb::ErrorType::NO_ERROR;
diff --git a/example/op_demo/rope/rope_demo.cpp b/example/op_demo/rope/rope_demo.cpp
index e562410f73bfccca8d7578bf51d469cbce266ade..16e83ebad832e8fcb461ad6404f0a07fafac7f80 100644
--- a/example/op_demo/rope/rope_demo.cpp
+++ b/example/op_demo/rope/rope_demo.cpp
@@ -10,11 +10,13 @@
 
 #include "../demo_util.h"
 
+namespace {
 const uint32_t BATCH_SIZE = 1; // 批处理大小
 const uint32_t NTOKENS = 4; // TOKEN大小
 const uint32_t HIDDENSIZEQ = 16; // Q 隐藏层大小
 const uint32_t HIDDENSIZEK = 16; // K 隐藏层大小
 const uint32_t HEAD_SIZE = 8; // 头大小
+}
 
 /**
  * @brief 创建一个rope operation
@@ -52,7 +54,8 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, atb::S
     atb::Tensor tensorSin;
     CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, sin, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND,
                                         {NTOKENS, HEAD_SIZE}, tensorSin));
-    std::vector<int32_t> seqLenHost(BATCH_SIZE, 4);
+    const int32_t seqlenValue = 4;
+    std::vector<int32_t> seqLenHost(BATCH_SIZE, seqlenValue);
     atb::Tensor tensorSeqLen;
     CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE}, tensorSeqLen));
     tensorSeqLen.hostData = seqLenHost.data();
diff --git a/example/op_demo/self_attention/self_attention_encoder_demo.cpp b/example/op_demo/self_attention/self_attention_encoder_demo.cpp
index f6f7e44d2e805e966356867729600460386727ce..003c7ecea36512d5099f89b5c3538d5aea963187 100644
--- a/example/op_demo/self_attention/self_attention_encoder_demo.cpp
+++ b/example/op_demo/self_attention/self_attention_encoder_demo.cpp
@@ -10,9 +10,11 @@
 
 #include "../demo_util.h"
 
-const uint32_t BATCH_SIZE = 100; // 批处理大小
-std::vector<int32_t> seqLenHost(BATCH_SIZE, 16); // host侧tensor值,用于存储每个批处理中的序列长度
-std::vector<int32_t> tokenOffsetHost(BATCH_SIZE, 16); // host侧tensor值,token偏移
+namespace {
+const uint32_t BATCH_SIZE = 100; // 批处理大小
+const int32_t SEQLEN_VALUE = 16; // 每个batch对应seqlen长度
+std::vector<int32_t> seqLenHost(BATCH_SIZE, SEQLEN_VALUE); // host侧tensor值,用于存储每个批处理中的序列长度
+std::vector<int32_t> tokenOffsetHost(BATCH_SIZE, SEQLEN_VALUE); // host侧tensor值,token偏移
 std::vector<int32_t> layerId(1, 0); // device侧,kvCache中取哪个计算
 const uint32_t NTOKENS = accumulate(seqLenHost.begin(), seqLenHost.end(), 0); // sum(seqLenHost)
 const uint32_t MAX_SEQ_LEN = 1024; // 最大序列长度
@@ -25,6 +27,7 @@ const int32_t IN_V_INDEX = 2;
 const int32_t IN_CACHE_K_INDEX = 3;
 const int32_t IN_CACHE_V_INDEX = 4;
 const int32_t IN_CACHE_MASK_INDEX = 5;
+}
 
 /**
  * @brief 准备atb::VariantPack中的所有输入tensor
@@ -55,16 +58,17 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, std::v
     std::vector<float> kvCacheData(BATCH_SIZE * MAX_SEQ_LEN * kvHiddenSize, 1.0);
     // 创建norm mask,值为-inf的上三角mask
     std::vector<float> maskData(BATCH_SIZE * MAX_SEQ_LEN * MAX_SEQ_LEN, 0);
+    const float negativeInf = -32768;
     for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < MAX_SEQ_LEN; ++j) {
             for (int k = j + 1; k < MAX_SEQ_LEN; ++k) {
-                maskData[i * MAX_SEQ_LEN * MAX_SEQ_LEN + j * MAX_SEQ_LEN + k] = -32768;
+                maskData[i * MAX_SEQ_LEN * MAX_SEQ_LEN + j * MAX_SEQ_LEN + k] = negativeInf;
             }
         }
     }
     // 创建tokenOffset,host侧tensor
     atb::Tensor tensorTokenOffset;
-    atb::Tensor tensorSeqLen
+    atb::Tensor tensorSeqLen;
     atb::Tensor tensorLayerId;
     CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE}, tensorTokenOffset));
     tensorTokenOffset.hostData = tokenOffsetHost.data(); // host侧tensor,拷贝值
diff --git a/example/op_demo/self_attention/self_attention_encoder_inference_demo.cpp b/example/op_demo/self_attention/self_attention_encoder_inference_demo.cpp
index 38a2546d750383a180aac29c6ae55f5b075f6726..ad654d0f27ac4ae824030a824b3e0c4e345aac4e 100644
--- a/example/op_demo/self_attention/self_attention_encoder_inference_demo.cpp
+++ b/example/op_demo/self_attention/self_attention_encoder_inference_demo.cpp
@@ -10,9 +10,11 @@
 
 #include "../demo_util.h"
 
-const uint32_t BATCH_SIZE = 1; // 批处理大小
-std::vector<int32_t> seqLenHost(BATCH_SIZE, 16); // host侧tensor值,用于存储每个批处理中的序列长度
-std::vector<int32_t> tokenOffsetHost(BATCH_SIZE, 16); // host侧tensor值,token偏移
+namespace {
+const uint32_t BATCH_SIZE = 1; // 批处理大小
+const int32_t SEQLEN_VALUE = 16; // 每个batch对应seqlen长度
+std::vector<int32_t> seqLenHost(BATCH_SIZE, SEQLEN_VALUE); // host侧tensor值,用于存储每个批处理中的序列长度
+std::vector<int32_t> tokenOffsetHost(BATCH_SIZE, SEQLEN_VALUE); // host侧tensor值,token偏移
 std::vector<int32_t> layerId(1, 0); // device侧,kvCache中取哪个计算
 const uint32_t NTOKENS = accumulate(seqLenHost.begin(), seqLenHost.end(), 0); // sum(seqLenHost)
 const uint32_t MAX_SEQ_LEN = 256; // 最大序列长度
@@ -20,6 +22,7 @@ const uint32_t HEAD_NUM = 16; //
 const uint32_t KV_HEAD_NUM = 16; // kv头数
 const uint32_t HEAD_SIZE = 16; // 头大小
 const uint32_t LAYER_NUM = 1; // 层大小
+}
 
 /**
  * @brief 准备atb::VariantPack中的所有输入tensor
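The mask these encoder demos build is a per-batch upper-triangular matrix whose strictly upper entries hold -32768, a large negative stand-in for -inf that is representable in float16. The same construction as a standalone helper (hypothetical name; the loop structure is taken directly from the demos):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Build a [batch, seqLen, seqLen] causal mask: 0 on and below the diagonal,
// negativeInf strictly above it, as in the demos' triple loop.
std::vector<float> BuildCausalMask(uint32_t batch, uint32_t seqLen, float negativeInf = -32768.0f)
{
    std::vector<float> mask(static_cast<std::size_t>(batch) * seqLen * seqLen, 0.0f);
    for (uint32_t b = 0; b < batch; ++b) {
        for (uint32_t row = 0; row < seqLen; ++row) {
            for (uint32_t col = row + 1; col < seqLen; ++col) {
                mask[(static_cast<std::size_t>(b) * seqLen + row) * seqLen + col] = negativeInf;
            }
        }
    }
    return mask;
}
```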
diff --git a/example/op_demo/self_attention/self_attention_pa_encoder_demo.cpp b/example/op_demo/self_attention/self_attention_pa_encoder_demo.cpp
index 0e1decd2863a69420e6bd14eac03a3bafc491fd0..933cec9228935530d8a50f175ad2ad79e2824ccd 100644
--- a/example/op_demo/self_attention/self_attention_pa_encoder_demo.cpp
+++ b/example/op_demo/self_attention/self_attention_pa_encoder_demo.cpp
@@ -10,13 +10,16 @@
 
 #include "../demo_util.h"
 
-const uint32_t BATCH_SIZE = 1; // 批处理大小
-std::vector<int32_t> seqLenHost(BATCH_SIZE, 16); // host侧tensor值,用于存储每个批处理中的序列长度
+namespace {
+const uint32_t BATCH_SIZE = 1; // 批处理大小
+const int32_t SEQLEN_VALUE = 16; // 每个batch对应seqlen长度
+std::vector<int32_t> seqLenHost(BATCH_SIZE, SEQLEN_VALUE); // host侧tensor值,用于存储每个批处理中的序列长度
 const uint32_t NTOKENS = accumulate(seqLenHost.begin(), seqLenHost.end(), 0); // sum(seqLenHost)
 const uint32_t MAX_SEQ_LEN = 1024; // 最大序列长度
 const uint32_t HEAD_NUM = 32; // 头数
 const uint32_t KV_HEAD_NUM = 32; // kv头数
 const uint32_t HEAD_SIZE = 64; // 头大小
+}
 
 /**
  * @brief 准备atb::VariantPack中的所有输入tensor
@@ -45,10 +48,11 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, std::v
                                         {NTOKENS, KV_HEAD_NUM, HEAD_SIZE}, tensorV));
     std::vector<float> maskData(BATCH_SIZE * MAX_SEQ_LEN * MAX_SEQ_LEN, 0);
     // 创建norm mask,值为-inf的上三角mask
+    const float negativeInf = -32768;
     for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < MAX_SEQ_LEN; ++j) {
             for (int k = j + 1; k < MAX_SEQ_LEN; ++k) {
-                maskData[i * MAX_SEQ_LEN * MAX_SEQ_LEN + j * MAX_SEQ_LEN + k] = -32768;
+                maskData[i * MAX_SEQ_LEN * MAX_SEQ_LEN + j * MAX_SEQ_LEN + k] = negativeInf;
             }
         }
     }
diff --git a/example/op_demo/self_attention/self_attention_prefix_encoder_demo.cpp b/example/op_demo/self_attention/self_attention_prefix_encoder_demo.cpp
index 56f6e95524b41467683704ce8a2f3644d74072e7..bd2ee2086884ecfc3b3e5f70753ea3ae4daf0201 100644
--- a/example/op_demo/self_attention/self_attention_prefix_encoder_demo.cpp
+++ b/example/op_demo/self_attention/self_attention_prefix_encoder_demo.cpp
@@ -10,6 +10,7 @@
 
 #include "../demo_util.h"
 
+namespace {
 const uint32_t BATCH_SIZE = 4; // 批处理大小
 std::vector<int32_t> seqLenHost = {16, 16, 32, 32}; // host侧tensor值,用于存储Query每个批处理中的序列长度
 const uint32_t NTOKENS = accumulate(seqLenHost.begin(), seqLenHost.end(), 0); // sum(seqLenHost)
@@ -19,6 +20,8 @@ const uint32_t HEAD_NUM = 32;
 const uint32_t KV_HEAD_NUM = 32; // kv头数
 const uint32_t HEAD_SIZE = 64; // 头大小
 const uint32_t BLOCK_SIZE = 128; // 以block存放的kv块大小
+const uint32_t BLOCK_TABLES_SIZE = 16; // blockTables大小
+}
 
 /**
  * @brief 准备atb::VariantPack中的所有输入tensor
@@ -50,7 +53,7 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, std::v
     // 创建blockTables
     atb::Tensor tensorBlockTables;
     CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE, 4}, tensorBlockTables));
-    std::vector<int32_t> blockTablesData(16);
+    std::vector<int32_t> blockTablesData(BLOCK_TABLES_SIZE);
     std::iota(blockTablesData.begin(), blockTablesData.end(), 0);
     CHECK_STATUS(aclrtMemcpy(tensorBlockTables.deviceData, tensorBlockTables.dataSize, blockTablesData.data(),
                              sizeof(int32_t) * blockTablesData.size(), ACL_MEMCPY_HOST_TO_DEVICE));
@@ -58,14 +61,14 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, std::v
     std::vector<float> maskData = std::vector<float>(HEAD_NUM * NTOKENS * 128, 0); // alibi128 mask
     for (int i = 0; i < HEAD_NUM; ++i) {
         for (int j = 0; j < NTOKENS; ++j) {
-            for (int k = j + 1; k < 128; ++k) {
-                maskData[i * NTOKENS * 128 + j * 128 + k] = 1;
+            for (int k = j + 1; k < BLOCK_SIZE; ++k) {
+                maskData[i * NTOKENS * BLOCK_SIZE + j * BLOCK_SIZE + k] = 1;
             }
         }
     }
     atb::Tensor tensorMask;
     CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, maskData, ACL_FLOAT16, aclFormat::ACL_FORMAT_ND,
-                                        {HEAD_NUM, NTOKENS, 128}, tensorMask));
+                                        {HEAD_NUM, NTOKENS, BLOCK_SIZE}, tensorMask));
     // 创建seqLen,host侧tensor
     atb::Tensor tensorSeqLen;
     CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE}, tensorSeqLen));
diff --git a/example/op_demo/split/split_demo.cpp b/example/op_demo/split/split_demo.cpp
index b4a92ce2974d3075382a5daba55f2cbd556ea4c4..e55c7b8fbbee272a0b35097327deb8bc61298d66 100644
--- a/example/op_demo/split/split_demo.cpp
+++ b/example/op_demo/split/split_demo.cpp
@@ -18,10 +18,11 @@
  */
 atb::Status RunSplitDemo(atb::Context *context, void *stream)
 {
+    const int32_t splitNum = 2;
     // 配置Op参数
     atb::infer::SplitParam opParam;
     opParam.splitDim = 1; // 设定切分轴为1
-    opParam.splitNum = 2; // 设置切分后得到的块数
+    opParam.splitNum = splitNum; // 设置切分后得到的块数
     opParam.splitSizes 
= {2, 3}; // 设置不均匀切分时每块大小 // 准备VariantPack diff --git a/example/op_demo/transdata/transdata_demo.cpp b/example/op_demo/transdata/transdata_demo.cpp index a1ea8fe6d382a09b3dbcd67cd42a39ab70b78da8..c7d45465dd843d29ef4c1fcbea0ff0638892120b 100644 --- a/example/op_demo/transdata/transdata_demo.cpp +++ b/example/op_demo/transdata/transdata_demo.cpp @@ -10,9 +10,11 @@ #include "../demo_util.h" +namespace { const uint32_t BATCH_SIZE = 8; // 批处理大小 const uint32_t SEQ_LEN = 100; // 序列长度 const uint32_t HIDDEN_SIZE = 30; // 隐藏层维度 +} /** * @brief 准备atb::VariantPack中的所有输入tensor diff --git a/example/op_demo/transpose/transpose_demo.cpp b/example/op_demo/transpose/transpose_demo.cpp index b73a16784f7f3cd8c4f21bceb5c4e091d33c236a..98f78a9a7e10fc2c2189b35c5ab1918519d029d3 100644 --- a/example/op_demo/transpose/transpose_demo.cpp +++ b/example/op_demo/transpose/transpose_demo.cpp @@ -10,8 +10,10 @@ #include "../demo_util.h" +namespace { const uint32_t DIM1 = 2; const uint32_t DIM2 = 3; +} /** * @brief 准备atb::VariantPack中的所有输入tensor diff --git a/include/atb/atb_acl.h b/include/atb/atb_acl.h index dfd8f0d4b54df0ba8df434c56437b180461d1894..919ad560f41f5e652e6b8f86135daa4b5b051d92 100644 --- a/include/atb/atb_acl.h +++ b/include/atb/atb_acl.h @@ -113,7 +113,7 @@ atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRop //! \param context MLA算子的上下文参数 //! //! \return 表示函数是否执行成功的状态码 -atb::Status AtbMLA(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context); +atb::Status AtbMLA(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context); //! //! \brief MLA prefill 前处理接口 @@ -139,10 +139,11 @@ atb::Status AtbMLA(void* workspace, uint64_t workspaceSize, atb::Operation *op, //! //! \return 表示函数是否执行成功的状态码 atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *qRope, const aclTensor *k, - const aclTensor *kRope, const aclTensor *v, const aclTensor *qSeqLen, const aclTensor *kvSeqLen, - const aclTensor *mask, int32_t headNum, float qkScale, int32_t kvHeadNum, - int maskType, uint8_t cacheMode, aclTensor *attenOut, - uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); + const aclTensor *kRope, const aclTensor *v, const aclTensor *qSeqLen, + const aclTensor *kvSeqLen, const aclTensor *mask, int32_t headNum, + float qkScale, int32_t kvHeadNum, int maskType, uint8_t cacheMode, + aclTensor *attenOut, uint64_t *workspaceSize, atb::Operation **op, + atb::Context *context); //! //! \brief MLA prefill 处理接口 @@ -153,7 +154,7 @@ atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *q //! \param context MLA算子的上下文参数 //! //! \return 表示函数是否执行成功的状态码 -atb::Status AtbMLAPreFill(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context); +atb::Status AtbMLAPreFill(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context); //! //! \brief 关于MlaPreprocess算子使用aclnn风格调用的2段式接口的第1段, @@ -326,7 +327,8 @@ atb::Status AtbRingMLA(void *workspace, uint64_t workspaceSize, atb::Operation * //! \param mask SelfAttentionPrefixEncoder算子的输入tensor(maskType为MASK_TYPE_CASUAL_MASK时,需要置为nullptr) //! \param seqLen SelfAttentionPrefixEncoder算子的输入tensor //! \param kvSeqLen SelfAttentionPrefixEncoder算子的输入tensor -//! \param slopes SelfAttentionPrefixEncoder算子的输入tensor(maskType不为MASK_TYPE_ALIBI_COMPRESS或MASK_TYPE_ALIBI_COMPRESS_SQRT时,需要置为nullptr) +//! \param slopes SelfAttentionPrefixEncoder算子的输入tensor +//! 
maskType不为MASK_TYPE_ALIBI_COMPRESS或MASK_TYPE_ALIBI_COMPRESS_SQRT时,需要置为nullptr)
 //! \param maskType SelfAttentionPrefixEncoder mask类型
 //! \param headNum SelfAttentionPrefixEncoder算子头大小
diff --git a/include/atb/comm.h b/include/atb/comm.h
index 67a2a5f896f90cee4ddec2998e3e14ae0ff5f631..c86fc5975a2da66075c1a26943e065687c4ccd8a 100644
--- a/include/atb/comm.h
+++ b/include/atb/comm.h
@@ -24,7 +24,7 @@ namespace atb {
 //! \brief 通信域指针
 //!
-using HcclComm = void*;
+using HcclComm = void *;
 //!
 //! \namespace Comm
 //!
@@ -52,8 +52,7 @@ HcclComm CreateHcclComm(int32_t rank, int32_t rankRoot, int32_t rankSize, char *
 //!
 //! \return 返回通信域指针
 //!
-HcclComm CreateHcclCommByRankTableFile(int32_t rank, int32_t rankSize, const char *rankTableFile,
-                                       char *commName);
+HcclComm CreateHcclCommByRankTableFile(int32_t rank, int32_t rankSize, const char *rankTableFile, char *commName);
 
 //!
 //! \brief 创建HCCL多机通信域
diff --git a/include/atb/context.h b/include/atb/context.h
index 9f9d8150b7b4a98b5042dc6d4fe3f19bddacfa8f..b22ddde3683d2371aaccc86fa783ff585ab2f08a 100644
--- a/include/atb/context.h
+++ b/include/atb/context.h
@@ -158,14 +158,15 @@
 //! 在当前进程或线程中显式创建一个由用户管理Tiling内存的Context.
 //!
 //! \param context 传入的context
 //!
 //! \param alloc 传入的Tiling内存分配方法
 //!
 //! \param dealloc 传入的Tiling内存释放方法
 //!
 //! \return 状态值.如果设置成功,返回NO_ERROR.
 //!
-Status CreateContext(Context **context, const std::function& alloc, const std::function& dealloc);
+Status CreateContext(Context **context, const std::function &alloc,
+                     const std::function &dealloc);
 
 //!
 //! \brief 销毁上下文.
diff --git a/scripts/build.sh b/scripts/build.sh
index 2542c458c0b3126a20fbcd837f145072c1a7c87c..9b5e4079e9fa1ddb34282b01dcfd2be70b69866c 100644
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -470,7 +470,7 @@ function fn_build()
     fi
     fn_build_3rdparty_for_compile
     cd $CACHE_DIR
-    if [ "$CMAKE_CXX_COMPILER_LAUNCHER" == "" -a command -v ccache &> /dev/null ]; then
+    if [ "$CMAKE_CXX_COMPILER_LAUNCHER" == "" ] && command -v ccache &> /dev/null; then
         COMPILE_OPTIONS="${COMPILE_OPTIONS} -DCMAKE_CXX_COMPILER_LAUNCHER=ccache"
     fi
     echo "COMPILE_OPTIONS:$COMPILE_OPTIONS"
diff --git a/scripts/set_env.sh b/scripts/set_env.sh
index 61753aec630408c79e778b0fab331a22f528b01f..e53db017aa1232bd588695d4df15914529973719 100644
--- a/scripts/set_env.sh
+++ b/scripts/set_env.sh
@@ -44,7 +44,7 @@ if [[ -n "$ZSH_VERSION" ]]; then
     set_env_path="$0"
 fi
 
-if [[ -f "$set_env_path" ]] && [[ "$set_env_path" =~ 'set_env.sh' ]];then
+if [[ -f "$set_env_path" ]] && [[ "$set_env_path" =~ set_env.sh ]];then
     atb_path=$(cd $(dirname $set_env_path); pwd)
     get_cxx_abi_option "$@"
     export ATB_HOME_PATH="${atb_path}/cxx_abi_${cxx_abi}"
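The two allocator files below receive only log reflows; the state they log is an address-to-size map plus a running byte total. A simplified standalone model of that bookkeeping (assumed semantics; the real classes allocate with aclrtMalloc/aclrtMallocHost rather than malloc):

```cpp
#include <cstdlib>
#include <unordered_map>

// Simplified model of Default{Device,Host}Allocator bookkeeping: Allocate
// records (addr -> size) and grows the running total; Deallocate reverses it.
class TrackingAllocator {
public:
    void *Allocate(std::size_t bufferSize)
    {
        void *addr = std::malloc(bufferSize);
        if (addr != nullptr) {
            memMap_[addr] = bufferSize;
            currentAllocateSize_ += bufferSize;
        }
        return addr;
    }

    int Deallocate(void *addr)
    {
        auto it = memMap_.find(addr);
        if (it == memMap_.end()) {
            return -1;  // address was not allocated by this allocator
        }
        currentAllocateSize_ -= it->second;
        memMap_.erase(it);
        std::free(addr);
        return 0;
    }

private:
    std::unordered_map<void *, std::size_t> memMap_;
    std::size_t currentAllocateSize_ = 0;
};
```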
"DefaultDeviceAllocator::~DefaultDeviceAllocator aclrtFree free device bufferSize: " << it->second - << ", and currentAllocateSize_: " << currentAllocateSize_; + ATB_LOG(INFO) << "DefaultDeviceAllocator::~DefaultDeviceAllocator aclrtFree free device bufferSize: " + << it->second << ", and currentAllocateSize_: " << currentAllocateSize_; #endif } } @@ -52,10 +52,11 @@ void *DefaultDeviceAllocator::Allocate(size_t bufferSize) memMap.insert(std::make_pair(addr, bufferSize)); #ifdef _DEBUG ATB_LOG(INFO) << "DefaultDeviceAllocator::Allocate device buffer success, deivce buffer is " << addr - << ", which bufferSize is " << bufferSize << " and the currentAllocateSize_: " << currentAllocateSize_; + << ", which bufferSize is " << bufferSize + << " and the currentAllocateSize_: " << currentAllocateSize_; #else - ATB_LOG(INFO) << "DefaultDeviceAllocator::Allocate device buffer success, bufferSize is " << bufferSize << " currentAllocateSize_: " - << currentAllocateSize_; + ATB_LOG(INFO) << "DefaultDeviceAllocator::Allocate device buffer success, bufferSize is " << bufferSize + << " currentAllocateSize_: " << currentAllocateSize_; #endif return addr; } @@ -80,9 +81,9 @@ Status DefaultDeviceAllocator::Deallocate(void *addr) currentAllocateSize_ -= it->second; #ifdef _DEBUG ATB_LOG(INFO) << "DefaultDeviceAllocator::Deallocate device buffer success, free device buffer: " << addr - << ", which bufferSize is "<< it->second << ", currentAllocateSize_: " << currentAllocateSize_; + << ", which bufferSize is " << it->second << ", currentAllocateSize_: " << currentAllocateSize_; #else - ATB_LOG(INFO) << "DefaultDeviceAllocator::Deallocate device buffer success, free bufferSize: "<< it->second + ATB_LOG(INFO) << "DefaultDeviceAllocator::Deallocate device buffer success, free bufferSize: " << it->second << ", currentAllocateSize_: " << currentAllocateSize_; #endif memMap.erase(addr); diff --git a/src/atb/core/allocator/default_host_allocator.cpp b/src/atb/core/allocator/default_host_allocator.cpp index 76c82c35943ba51f741686e8a594ce2af4fbbd48..830df2bb12c684e38665d6fdd92091528a81ac9b 100644 --- a/src/atb/core/allocator/default_host_allocator.cpp +++ b/src/atb/core/allocator/default_host_allocator.cpp @@ -24,11 +24,11 @@ DefaultHostAllocator::~DefaultHostAllocator() currentAllocateSize_ -= it->second; #ifdef _DEBUG ATB_LOG(INFO) << "DefaultHostAllocator::~DefaultHostAllocator aclrtFreeHost free host buffer: " << it->first - << ", which the host bufferSize is " << it->second << ", currentAllocateSize_: " - << currentAllocateSize_; + << ", which the host bufferSize is " << it->second + << ", currentAllocateSize_: " << currentAllocateSize_; #else - ATB_LOG(INFO) << "DefaultHostAllocator::~DefaultHostAllocator aclrtFreeHost free host bufferSize: " << it->second - << ", and currentAllocateSize_: " << currentAllocateSize_; + ATB_LOG(INFO) << "DefaultHostAllocator::~DefaultHostAllocator aclrtFreeHost free host bufferSize: " + << it->second << ", and currentAllocateSize_: " << currentAllocateSize_; #endif } } @@ -52,10 +52,11 @@ void *DefaultHostAllocator::Allocate(size_t bufferSize) memMap.insert(std::make_pair(addr, bufferSize)); #ifdef _DEBUG ATB_LOG(INFO) << "DefaultHostAllocator::Allocate host buffer success, host buffer is " << addr - << ", which bufferSize is " << bufferSize << " and the currentAllocateSize_: " << currentAllocateSize_; + << ", which bufferSize is " << bufferSize + << " and the currentAllocateSize_: " << currentAllocateSize_; #else - ATB_LOG(INFO) << "DefaultHostAllocator::Allocate host buffer 
success, bufferSize is " << bufferSize << " currentAllocateSize_: " - << currentAllocateSize_; + ATB_LOG(INFO) << "DefaultHostAllocator::Allocate host buffer success, bufferSize is " << bufferSize + << " currentAllocateSize_: " << currentAllocateSize_; #endif return addr; } @@ -80,10 +81,10 @@ Status DefaultHostAllocator::Deallocate(void *addr) currentAllocateSize_ -= it->second; #ifdef _DEBUG ATB_LOG(INFO) << "DefaultHostAllocator::Deallocate host buffer success, free host buffer: " << addr - << ", which bufferSize is "<< it->second << ", currentAllocateSize_: " << currentAllocateSize_; + << ", which bufferSize is " << it->second << ", currentAllocateSize_: " << currentAllocateSize_; #else - ATB_LOG(INFO) << "DefaultHostAllocator::Deallocate host buffer success, free bufferSize: "<< it->second - << ", currentAllocateSize_: " << currentAllocateSize_; + ATB_LOG(INFO) << "DefaultHostAllocator::Deallocate host buffer success, free bufferSize: " << it->second + << ", currentAllocateSize_: " << currentAllocateSize_; #endif memMap.erase(addr); return NO_ERROR; diff --git a/src/atb/core/context.cpp b/src/atb/core/context.cpp index 79f8659d4a8bdf7d917ad7720ffd4d482bad3a34..7b341b4dbc2b95d7a023c9af0e038bbe00bfdcf5 100644 --- a/src/atb/core/context.cpp +++ b/src/atb/core/context.cpp @@ -38,7 +38,8 @@ Status CreateContext(Context **context) return NO_ERROR; } -Status CreateContext(Context **context, const std::function& alloc, const std::function& dealloc) +Status CreateContext(Context **context, const std::function &alloc, + const std::function &dealloc) { if (!context) { ATB_LOG(ERROR) << "param context is null, CreateContext fail"; diff --git a/src/atb/core/context_base.cpp b/src/atb/core/context_base.cpp index ae0812d856bd51a31f0d33bff2d79f9e6850d79d..411e1500981c31d8931714cb14bb2af09893c0c1 100644 --- a/src/atb/core/context_base.cpp +++ b/src/atb/core/context_base.cpp @@ -48,7 +48,7 @@ ContextBase::~ContextBase() noexcept } } -Status ContextBase::Init(const std::function& alloc, const std::function& dealloc) +Status ContextBase::Init(const std::function &alloc, const std::function &dealloc) { executeStreams_.resize(DEFAULT_EXECUTE_STREAM_NUMBER); @@ -63,17 +63,19 @@ Status ContextBase::Init(const std::function& alloc, const std::f return st; } if (alloc && dealloc) { - ATB_LOG(INFO) << "Using the Custom Allocate Function and Deallocate Funciton to allocate and deallocate device tiling buffer"; + ATB_LOG(INFO) << "Using the Custom Allocate Function and Deallocate Funciton to allocate and " + "deallocate device tiling buffer"; allocateFunc_ = alloc; deallocateFunc_ = dealloc; } else if (!alloc && !dealloc) { - ATB_LOG(INFO) << "Using the Default Allocate Function and Default Deallocate Function to allocate and deallocate device tiling buffer"; + ATB_LOG(INFO) << "Using the Default Allocate Function and Default Deallocate Function to allocate and " + "deallocate device tiling buffer"; } else { ATB_LOG(ERROR) << "Can not support to pass in only Allocate Function or Deallocate Function"; return ERROR_INVALID_PARAM; } - deviceTilingBufferPool_ = std::make_unique(GetSingleton().GetDeviceTilingBlockNum(), - TILING_BUFFER_BLOCK_SIZE, allocateFunc_, deallocateFunc_); + deviceTilingBufferPool_ = std::make_unique( + GetSingleton().GetDeviceTilingBlockNum(), TILING_BUFFER_BLOCK_SIZE, allocateFunc_, deallocateFunc_); if (!deviceTilingBufferPool_) { return ERROR_OUT_OF_HOST_MEMORY; } @@ -175,7 +177,7 @@ uint8_t *ContextBase::GetHostTilingBuffer() // 如果走图模式的话直接使用hostAllocator申请内存 if (mode_ == GRAPH_LAUNCH_MODE) 
{ ATB_LOG(INFO) << "At GRAPH_LAUNCH_MODE, contextBase start allocate host tiling buffer using Allocator"; - return reinterpret_cast(hostAllocator_->Allocate(TILING_BUFFER_BLOCK_SIZE)); + return reinterpret_cast(hostAllocator_->Allocate(TILING_BUFFER_BLOCK_SIZE)); } return hostTilingBufferPool_ ? hostTilingBufferPool_->GetBuffer() : nullptr; } @@ -185,7 +187,7 @@ uint8_t *ContextBase::GetDeviceTilingBuffer() // 如果走图模式的话直接使用deviceAllocator申请内存 if (mode_ == GRAPH_LAUNCH_MODE) { ATB_LOG(INFO) << "At GRAPH_LAUNCH_MODE, contextBase start allocate device tiling buffer using Allocator"; - return reinterpret_cast(deviceAllocator_->Allocate(TILING_BUFFER_BLOCK_SIZE)); + return reinterpret_cast(deviceAllocator_->Allocate(TILING_BUFFER_BLOCK_SIZE)); } return deviceTilingBufferPool_ ? deviceTilingBufferPool_->GetBuffer() : nullptr; } @@ -349,7 +351,7 @@ Status ContextBase::FreeArgsHostBuffer(void *addr) { return hostAllocator_->Deallocate(addr); } -bool ContextBase::GetLaunchWithTilingStatus() +bool ContextBase::GetLaunchWithTilingStatus() const { return mode_ != GRAPH_LAUNCH_MODE; } diff --git a/src/atb/core/node_impl/mki_node_implement.cpp b/src/atb/core/node_impl/mki_node_implement.cpp index dcdc54622cf30fb92e7c9de1dc83b103804a5cea..eacb23b7ed3531494ec323487992c557b0d7352a 100644 --- a/src/atb/core/node_impl/mki_node_implement.cpp +++ b/src/atb/core/node_impl/mki_node_implement.cpp @@ -15,7 +15,6 @@ #include "atb/utils/tensor_util.h" #include "atb/utils/statistic.h" #include "atb/utils/store_util.h" -#include "atb/utils/singleton.h" #include "atb/utils/probe.h" namespace atb { @@ -385,4 +384,4 @@ Status MkiNodeImplement::BuildLaunchParam(const SVector &inTensor } return NO_ERROR; } -} // namespace atb \ No newline at end of file +} // namespace atb diff --git a/src/atb/core/runner_pool.cpp b/src/atb/core/runner_pool.cpp index 5853eda030f8dc471d927370cbd521a3b4238ff4..d8487f0315e45e53862561160002cd7e638f86d0 100644 --- a/src/atb/core/runner_pool.cpp +++ b/src/atb/core/runner_pool.cpp @@ -11,7 +11,6 @@ #include "atb/core/runner_pool.h" #include "atb/utils/config.h" #include "atb/runner/runner.h" -#include "atb/utils/singleton.h" static const uint32_t DEFAULT_RUNNER_POOL_SIZE = 64; diff --git a/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp b/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp index c9df0409a9f7efe21dd5b8f0791ea9645ed51663..5f086f38374b309b1e2bbb0cecaae9b808ebb761 100644 --- a/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp +++ b/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp @@ -11,7 +11,9 @@ #include "atb/utils/log.h" namespace atb { -DeviceTilingBufferPool::DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, const std::function& alloc, const std::function& dealloc) +DeviceTilingBufferPool::DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, + const std::function &alloc, + const std::function &dealloc) : TilingBufferPool(blockNum, blockSize), allocateFunc_(alloc), deallocateFunc_(dealloc) { } diff --git a/src/atb/operation/operation_base.cpp b/src/atb/operation/operation_base.cpp index 1a6106ce630cdbf4275ec108a1ffc808bd2dad3a..c1fa98add0d9682465209d3b0ea697515b3c5f44 100644 --- a/src/atb/operation/operation_base.cpp +++ b/src/atb/operation/operation_base.cpp @@ -157,7 +157,7 @@ void OperationBase::InitEmptyOutTensorPerms() const } ATB_LOG(INFO) << GetLogPrefix() << "InitEmptyOutTensorPerms finished:" << emptyOutTensorPerms_; } - + SVector OperationBase::GetEmptyOutTensorPermissions() const { if 
(emptyOutTensorPerms_.size() == 0) { @@ -728,18 +728,18 @@ Status OperationBase::CopyTilingToDevice() template Status OperationBase::ExecuteVariantPackInTensorCheck(const SVector &inTensors) const { - std::string Prefix = GetLogPrefix(); + std::string prefix = GetLogPrefix(); if (inTensors.size() != runnerVariantPack_.inTensors.size()) { - ATB_LOG(ERROR) << GetLogPrefix() << "execute inTensors.size:" << inTensors.size() + ATB_LOG(ERROR) << prefix << "execute inTensors.size:" << inTensors.size() << " != setup inTensors.size:" << runnerVariantPack_.inTensors.size(); return ERROR_INVALID_PARAM; } SVector emptyInTensorPerms = GetEmptyInTensorPermissions(); for (size_t i = 0; i < inTensors.size(); i++) { const Tensor &variantPackInTensor = inTensors.at(i); - if (Prefix.find("WithStride") == std::string::npos && // "WithStride" indicates non continuous tensors + if (prefix.find("WithStride") == std::string::npos && // "WithStride" indicates non continuous tensors variantPackInTensor.dataSize != Utils::GetTensorSize(runnerVariantPack_.inTensors.at(i).desc)) { - ATB_LOG(ERROR) << GetLogPrefix() << "execute variantPack.inTensors(" << i + ATB_LOG(ERROR) << prefix << "execute variantPack.inTensors(" << i << ").dataSize is Not equal to the setup dataSize"; return ERROR_INVALID_PARAM; } @@ -757,9 +757,9 @@ Status OperationBase::ExecuteVariantPackInTensorCheck(const SVector template Status OperationBase::ExecuteVariantPackOutTensorCheck(const SVector &outTensors) const { - std::string Prefix = GetLogPrefix(); + std::string prefix = GetLogPrefix(); if (outTensors.size() != runnerVariantPack_.outTensors.size()) { - ATB_LOG(ERROR) << GetLogPrefix() << "execute outTensors.size:" << outTensors.size() + ATB_LOG(ERROR) << prefix << "execute outTensors.size:" << outTensors.size() << " != setup outTensors.size:" << runnerVariantPack_.outTensors.size(); return ERROR_INVALID_PARAM; } @@ -767,7 +767,7 @@ Status OperationBase::ExecuteVariantPackOutTensorCheck(const SVector for (size_t i = 0; i < outTensors.size(); i++) { const Tensor &variantPackOutTensor = outTensors.at(i); if (variantPackOutTensor.dataSize != Utils::GetTensorSize(runnerVariantPack_.outTensors.at(i).desc)) { - ATB_LOG(ERROR) << GetLogPrefix() << "execute variantPack.outTensors(" << i + ATB_LOG(ERROR) << prefix << "execute variantPack.outTensors(" << i << ").dataSize is Not equal to the setup dataSize"; return ERROR_INVALID_PARAM; } @@ -776,15 +776,14 @@ Status OperationBase::ExecuteVariantPackOutTensorCheck(const SVector continue; } if (!variantPackOutTensor.deviceData && !variantPackOutTensor.hostData) { - ATB_LOG(ERROR) << GetLogPrefix() << "execute variantPack.outTensors(" << i - << ") deviceData&hostData is null"; + ATB_LOG(ERROR) << prefix << "execute variantPack.outTensors(" << i << ") deviceData&hostData is null"; return ERROR_INVALID_PARAM; } } return NO_ERROR; } -Status OperationBase::ExecuteVariantPackCheck(const VariantPack &variantPack) +Status OperationBase::ExecuteVariantPackCheck(const VariantPack &variantPack) const { Status st = NO_ERROR; st = ExecuteVariantPackInTensorCheck(variantPack.inTensors); @@ -946,10 +945,13 @@ Status OperationBase::GraphModePreLaunch(const VariantPack &variantPack, uint8_t } else if (workspace > runnerVariantPack_.workspaceBuffer) { // 如果workspace发生了变化,计算workspace变化带来的偏移量时需要再加上workspaceBufferSize才是中间tensor对应内存的起始地址 runnerVariantPack_.intermediateBuffer = workspace - - reinterpret_cast(runnerVariantPack_.workspaceBuffer) + runnerVariantPack_.workspaceBufferSize; + 
reinterpret_cast(runnerVariantPack_.workspaceBuffer) + + runnerVariantPack_.workspaceBufferSize; #ifdef _DEBUG - ATB_LOG(INFO) << GetLogPrefix() << "changing the old workspace: " << static_cast(runnerVariantPack_.workspaceBuffer) - << " to new workspace: " << static_cast(workspace) << ", and the runnerVariantPack_.intermediateBuffer: " + ATB_LOG(INFO) << GetLogPrefix() + << "changing the old workspace: " << static_cast(runnerVariantPack_.workspaceBuffer) + << " to new workspace: " << static_cast(workspace) + << ", and the runnerVariantPack_.intermediateBuffer: " << static_cast(runnerVariantPack_.intermediateBuffer); #endif runnerVariantPack_.workspaceBuffer = workspace; @@ -957,10 +959,13 @@ Status OperationBase::GraphModePreLaunch(const VariantPack &variantPack, uint8_t st = runner_->UpdateWorkspaceBuffer(runnerVariantPack_); } else { runnerVariantPack_.intermediateBuffer = runnerVariantPack_.workspaceBuffer - - reinterpret_cast(workspace) + runnerVariantPack_.workspaceBufferSize; + reinterpret_cast(workspace) + + runnerVariantPack_.workspaceBufferSize; #ifdef _DEBUG - ATB_LOG(INFO) << GetLogPrefix() << "changing the old workspace: " << static_cast(runnerVariantPack_.workspaceBuffer) - << " to new workspace: " << static_cast(workspace) << ", and the runnerVariantPack_.intermediateBuffer: " + ATB_LOG(INFO) << GetLogPrefix() + << "changing the old workspace: " << static_cast(runnerVariantPack_.workspaceBuffer) + << " to new workspace: " << static_cast(workspace) + << ", and the runnerVariantPack_.intermediateBuffer: " << static_cast(runnerVariantPack_.intermediateBuffer); #endif runnerVariantPack_.workspaceBuffer = workspace; @@ -1000,7 +1005,7 @@ Status OperationBase::Launch() Status OperationBase::EagerModeLaunch() { - Mki::Timer ExecuteTime; + Mki::Timer executeTime; void *executeStream = GetExecuteStream(runnerVariantPack_.context); #ifdef _DEBUG ATB_LOG(INFO) << GetLogPrefix() << "execute " << runner_->GetName() << "_" << runner_.get() << " start"; @@ -1031,7 +1036,7 @@ Status OperationBase::EagerModeLaunch() int ret = aclrtSynchronizeStream(executeStream); ATB_LOG_IF(ret != 0, ERROR) << GetLogPrefix() << "stream sync fail, ret:" << ret; } - GetOpExecuteStatistic().launchTime += ExecuteTime.ElapsedMicroSecond(); + GetOpExecuteStatistic().launchTime += executeTime.ElapsedMicroSecond(); GetOpExecuteStatistic().totalTime += GetOpExecuteStatistic().preLaunchTime + GetOpExecuteStatistic().launchTime; ATB_LOG(INFO) << GetLogPrefix() << "execute statistic:" << GetOpExecuteStatistic().ToString(); return st; @@ -1074,7 +1079,7 @@ Status OperationBase::Execute(const VariantPack &variantPack, uint8_t *workspace OPERATION_EXECUTE : (executeType == EXECUTE_PRELAUNCH ? 
OPERATION_PRELAUNCH : OPERATION_LAUNCH); std::shared_ptr mstxMemRegister{nullptr}; - if (workspaceSize && MstxMemRegister::IsMstxEnable()) { + if (workspaceSize != 0 && MstxMemRegister::IsMstxEnable()) { mstxMemRegister = std::make_shared(); if (mstxMemRegister->MstxHeapRegister(workspace, workspaceSize) == NO_ERROR) { runnerVariantPack_.mstxMemRegister = mstxMemRegister.get(); @@ -1228,7 +1233,8 @@ void OperationBase::FillHostTilingBuffer() } Mki::Timer runnerFillHostTilingTimer; - Status st = runner_->FillHostTilingBuffer(hostTilingBuffer_, runnerVariantPack_.tilingBufferSize, runnerVariantPack_.context); + Status st = runner_->FillHostTilingBuffer(hostTilingBuffer_, runnerVariantPack_.tilingBufferSize, + runnerVariantPack_.context); if (st != NO_ERROR) { ATB_LOG(ERROR) << GetLogPrefix() << "fill host tiling buffer fail"; return; @@ -1347,12 +1353,12 @@ aclrtStream OperationBase::GetExecuteStream(Context *context) const return streams.at(streamId_); } -Status OperationBase::CopyArgsToDevice(Context *context) +Status OperationBase::CopyArgsToDevice(Context *context) const { Status st = NO_ERROR; #ifdef _DEBUG ATB_LOG(DEBUG) << GetLogPrefix() << "args in graphMode is:"; - const size_t counter = argsBufferSize_ / sizeof(void *); + const size_t counter = argsBufferSize_ / sizeof(void *); for (size_t i = 0; i < counter; i++) { ATB_LOG(DEBUG) << ((void **)(hostArgsBuffer_))[i]; } diff --git a/src/atb/runner/graph_runner.cpp b/src/atb/runner/graph_runner.cpp index 62df27e603112a57c15333a805b831f236b47cde..d203b999868bfca0e3a3f5e4df7f43bfc76a0336 100644 --- a/src/atb/runner/graph_runner.cpp +++ b/src/atb/runner/graph_runner.cpp @@ -68,8 +68,9 @@ void GraphRunner::Graph::SetNonReuseTensors() for (size_t nodeId = 0; nodeId < nodes.size(); ++nodeId) { auto &node = nodes.at(nodeId); uint32_t streamId = GetExecuteStreamId(node.op.get()); - if (streamId == 0) + if (streamId == 0) { continue; + } for (auto inTensorIt : node.inTensors) { auto it = isInTensorCanFree.find(inTensorIt); if (it != isInTensorCanFree.end()) { @@ -353,14 +354,13 @@ uint64_t GraphRunner::GetTilingBufferSizeImpl() return totalTilingBufferSize_; } -Status GraphRunner::FillHostTilingBufferImpl(uint8_t *hostTilingBuffer, uint64_t tilingBufferSize, - ContextBase *context) +Status GraphRunner::FillHostTilingBufferImpl(uint8_t *hostTilingBuffer, uint64_t tilingBufferSize, ContextBase *context) { uint64_t tilingOffset = 0; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); - Status ret = node.runner->FillHostTilingBuffer(hostTilingBuffer + tilingOffset, - tilingBufferSizes_.at(nodeId), context); + Status ret = + node.runner->FillHostTilingBuffer(hostTilingBuffer + tilingOffset, tilingBufferSizes_.at(nodeId), context); if (ret != NO_ERROR) { ATB_LOG(ERROR) << GetLogPrefix() << "GraphRunner::FillHostTilingBufferImpl failed! 
error code:" << ret; return ret; @@ -872,7 +872,7 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) node.runnerVariantPack.argsDeviceBuffer = runnerVariantPack.argsDeviceBuffer + offset; offset += node.runner->GetArgsSize(); ATB_LOG(DEBUG) << GetLogPrefix() << "Graph node " << nodeId << " argsDeviceAddr is " - << reinterpret_cast(node.runnerVariantPack.argsDeviceBuffer); + << static_cast(node.runnerVariantPack.argsDeviceBuffer); } } @@ -883,7 +883,7 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) node.runnerVariantPack.argsHostBuffer = runnerVariantPack.argsHostBuffer + offset; offset += node.runner->GetArgsSize(); ATB_LOG(DEBUG) << GetLogPrefix() << "Graph node " << nodeId << " argsHostAddr is " - << reinterpret_cast(node.runnerVariantPack.argsHostBuffer); + << static_cast(node.runnerVariantPack.argsHostBuffer); } } ATB_LOG(INFO) << GetLogPrefix() << " update runner variant pack's buffer end"; @@ -943,19 +943,23 @@ Status GraphRunner::ExecuteAllRunner(RunnerVariantPack &runnerVariantPack) { for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); - ATB_LOG(INFO) << GetLogPrefix() << " mstx registe tensor.data node[" << nodeId << "]" << "graphrunner start"; - if (runnerVariantPack.mstxMemRegister != nullptr && !(dynamic_cast(node.runner.get()))) { + ATB_LOG(INFO) << GetLogPrefix() << " mstx registe tensor.data node[" << nodeId << "]" + << "graphrunner start"; + if (runnerVariantPack.mstxMemRegister != nullptr && + !(dynamic_cast(node.runner.get()) || dynamic_cast(node.runner.get()))) { runnerVariantPack.mstxMemRegister->ClearMstxMemRegions(); for (size_t i = 0; i < node.runnerVariantPack.inTensors.size(); ++i) { auto &tensor = node.runnerVariantPack.inTensors.at(i); if (node.inTensorTypes.at(i) == GraphRunner::INTERMEDIATE_TENSOR) { - runnerVariantPack.mstxMemRegister->AddTensorMemRegions(tensor.deviceData, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); + runnerVariantPack.mstxMemRegister->AddTensorMemRegions( + tensor.deviceData, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); } } for (size_t i = 0; i < node.runnerVariantPack.outTensors.size(); ++i) { auto &tensor = node.runnerVariantPack.outTensors.at(i); if (node.outTensorTypes.at(i) == GraphRunner::INTERMEDIATE_TENSOR) { - runnerVariantPack.mstxMemRegister->AddTensorMemRegions(tensor.deviceData, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); + runnerVariantPack.mstxMemRegister->AddTensorMemRegions( + tensor.deviceData, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); } } if (runnerVariantPack.mstxMemRegister->CheckTensorRange()) { @@ -963,14 +967,14 @@ Status GraphRunner::ExecuteAllRunner(RunnerVariantPack &runnerVariantPack) } } ATB_LOG(INFO) << GetLogPrefix() << " node[" << nodeId << "] execute start, runner:" << node.runner->GetName() - << ", variantPack:\n" - << node.runnerVariantPack.ToString(); + << ", variantPack:\n" + << node.runnerVariantPack.ToString(); node.runnerVariantPack.context = runnerVariantPack.context; node.runnerVariantPack.mstxMemRegister = runnerVariantPack.mstxMemRegister; Status st = node.runner->Execute(node.runnerVariantPack); if (st != 0) { ATB_LOG(ERROR) << GetLogPrefix() << " node[" << nodeId - << "] execute fail, runner name:" << node.runner->GetName(); + << "] execute fail, runner name:" << node.runner->GetName(); return st; } if (runnerVariantPack.mstxMemRegister != nullptr && 
runnerVariantPack.mstxMemRegister->CheckTensorRange()) { diff --git a/src/atb/runner/ops_runner.cpp b/src/atb/runner/ops_runner.cpp index eb754df9284102773211057097da50d1a4172e89..d5f022d4497d52b69f7833431689d029b50666b4 100644 --- a/src/atb/runner/ops_runner.cpp +++ b/src/atb/runner/ops_runner.cpp @@ -617,17 +617,21 @@ Status OpsRunner::RunAllKernel(RunnerVariantPack &runnerVariantPack) KernelGraphNode &node = kernelGraph_.nodes.at(nodeId); if (runnerVariantPack.mstxMemRegister != nullptr) { runnerVariantPack.mstxMemRegister->ClearMstxMemRegions(); - if (runnerVariantPack.workspaceBufferSize) { - runnerVariantPack.workspaceBufferSize = static_cast(TensorUtil::AlignInt(runnerVariantPack.workspaceBufferSize, ALIGN_INT)); - runnerVariantPack.mstxMemRegister->AddTensorMemRegions(runnerVariantPack.workspaceBuffer, runnerVariantPack.workspaceBufferSize); + if (runnerVariantPack.workspaceBufferSize != 0) { + runnerVariantPack.workspaceBufferSize = + static_cast(TensorUtil::AlignInt(runnerVariantPack.workspaceBufferSize, ALIGN_INT)); + runnerVariantPack.mstxMemRegister->AddTensorMemRegions(runnerVariantPack.workspaceBuffer, + runnerVariantPack.workspaceBufferSize); } auto &inTensors = node.impl->GetInTensors(); const uint64_t inTensorsSize = inTensors.size(); for (uint64_t tensorId = 0; tensorId < inTensorsSize; tensorId++) { Mki::Tensor &tensor = inTensors.at(tensorId); if (node.inTensorsType.at(tensorId) == TensorType::INTERMEDIATE_TENSOR) { - tensor.data = runnerVariantPack.intermediateBuffer + reinterpret_cast(node.inTensors.at(tensorId)->data); - runnerVariantPack.mstxMemRegister->AddTensorMemRegions(tensor.data, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); + tensor.data = runnerVariantPack.intermediateBuffer + + reinterpret_cast(node.inTensors.at(tensorId)->data); + runnerVariantPack.mstxMemRegister->AddTensorMemRegions( + tensor.data, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); } } auto &outTensors = node.impl->GetOutTensors(); @@ -635,8 +639,10 @@ Status OpsRunner::RunAllKernel(RunnerVariantPack &runnerVariantPack) for (uint64_t tensorId = 0; tensorId < outTensorsSize; tensorId++) { Mki::Tensor &tensor = outTensors.at(tensorId); if (node.outTensorsType.at(tensorId) == TensorType::INTERMEDIATE_TENSOR) { - tensor.data = runnerVariantPack.intermediateBuffer + reinterpret_cast(node.outTensors.at(tensorId)->data); - runnerVariantPack.mstxMemRegister->AddTensorMemRegions(tensor.data, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); + tensor.data = runnerVariantPack.intermediateBuffer + + reinterpret_cast(node.outTensors.at(tensorId)->data); + runnerVariantPack.mstxMemRegister->AddTensorMemRegions( + tensor.data, static_cast(TensorUtil::AlignInt(tensor.dataSize, ALIGN_INT))); } } if (runnerVariantPack.mstxMemRegister->CheckTensorRange()) { @@ -735,8 +741,8 @@ bool OpsRunner::UpdateNodeBestKernelAndTiling(KernelGraphNode &node, size_t node uint64_t maxTilingSize, bool launchWithTiling) { uint64_t tilingSize = 0; - bool getTilingSuccess = GetCachedTiling(node, nodeId, kernelHostTilingBuffer, - maxTilingSize, tilingSize, launchWithTiling); + bool getTilingSuccess = + GetCachedTiling(node, nodeId, kernelHostTilingBuffer, maxTilingSize, tilingSize, launchWithTiling); if (!node.impl->UpdateBestKernel()) { ATB_LOG(ERROR) << GetLogPrefix() << " node[" << nodeId << "] " << node.GetName() << " update best kernel failed"; @@ -1083,8 +1089,8 @@ bool OpsRunner::GetCachedTiling(KernelGraphNode &node, size_t nodeId, uint8_t *k for (size_t i = 0; i < 
kernelCachesSize; ++i) {
         KernelCache *kernelCache = kernelCaches_.at(i).first;
         bool isLocalCache = kernelCaches_.at(i).second;
-        bool getTilingSuccess = node.impl->GetCachedTiling(*kernelCache, nodeId, kernelHostTilingBuffer,
-                                                           maxTilingSize, tilingSizeFetched, launchWithTiling);
+        bool getTilingSuccess = node.impl->GetCachedTiling(*kernelCache, nodeId, kernelHostTilingBuffer, maxTilingSize,
+                                                           tilingSizeFetched, launchWithTiling);
         if (getTilingSuccess) {
             ATB_LOG(INFO) << GetLogPrefix() << " node[" << nodeId << "] kernel cache get last tiling";
             IncreaseStatisticCacheHitCount(isLocalCache);
@@ -1408,10 +1414,11 @@ Status OpsRunner::UpdateWorkspaceBuffer(RunnerVariantPack &runnerVariantPack)
         KernelGraphNode &node = kernelGraph_.nodes.at(nodeId);
         if (needSetworkspace) {
 #ifdef _DEBUG
-            ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] update kernel runinfo workspaceBuffer, and new workspaceBuffer is "
-                          << static_cast(runnerVariantPack.workspaceBuffer);
+            ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId
+                          << "] update kernel runinfo workspaceBuffer, and new workspaceBuffer is "
+                          << static_cast(runnerVariantPack.workspaceBuffer);
 #else
-            ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] update kernel runinfo workspaceBuffer";
+            ATB_LOG(INFO) << GetLogPrefix() << "node[" << nodeId << "] update kernel runinfo workspaceBuffer";
 #endif
             node.impl->SetWorkspaceDeviceAddr(runnerVariantPack.workspaceBuffer);
         }
diff --git a/src/atb/utils/comm.cpp b/src/atb/utils/comm.cpp
index 9534aca04df700c39b5a32a74526fb95bf7bb27f..44a8e946b68d0240a6f6fedc5d490b1f2ea87d4d 100644
--- a/src/atb/utils/comm.cpp
+++ b/src/atb/utils/comm.cpp
@@ -210,7 +210,7 @@ std::shared_ptr CreateHcclCommByClusterInfo(uint32_t subCommRankId, const
                    << ret;
         return std::shared_ptr();
     }
-    return std::shared_ptr(static_cast(newHcclComm), [=](void *hcclComm) {
+    return std::shared_ptr(static_cast(newHcclComm), [=](const void *hcclComm) {
         (void)hcclComm;
         ATB_LOG(INFO) << "destroy hcclComm, but not call HcclCommDestroy hcclComm:" << hcclComm;
     });
diff --git a/src/atb/utils/common_utils.cpp b/src/atb/utils/common_utils.cpp
index 3e5b5a42bbdf3569ad1dcb49179c18ce2ca8a7a9..0f109baa99319cbe74211b94988747d31395660e 100644
--- a/src/atb/utils/common_utils.cpp
+++ b/src/atb/utils/common_utils.cpp
@@ -48,7 +48,7 @@ HcclDataType GetHcclDtype(const aclDataType dtype)
             return HCCL_DATA_TYPE_BFP16;
         default:
             ATB_LOG(ERROR) << "not support dtype:" << dtype;
-            return static_cast<HcclDataType>(255); // RESERVED TYPE
+            return static_cast<HcclDataType>(255); // 255: RESERVED TYPE
     }
 }
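Several hunks above register mstx memory regions with sizes rounded up by TensorUtil::AlignInt(size, ALIGN_INT). Assuming the usual power-of-two align-up semantics (the helper itself is not shown in this patch), the operation reduces to:

```cpp
#include <cstdint>

// Round size up to the next multiple of align; align must be a power of two.
constexpr uint64_t AlignUp(uint64_t size, uint64_t align)
{
    return (size + align - 1) & ~(align - 1);
}

static_assert(AlignUp(100, 64) == 128, "sizes in (64, 128] round up to 128");
static_assert(AlignUp(128, 64) == 128, "already-aligned sizes are unchanged");
```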
diff --git a/src/atb/utils/mstx_mem_register.cpp b/src/atb/utils/mstx_mem_register.cpp
index 490a5cf885978896951660aaac268ebfbc59898a..a37001a5ec5ce3ef8a2a132cef876e723321538b 100644
--- a/src/atb/utils/mstx_mem_register.cpp
+++ b/src/atb/utils/mstx_mem_register.cpp
@@ -8,8 +8,6 @@
  * See LICENSE in the root of the software repository for the full text of the License.
  */
 #include "atb/utils/mstx_mem_register.h"
-#include
-#include
 #include
 #include "atb/utils/log.h"
@@ -114,13 +112,9 @@ void MstxMemRegister::AddTensorMemRegions(void *ptr, uint64_t size)
     }
 }
 
-Status MstxMemRegister::CheckTensorRange()
+bool MstxMemRegister::CheckTensorRange()
 {
-    if (rangesDesc_.empty()) {
-        return false;
-    } else {
-        return true;
-    }
+    return !rangesDesc_.empty();
 }
 } // namespace atb
diff --git a/src/atb/utils/operation_util.cpp b/src/atb/utils/operation_util.cpp
index 6598d340194cf2d588cd8d385a54ab98e663b753..389f6bb08da87b885819cbbf972684e64aa2d712 100644
--- a/src/atb/utils/operation_util.cpp
+++ b/src/atb/utils/operation_util.cpp
@@ -475,8 +475,9 @@ bool OperationUtil::MatmulInputWeightShapeCheck(const SVector &inTen
     // check batch
     if (!param.isMoe) {
         bool isPass;
-        if (CheckBatch(inTensorDescs, logPrefix, param, isPass))
-            return isPass;
+        if (CheckBatch(inTensorDescs, logPrefix, param, isPass)) {
+            return isPass;
+        }
     }
 
     // check k
diff --git a/src/cinterface/atb_acl_fused_add_topk_div.cpp b/src/cinterface/atb_acl_fused_add_topk_div.cpp
index 587eddbb0d374bf14d327883aac26251e1b332b4..2249ea550b6d02bcf3a1a96ec7b7d2e62d1033fd 100644
--- a/src/cinterface/atb_acl_fused_add_topk_div.cpp
+++ b/src/cinterface/atb_acl_fused_add_topk_div.cpp
@@ -14,8 +14,8 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-const size_t g_FUSED_ADD_TOPK_INTENSOR_NUM = 2;
-const size_t g_FUSED_ADD_TOPK_OUTTENSOR_NUM = 2;
+const size_t FUSED_ADD_TOPK_INTENSOR_NUM = 2;
+const size_t FUSED_ADD_TOPK_OUTTENSOR_NUM = 2;
 
 atb::Status AtbFusedAddTopkDivGetWorkspaceSize(const aclTensor *x, const aclTensor *addNum,
                                                const aclTensor *mappingNum, const aclTensor *mappingTable,
                                                uint32_t groupNum, uint32_t groupTopk,
@@ -42,7 +42,7 @@ atb::Status AtbFusedAddTopkDivGetWorkspaceSize(const aclTens
     }
 
     atb::VariantPack pack;
-    size_t intensorNum = g_FUSED_ADD_TOPK_INTENSOR_NUM;
+    size_t intensorNum = FUSED_ADD_TOPK_INTENSOR_NUM;
     if (enableExpertMapping) {
         intensorNum += 2; // 2: mappingNum, mappingTable
     }
@@ -60,7 +60,7 @@ atb::Status AtbFusedAddTopkDivGetWorkspaceSize(const aclTens
     }
 
     index = 0;
-    pack.outTensors.resize(g_FUSED_ADD_TOPK_OUTTENSOR_NUM);
+    pack.outTensors.resize(FUSED_ADD_TOPK_OUTTENSOR_NUM);
     status = aclTensorToAtbTensor(y, &(pack.outTensors[index++]));
     ATB_CHECK(status == atb::NO_ERROR, "y create failed!", return status);
     status = aclTensorToAtbTensor(indices, &(pack.outTensors[index++]));
diff --git a/src/cinterface/atb_acl_mla.cpp b/src/cinterface/atb_acl_mla.cpp
index 8e6bc83497d2db8fa844a7e24c41c8fba20cdcf6..9cff31510f291bc3e4ab052413f546da20ad3560 100644
--- a/src/cinterface/atb_acl_mla.cpp
+++ b/src/cinterface/atb_acl_mla.cpp
@@ -15,12 +15,12 @@
 extern "C" {
 #endif
 
-const size_t g_MLAINTENSORNUMINT8NOMASK = 9;
-const size_t g_MLAINTENSORNUMINT8MASK = 10;
-const size_t g_MLAINTENSORNUMNOMASK = 7;
-const size_t g_MLAINTENSORNUMMASK = 8;
-const size_t g_MLAOUTTENSORNUMCALCRING = 2;
-const size_t g_MLAOUTTENSORNUMNOCALCRING = 1;
+const size_t MLA_INTENSOR_NUM_INT8_NO_MASK = 9;
+const size_t MLA_INTENSOR_NUM_INT8_MASK = 10;
+const size_t MLA_INTENSOR_NUM_NO_MASK = 7;
+const size_t MLA_INTENSOR_NUM_MASK = 8;
+const size_t MLA_OUTTENSOR_NUM_CALCRING = 2;
+const size_t MLA_OUTTENSOR_NUM_NO_CALCRING = 1;
 
 atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRope, const aclTensor *ctKV,
                                    const aclTensor *kRope, const aclTensor *blockTables, const aclTensor *contextLens,
@@ -48,19 +48,19 @@ atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRop
     size_t counter = 0;
     if (param.cacheMode == atb::infer::MultiLatentAttentionParam::CacheMode::INT8_NZCACHE) {
         if (param.maskType == atb::infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
-            pack.inTensors.resize(g_MLAINTENSORNUMINT8NOMASK);
-            counter = g_MLAINTENSORNUMINT8NOMASK;
+            pack.inTensors.resize(MLA_INTENSOR_NUM_INT8_NO_MASK);
+            counter = MLA_INTENSOR_NUM_INT8_NO_MASK;
         } else {
-            pack.inTensors.resize(g_MLAINTENSORNUMINT8MASK);
-            counter = g_MLAINTENSORNUMINT8MASK;
+            pack.inTensors.resize(MLA_INTENSOR_NUM_INT8_MASK);
+            counter = MLA_INTENSOR_NUM_INT8_MASK;
         }
     } else {
         if (param.maskType == atb::infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
-            pack.inTensors.resize(g_MLAINTENSORNUMNOMASK);
-            counter = g_MLAINTENSORNUMNOMASK;
+            pack.inTensors.resize(MLA_INTENSOR_NUM_NO_MASK);
+            counter = MLA_INTENSOR_NUM_NO_MASK;
         } else {
-            pack.inTensors.resize(g_MLAINTENSORNUMMASK);
-            counter = g_MLAINTENSORNUMMASK;
+            pack.inTensors.resize(MLA_INTENSOR_NUM_MASK);
+            counter = MLA_INTENSOR_NUM_MASK;
         }
     }
     if (param.calcType != atb::infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC) {
@@ -95,11 +95,11 @@
     }
     i = 0;
     if (param.calcType != atb::infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING) {
-        pack.outTensors.resize(g_MLAOUTTENSORNUMNOCALCRING);
+        pack.outTensors.resize(MLA_OUTTENSOR_NUM_NO_CALCRING);
         status = aclTensorToAtbTensor(attenOut, &(pack.outTensors[i++]));
         ATB_CHECK(status == atb::NO_ERROR, "attenOut create failed!", return status);
     } else {
-        pack.outTensors.resize(g_MLAOUTTENSORNUMCALCRING);
+        pack.outTensors.resize(MLA_OUTTENSOR_NUM_CALCRING);
         status = aclTensorToAtbTensor(attenOut, &(pack.outTensors[i++]));
         ATB_CHECK(status == atb::NO_ERROR, "calc_type_ring attenOut create failed!", return status);
         status = aclTensorToAtbTensor(lse, &(pack.outTensors[i++]));
@@ -109,15 +109,19 @@
         ATB_LOG(ERROR) << "AtbMLAGetWorkspaceSize opeartion pointer is nullptr!";
         return atb::ERROR_INVALID_OPERATION_ADDR;
     }
+    if (op == nullptr || *op == nullptr) {
+        ATB_LOG(ERROR) << "AtbMLAGetWorkspaceSize operation pointer is nullptr!";
+        return atb::ERROR_INVALID_OPERATION_ADDR;
+    }
     atb::Status st = (*op)->Setup(pack, *workspaceSize, context);
     ATB_CHECK(st == atb::NO_ERROR, "AtbMLA Setup failed!", return st);
     return atb::NO_ERROR;
 }
 
-atb::Status AtbMLA(void *workSpcace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context)
+atb::Status AtbMLA(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context)
 {
     atb::VariantPack pack;
-    atb::Status st = op->Execute(pack, (uint8_t *)(workSpcace), workspaceSize, context);
+    atb::Status st = op->Execute(pack, (uint8_t *)(workspace), workspaceSize, context);
     ATB_CHECK(st == atb::NO_ERROR, "AtbMLA Execute failed!", return st);
     return st;
 }
@@ -147,9 +151,9 @@ atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *q
     size_t i = 0;
     if (param.maskType == atb::infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
-        pack.inTensors.resize(g_MLAINTENSORNUMNOMASK);
+        pack.inTensors.resize(MLA_INTENSOR_NUM_NO_MASK);
     } else {
-        pack.inTensors.resize(g_MLAINTENSORNUMMASK);
+        pack.inTensors.resize(MLA_INTENSOR_NUM_MASK);
     }
 
     auto status = aclTensorToAtbTensor(q, &(pack.inTensors[i++]));
@@ -172,7 +176,7 @@ atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *q
         ATB_CHECK(status == atb::NO_ERROR, 
"mask create failed!", return status); } - pack.outTensors.resize(g_MLAOUTTENSORNUMNOCALCRING); + pack.outTensors.resize(MLA_OUTTENSOR_NUM_NO_CALCRING); status = aclTensorToAtbTensor(attenOut, &(pack.outTensors[0])); ATB_CHECK(status == atb::NO_ERROR, "attenOut create failed!", return status); diff --git a/src/cinterface/atb_acl_mla_preprocess.cpp b/src/cinterface/atb_acl_mla_preprocess.cpp index 12915080ae70e15c9b47b5afa49652022e0915de..b4385a8793d4262a340e11519b07eda8547f9885 100644 --- a/src/cinterface/atb_acl_mla_preprocess.cpp +++ b/src/cinterface/atb_acl_mla_preprocess.cpp @@ -15,9 +15,9 @@ extern "C" { #endif -const size_t g_MLAPPINTENSORNUM = 24; -const size_t g_MLAPPOUTTENSORNUMCACHEMODE = 4; -const size_t g_MLAPPOUTTENSORNUM = 2; +const size_t MLAPPINTENSORNUM = 24; +const size_t MLAPPOUTTENSORNUMCACHEMODE = 4; +const size_t MLAPPOUTTENSORNUM = 2; atb::Status AtbMLAPreprocessGetWorkspaceSize( const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0, @@ -53,7 +53,7 @@ atb::Status AtbMLAPreprocessGetWorkspaceSize( } atb::VariantPack pack; size_t i = 0; - pack.inTensors.resize(g_MLAPPINTENSORNUM); + pack.inTensors.resize(MLAPPINTENSORNUM); auto status = aclTensorToAtbTensor(input, &(pack.inTensors[i++])); ATB_CHECK(status == ACL_ERROR_NONE, "input create failed!", return status); status = aclTensorToAtbTensor(gamma0, &(pack.inTensors[i++])); @@ -139,7 +139,7 @@ atb::Status AtbMLAPreprocessGetWorkspaceSize( i = 0; if (param.cacheMode != atb::infer::MlaPreprocessParam::CacheMode::KVCACHE) { - pack.outTensors.resize(g_MLAPPOUTTENSORNUMCACHEMODE); + pack.outTensors.resize(MLAPPOUTTENSORNUMCACHEMODE); status = aclTensorToAtbTensor(qOut0, &(pack.outTensors[i++])); ATB_CHECK(status == ACL_ERROR_NONE, "qOut0 create failed!", return status); status = aclTensorToAtbTensor(kvCacheOut0, &(pack.outTensors[i++])); @@ -149,7 +149,7 @@ atb::Status AtbMLAPreprocessGetWorkspaceSize( status = aclTensorToAtbTensor(kvCacheOut1, &(pack.outTensors[i++])); ATB_CHECK(status == ACL_ERROR_NONE, "kvCacheOut1 create failed!", return status); } else { - pack.outTensors.resize(g_MLAPPOUTTENSORNUM); + pack.outTensors.resize(MLAPPOUTTENSORNUM); status = aclTensorToAtbTensor(qOut0, &(pack.outTensors[i++])); ATB_CHECK(status == ACL_ERROR_NONE, "qOut0 create failed!", return status); status = aclTensorToAtbTensor(kvCacheOut0, &(pack.outTensors[i++])); diff --git a/src/cinterface/atb_acl_paged_cache_load.cpp b/src/cinterface/atb_acl_paged_cache_load.cpp index b84b9f8bb2b0e67bb0b04a28a00d3fa899129acd..670d7bd826f4ae1bb52677869175386a8a1e839f 100644 --- a/src/cinterface/atb_acl_paged_cache_load.cpp +++ b/src/cinterface/atb_acl_paged_cache_load.cpp @@ -15,8 +15,8 @@ extern "C" { #endif -const size_t g_PAGED_CACHE_LOAD_INTENSOR_NUM = 6; -const size_t g_PAGED_CACHE_LOAD_OUTTENSOR_NUM = 2; +const size_t PAGED_CACHE_LOAD_INTENSOR_NUM = 6; +const size_t PAGED_CACHE_LOAD_OUTTENSOR_NUM = 2; atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const aclTensor *valueCache, const aclTensor *blockTables, const aclTensor *contextLens, @@ -39,9 +39,9 @@ atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const a atb::VariantPack pack; size_t i = 0; if (param.hasSeqStarts) { - pack.inTensors.resize(g_PAGED_CACHE_LOAD_INTENSOR_NUM + 1); + pack.inTensors.resize(PAGED_CACHE_LOAD_INTENSOR_NUM + 1); } else { - pack.inTensors.resize(g_PAGED_CACHE_LOAD_INTENSOR_NUM); + pack.inTensors.resize(PAGED_CACHE_LOAD_INTENSOR_NUM); } auto status = 
diff --git a/src/cinterface/atb_acl_paged_cache_load.cpp b/src/cinterface/atb_acl_paged_cache_load.cpp
index b84b9f8bb2b0e67bb0b04a28a00d3fa899129acd..670d7bd826f4ae1bb52677869175386a8a1e839f 100644
--- a/src/cinterface/atb_acl_paged_cache_load.cpp
+++ b/src/cinterface/atb_acl_paged_cache_load.cpp
@@ -15,8 +15,8 @@
 extern "C" {
 #endif

-const size_t g_PAGED_CACHE_LOAD_INTENSOR_NUM = 6;
-const size_t g_PAGED_CACHE_LOAD_OUTTENSOR_NUM = 2;
+const size_t PAGED_CACHE_LOAD_INTENSOR_NUM = 6;
+const size_t PAGED_CACHE_LOAD_OUTTENSOR_NUM = 2;

 atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const aclTensor *valueCache,
                                               const aclTensor *blockTables, const aclTensor *contextLens,
@@ -39,9 +39,9 @@ atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const a
     atb::VariantPack pack;
     size_t i = 0;
     if (param.hasSeqStarts) {
-        pack.inTensors.resize(g_PAGED_CACHE_LOAD_INTENSOR_NUM + 1);
+        pack.inTensors.resize(PAGED_CACHE_LOAD_INTENSOR_NUM + 1);
     } else {
-        pack.inTensors.resize(g_PAGED_CACHE_LOAD_INTENSOR_NUM);
+        pack.inTensors.resize(PAGED_CACHE_LOAD_INTENSOR_NUM);
     }
     auto status = aclTensorToAtbTensor(keyCache, &(pack.inTensors[i++]));
@@ -62,7 +62,7 @@ atb::Status AtbPagedCacheLoadGetWorkspaceSize(const aclTensor *keyCache, const a
     }
     i = 0;
-    pack.outTensors.resize(g_PAGED_CACHE_LOAD_OUTTENSOR_NUM);
+    pack.outTensors.resize(PAGED_CACHE_LOAD_OUTTENSOR_NUM);
     status = aclTensorToAtbTensor(key, &(pack.outTensors[i++]));
     ATB_CHECK(status == atb::NO_ERROR, "key create failed!", return status);
     status = aclTensorToAtbTensor(value, &(pack.outTensors[i++]));
diff --git a/src/cinterface/atb_acl_ring_mla.cpp b/src/cinterface/atb_acl_ring_mla.cpp
index 0f80b5eb087e9a69e0bfbeac1ca9b8554e1e74ae..65c25f234e4c1ee751cac94330f79136685a3202 100644
--- a/src/cinterface/atb_acl_ring_mla.cpp
+++ b/src/cinterface/atb_acl_ring_mla.cpp
@@ -15,8 +15,8 @@
 extern "C" {
 #endif

-const size_t g_RING_MLA_INTENSOR_NUM = 7;
-const size_t g_RING_MLA_OUTTENSOR_NUM = 2;
+const size_t RING_MLA_INTENSOR_NUM = 7;
+const size_t RING_MLA_OUTTENSOR_NUM = 2;

 atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTensor *querySplit2,
                                        const aclTensor *keySplit1, const aclTensor *keySplit2, const aclTensor *value,
@@ -44,9 +44,9 @@ atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTe
     atb::VariantPack pack;
     size_t index = 0;
     if (param.calcType == atb::infer::RingMLAParam::CalcType::CALC_TYPE_DEFAULT) {
-        pack.inTensors.resize(g_RING_MLA_INTENSOR_NUM + 2); // 2: prevOut, prevLse
+        pack.inTensors.resize(RING_MLA_INTENSOR_NUM + 2); // 2: prevOut, prevLse
     } else {
-        pack.inTensors.resize(g_RING_MLA_INTENSOR_NUM);
+        pack.inTensors.resize(RING_MLA_INTENSOR_NUM);
     }
     auto status = aclTensorToAtbTensor(querySplit1, &(pack.inTensors[index++]));
@@ -71,7 +71,7 @@ atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTe
     }
     index = 0;
-    pack.outTensors.resize(g_RING_MLA_OUTTENSOR_NUM);
+    pack.outTensors.resize(RING_MLA_OUTTENSOR_NUM);
     status = aclTensorToAtbTensor(output, &(pack.outTensors[index++]));
     ATB_CHECK(status == atb::NO_ERROR, "output create failed!", return status);
     status = aclTensorToAtbTensor(softmaxLse, &(pack.outTensors[index++]));
diff --git a/src/cinterface/atb_acl_self_attention_prefix_encoder.cpp b/src/cinterface/atb_acl_self_attention_prefix_encoder.cpp
index ff5440a211a6ad6abb3c960e70494106f7f1eb45..fca57823fc0e15c9e5dd0d74c11de40011fbb588 100644
--- a/src/cinterface/atb_acl_self_attention_prefix_encoder.cpp
+++ b/src/cinterface/atb_acl_self_attention_prefix_encoder.cpp
@@ -15,8 +15,8 @@ See LICENSE in the root of the software repository for the full text of the Lice
 extern "C" {
 #endif

-const size_t g_SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM = 6;
-const size_t g_SELF_ATTENTION_PREFIX_ENCODER_OUTTENSOR_NUM = 1;
+const size_t SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM = 6;
+const size_t SELF_ATTENTION_PREFIX_ENCODER_OUTTENSOR_NUM = 1;

 atb::Status AtbSelfAttentionPrefixEncoderGetWorkspaceSize(const aclTensor *query, const aclTensor *key,
                                                           const aclTensor *value, const aclTensor *blockTables,
@@ -57,11 +57,11 @@ atb::Status AtbSelfAttentionPrefixEncoderGetWorkspaceSize(const aclTensor *query
     bool isAlibiMask = param.maskType == atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI_COMPRESS ||
                        param.maskType == atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI_COMPRESS_SQRT;
     if (isAlibiMask) {
-        pack.inTensors.resize(g_SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM + 2); // 2: mask, slopes
+        pack.inTensors.resize(SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM + 2); // 2: mask, slopes
     } else if (param.maskType == atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
-        pack.inTensors.resize(g_SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM); // mask auto-generated
+        pack.inTensors.resize(SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM); // mask auto-generated
     } else {
-        pack.inTensors.resize(g_SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM + 1); // 1: mask
+        pack.inTensors.resize(SELF_ATTENTION_PREFIX_ENCODER_INTENSOR_NUM + 1); // 1: mask
     }

     auto status = aclTensorToAtbTensor(query, &(pack.inTensors[index++]));
@@ -86,7 +86,7 @@ atb::Status AtbSelfAttentionPrefixEncoderGetWorkspaceSize(const aclTensor *query
     }
     index = 0;
-    pack.outTensors.resize(g_SELF_ATTENTION_PREFIX_ENCODER_OUTTENSOR_NUM);
+    pack.outTensors.resize(SELF_ATTENTION_PREFIX_ENCODER_OUTTENSOR_NUM);
     status = aclTensorToAtbTensor(attnOut, &(pack.outTensors[index]));
     ATB_CHECK(status == atb::NO_ERROR, "attnOut create failed!", return status);
diff --git a/src/cinterface/atb_acl_util.cpp b/src/cinterface/atb_acl_util.cpp
index 9d1eb39ad4fe5f3d53ca28d2ab6765e5c46b66a4..cf2696720930af06a3fb22b2d352b8b774173fb0 100644
--- a/src/cinterface/atb_acl_util.cpp
+++ b/src/cinterface/atb_acl_util.cpp
@@ -51,7 +51,8 @@ atb::Status aclTensorToAtbTensor(const aclTensor *aclTensorSrc, atb::Tensor *atb
     atbTensorDst->desc = desc;
     atbTensorDst->deviceData = aclTensorSrc->GetData();
     atbTensorDst->hostData = nullptr;
-    atbTensorDst->dataSize = GetTensorSize(aclTensorSrc) * aclDataTypeSize(dataType);
+    atbTensorDst->dataSize =
+        static_cast(GetTensorSize(aclTensorSrc) * static_cast(aclDataTypeSize(dataType)));
     return atb::NO_ERROR;
 }
diff --git a/src/cinterface/atb_acl_util.h b/src/cinterface/atb_acl_util.h
index 96ce1b1d5dd3d478f8f5d9743cbcbe743926f976..101288638d6884085995e968381866599f114246 100644
--- a/src/cinterface/atb_acl_util.h
+++ b/src/cinterface/atb_acl_util.h
@@ -11,6 +11,8 @@ See LICENSE in the root of the software repository for the full text of the Lice
 #ifndef ATB_ACL_UTIL_H
 #define ATB_ACL_UTIL_H

+#include "atb/types.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
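Note: the dataSize change in atb_acl_util.cpp widens the multiplication so that element count times element size is computed in 64 bits; with 32-bit intermediates the product can wrap for large tensors. A standalone illustration of the same idea (names are illustrative, not from the patch):

    #include <cstdint>
    // Widening one operand before the multiply makes the whole product 64-bit.
    uint64_t TensorBytes(uint64_t elemCount, uint32_t elemSize)
    {
        return elemCount * static_cast<uint64_t>(elemSize);
    }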
diff --git a/src/include/atb/core/context_base.h b/src/include/atb/core/context_base.h
index baf07acef91de9b26c63024035c6c9a36a224c1e..af4321e18d3b6bba72175339574ba14551c9b94f 100644
--- a/src/include/atb/core/context_base.h
+++ b/src/include/atb/core/context_base.h
@@ -24,7 +24,8 @@ public:
     ~ContextBase() override;
     ContextBase(const ContextBase &other) = delete;
     ContextBase &operator=(const ContextBase &other) = delete;
-    Status Init(const std::function& alloc = nullptr, const std::function& dealloc = nullptr);
+    Status Init(const std::function &alloc = nullptr,
+                const std::function &dealloc = nullptr);
     void Destroy();
     Status SetExecuteStream(aclrtStream stream) override;
     aclrtStream GetExecuteStream() const override;
@@ -49,7 +50,7 @@ public:
     void *GetArgsHostBuffer(size_t bufferSize);
     Status FreeArgsDeviceBuffer(void *addr);
     Status FreeArgsHostBuffer(void *addr);
-    bool GetLaunchWithTilingStatus();
+    bool GetLaunchWithTilingStatus() const;

 private:
     Status CreateCopyStreamAndEvents();
@@ -70,10 +71,10 @@ private:
     Tensor overflowOutTensor_;
     static thread_local ExecuteType executeType_;
     LaunchMode mode_ = KERNEL_LAUNCH_MODE;
-    std::unique_ptr deviceAllocator_; // 一开始就赋值为defaultDeviceAllocator
-    std::unique_ptr hostAllocator_;   // 一开始就赋值为defaultHostAllocator
-    std::function allocateFunc_;      // 默认使用defaultDeviceAllocator中的Allocate方法
-    std::function deallocateFunc_;    // 默认使用defaultDeviceAllocator中的Deallocate方法
+    std::unique_ptr deviceAllocator_; // assigned defaultDeviceAllocator at construction
+    std::unique_ptr hostAllocator_;   // assigned defaultHostAllocator at construction
+    std::function allocateFunc_;      // defaults to Allocate of defaultDeviceAllocator
+    std::function deallocateFunc_;    // defaults to Deallocate of defaultDeviceAllocator
 };
 } // namespace atb
 #endif
\ No newline at end of file
diff --git a/src/include/atb/core/node_impl/mki_node_implement.h b/src/include/atb/core/node_impl/mki_node_implement.h
index 589cd3c1eb42917fb5a588f8e54b8d944082fd0f..c44772d3922b36de2dd47ee002cd7f6ad224b735 100644
--- a/src/include/atb/core/node_impl/mki_node_implement.h
+++ b/src/include/atb/core/node_impl/mki_node_implement.h
@@ -32,12 +32,12 @@ public:
     size_t GetTilingSize() const override;
     bool UpdateBestKernel() override;
     int64_t GetWorkspaceSize() const override;
-    Status InitKernelInfo(uint8_t *hostTilingBuffer, uint64_t tilingSize, bool launchWithTiling) override;
+    Status InitKernelInfo(uint8_t *hostTilingBuffer, uint64_t tilingSize, bool isLaunchWithTiling) override;
     void SetWorkspaceDeviceAddr(uint8_t *deviceWorkspaceBuffer) override;
     void SetTilingDeviceAddr(uint8_t *deviceTilingBuffer) override;
     Status Run(aclrtStream stream) override;
     bool GetCachedTiling(KernelCache &kernelCache, size_t kernelIndex, uint8_t *kernelHostTilingBuffer,
-                         uint64_t maxTilingSize, uint64_t &tilingSizeFetched, bool launchWithTiling) override;
+                         uint64_t maxTilingSize, uint64_t &tilingSizeFetched, bool isLaunchWithTiling) override;
     void AddTiling(KernelCache &kernelCache, size_t kernelIndex, uint8_t *hostTilingBuffer,
                    size_t tilingSize) const override;
     void SetArgsDeviceBuffer(void *deviceBuffer) override;
@@ -76,7 +76,7 @@ private:
     void *argsHostBuffer_ = nullptr;
 };

-static const std::unordered_map InitAtbMkiErrorHash() noexcept
+static inline const std::unordered_map InitAtbMkiErrorHash() noexcept
 {
     return {{Mki::ErrorType::NO_ERROR, atb::ErrorType::NO_ERROR},
             {Mki::ErrorType::ERROR_INVALID_VALUE, atb::ErrorType::ERROR_INVALID_PARAM},
diff --git a/src/include/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.h b/src/include/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.h
index 2fd14bb98b2517f798272bd2894e2edf5d7792ef..9f654daa724fd88da1c70a9d313541f42375ab26 100644
--- a/src/include/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.h
+++ b/src/include/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.h
@@ -15,7 +15,8 @@
 namespace atb {
 class DeviceTilingBufferPool : public TilingBufferPool {
 public:
-    DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, const std::function& alloc, const std::function& dealloc);
+    DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, const std::function &alloc,
+                           const std::function &dealloc);
     ~DeviceTilingBufferPool() override;

 protected:
@@ -24,8 +25,8 @@ protected:
     bool IsDeviceBufferPool() override;

 private:
-    std::function allocateFunc_;
-    std::function deallocateFunc_;
+    std::function allocateFunc_;
+    std::function deallocateFunc_;
 };
 } // namespace atb
 #endif
\ No newline at end of file
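Note: both ContextBase::Init and DeviceTilingBufferPool take user-supplied alloc/dealloc callbacks, falling back to the default allocators when none are given. The exact std::function signatures were not preserved in this patch context, so the sketch below states them as assumptions:

    #include <cstddef>
    #include <functional>
    // Illustration only: the callback signatures are assumptions, not the library's API.
    struct PoolSketch {
        std::function<void *(size_t)> alloc; // assumed: returns a device buffer of the given size
        std::function<void(void *)> dealloc; // assumed: releases a buffer obtained from alloc
        void *Acquire(size_t bytes) { return alloc ? alloc(bytes) : nullptr; }
        void Release(void *p)
        {
            if (p != nullptr && dealloc) {
                dealloc(p);
            }
        }
    };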
ExecuteVariantPackCheck(const VariantPack &variantPack) const;
     void InitRunnerVariantPack(const VariantPack &variantPack);
     Status CopyHostTilingToDevice(aclrtStream stream);
     Status CopyTilingToDevice();
@@ -116,7 +116,7 @@ private:
     Status EagerModeLaunch();
     Status GraphModeLaunch();
     void ProfilingPrepare();
-    Status CopyArgsToDevice(Context *context);
+    Status CopyArgsToDevice(Context *context) const;

 private:
     std::string logPrefix_;
diff --git a/src/include/atb/utils/mstx_mem_register.h b/src/include/atb/utils/mstx_mem_register.h
index eac540651d30d9d1ecedc0b7a53d63faf086d4a4..12e90afacdd3802d938d4ac01c876f141595122f 100644
--- a/src/include/atb/utils/mstx_mem_register.h
+++ b/src/include/atb/utils/mstx_mem_register.h
@@ -9,8 +9,6 @@
  */
 #ifndef ATB_UTILS_MSTX_REGISTER_H
 #define ATB_UTILS_MSTX_REGISTER_H
-#include
-#include
 #include
 #include
 #include
@@ -23,7 +21,7 @@ public:
     ~MstxMemRegister();
     static mstxDomainHandle_t &GetRegisterDomain();
     static bool IsMstxEnable();
-    Status CheckTensorRange();
+    bool CheckTensorRange();
     Status MstxHeapRegister(void *workspace, uint64_t workspaceSize);
     void MstxMemRegionsRegister();
     void MstxMemRegionsUnregister();
diff --git a/src/ops_infer/block_copy/block_copy_operation.cpp b/src/ops_infer/block_copy/block_copy_operation.cpp
index a9f70eb717464d7a651b93aa04312be89fb4965d..0ac8cb881c00ec1990f147c0419a2a876a569a96 100644
--- a/src/ops_infer/block_copy/block_copy_operation.cpp
+++ b/src/ops_infer/block_copy/block_copy_operation.cpp
@@ -21,6 +21,9 @@
 namespace {
 constexpr static uint32_t CACHE_DIM = 4;
 constexpr static uint32_t INDICES_DIM = 1;
+
+constexpr static size_t INPUT_K_CACHE = 0;
+constexpr static size_t INPUT_V_CACHE = 1;
 constexpr static size_t INPUT_SRC_BLOCK = 2;
 constexpr static size_t INPUT_DST_BLOCK = 3;
 constexpr static size_t INPUT_CUMSUM = 4;
@@ -63,12 +66,13 @@ uint32_t BlockCopyOperation::GetOutputNum() const

 Status BlockCopyOperation::InferShapeCheckImpl(const SVector &inTensorDescs) const
 {
-    int64_t blockCount = inTensorDescs.at(0).shape.dims[0];
-    if (!TensorUtil::TensorShapeEqual(inTensorDescs.at(0).shape, inTensorDescs.at(1).shape)) {
+    int64_t blockCount = inTensorDescs.at(INPUT_K_CACHE).shape.dims[0];
+    if (!TensorUtil::TensorShapeEqual(inTensorDescs.at(INPUT_K_CACHE).shape, inTensorDescs.at(INPUT_V_CACHE).shape)) {
         ATB_LOG(ERROR) << GetLogPrefix() << "kCache shape is not equal vCache shape";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (inTensorDescs.at(0).shape.dimNum != CACHE_DIM || inTensorDescs.at(1).shape.dimNum != CACHE_DIM) {
+    if (inTensorDescs.at(INPUT_K_CACHE).shape.dimNum != CACHE_DIM ||
+        inTensorDescs.at(INPUT_V_CACHE).shape.dimNum != CACHE_DIM) {
         ATB_LOG(ERROR) << GetLogPrefix() << "cache shape is not " << CACHE_DIM;
         return ERROR_INVALID_TENSOR_DIM_NUM;
     }
@@ -100,12 +104,13 @@ Status BlockCopyOperation::InferShapeImpl(const SVector &inTensorDes
 Status BlockCopyOperation::SetupCheckImpl(const SVector &inTensors, const SVector &outTensors) const
 {
     (void)outTensors;
-    int64_t blockCount = inTensors.at(0).desc.shape.dims[0];
-    if (!TensorUtil::TensorDescEqual(inTensors.at(0).desc, inTensors.at(1).desc)) {
+    int64_t blockCount = inTensors.at(INPUT_K_CACHE).desc.shape.dims[0];
+    if (!TensorUtil::TensorDescEqual(inTensors.at(INPUT_K_CACHE).desc, inTensors.at(INPUT_V_CACHE).desc)) {
         ATB_LOG(ERROR) << GetLogPrefix() << "kCache desc is not equal vCache desc";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (inTensors.at(0).desc.shape.dimNum != CACHE_DIM || inTensors.at(1).desc.shape.dimNum != CACHE_DIM) {
+    if (inTensors.at(INPUT_K_CACHE).desc.shape.dimNum != CACHE_DIM ||
+        inTensors.at(INPUT_V_CACHE).desc.shape.dimNum != CACHE_DIM) {
         ATB_LOG(ERROR) << GetLogPrefix() << "cache shape is not 4";
         return ERROR_INVALID_TENSOR_DIM_NUM;
     }
@@ -125,8 +130,8 @@ Status BlockCopyOperation::SetupCheckImpl(const SVector &inTensors, cons
         return ERROR_INVALID_TENSOR_DIM;
     }
     if (GetSingleton().Is310P()) {
-        if ((inTensors.at(0).desc.dtype != ACL_FLOAT16) ||
-            (inTensors.at(1).desc.dtype != ACL_FLOAT16)) {
+        if ((inTensors.at(INPUT_K_CACHE).desc.dtype != ACL_FLOAT16) ||
+            (inTensors.at(INPUT_V_CACHE).desc.dtype != ACL_FLOAT16)) {
             ATB_LOG(ERROR) << "Atlas 300I Duo inference products only support fp16";
             return ERROR_INVALID_TENSOR_DTYPE;
         }
@@ -134,45 +139,51 @@ Status BlockCopyOperation::SetupCheckImpl(const SVector &inTensors, cons
         if (status != NO_ERROR) {
             return status;
         }
-    } else if (inTensors.at(0).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ ||
-               inTensors.at(1).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
-        return ERROR_INVALID_TENSOR_FORMAT;
+    } else if (inTensors.at(INPUT_K_CACHE).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ ||
+               inTensors.at(INPUT_V_CACHE).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
+        return ERROR_INVALID_TENSOR_FORMAT;
     }
     return NO_ERROR;
 }

 Status BlockCopyOperation::SetupDimCheck310P(const SVector &inTensors) const
 {
-    if ((inTensors.at(0).desc.shape.dimNum != 4) || // kCache: 4 dim
-        (inTensors.at(1).desc.shape.dimNum != 4) || // vCache: 4 dim
-        (inTensors.at(2).desc.shape.dimNum != 1) || // 2: src
-        (inTensors.at(3).desc.shape.dimNum != 1) || // 3: dst
-        (inTensors.at(4).desc.shape.dimNum != 1)    // 4: cumsum
+    if ((inTensors.at(INPUT_K_CACHE).desc.shape.dimNum != 4) ||   // kCache: 4 dim
+        (inTensors.at(INPUT_V_CACHE).desc.shape.dimNum != 4) ||   // vCache: 4 dim
+        (inTensors.at(INPUT_SRC_BLOCK).desc.shape.dimNum != 1) || // 2: src, 1: dimNum
+        (inTensors.at(INPUT_DST_BLOCK).desc.shape.dimNum != 1) || // 3: dst, 1: dimNum
+        (inTensors.at(INPUT_CUMSUM).desc.shape.dimNum != 1)       // 4: cumsum, 1: dimNum
     ) {
         ATB_LOG(ERROR) << "invalid intensor dimNums";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (inTensors.at(0).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ ||
-        inTensors.at(1).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
-        if ((inTensors.at(0).desc.shape.dims[3] != NZBLOCKSIZE) ||
-            (inTensors.at(1).desc.shape.dims[3] != NZBLOCKSIZE) ||
-            (inTensors.at(0).desc.shape.dims[2] % NZBLOCKSIZE != 0) ||
-            (inTensors.at(1).desc.shape.dims[2] % NZBLOCKSIZE != 0)) { // 2: dim
-            ATB_LOG(ERROR) << GetLogPrefix() << "NZ format tensor dim should be aligned to 16";
-            return ERROR_INVALID_TENSOR_DIM;
-        }
+    if (inTensors.at(INPUT_K_CACHE).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ ||
+        inTensors.at(INPUT_V_CACHE).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ) {
+        if ((inTensors.at(INPUT_K_CACHE).desc.shape.dims[3] != NZBLOCKSIZE) ||     // 3: dim
+            (inTensors.at(INPUT_V_CACHE).desc.shape.dims[3] != NZBLOCKSIZE) ||     // 3: dim
+            (inTensors.at(INPUT_K_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0) || // 2: blockSize dim
+            (inTensors.at(INPUT_V_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0)) { // 2: blockSize dim
+            ATB_LOG(ERROR) << GetLogPrefix() << "NZ format tensor dim should be aligned to 16";
+            return ERROR_INVALID_TENSOR_DIM;
+        }
     } else {
-        if ((inTensors.at(0).desc.shape.dims[3] * inTensors.at(0).desc.shape.dims[2] * inTensors.at(0).desc.shape.dims[1]) % NZBLOCKSIZE != 0 ||
-            (inTensors.at(1).desc.shape.dims[3] * inTensors.at(1).desc.shape.dims[2] * inTensors.at(0).desc.shape.dims[1]) % NZBLOCKSIZE != 0) {
-            ATB_LOG(ERROR) << GetLogPrefix() << "ND format product of the first three tensor dim should be aligned to 16";
-            return ERROR_INVALID_TENSOR_DIM;
-        }
+        const int64_t ndProduct1 = inTensors.at(INPUT_K_CACHE).desc.shape.dims[3] * // 3: dim
+                                   inTensors.at(INPUT_K_CACHE).desc.shape.dims[2] * // 2: dim
+                                   inTensors.at(INPUT_K_CACHE).desc.shape.dims[1];  // 1: dim
+        const int64_t ndProduct2 = inTensors.at(INPUT_V_CACHE).desc.shape.dims[3] * // 3: dim
+                                   inTensors.at(INPUT_V_CACHE).desc.shape.dims[2] * // 2: dim
+                                   inTensors.at(INPUT_V_CACHE).desc.shape.dims[1];  // 1: dim
+        if ((ndProduct1 % NZBLOCKSIZE != 0) || (ndProduct2 % NZBLOCKSIZE != 0)) {
+            ATB_LOG(ERROR) << GetLogPrefix()
+                           << "ND format product of the first three tensor dim should be aligned to 16";
+            return ERROR_INVALID_TENSOR_DIM;
+        }
     }
-    if (inTensors.at(2).desc.shape.dims[0] != inTensors.at(4).desc.shape.dims[0]) {
+    if (inTensors.at(INPUT_SRC_BLOCK).desc.shape.dims[0] != inTensors.at(INPUT_CUMSUM).desc.shape.dims[0]) {
         ATB_LOG(ERROR) << "src dim should be same as cumsum";
         return ERROR_INVALID_TENSOR_DIM;
     }
-        return NO_ERROR;
+    return NO_ERROR;
 }

 std::shared_ptr BlockCopyOperation::CreateRunner(Context &context) const
diff --git a/src/ops_infer/fill/fill_ops_runner.h b/src/ops_infer/fill/fill_ops_runner.h
index 8ba3c4caec70021ed48fb90de3cd47773f4c514a..fb2fbb4544b893dfef83711cab810c6ae2087eb0 100644
--- a/src/ops_infer/fill/fill_ops_runner.h
+++ b/src/ops_infer/fill/fill_ops_runner.h
@@ -27,22 +27,24 @@ private:
 };

 namespace infer {
+inline bool IsFloatSVectorEqual(const SVector &v1, const SVector &v2)
+{
+    if (v1.size() != v2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < v1.size(); ++i) {
+        if (!UtilsInternal::IsFloatEqual(v1[i], v2[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 inline bool operator==(const FillParam &left, const FillParam &right)
 {
-    return left.withMask == right.withMask &&
-           [](const SVector &v1, const SVector &v2) {
-               if (v1.size() != v2.size()) {
-                   return false;
-               }
-               for (size_t i = 0; i < v1.size(); ++i) {
-                   if (!UtilsInternal::IsFloatEqual(v1[i], v2[i])) {
-                       return false;
-                   }
-               }
-               return true;
-           }(left.value, right.value) &&
+    return left.withMask == right.withMask && IsFloatSVectorEqual(left.value, right.value) &&
            left.outDim == right.outDim;
 }
 } // namespace infer
 } // namespace atb
-#endif
\ No newline at end of file
+#endif
diff --git a/src/ops_infer/gmm_deq_swiglu_quant_gmm_deq/gmm_deq_swiglu_quant_gmm_deq_operation.cpp b/src/ops_infer/gmm_deq_swiglu_quant_gmm_deq/gmm_deq_swiglu_quant_gmm_deq_operation.cpp
index 5e5956e2d0c7b201d066464cabb7e550cd0f7638..59cfe8288a4f55f168f164cd22b9b1436ddacaca 100644
--- a/src/ops_infer/gmm_deq_swiglu_quant_gmm_deq/gmm_deq_swiglu_quant_gmm_deq_operation.cpp
+++ b/src/ops_infer/gmm_deq_swiglu_quant_gmm_deq/gmm_deq_swiglu_quant_gmm_deq_operation.cpp
@@ -100,12 +100,12 @@ bool ParamCheck(const atb::infer::GmmDeqSwigluQuantGmmDeqParam &opParam)
         return false;
     }

-    if (opParam.transposeWeightUp != false) {
+    if (opParam.transposeWeightUp) {
         ATB_LOG(ERROR) << "Param transposeWeightUp only support false.";
         return false;
     }

-    if (opParam.transposeWeightDown != true) {
+    if (!opParam.transposeWeightDown) {
         ATB_LOG(ERROR) << "Param transposeWeightDown only support true.";
         return false;
     }
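Note: the fill_ops_runner.h change hoists an immediately-invoked lambda into the named helper IsFloatSVectorEqual, which compares float vectors element-wise via a tolerance check rather than operator==. A standalone sketch of the same comparison (the epsilon here is illustrative; the real helper delegates to UtilsInternal::IsFloatEqual):

    #include <cmath>
    #include <cstddef>
    #include <vector>
    bool FloatVecEqual(const std::vector<float> &a, const std::vector<float> &b)
    {
        if (a.size() != b.size()) {
            return false;
        }
        for (std::size_t i = 0; i < a.size(); ++i) {
            if (std::fabs(a[i] - b[i]) > 1e-6f) { // tolerance-based, not bitwise, equality
                return false;
            }
        }
        return true;
    }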
diff --git a/src/ops_infer/linear/linear_operation.cpp b/src/ops_infer/linear/linear_operation.cpp
index 8a9af41c2eb27a98183ca3cab1b142479568dc07..041c53bb16e8a273270db6798247ddcb13e4d1a0 100644
--- a/src/ops_infer/linear/linear_operation.cpp
+++ b/src/ops_infer/linear/linear_operation.cpp
@@ -33,7 +33,8 @@ bool MatmulParamCheck(const infer::LinearParam &opParam, ExternalError &error)
 {
     if (opParam.quantMode != atb::infer::LinearParam::QUANT_UNDEFINED) {
         error.errorData = OperationUtil::ConcatInfo(error.errorData, "quantMode = ", opParam.quantMode);
-        error.errorDesc = "QuantMode is not the default value. When outDataType is undefind, quantMode needs to be undefind.";
+        error.errorDesc =
+            "QuantMode is not the default value. When outDataType is undefined, quantMode needs to be undefined.";
         error.solutionDesc = "Please check the quantMode value of input params.";
         ATB_LOG(ERROR) << error;
         return false;
@@ -376,7 +377,8 @@ Status LinearOperation::InTensorDescsCheck(const SVector &inTensorDe
         const TensorDesc &deqScaleTensorDesc = inTensorDescs.at(inTensorId++);
         if (param_.quantMode == infer::LinearParam::PER_TOKEN) {
             const TensorDesc &perTokenDeqScaleTensorDesc = inTensorDescs.at(inTensorId++);
-            status = PerTokenDeqScaleCheck(deqScaleTensorDesc, weightTensorDesc, xTensorDesc, perTokenDeqScaleTensorDesc);
+            status =
+                PerTokenDeqScaleCheck(deqScaleTensorDesc, weightTensorDesc, xTensorDesc, perTokenDeqScaleTensorDesc);
         } else {
             status = DeqScaleCheck(deqScaleTensorDesc, weightTensorDesc);
         }
@@ -519,7 +521,8 @@ Status LinearOperation::DeqScaleCheck(const TensorDesc &deqScaleTensorDesc, cons
 }

 Status LinearOperation::PerTokenDeqScaleCheck(const TensorDesc &deqScaleTensorDesc, const TensorDesc &weightTensorDesc,
-                                              const TensorDesc &xTensorDesc, const TensorDesc &perTokendeqScaleTensorDesc) const
+                                              const TensorDesc &xTensorDesc,
+                                              const TensorDesc &perTokendeqScaleTensorDesc) const
 {
     ExternalError error;
     error.solutionDesc = "Please check the shape of inTensors.";
@@ -654,14 +657,17 @@ bool LinearOperation::XWeightDimNumCheck(const TensorDesc &xTensorDesc, const Te
     return true;
 }

-bool LinearOperation::PerTokenXWeightDimNumCheck(const TensorDesc &xTensorDesc, const TensorDesc &weightTensorDesc) const
+bool LinearOperation::PerTokenXWeightDimNumCheck(const TensorDesc &xTensorDesc,
+                                                 const TensorDesc &weightTensorDesc) const
 {
     ExternalError error;
     error.errorType = ERROR_INVALID_TENSOR_DIM_NUM;
     error.solutionDesc = "Please check format and shape of inTensors.";
-    if (param_.quantMode == infer::LinearParam::PER_TOKEN && xTensorDesc.shape.dimNum == DIM_NUM_3 && weightTensorDesc.shape.dimNum == DIM_NUM_2) {
+    if (param_.quantMode == infer::LinearParam::PER_TOKEN && xTensorDesc.shape.dimNum == DIM_NUM_3 &&
+        weightTensorDesc.shape.dimNum == DIM_NUM_2) {
         error.errorDesc = "When quantMode is PER_TOKEN and inTensor0 dim num is 3, inTensor1 dim num cannot be 2,";
-        error.errorData = OperationUtil::ConcatInfo("quantMode = ", param_.quantMode, ", inTensor0 dimNum = ", xTensorDesc.shape.dimNum,
+        error.errorData = OperationUtil::ConcatInfo("quantMode = ", param_.quantMode,
+                                                    ", inTensor0 dimNum = ", xTensorDesc.shape.dimNum,
                                                     ", inTensor1 dimNum = ", weightTensorDesc.shape.dimNum);
         ATB_LOG(ERROR) << GetLogPrefix() << error;
         return false;
diff --git a/src/ops_infer/linear/linear_ops_runner.cpp b/src/ops_infer/linear/linear_ops_runner.cpp
index 023f967876dc427dc6d76d881327a8d39accbefe..220b2a5a5f857563d8660e03ced3998e91c865fe 100644
--- a/src/ops_infer/linear/linear_ops_runner.cpp
+++ b/src/ops_infer/linear/linear_ops_runner.cpp
@@ -14,6 +14,8 @@
 #include "atb/utils/singleton.h"
 #include "atb/utils/utils_internal.h"

+static constexpr size_t SIZE_0 = 0;
+static constexpr size_t SIZE_1 = 1;
 static constexpr size_t SIZE_2 = 2;
 static constexpr size_t SIZE_3 = 3;
 static constexpr size_t SIZE_4 = 4;
@@ -99,7 +101,7 @@ Status LinearOpsRunner::SetupKernelGraph(const OpsTensorPack &opsTensorPack)
     if (matmulParam_.enDequant) {
         if (GetSingleton().Is910B() && param_.quantMode == infer::LinearParam::PER_TOKEN) {
-            return SetupKernelGraphMatmulDequantPerToken910B();
+            return SetupKernelGraphMatmulDequantPerToken910B();
         } else if (GetSingleton().Is910B() || GetSingleton().Is310B()) {
             return SetupKernelGraphMatmulDequant910B();
         }
@@ -139,7 +141,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmul910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulA2";

-    InitKernelGraph(SIZE_2, 1, 0, 1);
+    InitKernelGraph(SIZE_2, SIZE_1, SIZE_0, SIZE_1);

     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(0);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(1);
@@ -168,7 +170,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWeightNdNot910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulWeightNdNotA2";

-    InitKernelGraph(SIZE_2, 1, SIZE_3, SIZE_4);
+    InitKernelGraph(SIZE_2, SIZE_1, SIZE_3, SIZE_4);

     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(0);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(1);
@@ -217,7 +219,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWeightNzNot910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulWeightNzNotA2";

-    InitKernelGraph(SIZE_2, 1, SIZE_2, SIZE_3);
+    InitKernelGraph(SIZE_2, SIZE_1, SIZE_2, SIZE_3);

     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(0);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(1);
@@ -259,7 +261,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAdd910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulElewiseAddA2";

-    InitKernelGraph(SIZE_3, 1, 1, SIZE_2);
+    InitKernelGraph(SIZE_3, SIZE_1, SIZE_1, SIZE_2);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -299,7 +301,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNdNot910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNdNotA2";

-    InitKernelGraph(SIZE_3, 1, SIZE_4, SIZE_5);
+    InitKernelGraph(SIZE_3, SIZE_1, SIZE_4, SIZE_5);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -358,7 +360,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNzNot910B()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulElewiseAddWeightNzNotA2";

-    InitKernelGraph(SIZE_3, 1, SIZE_3, SIZE_4);
+    InitKernelGraph(SIZE_3, SIZE_1, SIZE_3, SIZE_4);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -411,7 +413,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulWithBias()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulWithBias";

-    InitKernelGraph(SIZE_3, 1, 0, 1);
+    InitKernelGraph(SIZE_3, SIZE_1, SIZE_0, SIZE_1);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -453,10 +455,10 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum()
     Mki::Tensor &outTensor = kernelGraph_.outTensors.at(0);

     if (xTensor.desc.dims.size() == SIZE_2 && xTensor.desc.dims.at(1) > MATMUL_TRANSPOSE_THRESHOLD) {
-        if (kernelGraph_.nodes.size() != 2) {
+        if (kernelGraph_.nodes.size() != SIZE_2) {
             isParamUpdated_ = true;
         }
-        InitKernelGraph(SIZE_3, 1, 1, 2);
+        InitKernelGraph(SIZE_3, SIZE_1, SIZE_1, SIZE_2);
         Mki::Tensor &transposedXtensor = kernelGraph_.internalTensors.at(0);

         size_t nodeId = 0;
@@ -468,7 +470,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum()
         transposeNode.inTensors = {&xTensor};
         transposeNode.outTensors = {&transposedXtensor};

-        bool matmulTransposeA = !param_.transposeA;
+        bool matmulTransposeA = !param_.transposeA;
         matmulParam_.transposeA = matmulTransposeA;
         SetupMatmulOriShape(transposedXtensor, weightTensor);
@@ -484,7 +486,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum()
         if (kernelGraph_.nodes.size() != 1) {
             isParamUpdated_ = true;
         }
-        InitKernelGraph(SIZE_3, 1, 0, 1);
+        InitKernelGraph(SIZE_3, SIZE_1, SIZE_0, SIZE_1);
         KernelGraphNode &matmulNode = kernelGraph_.nodes.at(0);

         matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_ACCUM_ATOMIC;
@@ -502,51 +504,51 @@ Status LinearOpsRunner::SetupKernelGraphMatmulAccum()
 Status LinearOpsRunner::SetupKernelGraphMatmulEin()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulEin";
-
-    InitKernelGraph(SIZE_2, 1, 0, 1);
-
+
+    InitKernelGraph(SIZE_2, SIZE_1, SIZE_0, SIZE_1);
+
     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(inTensorId++);
-
+
     Mki::Tensor &outTensor = kernelGraph_.outTensors.at(0);
-
+
     KernelGraphNode &matmulNode = kernelGraph_.nodes.at(0);
-
+
     matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_EIN_SUM;
-    matmulNode.opDesc = { 0, "MatMulOperation", matmulParam_ };
-    matmulNode.inTensors = { &xTensor, &weightTensor };
-    matmulNode.outTensors = { &outTensor };
+    matmulNode.opDesc = {0, "MatMulOperation", matmulParam_};
+    matmulNode.inTensors = {&xTensor, &weightTensor};
+    matmulNode.outTensors = {&outTensor};
     matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size());
     if (isWeightNz_) {
         matmulNode.inTensorViewFuncs.at(1) = matmulNzReshape_;
     }
-
+
     return NO_ERROR;
 }

 Status LinearOpsRunner::SetupKernelGraphMatmulEinElewiseAdd()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMatmulEinElewiseAdd";
-
-    InitKernelGraph(SIZE_3, 1, 1, SIZE_2);
-
+
+    InitKernelGraph(SIZE_3, SIZE_1, SIZE_1, SIZE_2);
+
     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(inTensorId++);
     Mki::Tensor &biasTensor = kernelGraph_.inTensors.at(inTensorId++);
-
+
     Mki::Tensor &outTensor = kernelGraph_.outTensors.at(0);
     Mki::Tensor &matmuloutTensor = kernelGraph_.internalTensors.at(0);
-
+
     KernelGraphNode &matmulNode = kernelGraph_.nodes.at(0);
     KernelGraphNode &addNode = kernelGraph_.nodes.at(1);
-
+
     matmulParam_.matmulType = AsdOps::OpParam::MatMul::MatMulType::MATMUL_EIN_SUM;
-    matmulNode.opDesc = { 0, "MatMulOperation", matmulParam_ };
-    matmulNode.inTensors = { &xTensor, &weightTensor };
-    matmulNode.outTensors = { &matmuloutTensor };
+    matmulNode.opDesc = {0, "MatMulOperation", matmulParam_};
+    matmulNode.inTensors = {&xTensor, &weightTensor};
+    matmulNode.outTensors = {&matmuloutTensor};
     matmulNode.inTensorViewFuncs.resize(matmulNode.inTensors.size());
     if (isWeightNz_) {
         matmulNode.inTensorViewFuncs.at(1) = matmulNzReshape_;
@@ -557,7 +559,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulEinElewiseAdd()
     addNode.outTensors = {&outTensor};
     addNode.inTensorViewFuncs.resize(addNode.inTensors.size());
     addNode.inTensorViewFuncs.at(1) = elewiseAddUnsqueeze_;
-
+
     return NO_ERROR;
 }

@@ -572,7 +574,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulDequant910B()
         matmulParam_.quantMode = AsdOps::OpParam::MatMul::QuantMode::PER_CHANNEL_SYMM;
     }
-    InitKernelGraph(inTensorNum, 1, 0, 1);
+    InitKernelGraph(inTensorNum, SIZE_1, SIZE_0, SIZE_1);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -607,8 +609,8 @@ Status LinearOpsRunner::SetupKernelGraphMatmulDequantPerToken910B()
     ATB_LOG(INFO) << GetLogPrefix() << "SetupKernelGraphMatmulDequantPerTokenA2";
     size_t inTensorNum = param_.hasBias ? SIZE_5 : SIZE_4;
-    InitKernelGraph(inTensorNum, 1, 0, 1);
-
+    InitKernelGraph(inTensorNum, SIZE_1, SIZE_0, SIZE_1);
+
     matmulParam_.quantMode = AsdOps::OpParam::MatMul::QuantMode::PER_TOKEN_SYMM;

     size_t inTensorId = 0;
@@ -649,7 +651,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulDequantWeightNdNot910B()
     } else {
         matmulParam_.quantMode = AsdOps::OpParam::MatMul::QuantMode::PER_CHANNEL_SYMM;
     }
-    InitKernelGraph(inTensorNum, 1, SIZE_3, SIZE_4);
+    InitKernelGraph(inTensorNum, SIZE_1, SIZE_3, SIZE_4);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -689,7 +691,8 @@ Status LinearOpsRunner::SetupKernelGraphMatmulDequantWeightNdNot910B()
     matmulParam_.withBias = param_.hasBias;

     matmulNode.opDesc = {0, "MatMulOperation", matmulParam_};
-    matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor, &biasTensor, &deqScaleTensor, &perTokenScaleTensor};
+    matmulNode.inTensors = {&transdataAOutTensor, &transdataBOutTensor, &biasTensor, &deqScaleTensor,
+                            &perTokenScaleTensor};
     matmulNode.outTensors = {&matmulOutTensor};

     transdataOutNode.opDesc = {0, "TransdataOperation", transdataNzToNdParam_};
@@ -709,7 +712,7 @@ Status LinearOpsRunner::SetupKernelGraphMatmulDequantWeightNzNot910B()
     } else {
         matmulParam_.quantMode = AsdOps::OpParam::MatMul::QuantMode::PER_CHANNEL_SYMM;
     }
-    InitKernelGraph(inTensorNum, 1, SIZE_2, SIZE_3);
+    InitKernelGraph(inTensorNum, SIZE_1, SIZE_2, SIZE_3);

     size_t inTensorId = 0;
     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(inTensorId++);
@@ -756,7 +759,7 @@ Status LinearOpsRunner::SetupKernelGraphMoeGateCorr()
 {
     ATB_LOG(INFO) << GetLogPrefix() << "LinearOpsRunner::SetupKernelGraphMoeGateCorr";

-    InitKernelGraph(SIZE_2, 1, 0, 1);
+    InitKernelGraph(SIZE_2, SIZE_1, SIZE_0, SIZE_1);

     Mki::Tensor &xTensor = kernelGraph_.inTensors.at(0);
     Mki::Tensor &weightTensor = kernelGraph_.inTensors.at(1);
diff --git a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
index 411e3281f1bbb028f5ac05249be15db9afd3ca25..a6d19830d454e2a36657c7f5fee3182d62e37869 100644
--- a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
+++ b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
@@ -110,7 +110,7 @@ template <> Status CreateOperation(const infer::LinearParallelParam &opParam, Op
         ATB_LOG(ERROR) << "LinearParallelOperation DistributedInitCheck failed.";
         return ERROR_INVALID_PARAM;
     }
-    int rankSize = opParam.rankSize;
+    uint32_t rankSize = static_cast(opParam.rankSize);
     if (opParam.rankSize <= 0 || (rankSize & (rankSize - 1)) != 0) {
         ATB_LOG(ERROR) << "LinearParallel rankSize support power of 2 but got [" << opParam.rankSize << "]";
         return ERROR_INVALID_PARAM;
@@ -121,8 +121,9 @@ template <> Status CreateOperation(const infer::LinearParallelParam &opParam, Op
     }
     if (opParam.backend == "lcoc") {
         Status isOk;
-        if (CheckType(opParam, isOk))
+        if (CheckType(opParam, isOk)) {
             return isOk;
+        }
     }
     *operation = new (std::nothrow) LinearParallelOperation(opParam);
     if (*operation == nullptr) {
@@ -336,7 +337,7 @@ Status LinearParallelOperation::InferShapeCheckLinearAllReduce(const SVector
                infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT &&
            param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX;
-    if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) {
+    if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) { // 3: deqScale
         ATB_LOG(ERROR) << GetLogPrefix() << "when perChannelScale's type is float, "
                        << "outputDataType do not support float16_t";
         return ERROR_INVALID_TENSOR_INI_MATCH;
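Note: the rankSize validation above relies on the classic bit trick — for n > 0, n is a power of two exactly when clearing its lowest set bit yields zero. A standalone form:

    #include <cstdint>
    bool IsPowerOfTwo(uint32_t n)
    {
        return n != 0 && (n & (n - 1)) == 0; // n - 1 flips all bits up to and including the lowest set bit
    }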
diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_ops_runner.h b/src/ops_infer/mla_preprocess/mla_preprocess_ops_runner.h
index 38820854fd1a27dfd76dd8681ee334bc5196c24d..63983c3cd300c831f96f9b8dd96c968fdc5fef1f 100644
--- a/src/ops_infer/mla_preprocess/mla_preprocess_ops_runner.h
+++ b/src/ops_infer/mla_preprocess/mla_preprocess_ops_runner.h
@@ -9,7 +9,6 @@
  */
 #ifndef ATB_MLAPREPROCESS_OPS_RUNNER_H
 #define ATB_MLAPREPROCESS_OPS_RUNNER_H
-#include
 #include "atb/runner/ops_runner.h"
 #include "atb/infer_op_params.h"
 #include "atb/utils/utils_internal.h"
@@ -43,4 +42,4 @@ inline bool operator==(const MlaPreprocessParam &left, const MlaPreprocessParam
 } // namespace infer
 } // namespace atb

-#endif
\ No newline at end of file
+#endif
diff --git a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
index 3621b519e3674d4b22872e9932336aa6b36245cc..0bcb19d3a56d1a1e0df63700b39ae74c708513c7 100644
--- a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
+++ b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
@@ -46,8 +46,7 @@ struct XDimIndex {
 };

 constexpr size_t WEIGHT_DIMS = 2;
-template
-struct WeightDimIndex {
+template struct WeightDimIndex {
     static constexpr size_t K = IsTrans ? 1 : 0;
     static constexpr size_t N = IsTrans ? 0 : 1;
 };
@@ -86,12 +85,12 @@ bool ParamCheck(const atb::infer::MmDeqSwigluQuantMmDeqParam &opParam)
         return false;
     }

-    if (opParam.transposeWeightUp != false) {
+    if (opParam.transposeWeightUp) {
         ATB_LOG(ERROR) << "Param transposeWeightUp only support false.";
         return false;
     }

-    if (opParam.transposeWeightDown != true) {
+    if (!opParam.transposeWeightDown) {
         ATB_LOG(ERROR) << "Param transposeWeightDown only support true.";
         return false;
     }
@@ -113,8 +112,8 @@ atb::Status CheckInTensorsDescDimNum(const atb::SVector &inTens
         ATB_LOG(ERROR) << "scale1 dim num is not support for MmDeqSwigluQuantMmDeqOperation.";
         return atb::ERROR_INVALID_TENSOR_DIM_NUM;
     }
-    if (!atb::TensorCheck::IsTensorDescDimNumValid(
-        inTensorDescs.at(InTensorIndex::PER_TOKEN_SCALE1), PER_TOKEN_SCALE_DIMS)) {
+    if (!atb::TensorCheck::IsTensorDescDimNumValid(inTensorDescs.at(InTensorIndex::PER_TOKEN_SCALE1),
+                                                   PER_TOKEN_SCALE_DIMS)) {
         ATB_LOG(ERROR) << "pertokenScale1 dim num is not support for MmDeqSwigluQuantMmDeqOperation.";
         return atb::ERROR_INVALID_TENSOR_DIM_NUM;
     }
@@ -131,14 +130,13 @@ atb::Status CheckInTensorsDescDimNum(const atb::SVector &inTens

 bool CheckX1Shape(const atb::TensorDesc &x1Desc, int64_t m)
 {
-    return x1Desc.shape.dims[XDimIndex::M] == m &&
-           x1Desc.shape.dims[XDimIndex::K] == SUPPORTED_K1;
+    return x1Desc.shape.dims[XDimIndex::M] == m && x1Desc.shape.dims[XDimIndex::K] == SUPPORTED_K1;
 }

 bool CheckWeight1Shape(const atb::TensorDesc &weight1Desc)
 {
     return weight1Desc.shape.dims[WeightDimIndex::K] == SUPPORTED_K1 &&
-        weight1Desc.shape.dims[WeightDimIndex::N] == SUPPORTED_N1;
+           weight1Desc.shape.dims[WeightDimIndex::N] == SUPPORTED_N1;
 }

 bool CheckScale1Shape(const atb::TensorDesc &scale1Desc)
@@ -154,7 +152,7 @@ bool CheckPerTokenScale1Shape(const atb::TensorDesc &perTokenScale1Desc, int64_t
 bool CheckWeight2Shape(const atb::TensorDesc &weight2Desc)
 {
     return weight2Desc.shape.dims[WeightDimIndex::N] == SUPPORTED_N2 &&
-        weight2Desc.shape.dims[WeightDimIndex::K] == SUPPORTED_K2;
+           weight2Desc.shape.dims[WeightDimIndex::K] == SUPPORTED_K2;
 }

 bool CheckScale2Shape(const atb::TensorDesc &scale2Desc)
@@ -264,7 +262,7 @@ void MmDeqSwigluQuantMmDeqOperation::SetParam(const infer::MmDeqSwigluQuantMmDeq
 }

 Status MmDeqSwigluQuantMmDeqOperation::InferShapeImpl(const SVector &inTensorDescs,
-                                                       SVector &outTensorDescs) const
+                                                      SVector &outTensorDescs) const
 {
     int64_t m = OperationUtil::GetXTensorM(inTensorDescs.at(InTensorIndex::X1), false);
     auto &outDesc = outTensorDescs.at(0);
@@ -282,7 +280,7 @@ Status MmDeqSwigluQuantMmDeqOperation::InferShapeCheckImpl(const SVector &inTensors,
-                                                      const SVector &outTensors) const
+                                                     const SVector &outTensors) const
 {
     Status status = CheckInTensors(inTensors);
     if (status != NO_ERROR) {
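Note: WeightDimIndex above maps the K and N axes of a weight tensor at compile time — a transposed weight swaps which dimension holds K and which holds N. A self-contained sketch of the same idiom (names are illustrative):

    #include <cstddef>
    template <bool IsTrans> struct DimIndexSketch {
        static constexpr size_t K = IsTrans ? 1 : 0;
        static constexpr size_t N = IsTrans ? 0 : 1;
    };
    static_assert(DimIndexSketch<false>::K == 0 && DimIndexSketch<true>::K == 1, "transpose swaps the K axis");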
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
index 02949a59047a49d436b1c3ae8cfe4162d5dc14e2..8bfc4a0b6424763d8443fae970f07ec4e3ccb6b2 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
@@ -54,8 +54,7 @@ static bool ParamRangeCheck(const infer::MultiLatentAttentionParam &opParam);
 static bool ParamCheck(const infer::MultiLatentAttentionParam &opParam);
 static bool ParamPrefillCheck(const infer::MultiLatentAttentionParam &opParam);

-template <>
-Status CreateOperation(const infer::MultiLatentAttentionParam &opParam, Operation **operation)
+template <> Status CreateOperation(const infer::MultiLatentAttentionParam &opParam, Operation **operation)
 {
     if (operation == nullptr) {
         return ERROR_INVALID_PARAM;
@@ -87,8 +86,8 @@ Status CreateOperation(const infer::MultiLatentAttentionParam &opParam, Operatio

 static bool ParamCheck(const infer::MultiLatentAttentionParam &opParam)
 {
-    if (opParam.headNum != 8 && opParam.headNum != 16 && opParam.headNum != 32 && // 8, 16, 32: headNum
-        opParam.headNum != 64 && opParam.headNum != 128) { // 64, 128: headNum
+    if (opParam.headNum != 8 && opParam.headNum != 16 && opParam.headNum != 32 && // 8, 16, 32: headNum
+        opParam.headNum != 64 && opParam.headNum != 128) {                        // 64, 128: headNum
         ATB_LOG(ERROR) << "headNum should be {8,16,32,64,128}";
         return false;
     }
@@ -187,12 +186,12 @@ MultiLatentAttentionOperation::MultiLatentAttentionOperation(const infer::MultiL
         param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         opIrKeyStr += "Mask";
     }
-    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC || param_.
-        calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
+    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC ||
+        param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
         opIrKeyStr += "Qlens";
     }
-    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING || param_.
-        calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
+    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING ||
+        param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
         opIrKeyStr += "IsRing";
     }
     if (param_.cacheMode == infer::MultiLatentAttentionParam::CacheMode::INT8_NZCACHE) {
@@ -219,8 +218,8 @@ uint32_t MultiLatentAttentionOperation::GetInputNum() const
     if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_PREFILL) {
         intensorNumBase++;
     }
-    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC || param_.
-        calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
+    if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC ||
+        param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING) {
         intensorNumBase++;
     }
     if (param_.cacheMode == infer::MultiLatentAttentionParam::CacheMode::INT8_NZCACHE) {
@@ -231,8 +230,8 @@ uint32_t MultiLatentAttentionOperation::GetInputNum() const

 uint32_t MultiLatentAttentionOperation::GetOutputNum() const
 {
-    bool isRing = param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING || param_.
-        calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING;
+    bool isRing = param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING ||
+                  param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING;
     return isRing ? OUT_TENSOR_NUM_2 : OUT_TENSOR_NUM_1;
 }

@@ -241,8 +240,8 @@ Status MultiLatentAttentionOperation::InferShapeImpl(const SVector &
 {
     outTensorDescs.at(0) = inTensorDescs.at(0);
     outTensorDescs.at(0).dtype = inTensorDescs.at(1).dtype;
-    if ((param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING || param_.
-        calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING)) {
+    if ((param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_RING ||
+         param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC_AND_RING)) {
         outTensorDescs.at(1) = outTensorDescs.at(0);
         if (param_.cacheMode == infer::MultiLatentAttentionParam::CacheMode::INT8_NZCACHE) {
             outTensorDescs.at(1).dtype = ACL_FLOAT;
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.h b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.h
index 7837fefdf8304ca6afaced0832dc155ab80cf788..1018e4cbfbc7bfff31bbf3e3327987b1a66d4867 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.h
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.h
@@ -11,7 +11,6 @@
 #define ATB_MULTI_LATENT_ATTENTION_OPS_RUNNER_PREFILL_H
 #include "atb/runner/ops_runner.h"
 #include "atb/infer_op_params.h"
-#include "atb/utils/utils_internal.h"
 #include "param.h"

 namespace atb {
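Note: the opIrKey construction above appends a suffix per optional feature so that every tensor-signature variant of the operation gets a distinct lookup key. An illustrative reduction of that pattern (the base key name is invented for the sketch):

    #include <string>
    std::string BuildVariantKey(bool hasMask, bool hasQlens, bool isRing)
    {
        std::string key = "MultiLatentAttention"; // illustrative base name
        if (hasMask) {
            key += "Mask";
        }
        if (hasQlens) {
            key += "Qlens";
        }
        if (isRing) {
            key += "IsRing";
        }
        return key;
    }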
diff --git a/src/ops_infer/paged_attention/paged_attention_operation.cpp b/src/ops_infer/paged_attention/paged_attention_operation.cpp
index 702711fa8ba7925dd986a21a92bc00ef0a1210d0..37f78af80878b184985b4cd969e14c36134383ef 100644
--- a/src/ops_infer/paged_attention/paged_attention_operation.cpp
+++ b/src/ops_infer/paged_attention/paged_attention_operation.cpp
@@ -34,6 +34,13 @@ static const int RAZOROFFSET_BIT = 0x00020;
 static const int LOGN_BIT = 0x00040;
 static const int QKVQUANTOFFLINE_BIT = 0x00040;
 static const int QKVQUANTONLINE_BIT = 0x00080;
+static const int BLOCK_SIZE_DIM128 = 128;
+static const int DIM0 = 0;
+static const int DIM1 = 1;
+static const int DIM2 = 2;
+static const int DIM3 = 3;
+static const int IN_MASK_IDX = 5;
+static const int MAX_BLOCK_SIZE = 256;
 } // namespace

 namespace atb {
@@ -466,7 +473,8 @@ Status PagedAttentionOperation::KVCacheDimCheck310P(const SVector &i
         ATB_LOG(ERROR) << "head_size should align 16 when format of keycache is NZ";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (headSize > 256 || headSize * blockSize > 128 * 128) { // 256: 310p headSize大小限制 // 128: 大小限制
+    if (headSize > MAX_BLOCK_SIZE ||
+        headSize * blockSize > BLOCK_SIZE_DIM128 * BLOCK_SIZE_DIM128) { // 256: 310P headSize limit // 128: size limit
         ATB_LOG(ERROR) << "head_size of keyCache should be no greater than 256 and "
                        << "block_size * head_size should be no greater than 128 * 128";
         return ERROR_INVALID_TENSOR_DIM;
@@ -515,7 +523,7 @@ Status PagedAttentionOperation::KVCacheDimCheck910B(const SVector &i
         ATB_LOG(ERROR) << "headSize of keyCache and valueCache should be same";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (headSize > 256 || headSize * blockSize > 128 * 128) { // 256: 310p headSize大小限制 // 128: 大小限制
+    if (headSize > MAX_BLOCK_SIZE || headSize * blockSize > BLOCK_SIZE_DIM128 * BLOCK_SIZE_DIM128) {
         ATB_LOG(ERROR) << "head_size of keyCache should be no greater than 256 and "
                        << "block_size * head_size should be no greater than 128 * 128";
         return ERROR_INVALID_TENSOR_DIM;
@@ -536,14 +544,15 @@ Status PagedAttentionOperation::KVCacheDimCheck910B(const SVector &i
     }
     // special scenario: blockSize 256 is supported
     bool blockSize256Check =
-        param_.mlaVHeadSize > 0 && blockSize == 256 && param_.kvHeadNum == 1 && // 256: blockSize
-        (param_.headNum == 16 || param_.headNum == 32) && headSize == 576 && // 16 32: headNum 576: headSize
-        headSizeV == 512 && // 512: headSizeV
+        param_.mlaVHeadSize > 0 && blockSize == MAX_BLOCK_SIZE && param_.kvHeadNum == 1 && // 256: blockSize
+        (param_.headNum == 16 || param_.headNum == 32) && headSize == 576 &&               // 16 32: headNum 576: headSize
+        headSizeV == 512 &&                                                                // 512: headSizeV
         param_.calcType != infer::PagedAttentionParam::CalcType::CALC_TYPE_SPEC;
     if (blockSize256Check) {
         return NO_ERROR;
     }
-    if (((headSize > 256 || headSizeV > 256) && blockSize > 128)) { // 128: mla blockSize大小限制 256:headsize阈值
+    if (((headSize > MAX_BLOCK_SIZE || headSizeV > MAX_BLOCK_SIZE) &&
+         blockSize > BLOCK_SIZE_DIM128)) { // 128: MLA blockSize limit; 256: headSize threshold
         ATB_LOG(ERROR) << "blockSize should be no greater than 128 if headSize > 256";
         return ERROR_INVALID_TENSOR_DIM;
     }
@@ -647,17 +656,22 @@ Status PagedAttentionOperation::MaskFreeInferShapeCheck310P(const SVector
     if (GetSingleton().Is310P()) {
-        if (inTensorDescs.at(5).shape.dimNum != 4) {
-            ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dim num should be 4";
+        if (inTensorDescs.at(IN_MASK_IDX).shape.dimNum != 4) { // 4: PA MASK_TYPE_MASK_FREE dimNum
+            ATB_LOG(ERROR)
+                << "When maskType is mask free on Atlas 300I Duo inference products, mask dim num should be 4";
             return ERROR_INVALID_TENSOR_DIM;
         }
-        if (inTensorDescs.at(5).shape.dims[0] != 1 || inTensorDescs.at(5).shape.dims[1] != 8 || inTensorDescs.at(5).shape.dims[2] != 128 || inTensorDescs.at(5).shape.dims[3] != 16) {
-            ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dims should be [1,8,128,16]";
+        if (inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM0] != 1 || // 1: mask dims [1,8,128,16]
+            inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM1] != 8 || // 8: mask dims [1,8,128,16]
+            inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM2] != BLOCK_SIZE_DIM128 ||
+            inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM3] != DIM_ALIGN_16_NZ) {
+            ATB_LOG(ERROR) << "When maskType is mask free on Atlas 300I Duo inference products, mask dims should "
+                              "be [1,8,128,16]";
             return ERROR_INVALID_TENSOR_DIM;
         }
-        size_t kBlockSize = inTensorDescs.at(1).shape.dims[2];
-        size_t vBlockSize = inTensorDescs.at(2).shape.dims[2];
-        if (kBlockSize != 128 || vBlockSize != 128) {
+        size_t kBlockSize = inTensorDescs.at(DIM1).shape.dims[2]; // 1: k, 2: blockSize
+        size_t vBlockSize = inTensorDescs.at(DIM2).shape.dims[2]; // 2: v, 2: blockSize
+        if (kBlockSize != BLOCK_SIZE_DIM128 || vBlockSize != BLOCK_SIZE_DIM128) {
             ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128.";
             return ERROR_INVALID_PARAM;
         }
@@ -671,48 +685,54 @@ Status PagedAttentionOperation::MaskFreeInferShapeCheck310P(const SVector &inTensor) const
 {
-    if (param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) {
-        if (GetSingleton().Is310P()) {
-            if (GetSingleton().Is310P() && param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) {
-                if (inTensor.at(5).desc.shape.dimNum != 4) {
-                    ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dim num should be 4";
-                    return ERROR_INVALID_TENSOR_DIM;
-                }
-                if (inTensor.at(5).desc.shape.dims[0] != 1 || inTensor.at(5).desc.shape.dims[1] != 8 || inTensor.at(5).desc.shape.dims[2] != 128 || inTensor.at(5).desc.shape.dims[3] != 16) {
-                    ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dims should be [1,8,128,16]";
-                    return ERROR_INVALID_TENSOR_DIM;
-                }
-            }
-            if (inTensor.at(4).desc.shape.dimNum == 1) {
-                size_t batch = inTensor.at(4).desc.shape.dims[0];
-                int *k_seqlen_list = static_cast(inTensor[4].hostData);
-                int *q_seqlen_list = static_cast(inTensor[6].hostData);
-
-                for (size_t i = 0; i < batch; i++) {
-                    if (k_seqlen_list[i] < q_seqlen_list[i]) {
-                        ATB_LOG(ERROR) << "PagedAttentionOperation intensor4[i] should bigger than intensor6[i].";
-                        return ERROR_INVALID_PARAM;
-                    }
-                    if ((k_seqlen_list[i] - q_seqlen_list[i]) % 128 != 0) {
-                        ATB_LOG(ERROR) << "PagedAttentionOperation (intensor4[i] - item in intensor6[i]) % 128 should be 0. ";
-                        return ERROR_INVALID_PARAM;
-                    }
-                }
-            } else {
-                ATB_LOG(ERROR) << "PagedAttentionOperation k_seqlen_list dims should be 1.";
+    if (param_.maskType != atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) {
+        return NO_ERROR;
+    }
+    if (!GetSingleton().Is310P()) {
+        ATB_LOG(ERROR) << "Only Atlas 300I Duo inference products support mask free";
+        return ERROR_INVALID_TENSOR_DIM;
+    }
+    if (GetSingleton().Is310P() && param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) {
+        if (inTensor.at(IN_MASK_IDX).desc.shape.dimNum != 4) { // 4: PA MASK_TYPE_MASK_FREE dimNum
+            ATB_LOG(ERROR)
+                << "When maskType is mask free on Atlas 300I Duo inference products, mask dim num should be 4";
+            return ERROR_INVALID_TENSOR_DIM;
+        }
+        if (inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM0] != 1 || // 1: mask dims [1,8,128,16]
+            inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM1] != 8 || // 8: mask dims [1,8,128,16]
+            inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM2] != BLOCK_SIZE_DIM128 ||
+            inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM3] != DIM_ALIGN_16_NZ) {
+            ATB_LOG(ERROR) << "When maskType is mask free on Atlas 300I Duo inference products, mask dims "
+                              "should be [1,8,128,16]";
+            return ERROR_INVALID_TENSOR_DIM;
+        }
+    }
+    static const int KSEQLEN_INDEX4 = 4;
+    if (inTensor.at(KSEQLEN_INDEX4).desc.shape.dimNum == 1) {
+        size_t batch = inTensor.at(KSEQLEN_INDEX4).desc.shape.dims[0];
+        int *kSeqlenList = static_cast(inTensor[KSEQLEN_INDEX4].hostData);
+        int *qSeqlenList = static_cast(inTensor[6].hostData); // 6: qSeqlen
+
+        for (size_t i = 0; i < batch; i++) {
+            if (kSeqlenList[i] < qSeqlenList[i]) {
+                ATB_LOG(ERROR) << "PagedAttentionOperation intensor4[i] should be no less than intensor6[i].";
                 return ERROR_INVALID_PARAM;
             }
-
-            size_t kBlockSize = inTensor.at(1).desc.shape.dims[2];
-            size_t vBlockSize = inTensor.at(2).desc.shape.dims[2];
-            if (kBlockSize != 128 || vBlockSize != 128) {
-                ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128.";
+            if ((kSeqlenList[i] - qSeqlenList[i]) % BLOCK_SIZE_DIM128 != 0) {
+                ATB_LOG(ERROR) << "PagedAttentionOperation (intensor4[i] - intensor6[i]) % 128 should be 0.";
                 return ERROR_INVALID_PARAM;
             }
-        } else {
-            ATB_LOG(ERROR) << "Only Altas 300I Duo inference products support mask free";
-            return ERROR_INVALID_TENSOR_DIM;
         }
+    } else {
+        ATB_LOG(ERROR) << "PagedAttentionOperation kSeqlenList dims should be 1.";
+        return ERROR_INVALID_PARAM;
+    }
+
+    size_t kBlockSize = inTensor.at(1).desc.shape.dims[2]; // 1: k, 2: blockSize dim
+    size_t vBlockSize = inTensor.at(2).desc.shape.dims[2]; // 2: v, 2: blockSize dim
+    if (kBlockSize != BLOCK_SIZE_DIM128 || vBlockSize != BLOCK_SIZE_DIM128) {
+        ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128.";
+        return ERROR_INVALID_PARAM;
     }
     return NO_ERROR;
 }
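Note: the restructured SetupCheckMaskFree above validates host-side sequence-length arrays — mask-free decoding requires every key length to cover its query and the overhang to be a whole number of 128-element blocks. A condensed form of that check:

    #include <cstddef>
    bool SeqLensValid(const int *kLens, const int *qLens, size_t batch)
    {
        for (size_t i = 0; i < batch; ++i) {
            if (kLens[i] < qLens[i] || (kLens[i] - qLens[i]) % 128 != 0) {
                return false; // key length must cover the query, overhang must be 128-aligned
            }
        }
        return true;
    }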
diff --git a/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp b/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp
index a18d5ab5afb0e65ca92d6a867cf24faeb1706950..63014e4df98dc673c60cc11c7b0ee1dba5523bc2 100644
--- a/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp
+++ b/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp
@@ -1,12 +1,12 @@
 /*
-* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-* This file is a part of the CANN Open Software.
-* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
-* Please refer to the License for details. You may not use this file except in compliance with the License.
-* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
-* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
-* See LICENSE in the root of the software repository for the full text of the License.
-*/
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */

 #include "paged_cache_load_operation.h"
 #include "paged_cache_load_ops_runner.h"
@@ -37,8 +37,8 @@ static const uint32_t OUT_DIM = 3;
 static const uint32_t SIXTEEN = 16;
 static const uint32_t THIRTYTWO = 32;
 static const uint32_t MAX_SEQ_LEN = 2048;
-static const uint32_t MAX_k = 147456;
-static const uint32_t MAX_v = 147456;
+static const uint32_t MAX_K = 147456;
+static const uint32_t MAX_V = 147456;
 static const uint32_t BLOCKSIZEINDEX = 2;
 static const uint32_t BLOCKSIZEINDEX_ND = 1;
 static const uint32_t NUM_HEADS_INDEX = 2;
@@ -47,6 +47,10 @@ static const uint32_t ALIGN_INT8 = 32;
 static const uint32_t ALIGN_FP16_BF16 = 16;
 static const uint32_t ZERO = 0;
 static const uint32_t ONE = 1;
+static const uint32_t DIM0 = 0;
+static const uint32_t DIM1 = 1;
+static const uint32_t DIM2 = 2;
+static const uint32_t DIM3 = 3;

 template <> Status CreateOperation(const infer::PagedCacheLoadParam &opParam, Operation **operation)
 {
@@ -105,10 +109,10 @@ uint32_t PagedCacheLoadOperation::GetOutputNum() const
 }

 Status PagedCacheLoadOperation::InferShapeImpl(const SVector &inTensorDescs,
-                                                SVector &outTensorDescs) const
+                                               SVector &outTensorDescs) const
 {
-    outTensorDescs.at(0) = inTensorDescs.at(IN_TENSOR_4_KEY);
-    outTensorDescs.at(1) = inTensorDescs.at(IN_TENSOR_5_VALUE);
+    outTensorDescs.at(IN_TENSOR_0_KEYCACHE) = inTensorDescs.at(IN_TENSOR_4_KEY);
+    outTensorDescs.at(IN_TENSOR_1_VALUECACHE) = inTensorDescs.at(IN_TENSOR_5_VALUE);
     return NO_ERROR;
 }

@@ -118,55 +122,63 @@ Status PagedCacheLoadOperation::InferShapeCheckImpl(const SVector &i
 }

 Status PagedCacheLoadOperation::SetupCheckImpl(const SVector &inTensors,
-                                                const SVector &outTensors) const
+                                               const SVector &outTensors) const
 {
     SVector inTensorDescs;
     for (size_t i = 0; i < inTensors.size(); i++) {
         inTensorDescs.push_back(inTensors.at(i).desc);
     }
     if (param_.kvCacheCfg == infer::PagedCacheLoadParam::KvCacheCfg::K_CACHE_V_CACHE_NZ) { // NZ
-        if (outTensors.at(0).desc.shape.dims[1] !=
-            inTensorDescs.at(0).shape.dims[1] * inTensorDescs.at(0).shape.dims[OUT_DIM] ||
-            outTensors.at(1).desc.shape.dims[1] !=
-            inTensorDescs.at(1).shape.dims[1] * inTensorDescs.at(1).shape.dims[OUT_DIM]) {
+        int64_t alignCacheK = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] *
+                              inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM3];
+        int64_t alignCacheV = inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] *
+                              inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM3];
+        if (outTensors.at(IN_TENSOR_0_KEYCACHE).desc.shape.dims[DIM1] != alignCacheK ||
+            outTensors.at(IN_TENSOR_1_VALUECACHE).desc.shape.dims[DIM1] != alignCacheV) {
             ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
             return ERROR_INVALID_TENSOR_DIM;
         }
-    } else if (outTensors.at(0).desc.shape.dims[1] != inTensorDescs.at(0).shape.dims[2] ||
-               outTensors.at(0).desc.shape.dims[2] != inTensorDescs.at(0).shape.dims[3] ||
-               outTensors.at(1).desc.shape.dims[1] != inTensorDescs.at(1).shape.dims[2] ||
-               outTensors.at(1).desc.shape.dims[2] != inTensorDescs.at(1).shape.dims[3]) {
-        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
-        return ERROR_INVALID_TENSOR_DIM;
-    }
-    Status st = DimCheck(inTensorDescs);
-    if (st != NO_ERROR) {
-        return st;
+    } else if (outTensors.at(IN_TENSOR_0_KEYCACHE).desc.shape.dims[DIM1] !=
+               inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM2]) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
+        return ERROR_INVALID_TENSOR_DIM;
+    } else if (outTensors.at(IN_TENSOR_0_KEYCACHE).desc.shape.dims[DIM2] !=
+               inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM3]) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
+        return ERROR_INVALID_TENSOR_DIM;
+    } else if (outTensors.at(IN_TENSOR_1_VALUECACHE).desc.shape.dims[DIM1] !=
+               inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM2]) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
+        return ERROR_INVALID_TENSOR_DIM;
+    } else if (outTensors.at(IN_TENSOR_1_VALUECACHE).desc.shape.dims[DIM2] !=
+               inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM3]) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned";
+        return ERROR_INVALID_TENSOR_DIM;
     }
-    return NO_ERROR;
+    return DimCheck(inTensorDescs);
 }

 Status PagedCacheLoadOperation::DimCheck(const SVector &inTensorDescs) const
 {
-    int64_t numBlocks = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[0];    // 0: keyCache
-    int64_t lencontext = inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dims[0]; // 2: blocktable
-    int64_t sumcontext = inTensorDescs.at(IN_TENSOR_4_KEY).shape.dims[0];        // 4: key
-    if (numBlocks != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[0]) {
+    int64_t numBlocks = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM0];    // 0: keyCache
+    int64_t lenContext = inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dims[DIM0]; // 2: blocktable
+    int64_t sumcontext = inTensorDescs.at(IN_TENSOR_4_KEY).shape.dims[DIM0];        // 4: key
+    if (numBlocks != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM0]) {
         ATB_LOG(ERROR) << GetLogPrefix() << "numBlocks should be same";
         return ERROR_INVALID_TENSOR_DIM;
     }
-    if (sumcontext != inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dims[0]) {
+    if (sumcontext != inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dims[DIM0]) {
         ATB_LOG(ERROR) << GetLogPrefix() << "sumcontextlens should be same";
         return ERROR_INVALID_TENSOR_DIM;
     }
     if (param_.kvCacheCfg == infer::PagedCacheLoadParam::KvCacheCfg::K_CACHE_V_CACHE_ND) {
         int64_t blockSize = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[BLOCKSIZEINDEX_ND]; // 1: keyCache
-        int64_t num_heads = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[BLOCKSIZEINDEX]; // 1: keyCache
+        int64_t numHeads = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[BLOCKSIZEINDEX];     // 2: keyCache
         if (blockSize != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[BLOCKSIZEINDEX_ND]) {
             ATB_LOG(ERROR) << GetLogPrefix() << "blockSizes should be same";
             return ERROR_INVALID_TENSOR_DIM;
         }
-        if (num_heads != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[BLOCKSIZEINDEX]) {
+        if (numHeads != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[BLOCKSIZEINDEX]) {
            ATB_LOG(ERROR) << GetLogPrefix() << "numHeads should be same";
            return ERROR_INVALID_TENSOR_DIM;
         }
@@ -175,17 +187,19 @@ Status PagedCacheLoadOperation::DimCheck(const SVector &inTensorDesc
         return ERROR_INVALID_TENSOR_DIM;
     }
     if (param_.isSeqLensCumsumMode) {
-        if (lencontext != inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dims[0] - ONE) { // 1:-1
-            ATB_LOG(ERROR) << GetLogPrefix() <<
-                "the lencontext of blocktable should match the lencontext of SeqLens when isSeqLensCumsumMode is true.";
+        if (lenContext != inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dims[DIM0] - ONE) { // 1: -1
+            ATB_LOG(ERROR) << GetLogPrefix()
+                           << "the lenContext of blocktable should match the lenContext of SeqLens when "
+                              "isSeqLensCumsumMode is true.";
             return ERROR_INVALID_TENSOR_DIM;
         }
-    } else
-            ATB_LOG(ERROR) << GetLogPrefix() <<
-            "the lencontext of blocktable should match the lencontext of SeqLens when isSeqLensCumsumMode is true.";
-            return ERROR_INVALID_TENSOR_DIM;
+        } else if (lenContext != inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dims[DIM0]) {
+            ATB_LOG(ERROR) << GetLogPrefix()
+                           << "the lenContext of blocktable should match the lenContext of SeqLens when "
+                              "isSeqLensCumsumMode is false.";
+            return ERROR_INVALID_TENSOR_DIM;
         }
-    } else { // NZ
+    } else {                                                                                    // NZ
         int64_t blockSize = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[BLOCKSIZEINDEX]; // 1: keyCache
         if (blockSize != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[BLOCKSIZEINDEX]) {
             ATB_LOG(ERROR) << GetLogPrefix() << "blockSizes should be same";
@@ -195,92 +209,95 @@ Status PagedCacheLoadOperation::DimCheck(const SVector<TensorDesc> &inTensorDesc
             ATB_LOG(ERROR) << GetLogPrefix() << "blockSize cannot be zero";
             return ERROR_INVALID_TENSOR_DIM;
         }
-        if (blockSize % 16 != ZERO) {
+        if (blockSize % ALIGN_FP16_BF16 != ZERO) {
             ATB_LOG(ERROR) << GetLogPrefix() << "blockSize must be aligned to 16";
             return ERROR_INVALID_TENSOR_DIM;
         }
-        if (lencontext != inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dims[0]) {
+        if (lenContext != inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dims[DIM0]) {
            ATB_LOG(ERROR) << GetLogPrefix() << "lenscontextlens should be same";
             return ERROR_INVALID_TENSOR_DIM;
         }
     }
     return (param_.kvCacheCfg == infer::PagedCacheLoadParam::KvCacheCfg::K_CACHE_V_CACHE_NZ) ?
-        KVCacheDimCheck910BNZ(inTensorDescs) : KVCacheDimCheck910BND(inTensorDescs);
+               KVCacheDimCheck910BNZ(inTensorDescs) :
+               KVCacheDimCheck910BND(inTensorDescs);
 }
 
 Status PagedCacheLoadOperation::KVCacheDimCheck910BNZ(const SVector<TensorDesc> &inTensorDescs) const
 {
-    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dimNum != INPUTKEY_DIM ||     // 0: keyCache
-        inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dimNum != INPUTVALUE_DIM || // 1: value Cache
-        inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dimNum != INPUTBLOCK_DIM || // 2: dim=2
+    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dimNum != INPUTKEY_DIM ||            // 0: keyCache
+        inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dimNum != INPUTVALUE_DIM ||        // 1: value Cache
+        inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dimNum != INPUTBLOCK_DIM ||        // 2: dim=2
         inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dimNum != INPUTCONTEXTLENS_DIM || // 1: dim=1
-        inTensorDescs.at(IN_TENSOR_4_KEY).shape.dimNum != INPUTBLOCK_DIM ||        // 2: dim=2
-        inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dimNum != INPUTBLOCK_DIM) {      // 2: dim=2
+        inTensorDescs.at(IN_TENSOR_4_KEY).shape.dimNum != INPUTBLOCK_DIM ||               // 2: dim=2
+        inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dimNum != INPUTBLOCK_DIM) {             // 2: dim=2
         ATB_LOG(ERROR) << GetLogPrefix() << "invalid intensor dimNum";
         return ERROR_INVALID_TENSOR_DIM_NUM;
     }
     if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).dtype == ACL_INT8) {
-        if (THIRTYTWO != inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] ||
-            THIRTYTWO!= inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM]) { // 1: valueCache
+        if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] != THIRTYTWO ||
+            inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM] != THIRTYTWO) { // 1: valueCache
             ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of keycache and valuecache must be 32";
             return ERROR_INVALID_TENSOR_DIM;
        }
-        if (MAX_k < inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[1] * THIRTYTWO ||
-            MAX_v < inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[1] * THIRTYTWO) {
+        if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * THIRTYTWO > MAX_K ||
+            inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * THIRTYTWO > MAX_V) {
             ATB_LOG(ERROR) << GetLogPrefix() << "The scend dimension of blocktables must be less than 147456";
             return ERROR_INVALID_TENSOR_DIM;
         }
-    } else if (SIXTEEN != inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] ||
-               SIXTEEN!= inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM]) { // 1: valueCache
-        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of keycache and valuecache must be 16";
-        return ERROR_INVALID_TENSOR_DIM;
-    } else if (MAX_k < inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[1] * SIXTEEN ||
-               MAX_v < inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[1] * SIXTEEN) {
-        ATB_LOG(ERROR) << GetLogPrefix() << "The scend dimension of blocktables must be less than 147456";
-        return ERROR_INVALID_TENSOR_DIM;
+    } else if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] != SIXTEEN ||
+               inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM] != SIXTEEN) { // 1: valueCache
+        ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of keycache and valuecache must be 16";
+        return ERROR_INVALID_TENSOR_DIM;
+    } else if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * SIXTEEN > MAX_K ||
+               inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * SIXTEEN > MAX_V) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "The second dimension of blocktables must not exceed 147456";
+        return ERROR_INVALID_TENSOR_DIM;
     }
     return NO_ERROR;
 }
 
 Status PagedCacheLoadOperation::KVCacheDimCheck910BND(const SVector<TensorDesc> &inTensorDescs) const
 {
-    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dimNum != INPUTKEY_DIM ||     // 0: keyCache
-        inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dimNum != INPUTVALUE_DIM || // 1: valueCache
-        inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dimNum != INPUTBLOCK_DIM || // 2: blockTable
+    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dimNum != INPUTKEY_DIM ||            // 0: keyCache
+        inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dimNum != INPUTVALUE_DIM ||        // 1: valueCache
+        inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dimNum != INPUTBLOCK_DIM ||        // 2: blockTable
        inTensorDescs.at(IN_TENSOR_3_CONTEXTLENS).shape.dimNum != INPUTCONTEXTLENS_DIM || // 3: SeqLens
-        inTensorDescs.at(IN_TENSOR_4_KEY).shape.dimNum != OUT_DIM ||               // 4:key
-        inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dimNum != OUT_DIM ||             // 5: value
-        inTensorDescs.at(IN_TENSOR_6_SEQ_STARTS).shape.dimNum != INPUTCONTEXTLENS_DIM) { // 6: seq start
+        inTensorDescs.at(IN_TENSOR_4_KEY).shape.dimNum != OUT_DIM ||                      // 4: key
+        inTensorDescs.at(IN_TENSOR_5_VALUE).shape.dimNum != OUT_DIM ||                    // 5: value
+        inTensorDescs.at(IN_TENSOR_6_SEQ_STARTS).shape.dimNum != INPUTCONTEXTLENS_DIM) {  // 6: seq start
         ATB_LOG(ERROR) << GetLogPrefix() << "invalid intensor dimNum";
         return ERROR_INVALID_TENSOR_DIM_NUM;
     }
-    int64_t num_heads = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[NUM_HEADS_INDEX];     // 2: num heads
-    int64_t head_size_k = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[HEAD_SIZE_INDEX];   // 3: head_size_k
-    int64_t head_size_v = inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[HEAD_SIZE_INDEX]; // 3: head_size_v
-    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).dtype == ACL_INT8) { // keyCache: int8
-        if (num_heads * head_size_k % ALIGN_INT8 != 0) {
-            ATB_LOG(ERROR) << GetLogPrefix() << "int8 ND format num_heads*head_size_k should be aligned to 32!";
+    int64_t numHeads = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[NUM_HEADS_INDEX];   // 2: num heads
+    int64_t headSizeK = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[HEAD_SIZE_INDEX];   // 3: headSizeK
+    int64_t headSizeV = inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[HEAD_SIZE_INDEX]; // 3: headSizeV
+    if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).dtype == ACL_INT8) { // keyCache: int8
+        if (numHeads * headSizeK % ALIGN_INT8 != 0) {
+            ATB_LOG(ERROR) << GetLogPrefix() << "int8 ND format numHeads*headSizeK should be aligned to 32!";
             return ERROR_INVALID_TENSOR_DIM;
         }
-    } else if (num_heads * head_size_k % ALIGN_FP16_BF16 != 0) { // keyCache: fp16/bf16
-        ATB_LOG(ERROR) << GetLogPrefix() << "fp16/bf16 ND format num_heads*head_size_k should be aligned to 16!";
-        return ERROR_INVALID_TENSOR_DIM;
+    } else if (numHeads * headSizeK % ALIGN_FP16_BF16 != 0) { // keyCache: fp16/bf16
+        ATB_LOG(ERROR) << GetLogPrefix() << "fp16/bf16 ND format numHeads*headSizeK should be aligned to 16!";
+        return ERROR_INVALID_TENSOR_DIM;
     }
     if (inTensorDescs.at(IN_TENSOR_1_VALUECACHE).dtype == ACL_INT8) { // valueCache: int8
-        if (num_heads * head_size_v % ALIGN_INT8 != 0) {
-            ATB_LOG(ERROR) << GetLogPrefix() << "int8 ND format num_heads*head_size_v should be aligned to 32!";
+        if (numHeads * headSizeV % ALIGN_INT8 != 0) {
+            ATB_LOG(ERROR) << GetLogPrefix() << "int8 ND format numHeads*headSizeV should be aligned to 32!";
             return ERROR_INVALID_TENSOR_DIM;
         }
-    } else if (num_heads * head_size_v % ALIGN_FP16_BF16 != 0) { // valueCache: fp16/bf16
-        ATB_LOG(ERROR) << GetLogPrefix() << "fp16/bf16 ND format num_heads*head_size_v should be aligned to 16!";
-        return ERROR_INVALID_TENSOR_DIM;
+    } else if (numHeads * headSizeV % ALIGN_FP16_BF16 != 0) { // valueCache: fp16/bf16
+        ATB_LOG(ERROR) << GetLogPrefix() << "fp16/bf16 ND format numHeads*headSizeV should be aligned to 16!";
+        return ERROR_INVALID_TENSOR_DIM;
     }
-    int64_t lencontext = inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dims[0]; // 2: blockTable-batch(len(contextLens))
-    int64_t seqstart = inTensorDescs.at(IN_TENSOR_6_SEQ_STARTS).shape.dims[0];   // 6: seq_start-batch(len(contextLens))
+    int64_t lenContext =
+        inTensorDescs.at(IN_TENSOR_2_BLOCKTABLE).shape.dims[DIM0]; // 2: blockTable-batch(len(contextLens))
+    int64_t seqStart =
+        inTensorDescs.at(IN_TENSOR_6_SEQ_STARTS).shape.dims[DIM0]; // 6: seq_start-batch(len(contextLens))
     if (param_.hasSeqStarts) {
-        if (seqstart != lencontext) {
-            ATB_LOG(ERROR) << GetLogPrefix() <<
-            "the length of seq_startus should match lencontext when hasSeqStarts is true.";
+        if (seqStart != lenContext) {
+            ATB_LOG(ERROR) << GetLogPrefix()
+                           << "the length of seq_starts should match lenContext when hasSeqStarts is true.";
             return ERROR_INVALID_TENSOR_DIM;
         }
     }
@@ -292,4 +309,4 @@ std::shared_ptr<Runner> PagedCacheLoadOperation::CreateRunner(Context &context)
     (void)context;
     return std::make_shared<PagedCacheLoadOpsRunner>(param_);
 }
-}
\ No newline at end of file
+} // namespace atb
diff --git a/src/ops_infer/paged_cache_load/paged_cache_load_ops_runner.cpp b/src/ops_infer/paged_cache_load/paged_cache_load_ops_runner.cpp
index cf13603a99c1c79426e11c63d4b104ea9952b6f2..2d807c0a74a17f365be81f3efcfc026de83cf83f 100644
--- a/src/ops_infer/paged_cache_load/paged_cache_load_ops_runner.cpp
+++ b/src/ops_infer/paged_cache_load/paged_cache_load_ops_runner.cpp
@@ -1,12 +1,12 @@
 /*
-* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-* This file is a part of the CANN Open Software.
-* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
-* Please refer to the License for details. You may not use this file except in compliance with the License.
-* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
-* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
-* See LICENSE in the root of the software repository for the full text of the License.
-*/
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
 #include "paged_cache_load_ops_runner.h"
 #include "atb/utils/log.h"
@@ -29,7 +29,7 @@ PagedCacheLoadOpsRunner::PagedCacheLoadOpsRunner(const infer::PagedCacheLoadPara
     Mki::Tensor &value = kernelGraph_.inTensors.at(inTensorStart++);
 
     if (param_.kvCacheCfg == infer::PagedCacheLoadParam::KvCacheCfg::K_CACHE_V_CACHE_NZ) {
-        Mki::Tensor &seq_starts = kernelGraph_.inTensors.at(3);
+        Mki::Tensor &seqStarts = kernelGraph_.inTensors.at(3);
 
         size_t outTensorStart = 0;
         Mki::Tensor &outKeyTensor = kernelGraph_.outTensors.at(outTensorStart++);
@@ -44,7 +44,8 @@ PagedCacheLoadOpsRunner::PagedCacheLoadOpsRunner(const infer::PagedCacheLoadPara
         pagedCacheLoadParam.hasSeqStarts = param_.hasSeqStarts;
         pagedCacheLoadNode.opDesc = {0, "PagedCacheLoadOperation", pagedCacheLoadParam};
-        pagedCacheLoadNode.inTensors = {&keyCacheTensor, &valueCacheTensor, &blockTablesTensor, &contextLens, &key, &value, &seq_starts};
+        pagedCacheLoadNode.inTensors = {&keyCacheTensor, &valueCacheTensor, &blockTablesTensor, &contextLens, &key,
+                                        &value, &seqStarts};
         pagedCacheLoadNode.outTensors = {&outKeyTensor, &outValueTensor};
         pagedCacheLoadNode.inferShapePreFunc = [](Mki::LaunchParam &launchParam) {
             for (size_t i = 0; i < launchParam.GetInTensorCount(); i++) {
@@ -56,7 +57,7 @@ PagedCacheLoadOpsRunner::PagedCacheLoadOpsRunner(const infer::PagedCacheLoadPara
             }
         };
     } else {
-        Mki::Tensor &seq_starts = kernelGraph_.inTensors.at(inTensorStart++);
+        Mki::Tensor &seqStarts = kernelGraph_.inTensors.at(inTensorStart++);
 
         size_t outTensorStart = 0;
         Mki::Tensor &outKeyTensor = kernelGraph_.outTensors.at(outTensorStart++);
@@ -70,7 +71,8 @@ PagedCacheLoadOpsRunner::PagedCacheLoadOpsRunner(const infer::PagedCacheLoadPara
         pagedCacheLoadParam.hasSeqStarts = param_.hasSeqStarts;
         pagedCacheLoadNode.opDesc = {0, "PagedCacheLoadOperation", pagedCacheLoadParam};
-        pagedCacheLoadNode.inTensors = {&keyCacheTensor, &valueCacheTensor, &blockTablesTensor, &contextLens, &key, &value, &seq_starts};
+        pagedCacheLoadNode.inTensors = {&keyCacheTensor, &valueCacheTensor, &blockTablesTensor, &contextLens, &key,
+                                        &value, &seqStarts};
         pagedCacheLoadNode.outTensors = {&outKeyTensor, &outValueTensor};
         pagedCacheLoadNode.inferShapePreFunc = [](Mki::LaunchParam &launchParam) {
             for (size_t i = 0; i < launchParam.GetInTensorCount(); i++) {
diff --git a/src/ops_infer/razor_fusion_attention/razor_fusion_attention_operation.cpp b/src/ops_infer/razor_fusion_attention/razor_fusion_attention_operation.cpp
index f60321e84cfeaf9c155ae4ed75d796d961cba420..8a9f68fb4a5c390a0b9eaf6c8825f5f85fad1447 100644
--- a/src/ops_infer/razor_fusion_attention/razor_fusion_attention_operation.cpp
+++ b/src/ops_infer/razor_fusion_attention/razor_fusion_attention_operation.cpp
@@ -1,12 +1,12 @@
 /*
-* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-* This file is a part of the CANN Open Software.
-* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
-* Please refer to the License for details. You may not use this file except in compliance with the License.
-* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
-* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
-* See LICENSE in the root of the software repository for the full text of the License.
-*/
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
 #include "razor_fusion_attention_operation.h"
 #include "atb/utils/param_to_json.h"
 #include "razor_fusion_attention_ops_runner.h"
@@ -158,7 +158,8 @@ Status RazorFusionAttentionOperation::DimCheck(const SVector<TensorDesc> &inTens
     int64_t valueDim1 = inTensorDescs.at(IN_TENSOR_2_VALUE).shape.dims[1];
     if (queryDimNum == IN_TENSOR_DIM_NUM_2) {
         if (queryDim1 != DEFAULT_HEAD_SIZE || keyDim1 != DEFAULT_HEAD_SIZE || valueDim1 != DEFAULT_HEAD_SIZE) {
-            ATB_LOG(ERROR) << GetLogPrefix() << "when dimNum is 2, intensor0Dim1, intensor1Dim1 and intensor2Dim1 must equal to 128.";
+            ATB_LOG(ERROR) << GetLogPrefix()
+                           << "when dimNum is 2, intensor0Dim1, intensor1Dim1 and intensor2Dim1 must be equal to 128.";
             return ERROR_INVALID_TENSOR_DIM;
         }
     }
diff --git a/src/ops_infer/relay_attention/param.cpp b/src/ops_infer/relay_attention/param.cpp
index 5ca38676319c05053fdd56a54ef5502c8609f072..34f3a5282f29f7b4cd847b40e2caa0bbe97a675c 100644
--- a/src/ops_infer/relay_attention/param.cpp
+++ b/src/ops_infer/relay_attention/param.cpp
@@ -39,7 +39,8 @@ bool RelayAttentionVariantPackParam::BuildFromTensor(const SVector<Tensor>
     return true;
 }
 
-void RelayAttentionVariantPackParam::ReintCastShapeFix(const Mki::Tensor tensor, std::vector<Tensor> &tensorList)
+void RelayAttentionVariantPackParam::ReintCastShapeFix(const Mki::Tensor tensor,
+                                                       std::vector<Tensor> &tensorList) const
 {
     if (tensor.desc.dims.size() - 1 != tensorList[0].desc.shape.dimNum) {
         size_t diffDimNum = static_cast<size_t>(tensorList[0].desc.shape.dimNum);
diff --git a/src/ops_infer/relay_attention/param.h b/src/ops_infer/relay_attention/param.h
index b8b6c7b9dea54d7a5a7d8eb1c7be01a5e83b8a9b..68bc70392ba4db60fafb6c1fbcbb43391838c842 100644
--- a/src/ops_infer/relay_attention/param.h
+++ b/src/ops_infer/relay_attention/param.h
@@ -30,7 +30,7 @@ struct RelayAttentionVariantPackParam {
     std::vector<Tensor> valueShare;
     bool BuildFromTensor(const SVector<Tensor> &inTensors);
     bool HostDataCheck(const SVector<Tensor> &inTensors);
-    void ReintCastShapeFix(const Mki::Tensor tensor, std::vector<Tensor> &tensorList);
+    void ReintCastShapeFix(const Mki::Tensor tensor, std::vector<Tensor> &tensorList) const;
 };
 } // namespace atb
 #endif
\ No newline at end of file
diff --git a/src/ops_infer/reshape_and_cache/reshape_and_cache_operation.cpp b/src/ops_infer/reshape_and_cache/reshape_and_cache_operation.cpp
index 583049b7e19ea54e1705fb79d5eb604fe4d9361c..bd226080fe22a13143e55d4bf84bdc1cf817ef49 100644
--- a/src/ops_infer/reshape_and_cache/reshape_and_cache_operation.cpp
+++ b/src/ops_infer/reshape_and_cache/reshape_and_cache_operation.cpp
@@ -318,7 +318,7 @@ Status ReshapeAndCacheOperation::KVCacheDimCheck910BNZ(const SVector<TensorDesc>
     }
     int64_t kNumHead = inTensorDescs.at(IN_TENSOR_0_KEY).shape.dims[1];        // 1: value
     int64_t kHeadSize = inTensorDescs.at(IN_TENSOR_0_KEY).shape.dims[2];       // 2: kheadSize dim
-    int64_t VHeadSize = inTensorDescs.at(IN_TENSOR_1_VALUE).shape.dims[2];     // 2: vheadSize dim
+    int64_t vHeadSize = inTensorDescs.at(IN_TENSOR_1_VALUE).shape.dims[2];     // 2: vheadSize dim
     int64_t vNumHead = inTensorDescs.at(IN_TENSOR_1_VALUE).shape.dims[1];      // 1: value
     int64_t blockSize = inTensorDescs.at(IN_TENSOR_2_KEYCACHE).shape.dims[2];  // 2: keyCache; 2: blocksize
     if (blockSize != inTensorDescs.at(IN_TENSOR_3_VALUECACHE).shape.dims[2]) { // 3: valueCache
@@ -345,7 +345,7 @@ Status ReshapeAndCacheOperation::KVCacheDimCheck910BNZ(const SVector<TensorDesc>
         return ERROR_INVALID_TENSOR_DIM;
         }
     } // value valueCache
-    if (vNumHead * VHeadSize != inTensorDescs.at(IN_TENSOR_3_VALUECACHE).shape.dims[1] * BLOCK_SIZE_16_NZ ||
+    if (vNumHead * vHeadSize != inTensorDescs.at(IN_TENSOR_3_VALUECACHE).shape.dims[1] * BLOCK_SIZE_16_NZ ||
         inTensorDescs.at(IN_TENSOR_3_VALUECACHE).shape.dims[3] != BLOCK_SIZE_16_NZ) { // 3: last dim
         ATB_LOG(ERROR) << GetLogPrefix() << "NZ format tensor dim should be aligned to 16!";
         return ERROR_INVALID_TENSOR_DIM;
diff --git a/src/ops_infer/ring_mla/ring_mla_operation.cpp b/src/ops_infer/ring_mla/ring_mla_operation.cpp
index ffbc3233daaf70c542b4a621f4fd6b0bfaaa5ebb..31bec89e7e09286dd5ab09c9fb806f3ffbb6f9bd 100644
--- a/src/ops_infer/ring_mla/ring_mla_operation.cpp
+++ b/src/ops_infer/ring_mla/ring_mla_operation.cpp
@@ -390,8 +390,7 @@ bool RingMLAOperation::InputLseDimCheck(const SVector<TensorDesc> &inTensorDescs
 
 Status RingMLAOperation::InferShapeCheckImpl(const SVector<TensorDesc> &inTensorDescs) const
 {
-    Status st;
-    st = DimCheck(inTensorDescs);
+    Status st = DimCheck(inTensorDescs);
     if (st != NO_ERROR) {
         return st;
     }
diff --git a/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp b/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp
index 2d54ff5ae715c47056c04ca246c21b3331ceaccf..94a8ae40f364b2c114230f83f4d7f64f43accbeb 100644
--- a/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp
+++ b/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp
@@ -15,7 +15,7 @@
 #include "param.h"
 
 namespace {
-// query_split1, query_split2, key_split1, key_split2, value, mask, seqLen, prevOut (optional), prevLse(optional)
+// querySplit1, querySplit2, keySplit1, keySplit2, value, mask, seqLen, prevOut (optional), prevLse (optional)
 static constexpr uint32_t VALUE_TENSOR_POS = 2;
 static constexpr uint32_t SEQLEN_TENSOR_POS = 6;
 static constexpr uint32_t IN_TENSOR_NUM = 15;
@@ -36,10 +36,10 @@ RingMLAOpsRunner::RingMLAOpsRunner(const infer::RingMLAParam &param)
     kernelGraph_.outTensors.resize(OUT_TENSOR_NUM);
 
     int inTensorStart = 0;
-    Mki::Tensor *query_split1 = &kernelGraph_.inTensors.at(inTensorStart++);
-    Mki::Tensor *query_split2 = &kernelGraph_.inTensors.at(inTensorStart++);
-    Mki::Tensor *key_split1 = &kernelGraph_.inTensors.at(inTensorStart++);
-    Mki::Tensor *key_split2 = &kernelGraph_.inTensors.at(inTensorStart++);
+    Mki::Tensor *querySplit1 = &kernelGraph_.inTensors.at(inTensorStart++);
+    Mki::Tensor *querySplit2 = &kernelGraph_.inTensors.at(inTensorStart++);
+    Mki::Tensor *keySplit1 = &kernelGraph_.inTensors.at(inTensorStart++);
+    Mki::Tensor *keySplit2 = &kernelGraph_.inTensors.at(inTensorStart++);
     Mki::Tensor *value = &kernelGraph_.inTensors.at(inTensorStart++);
     Mki::Tensor *mask = &kernelGraph_.inTensors.at(inTensorStart++);
     Mki::Tensor *seqLen = &kernelGraph_.inTensors.at(inTensorStart++);
@@ -63,9 +63,7 @@ RingMLAOpsRunner::RingMLAOpsRunner(const infer::RingMLAParam &param)
 
     ringMLANode.opDesc = {0, "RINGMLAOperation", ringMLAParam};
 
-    // flashAttentionEncoderNode.inTensors = {&query_split1, query_split2, &key_split1, key_split2, value,
-    //     mask, slopes, qkDescale, qkOffset, vpvDescale, vpvOffset, pScale, logN, prevOut, prevLse};
-    ringMLANode.inTensors = {query_split1, query_split2, key_split1, key_split2, value,
+    ringMLANode.inTensors = {querySplit1, querySplit2, keySplit1, keySplit2, value,
                              mask, &nullTensor_, &nullTensor_, &nullTensor_, &nullTensor_,
                              &nullTensor_, &nullTensor_, &nullTensor_, prevOut, prevLse};
 
diff --git a/src/ops_infer/ring_mla/ring_mla_ops_runner.h b/src/ops_infer/ring_mla/ring_mla_ops_runner.h
index dda96cce42611f63546f265846c0cf3232991424..3ebdc45ebcc7761c091c627d35e726426db4520a 100644
--- a/src/ops_infer/ring_mla/ring_mla_ops_runner.h
+++ b/src/ops_infer/ring_mla/ring_mla_ops_runner.h
@@ -23,7 +23,7 @@ public:
 
 private:
     Status ModifyKernelGraph(const OpsTensorPack &opsTensorPack) override;
-    void SetRingMLAParam(AtbOps::OpParam::RINGMLA &RingMLAParam);
+    void SetRingMLAParam(AtbOps::OpParam::RINGMLA &ringMLAParam);
     infer::RingMLAParam param_;
 
     bool isInputSoftmaxLse_ = false;
diff --git a/src/ops_infer/scatter_elements_v2/scatter_elements_v2_ops_runner.cpp b/src/ops_infer/scatter_elements_v2/scatter_elements_v2_ops_runner.cpp
index d5663f56be20e0873a64b741772e976bc2fa0786..42a5958f7e911d525c5564506ce5868128817f28 100644
--- a/src/ops_infer/scatter_elements_v2/scatter_elements_v2_ops_runner.cpp
+++ b/src/ops_infer/scatter_elements_v2/scatter_elements_v2_ops_runner.cpp
@@ -12,7 +12,7 @@
 #include "atb/utils/log.h"
 #include "atb/utils/tensor_util.h"
- 
+
 namespace atb {
 ScatterElementsV2OpsRunner::ScatterElementsV2OpsRunner(const infer::ScatterElementsV2Param &param)
     : OpsRunner("ScatterElementsV2OpsRunner", RUNNER_TYPE_GATHER), param_(param)
@@ -22,14 +22,15 @@ ScatterElementsV2OpsRunner::ScatterElementsV2OpsRunner(const infer::ScatterEleme
     Mki::Tensor &inputTensor = kernelGraph_.inTensors.at(0);
     Mki::Tensor &indiceTensor = kernelGraph_.inTensors.at(1);
     Mki::Tensor &updateTensor = kernelGraph_.inTensors.at(2);
- 
+
     // 原地写算子,无须创建outTensors
     kernelGraph_.outTensors.resize(0);
- 
+
     kernelGraph_.nodes.resize(1);
     auto &scatterElementsV2Node = kernelGraph_.nodes[0];
-    AsdOps::OpParam::ScatterElementsV2::ReductionType reduction = AsdOps::OpParam::ScatterElementsV2::ReductionType::NONE;
+    AsdOps::OpParam::ScatterElementsV2::ReductionType reduction =
+        AsdOps::OpParam::ScatterElementsV2::ReductionType::NONE;
     if (param_.reduction == atb::infer::ScatterElementsV2Param::ReductionType::NONE) {
         reduction = AsdOps::OpParam::ScatterElementsV2::ReductionType::NONE;
     } else if (param_.reduction == atb::infer::ScatterElementsV2Param::ReductionType::ADD) {
@@ -37,16 +38,16 @@ ScatterElementsV2OpsRunner::ScatterElementsV2OpsRunner(const infer::ScatterEleme
     } else {
         MKI_LOG(ERROR) << "reduction only support none or add";
     }
- 
+
     AsdOps::OpParam::ScatterElementsV2 scatterElementsV2NodeParam = {reduction, param_.axis};
- 
+
     scatterElementsV2Node.opDesc = {0, "ScatterElementsV2Operation", scatterElementsV2NodeParam};
     scatterElementsV2Node.inTensors = {&inputTensor, &indiceTensor, &updateTensor};
- 
+
     // 原地写算子,无须创建outTensors指向输入tensor
     scatterElementsV2Node.outTensors = {&inputTensor};
 }
- 
+
 ScatterElementsV2OpsRunner::~ScatterElementsV2OpsRunner() {}
- 
+
 } // namespace atb
\ No newline at end of file
diff --git a/src/ops_infer/self_attention/self_attention_prefix_encoder_ops_runner.cpp b/src/ops_infer/self_attention/self_attention_prefix_encoder_ops_runner.cpp
index f391bab79737d6610bea85201cf56fe6ca01b23f..c3096f9c73974a906033c56ddb8cb8626a7794ea 100644
--- a/src/ops_infer/self_attention/self_attention_prefix_encoder_ops_runner.cpp
+++ b/src/ops_infer/self_attention/self_attention_prefix_encoder_ops_runner.cpp
@@ -86,13 +86,13 @@ Status SelfAttentionPrefixEncoderOpsRunner::ModifyKernelGraph(const OpsTensorPac
 {
     // query, key, value, blockTables, mask, seqlen, kvSeqLen, slopes
     SelfAttentionFusionVariantPackParam newParam;
-    uint32_t seqlen_pos = SEQLEN_TENSOR_POS;
-    uint32_t kvSeqlen_pos = KVSEQLEN_TENSOR_POS;
+    uint32_t seqlenPos = SEQLEN_TENSOR_POS;
+    uint32_t kvSeqlenPos = KVSEQLEN_TENSOR_POS;
     if (!needMask_) {
-        seqlen_pos = SEQLEN_TENSOR_POS - 1;
-        kvSeqlen_pos = KVSEQLEN_TENSOR_POS - 1;
+        seqlenPos = SEQLEN_TENSOR_POS - 1;
+        kvSeqlenPos = KVSEQLEN_TENSOR_POS - 1;
     }
-    bool ret = newParam.BuildFromTensorPrefixEncoder(opsTensorPack.inTensors, seqlen_pos, kvSeqlen_pos);
+    bool ret = newParam.BuildFromTensorPrefixEncoder(opsTensorPack.inTensors, seqlenPos, kvSeqlenPos);
     if (!ret) {
         ATB_LOG(ERROR) << GetLogPrefix() << " build param from host tensor fail";
         return ERROR_INVALID_PARAM;
diff --git a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp
index 9a9eb15090402ef27aafe9d178eea030adc49b9e..5903ac492dbec7cd51920904df54c93a21476715 100644
--- a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp
+++ b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp
@@ -41,7 +41,7 @@ static const uint32_t LOG_PROBS_OUT_TENSOR_INDEX = 2;
 static const uint32_t LOG_PROBS_OUT_TENSOR_DIM = 2;
 static const uint32_t LAST_DIM = 1;
 
-using atbInferTopkToppSamplingType = atb::infer::TopkToppSamplingParam::TopkToppSamplingType;
+using AtbInferTopkToppSamplingType = atb::infer::TopkToppSamplingParam::TopkToppSamplingType;
 
 bool ParamCheck(const atb::infer::TopkToppSamplingParam &opParam)
 {
@@ -60,15 +60,15 @@ OPERATION_PARAM_FUNCS(TopkToppSamplingOperation, infer::TopkToppSamplingParam)
 static Mki::OperationIr *GetOperationIrForTopkToppSampling(const infer::TopkToppSamplingParam &param)
 {
     switch (param.topkToppSamplingType) {
-        case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING:
             return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKExpOperation");
-        case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
             return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKLogProbsExpOperation");
-        case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING:
             return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKMultiOperation");
-        case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
             return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKLogProbsMultiOperation");
-        case atbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING:
+        case AtbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING:
             return GetSingleton().GetOperationIr("TopkToppSamplingSingleTopKOperation");
         default:
             ATB_LOG(ERROR) << "UnSupported TopkToppSamplingType: " << param.topkToppSamplingType;
@@ -89,15 +89,15 @@ TopkToppSamplingOperation::~TopkToppSamplingOperation() {}
 uint32_t TopkToppSamplingOperation::GetInputNum() const
 {
     switch (param_.topkToppSamplingType) {
-        case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING:
             return BATCH_TOPK_EXP_IN_TENSOR_NUM;
-        case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING:
             return BATCH_TOPK_MULTI_IN_TENSOR_NUM;
-        case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
             return BATCH_TOPK_EXP_IN_TENSOR_NUM;
-        case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
             return BATCH_TOPK_MULTI_LOGPROBS_IN_TENSOR_NUM;
-        case atbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING:
+        case AtbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING:
             return SINGLE_TOPK_IN_TENSOR_NUM;
         default:
             ATB_LOG(ERROR) << GetLogPrefix() << "UnSupported TopkToppSamplingType: " << param_.topkToppSamplingType;
@@ -108,9 +108,9 @@ uint32_t TopkToppSamplingOperation::GetOutputNum() const
 {
     switch (param_.topkToppSamplingType) {
-        case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING:
             return LOG_PROBS_OUT_TENSOR_NUM;
-        case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
+        case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING:
             return LOG_PROBS_OUT_TENSOR_NUM;
         default:
             return OUT_TENSOR_NUM;
@@ -128,8 +128,8 @@ Status TopkToppSamplingOperation::InferShapeImpl(const SVector<TensorDesc> &inTe
     outTensorDescs.at(1) = inTensorDescs.at(0);
     outTensorDescs.at(1).shape.dims[dimNum - 1] = 1;
 
-    if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING ||
-        param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) {
+    if (param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING ||
+        param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) {
         outTensorDescs.at(OUT_TENSOR_LOGPROBS) = inTensorDescs.at(0);
         outTensorDescs.at(OUT_TENSOR_LOGPROBS).dtype = ACL_FLOAT;
         outTensorDescs.at(OUT_TENSOR_LOGPROBS).shape.dims[dimNum - 1] = param_.logProbsSize;
@@ -250,15 +250,15 @@ Status TopkToppSamplingOperation::CheckIntensorAndParam(const SVector<TensorDesc> &inTensor
             return ERROR_INVALID_TENSOR_DIM;
         }
     }
-    if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING ||
-        param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) {
-        Status LogProbsOutTensorCheckRes = TopkToppLogProbsOutTensorCheck(outTensorDescs);
-        if (LogProbsOutTensorCheckRes != NO_ERROR) {
-            return LogProbsOutTensorCheckRes;
+    if (param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING ||
+        param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) {
+        Status logProbsOutTensorCheckRes = TopkToppLogProbsOutTensorCheck(outTensorDescs);
+        if (logProbsOutTensorCheckRes != NO_ERROR) {
+            return logProbsOutTensorCheckRes;
         }
     }
     return CheckIntensorAndParam(inTensorDescs);
diff --git a/src/torch_atb/enger_graph_builder.cpp b/src/torch_atb/enger_graph_builder.cpp
index cf3a9374bfac2eb2f36c1aa4b78137d7e69c0936..d1d1bb4095125bdba4108595932114061d66fbd0 100644
--- a/src/torch_atb/enger_graph_builder.cpp
+++ b/src/torch_atb/enger_graph_builder.cpp
@@ -284,7 +284,7 @@ OperationWrapper GraphBuilder::Build()
         }
     }
     for (const std::string &outTensorName : graphNode.outTensorIds) {
-        if (outTensorIds_.count(outTensorName)) {
+        if (outTensorIds_.count(outTensorName) != 0) {
             node.outTensorIds.push_back(GetTensorId(outTensorName));
         } else {
             uint32_t id = graphParam_.inTensorNum + graphParam_.outTensorNum + internalId++;
diff --git a/src/torch_atb/graph_node.cpp b/src/torch_atb/graph_node.cpp
index fa207305d2e34e24215a63d37e54f2667a2e1424..d356587301a4a5c72f4b7610f02c00487b85fbd7 100644
--- a/src/torch_atb/graph_node.cpp
+++ b/src/torch_atb/graph_node.cpp
@@ -32,7 +32,7 @@ bool GraphNode::FindOutput(const std::string &id) const
     return it != outTensorIds.end();
 }
 
-void GraphNode::SetStreamId(uint32_t streamId)
+void GraphNode::SetStreamId(uint32_t streamId) const
 {
     if (!operation) {
         throw std::runtime_error("Set execute stream id fail, operation is nullptr");
@@ -40,7 +40,7 @@ void GraphNode::SetStreamId(uint32_t streamId)
     SetExecuteStreamId(operation, streamId);
 }
 
-uint32_t GraphNode::GetStreamId()
+uint32_t GraphNode::GetStreamId() const
 {
     if (!operation) {
         throw std::runtime_error("Get execute stream id fail, operation is nullptr");
diff --git a/src/torch_atb/graph_node.h b/src/torch_atb/graph_node.h
index bbf26ab4c68152e426aeeba915fe935856f35fab..08c4a97539165c7fd7c1a6c8d42313500de92fd6 100644
--- a/src/torch_atb/graph_node.h
+++ b/src/torch_atb/graph_node.h
@@ -22,8 +22,8 @@ public:
     void SetOperation(atb::Operation *op);
     std::string GetOutput(size_t index) const;
     bool FindOutput(const std::string &id) const;
-    void SetStreamId(uint32_t streamId);
-    uint32_t GetStreamId();
+    void SetStreamId(uint32_t streamId) const;
+    uint32_t GetStreamId() const;
 
     std::vector<std::string> inTensorIds;
     std::vector<std::string> outTensorIds;
diff --git a/src/torch_atb/resource/utils.h b/src/torch_atb/resource/utils.h
index 3be3a691a1812a2273b26e280dac41c230191c60..ef23ef8344b8343c613add912a7cffc19ee703fe 100644
--- a/src/torch_atb/resource/utils.h
+++ b/src/torch_atb/resource/utils.h
@@ -26,4 +26,4 @@ aclrtStream GetCurrentStream();
 } // namespace Utils
 } // namespace TorchAtb
 
-#endif // TORCH_ATB_UTILS_H
\ No newline at end of file
+#endif // TORCH_ATB_UTILS_H