diff --git a/.gitignore b/.gitignore
index 68b790b687ef52b8fe7e9397f5988832fbf89987..410ef0b679b8b66ab288b04cb946e0dd6c2a52c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,3 +60,4 @@ cover/
 checkpoints/
 imagenet_val/
 *.json
+data/
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/CMakeLists.txt b/models/nlp/language_model/bert_base_squad/ixrt/CMakeLists.txt
index 4b1d5075a2db96165f54940a58a6c2f36b20dd16..d20ef427ed6e475398ca3449ca9ea106ae6b1de2 100644
--- a/models/nlp/language_model/bert_base_squad/ixrt/CMakeLists.txt
+++ b/models/nlp/language_model/bert_base_squad/ixrt/CMakeLists.txt
@@ -25,6 +25,19 @@ if(DEFINED USE_TENSORRT)
     message(STATUS "cuda_libs = ${CUDA_LIBRARIES}")
     message(STATUS "cudadevrt_libs = ${CUDA_cudadevrt_LIBRARY}")
+
+    include(FindPluginFiles)
+
+    ################################## Compile Options ######################################
+    cuda_add_library(${SHARED_TARGET} SHARED
+        ${PLUGIN_FILES}
+    )
+
+    target_link_libraries(${SHARED_TARGET} ${CUDA_LIBRARIES} ${CUDA_cudadevrt_LIBRARY} ${TRT_LIBRARY})
+    target_link_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/lib64 ${TRT_LIB_PATH} ${IXRT_LIB_DIR})
+    target_include_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/include ${TRT_INC_PATH} src PUBLIC src/common)
+
+
 else()
     include(FindIxrt)
     include(FindCompiler)
@@ -35,15 +48,19 @@ else()
     add_definitions(-D__ILUVATAR__)
     string(APPEND CMAKE_CXX_FLAGS " -std=c++17")
-endif()
-include(FindPluginFiles)
+    file(GLOB_RECURSE PLUGIN_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src_ixrt/*.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/src_ixrt/*.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/src_ixrt/*.cu)
+
+    ################################## Compile Options ######################################
+    cuda_add_library(${SHARED_TARGET} SHARED
+        ${PLUGIN_FILES}
+    )
-################################## Compile Options ######################################
-cuda_add_library(${SHARED_TARGET} SHARED
-    ${PLUGIN_FILES}
-)
+    target_link_libraries(${SHARED_TARGET} ${CUDA_LIBRARIES} ${CUDA_cudadevrt_LIBRARY} ${TRT_LIBRARY})
+    target_link_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/lib64 ${TRT_LIB_PATH} ${IXRT_LIB_DIR})
+    target_include_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/include ${TRT_INC_PATH} src_ixrt PUBLIC src_ixrt/common)
+
+endif()
-target_link_libraries(${SHARED_TARGET} ${CUDA_LIBRARIES} ${CUDA_cudadevrt_LIBRARY} ${TRT_LIBRARY})
-target_link_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/lib64 ${TRT_LIB_PATH} ${IXRT_LIB_DIR})
-target_include_directories(${SHARED_TARGET} PUBLIC ${CUDA_PATH}/include ${TRT_INC_PATH} src PUBLIC src/common)
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/README.md b/models/nlp/language_model/bert_base_squad/ixrt/README.md
index 6d0858ac6f92edab227711c5526ce47d2b5e250c..983d4b7143865c46440983798f3f3300f14be764 100644
--- a/models/nlp/language_model/bert_base_squad/ixrt/README.md
+++ b/models/nlp/language_model/bert_base_squad/ixrt/README.md
@@ -29,18 +29,43 @@ bash script/prepare.sh v1_1
 
 ## Inference
 
-### FP16
+### On T4
+
 ```bash
+# FP16
 cd python
 pip install onnx pycuda
 # use --bs to set max_batch_size (dynamic)
-bash script/build_engine --bs 32
+bash script/build_engine.sh --bs 32
 bash script/inference_squad.sh --bs {batch_size}
 ```
+
+```bash
+# INT8
+cd python
+pip install onnx pycuda
+bash script/build_engine.sh --bs 32 --int8
+bash script/inference_squad.sh --bs {batch_size} --int8
+```
+### On Iluvatar
+
+```bash
+# FP16
+cd python/script
+bash infer_bert_base_squad_fp16_ixrt.sh
+```
+
+```bash
+# INT8
+cd python/script
+bash infer_bert_base_squad_int8_ixrt.sh
+```
+
 ## Results
 
 Model | BatchSize | Precision | FPS | ACC
 ------|-----------|-----------|-----|----
 BERT-Base-SQuAD | 32 | fp16 | Latency QPS: 1543.40 sentences/s | "exact_match": 80.92, "f1": 88.20
+
+## Reference
+- [bert-base-uncased.zip (external link)](https://drive.google.com/file/d/1_DJDdKBanqJ6h3VGhH78F9EPgE2wK_Tw/view?usp=drive_link)
\ No newline at end of file
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindCuda.cmake b/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindCuda.cmake
index 58e39e6003cb6a0545a76f9a6fab88e44fe39caa..e8aa67dc2dc3a2a03af152038dcd54f80c0497e8 100644
--- a/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindCuda.cmake
+++ b/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindCuda.cmake
@@ -11,7 +11,7 @@ if(DEFINED ENV{CUDA_PATH})
   set(CUDA_PATH "$ENV{CUDA_PATH}")
 else()
   set(CUDA_PATH
-      "/opt/sw_home/local/cuda"
+      "/usr/local/corex"
       CACHE PATH "cuda installation root path")
 endif()
 message(STATUS "Use CUDA_PATH=${CUDA_PATH} ")
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindIxrt.cmake b/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindIxrt.cmake
index 5b0f27293edaebf80cd5bfd622c363f49b36966b..3635406a9986bb5b8bedfe0145b7a5be5701df65 100644
--- a/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindIxrt.cmake
+++ b/models/nlp/language_model/bert_base_squad/ixrt/cmake/FindIxrt.cmake
@@ -5,8 +5,10 @@ if(NOT "${IXRT_HOME}" STREQUAL "")
     set(IXRT_LIB_DIR ${IXRT_HOME}/lib)
 # From default paths
 else()
-    set(IXRT_INCLUDE_DIR /usr/local/corex/include)
-    set(IXRT_LIB_DIR /usr/local/corex/lib)
+# set(IXRT_INCLUDE_DIR /usr/local/corex/include)
+# set(IXRT_LIB_DIR /usr/local/corex/lib)
+    set(IXRT_INCLUDE_DIR /usr/local/lib64/python3.10/site-packages/tensorrt/include)
+    set(IXRT_LIB_DIR /usr/local/lib64/python3.10/site-packages/tensorrt/lib)
 endif()
 
 message(STATUS "IXRT_INCLUDE_DIR: ${IXRT_INCLUDE_DIR}")
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8632d95dec10d22834cf928ef8f8c940c1c12962
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder.py
@@ -0,0 +1,394 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import json +import tensorrt as trt +import time +import sys +import ctypes +import os +import numpy as np +from builder_utils import load_onnx_weights_and_quant, load_pytorch_weights_and_quant +from builder_utils import WQKV, BQKV # Attention Keys +from builder_utils import W_AOUT, B_AOUT, W_MID, B_MID, W_LOUT, B_LOUT # Transformer Keys +from builder_utils import SQD_W, SQD_B # SQuAD Output Keys + +trt_version = [int(n) for n in trt.__version__.split('.')] +plugin_lib_name = "libnvinfer_plugin.so" if os.getenv('USE_TRT') == 'True' else "libixrt_plugin.so" +print(plugin_lib_name) + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) +from load_ixrt_plugin import load_ixrt_plugin +load_ixrt_plugin(TRT_LOGGER) + +plg_registry = trt.get_plugin_registry() +registry_list = plg_registry.plugin_creator_list +print("registry_list: ", [registry.name + '/' + registry.plugin_version for registry in registry_list]) +emln_plg_creator = plg_registry.get_plugin_creator("CustomEmbLayerNormPluginDynamic_IxRT", "1", "") +qkv2_plg_creator = plg_registry.get_plugin_creator("CustomQKVToContextPluginDynamic_IxRT", "1", "") +skln_plg_creator = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic_IxRT", "1", "") +ffn_plg_creator = plg_registry.get_plugin_creator("CustomFFNPluginDynamic_IxRT", "1", "") +gelu_plg_creator = plg_registry.get_plugin_creator("CustomGeluPluginDynamic_IxRT", "1", "") +fc_plg_creator = plg_registry.get_plugin_creator("CustomFCPluginDynamic_IxRT", "1", "") + +class BertConfig: + def __init__(self, bert_config_path, use_fp16, use_trt): + with open(bert_config_path, "r") as f: + data = json.load(f) + self.num_attention_heads = data["num_attention_heads"] + self.hidden_size = data["hidden_size"] + self.intermediate_size = data["intermediate_size"] + self.num_hidden_layers = data["num_hidden_layers"] + self.head_size = self.hidden_size // self.num_attention_heads + self.use_fp16 = use_fp16 + self.use_trt = use_trt + +def set_tensor_name(tensor, prefix, name): + tensor.name = prefix + name + +def set_output_name(layer, prefix, name, out_idx = 0): + set_tensor_name(layer.get_output(out_idx), prefix, name) + +def set_output_range(layer, maxval, out_idx = 0): + layer.get_output(out_idx).set_dynamic_range(-maxval, maxval) + +def get_mha_dtype(config): + dtype = trt.float32 + if config.use_fp16: + dtype = trt.float16 + return int(dtype) + +def custom_fc(network, input_tensor, out_dims, W, B): + pf_out_dims = trt.PluginField("out_dims", np.array(out_dims, dtype=np.int32), trt.PluginFieldType.INT32) + pf_type = trt.PluginField("type_id", np.array(int(trt.float16), dtype=np.int32), trt.PluginFieldType.INT32) + pf_W = trt.PluginField("W", W, trt.PluginFieldType.FLOAT32) + fields = [pf_out_dims, pf_type, pf_W] + if B is not None: + pf_B = trt.PluginField("B", B, trt.PluginFieldType.FLOAT32) + fields.append(pf_B) + + pfc = trt.PluginFieldCollection(fields) + fc_plugin = fc_plg_creator.create_plugin("fcplugin", pfc) + plug_inputs = [input_tensor] + out_dense = network.add_plugin_v2(plug_inputs, fc_plugin) + return out_dense + +def attention_layer_opt(prefix, 
config, init_dict, network, input_tensor, imask): + """ + Add the attention layer + """ + B, S, hidden_size = input_tensor.shape + num_heads = config.num_attention_heads + head_size = int(hidden_size / num_heads) + + Wall = init_dict[prefix + WQKV] + Ball = init_dict[prefix + BQKV] + + # FC_attention + mult_all = custom_fc(network, input_tensor, 3 * hidden_size, Wall, Ball) + + has_mask = imask is not None + # QKV2CTX + pf_type = trt.PluginField("type_id", np.array([get_mha_dtype(config)], np.int32), trt.PluginFieldType.INT32) + pf_hidden_size = trt.PluginField("hidden_size", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) + pf_num_heads = trt.PluginField("num_heads", np.array([num_heads], np.int32), trt.PluginFieldType.INT32) + pf_has_mask = trt.PluginField("has_mask", np.array([has_mask], np.int32), trt.PluginFieldType.INT32) + pfc = trt.PluginFieldCollection([pf_hidden_size, pf_num_heads, pf_has_mask, pf_type]) + qkv2ctx_plug = qkv2_plg_creator.create_plugin("qkv2ctx", pfc) + + qkv_in = [mult_all.get_output(0)] + if has_mask: + qkv_in.append(imask) + qkv2ctx = network.add_plugin_v2(qkv_in, qkv2ctx_plug) + return qkv2ctx + + +def skipln(prefix, config, init_dict, network, input_tensor, skip, bias=None): + """ + Add the skip layer + """ + idims = input_tensor.shape + hidden_size = idims[2] + + dtype = trt.float32 + if config.use_fp16: + dtype = trt.float16 + + pf_ld = trt.PluginField("ld", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) + wbeta = init_dict[prefix + "beta"] + pf_beta = trt.PluginField("beta", wbeta, trt.PluginFieldType.FLOAT32) + wgamma = init_dict[prefix + "gamma"] + pf_gamma = trt.PluginField("gamma", wgamma, trt.PluginFieldType.FLOAT32) + pf_type = trt.PluginField("type_id", np.array([int(dtype)], np.int32), trt.PluginFieldType.INT32) + + fields = [pf_ld, pf_beta, pf_gamma, pf_type ] + + if bias is not None: + pf_bias = trt.PluginField("bias", bias, trt.PluginFieldType.FLOAT32) + fields.append(pf_bias) + + pfc = trt.PluginFieldCollection(fields) + skipln_plug = skln_plg_creator.create_plugin("skipln", pfc) + + skipln_inputs = [input_tensor, skip] + layer = network.add_plugin_v2(skipln_inputs, skipln_plug) + return layer + +def ffn_trt(prefix, config, init_dict, network, input_tensor): + # FC1 + GELU + B_mid = init_dict[prefix + B_MID] + W_mid = init_dict[prefix + W_MID] + mid_dense = network.add_fully_connected(input_tensor, config.intermediate_size, W_mid, B_mid) + + dtype = trt.float32 + if config.use_fp16: + dtype = trt.float16 + pf_type = trt.PluginField("type_id", np.array([int(dtype)], np.int32), trt.PluginFieldType.INT32) + pf_ld = trt.PluginField("ld", np.array([config.hidden_size], np.int32), trt.PluginFieldType.INT32) + + pfc = trt.PluginFieldCollection([pf_type, pf_ld]) + gelu_plug = gelu_plg_creator.create_plugin("gelu", pfc) + + gelu_inputs = [mid_dense.get_output(0)] + gelu_layer = network.add_plugin_v2(gelu_inputs, gelu_plug) + + intermediate_act = gelu_layer.get_output(0) + + # FC2 + # Dense to hidden size + B_lout = init_dict[prefix + B_LOUT] + W_lout = init_dict[prefix + W_LOUT] + out_dense = network.add_fully_connected(intermediate_act, config.hidden_size, W_lout, B_lout) + B_lout = None + + out_layer = skipln(prefix + "output_layernorm_", config, init_dict, network, out_dense.get_output(0), input_tensor, B_lout) + return out_layer + +def ffn(prefix, config, init_dict, network, input_tensor): + # FC1 + GELU + B_mid = init_dict[prefix + B_MID] + W_mid = init_dict[prefix + W_MID] + B_lout = init_dict[prefix + B_LOUT] + W_lout 
= init_dict[prefix + W_LOUT] + pf_out_dim = trt.PluginField("out_dims", np.array(config.hidden_size, np.int32), trt.PluginFieldType.INT32) + pf_type = trt.PluginField("type_id", np.array(int(trt.float16), np.int32), trt.PluginFieldType.INT32) + pf_W1 = trt.PluginField("W1", W_mid, trt.PluginFieldType.FLOAT32) + pf_W2 = trt.PluginField("W2", W_lout, trt.PluginFieldType.FLOAT32) + pf_B1 = trt.PluginField("B1", B_mid, trt.PluginFieldType.FLOAT32) + pf_act_type = trt.PluginField("act_type", np.array(int(3), np.int32), trt.PluginFieldType.INT32) + pfc = trt.PluginFieldCollection([pf_out_dim, pf_type, pf_W1, pf_W2, pf_B1, pf_act_type]) + ffn_plug = ffn_plg_creator.create_plugin("ffn", pfc) + + ffn_inputs = [input_tensor] + ffn_layer = network.add_plugin_v2(ffn_inputs, ffn_plug) + + out_layer = skipln(prefix + "output_layernorm_", config, init_dict, network, ffn_layer.get_output(0), input_tensor, B_lout) + return out_layer + +def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imask): + """ + Add the transformer layer + """ + idims = input_tensor.shape + hidden_size = idims[2] + + context_transposed = attention_layer_opt(prefix + "attention_", config, init_dict, network, input_tensor, imask) + attention_heads = context_transposed.get_output(0) + + # FC0 + B_aout = init_dict[prefix + B_AOUT] + W_aout = init_dict[prefix + W_AOUT] + attention_out_fc = custom_fc(network, attention_heads, hidden_size, W_aout, B_aout) + B_aout = None + + skiplayer = skipln(prefix + "attention_output_layernorm_",config, init_dict, network, attention_out_fc.get_output(0), input_tensor, B_aout) + attention_ln = skiplayer.get_output(0) + + if config.use_trt: + ffn_layer = ffn_trt(prefix, config, init_dict, network, attention_ln) + else: + ffn_layer = ffn(prefix, config, init_dict, network, attention_ln) + return ffn_layer + +def bert_model(config, init_dict, network, input_tensor, input_mask): + """ + Create the bert model + """ + prev_input = input_tensor + for layer in range(0, config.num_hidden_layers): + ss = "l{}_".format(layer) + out_layer = transformer_layer_opt(ss, config, init_dict, network, prev_input, input_mask) + prev_input = out_layer.get_output(0) + return prev_input + +def squad_output(prefix, config, init_dict, network, input_tensor): + """ + Create the squad output + """ + + idims = input_tensor.shape + B, S, hidden_size = idims + + W_out = init_dict[prefix + SQD_W] + B_out = init_dict[prefix + SQD_B] + + dense = custom_fc(network, input_tensor, 2, W_out, B_out) + + if config.use_trt: + OUT = network.add_shuffle(dense.get_output(0)) + OUT.second_transpose = (1, 0, 2) + return OUT + return dense + +def emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes): + input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + + if len(sequence_lengths) > 1: + profile = builder.create_optimization_profile() + min_shape = (batch_sizes[0], sequence_lengths[0]) + opt_shape = (batch_sizes[1], sequence_lengths[1]) + max_shape = (batch_sizes[2], sequence_lengths[2]) + 
assert(sequence_lengths[0] <= sequence_lengths[1] and sequence_lengths[1] <= sequence_lengths[2]) + + print('set dynamic shape -> ', min_shape, opt_shape, max_shape) + profile.set_shape("input_ids", min_shape, opt_shape, max_shape) + profile.set_shape("segment_ids", min_shape, opt_shape, max_shape) + profile.set_shape("input_mask", min_shape, opt_shape, max_shape) + builder_config.add_optimization_profile(profile) + + wbeta = trt.PluginField("bert_embeddings_layernorm_beta", weights_dict["bert_embeddings_layernorm_beta"], trt.PluginFieldType.FLOAT32) + wgamma = trt.PluginField("bert_embeddings_layernorm_gamma", weights_dict["bert_embeddings_layernorm_gamma"], trt.PluginFieldType.FLOAT32) + wwordemb = trt.PluginField("bert_embeddings_word_embeddings", weights_dict["bert_embeddings_word_embeddings"], trt.PluginFieldType.FLOAT32) + wtokemb = trt.PluginField("bert_embeddings_token_type_embeddings", weights_dict["bert_embeddings_token_type_embeddings"], trt.PluginFieldType.FLOAT32) + wposemb = trt.PluginField("bert_embeddings_position_embeddings", weights_dict["bert_embeddings_position_embeddings"], trt.PluginFieldType.FLOAT32) + + output_fp16 = trt.PluginField("output_fp16", np.array([1 if config.use_fp16 else 0]).astype(np.int32), trt.PluginFieldType.INT32) + mha_type = trt.PluginField("mha_type_id", np.array([get_mha_dtype(config)], np.int32), trt.PluginFieldType.INT32) + + pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type]) + fn = emln_plg_creator.create_plugin("embeddings", pfc) + + if config.use_trt: + input_ids = network.add_shuffle(input_ids) + input_ids.second_transpose = (1, 0) + segment_ids = network.add_shuffle(segment_ids) + segment_ids.second_transpose = (1, 0) + input_mask = network.add_shuffle(input_mask) + input_mask.second_transpose = (1, 0) + inputs = [input_ids.get_output(0), segment_ids.get_output(0), input_mask.get_output(0)] + else: + inputs = [input_ids, segment_ids, input_mask] + emb_layer = network.add_plugin_v2(inputs, fn) + return emb_layer + +def build_engine(batch_sizes, sequence_lengths, config, weights_dict): + explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + + builder = trt.Builder(TRT_LOGGER) + with builder.create_network(explicit_batch_flag) as network, builder.create_builder_config() as builder_config: + if config.use_fp16: + builder_config.set_flag(trt.BuilderFlag.FP16) + + # Create the network + emb_layer = emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes) + embeddings = emb_layer.get_output(0) + mask_idx = emb_layer.get_output(1) + + bert_out = bert_model(config, weights_dict, network, embeddings, mask_idx) + + squad_logits = squad_output("cls_", config, weights_dict, network, bert_out) + squad_logits_out = squad_logits.get_output(0) + + network.mark_output(squad_logits_out) + + build_start_time = time.time() + plan = builder.build_serialized_network(network, builder_config) + build_time_elapsed = (time.time() - build_start_time) + TRT_LOGGER.log(TRT_LOGGER.INFO, "build engine in {:.3f} Sec".format(build_time_elapsed)) + return plan + +def str2bool(v): + return v.lower() in ('yes', 'true') + +def main(): + parser = argparse.ArgumentParser(description="TensorRT BERT Sample", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-z", "--use_trt", type=str2bool, default=False, help = "Whether to use tensorRT or IxRT") + parser.add_argument("-x", "--onnx", required=False, help="The ONNX model file path.") 
+ parser.add_argument("-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path.") + parser.add_argument("-o", "--output", required=True, default="bert_base_384.engine", help="The bert engine file, ex bert.engine") + parser.add_argument("-b", "--batch-size", nargs='+', help="Batch size(s) to optimize for. The engine will be usable with any batch size below this, but may not be optimal for smaller sizes. Can be specified multiple times to optimize for more than one batch size.", type=int) + parser.add_argument("-s", "--sequence-length", nargs='+', help="Sequence length of the BERT model", type=int) + parser.add_argument("-c", "--config-dir", required=True, + help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google") + parser.add_argument("-f", "--fp16", action="store_true", help="Indicates that inference should be run in FP16 precision", required=False) + parser.add_argument("-j", "--squad-json", default="squad/dev-v1.1.json", help="squad json dataset used for int8 calibration", required=False) + parser.add_argument("-v", "--vocab-file", default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", help="Path to file containing entire understandable vocab", required=False) + parser.add_argument("--verbose", action="store_true", help="Turn on verbose logger and set profiling verbosity to DETAILED", required=False) + + args, _ = parser.parse_known_args() + args.batch_size = args.batch_size or [1] + args.sequence_length = args.sequence_length or [128] + + if len(args.sequence_length) not in [1, 3]: + print("Error: You must provide either one or three integers.") + sys.exit(1) + + if len(args.batch_size) not in [1, 3]: + print("Error: You must provide either one or three integers.") + sys.exit(1) + + if args.verbose: + TRT_LOGGER.min_severity = TRT_LOGGER.VERBOSE + + bert_config_path = args.config_dir + TRT_LOGGER.log(TRT_LOGGER.INFO, "Using configuration file: {:}".format(bert_config_path)) + + config = BertConfig(bert_config_path, args.fp16, args.use_trt) + + if args.onnx != None: + weights_dict = load_onnx_weights_and_quant(args.onnx, config) + elif args.pytorch != None: + weights_dict = load_pytorch_weights_and_quant(args.pytorch, config) + else: + raise RuntimeError("You need either specify TF checkpoint using option --ckpt or ONNX using option --onnx to build TRT BERT model.") + + with build_engine(args.batch_size, args.sequence_length, config, weights_dict) as serialized_engine: + TRT_LOGGER.log(TRT_LOGGER.INFO, "Saving Engine to {:}".format(args.output)) + with open(args.output, "wb") as fout: + fout.write(serialized_engine) + TRT_LOGGER.log(TRT_LOGGER.INFO, "Done.") + +if __name__ == "__main__": + main() diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_int8.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_int8.py new file mode 100644 index 0000000000000000000000000000000000000000..7167882bff938a2020dfd896cacfd43572e6d5be --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_int8.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import argparse +import json +import tensorrt as trt +import time +import sys +import ctypes +import os +import numpy as np +from builder_utils_int8 import load_pytorch_weights_and_quant +from builder_utils_int8 import WQKV, BQKV # Attention Keys +from builder_utils_int8 import W_AOUT, B_AOUT, W_MID, B_MID, W_LOUT, B_LOUT # Transformer Keys +from builder_utils_int8 import SQD_W, SQD_B # SQuAD Output Keys +from builder import custom_fc as custom_fc_fp16 + +trt_version = [int(n) for n in trt.__version__.split('.')] + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +from load_ixrt_plugin import load_ixrt_plugin +load_ixrt_plugin(TRT_LOGGER) + +plg_registry = trt.get_plugin_registry() +registry_list = plg_registry.plugin_creator_list +print("registry_list: ", [registry.name + '/' + registry.plugin_version for registry in registry_list]) +emln_plg_creator = plg_registry.get_plugin_creator("CustomEmbLayerNormPluginDynamic_IxRT", "2", "") +qkv2_plg_creator = plg_registry.get_plugin_creator("CustomQKVToContextPluginDynamic_IxRT", "3", "") +skln_plg_creator = plg_registry.get_plugin_creator("CustomSkipLayerNormPluginDynamic_IxRT", "3", "") +gelu_plg_creator = plg_registry.get_plugin_creator("CustomGeluPluginDynamic_IxRT", "1", "") +fc_plg_creator = plg_registry.get_plugin_creator("CustomFCPluginDynamic_IxRT", "2", "") + +# +class BertConfig: + def __init__(self, bert_config_path, use_int8): + with open(bert_config_path, "r") as f: + data = json.load(f) + self.num_attention_heads = data["num_attention_heads"] + self.hidden_size = data["hidden_size"] + self.intermediate_size = data["intermediate_size"] + self.num_hidden_layers = data["num_hidden_layers"] + self.head_size = self.hidden_size // self.num_attention_heads + self.use_int8 = use_int8 + +def set_tensor_name(tensor, prefix, name): + tensor.name = prefix + name + +def set_output_name(layer, prefix, name, out_idx = 0): + set_tensor_name(layer.get_output(out_idx), prefix, name) + +def set_output_range(layer, maxval, out_idx = 0): + layer.get_output(out_idx).set_dynamic_range(-maxval, maxval) + +def get_mha_dtype(config): + dtype = trt.float32 + if config.use_int8: + dtype = trt.int8 + return int(dtype) + +def custom_fc(prefix, config, init_dict, network, input_tensor, out_dims, W, B): + pf_out_dims = trt.PluginField("out_dims", np.array([out_dims], dtype=np.int32), trt.PluginFieldType.INT32) + pf_W = trt.PluginField("W", W, 
trt.PluginFieldType.FLOAT32) + + fields = [pf_out_dims, pf_W] + + if config.use_int8: + amax_vec = [init_dict[prefix + "wei_amax"]] + if B is not None: + pf_B = trt.PluginField("Bias", B, trt.PluginFieldType.FLOAT32) + amax_vec.append(init_dict[prefix + "out_amax"]) + pf_amax = trt.PluginField("fc_amax", np.array(amax_vec, np.float32), trt.PluginFieldType.FLOAT32) + fields.append(pf_B) + fields.append(pf_amax) + else: + pf_amax = trt.PluginField("fc_amax", np.array(amax_vec, np.float32), trt.PluginFieldType.FLOAT32) + fields.append(pf_amax) + + pfc = trt.PluginFieldCollection(fields) + fc_plugin = fc_plg_creator.create_plugin("fcplugin", pfc) + plug_inputs = [input_tensor] + out_dense = network.add_plugin_v2(plug_inputs, fc_plugin) + return out_dense + +def attention_layer_opt(prefix, config, init_dict, network, input_tensor, imask): + """ + Add the attention layer + """ + B, S, hidden_size = input_tensor.shape + num_heads = config.num_attention_heads + head_size = int(hidden_size / num_heads) + + Wall = init_dict[prefix + WQKV] + Ball = init_dict[prefix + BQKV] + + # FC_attention + mult_all = custom_fc(prefix + "self_qkv_", config, init_dict, network, input_tensor, 3*hidden_size, Wall, Ball) + set_output_range(mult_all, init_dict[prefix + "self_qkv_out_amax"]) + + has_mask = imask is not None + + # QKV2CTX + pf_hidden_size = trt.PluginField("hidden_size", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) + pf_num_heads = trt.PluginField("num_heads", np.array([num_heads], np.int32), trt.PluginFieldType.INT32) + fields = [pf_hidden_size, pf_num_heads] + dq_probs = [ + init_dict[prefix + "arrange_qkv_amax"], + init_dict[prefix + "softmax_in_amax"], + init_dict[prefix + "softmax_out_amax"] + ] + pf_dq = trt.PluginField("dq_probs", np.array(dq_probs, np.float32), trt.PluginFieldType.FLOAT32) + fields.append(pf_dq) + + pfc = trt.PluginFieldCollection(fields) + qkv2ctx_plug = qkv2_plg_creator.create_plugin("qkv2ctx", pfc) + + qkv_in = [mult_all.get_output(0)] + if has_mask: + qkv_in.append(imask) + qkv2ctx = network.add_plugin_v2(qkv_in, qkv2ctx_plug) + if config.use_int8: + set_output_range(qkv2ctx, init_dict[prefix + "output_dense_in_amax"]) + return qkv2ctx + + +def skipln(prefix, config, init_dict, network, input_tensor, skip, residual, is_last_layer, bias=None): + """ + Add the skip layer + """ + idims = input_tensor.shape + hidden_size = idims[2] + + dtype = trt.float32 + if config.use_int8: + dtype = trt.int8 + + wbeta = init_dict[prefix + "beta"] + wgamma = init_dict[prefix + "gamma"] + + pf_ld = trt.PluginField("ld", np.array([hidden_size], np.int32), trt.PluginFieldType.INT32) + pf_beta = trt.PluginField("beta", wbeta, trt.PluginFieldType.FLOAT32) + pf_gamma = trt.PluginField("gamma", wgamma, trt.PluginFieldType.FLOAT32) + pf_type = trt.PluginField("type_id", np.array([int(dtype)], np.int32), trt.PluginFieldType.INT32) + + fields = [pf_ld, pf_beta, pf_gamma, pf_type ] + if bias is not None: + pf_bias = trt.PluginField("bias", bias, trt.PluginFieldType.FLOAT32) + fields.append(pf_bias) + if is_last_layer: + pf_fp32 = trt.PluginField("output_fp32", np.array([1], np.int32), trt.PluginFieldType.INT32) + fields.append(pf_fp32) + + pfc = trt.PluginFieldCollection(fields) + skipln_plug = skln_plg_creator.create_plugin("skipln", pfc) + + skipln_inputs = [input_tensor, skip] + if config.use_int8: + skipln_inputs.append(residual) + layer = network.add_plugin_v2(skipln_inputs, skipln_plug) + return layer + +def ffn(prefix, config, init_dict, network, input_tensor, residual, 
is_last_layer): + # FC1 + GELU + B_mid = init_dict[prefix + B_MID] + W_mid = init_dict[prefix + W_MID] + + mid_dense = custom_fc(prefix + "intermediate_dense_", config, init_dict, network, input_tensor, config.intermediate_size, W_mid, None) + set_output_range(mid_dense, init_dict[prefix + "intermediate_dense_out_amax"]) + + dtype = trt.float32 + + if config.use_int8: + dtype = trt.int8 + + pf_type = trt.PluginField("type_id", np.array([int(dtype)], np.int32), trt.PluginFieldType.INT32) + pf_ld = trt.PluginField("ld", np.array([int(config.intermediate_size)], np.int32), trt.PluginFieldType.INT32) + fields = [pf_type, pf_ld] + if config.use_int8: + pf_bias = trt.PluginField("bias", B_mid, trt.PluginFieldType.FLOAT32) + fields.append(pf_bias) + + pfc = trt.PluginFieldCollection(fields) + gelu_plug = gelu_plg_creator.create_plugin("gelu", pfc) + + gelu_inputs = [mid_dense.get_output(0)] + gelu_layer = network.add_plugin_v2(gelu_inputs, gelu_plug) + + if config.use_int8: + set_output_range(gelu_layer, init_dict[prefix + "output_dense_in_amax"]) + + intermediate_act = gelu_layer.get_output(0) + # set_tensor_name(intermediate_act, prefix, "gelu") + + # FC2 + # Dense to hidden size + B_lout = init_dict[prefix + B_LOUT] + W_lout = init_dict[prefix + W_LOUT] + out_dense = custom_fc(prefix + "output_dense_", config, init_dict, network, intermediate_act, config.hidden_size, W_lout, None) + set_output_range(out_dense, init_dict[prefix + "output_dense_out_amax"]) + + out_layer = skipln(prefix + "output_layernorm_", config, init_dict, network, out_dense.get_output(0), input_tensor, residual, is_last_layer, B_lout) + return out_layer + +def transformer_layer_opt(prefix, config, init_dict, network, input_tensor, imask, residual, is_last_layer): + """ + Add the transformer layer + """ + idims = input_tensor.shape + hidden_size = idims[2] + + context_transposed = attention_layer_opt(prefix + "attention_", config, init_dict, network, input_tensor, imask) + attention_heads = context_transposed.get_output(0) + + # FC0 + B_aout = init_dict[prefix + B_AOUT] + W_aout = init_dict[prefix + W_AOUT] + attention_out_fc = custom_fc(prefix + "attention_output_dense_", config, init_dict, network, attention_heads, hidden_size, W_aout, None) + set_output_range(attention_out_fc, init_dict[prefix + "attention_output_dense_out_amax"]) + + skiplayer = skipln(prefix + "attention_output_layernorm_", config, init_dict, network, attention_out_fc.get_output(0), input_tensor, residual, False, B_aout) + if config.use_int8: + set_output_range(skiplayer, init_dict[prefix + "intermediate_dense_in_amax"]) + + ffn_layer = ffn(prefix, config, init_dict, network, skiplayer.get_output(0), skiplayer.get_output(1), is_last_layer) + return ffn_layer + +def bert_model(config, init_dict, network, input_tensor, input_mask, residual): + """ + Create the bert model + """ + prev_input = input_tensor + for layer in range(0, config.num_hidden_layers): + ss = "l{}_".format(layer) + out_layer = transformer_layer_opt(ss, config, init_dict, network, prev_input, input_mask, residual, + True if config.use_int8 and layer == config.num_hidden_layers - 1 else False) + prev_input = out_layer.get_output(0) + residual = None + if config.use_int8: + residual = out_layer.get_output(1) + if layer < config.num_hidden_layers - 1: + set_output_range(out_layer, init_dict["l{}_".format(layer+1) + "attention_self_qkv_in_amax"]) + else: + set_output_range(out_layer, 1) + + return prev_input + +def squad_output(prefix, config, init_dict, network, input_tensor): + """ + 
Create the squad output + """ + + idims = input_tensor.shape + B, S, hidden_size = idims + + W_out = init_dict[prefix + SQD_W] + B_out = init_dict[prefix + SQD_B] + + dense = custom_fc_fp16(network, input_tensor, 2, W_out, B_out) + return dense + +def emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes): + input_ids = network.add_input(name="input_ids", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + segment_ids = network.add_input(name="segment_ids", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + input_mask = network.add_input(name="input_mask", dtype=trt.int32, shape=(-1 if len(batch_sizes) > 1 else batch_sizes[0], -1 if len(sequence_lengths) > 1 else sequence_lengths[0])) + + if len(sequence_lengths) > 1: + profile = builder.create_optimization_profile() + min_shape = (batch_sizes[0], sequence_lengths[0]) + opt_shape = (batch_sizes[1], sequence_lengths[1]) + max_shape = (batch_sizes[2], sequence_lengths[2]) + assert(sequence_lengths[0] <= sequence_lengths[1] and sequence_lengths[1] <= sequence_lengths[2]) + + print('set dynamic shape -> ', min_shape, opt_shape, max_shape) + profile.set_shape("input_ids", min_shape, opt_shape, max_shape) + profile.set_shape("segment_ids", min_shape, opt_shape, max_shape) + profile.set_shape("input_mask", min_shape, opt_shape, max_shape) + builder_config.add_optimization_profile(profile) + + wbeta = trt.PluginField("bert_embeddings_layernorm_beta", weights_dict["bert_embeddings_layernorm_beta"], trt.PluginFieldType.FLOAT32) + wgamma = trt.PluginField("bert_embeddings_layernorm_gamma", weights_dict["bert_embeddings_layernorm_gamma"], trt.PluginFieldType.FLOAT32) + wwordemb = trt.PluginField("bert_embeddings_word_embeddings", weights_dict["bert_embeddings_word_embeddings"], trt.PluginFieldType.FLOAT32) + wtokemb = trt.PluginField("bert_embeddings_token_type_embeddings", weights_dict["bert_embeddings_token_type_embeddings"], trt.PluginFieldType.FLOAT32) + wposemb = trt.PluginField("bert_embeddings_position_embeddings", weights_dict["bert_embeddings_position_embeddings"], trt.PluginFieldType.FLOAT32) + + output_fp16 = trt.PluginField("output_fp16", np.array([1]).astype(np.int32), trt.PluginFieldType.INT32) + mha_type = trt.PluginField("mha_type_id", np.array([get_mha_dtype(config)], np.int32), trt.PluginFieldType.INT32) + + pfc = trt.PluginFieldCollection([wbeta, wgamma, wwordemb, wtokemb, wposemb, output_fp16, mha_type]) + fn = emln_plg_creator.create_plugin("embeddings", pfc) + + inputs = [input_ids, segment_ids, input_mask] + emb_layer = network.add_plugin_v2(inputs, fn) + + if config.use_int8: + set_output_range(emb_layer, weights_dict["l0_attention_self_qkv_in_amax"]) + set_output_range(emb_layer, 1.0, 1) + return emb_layer + +def build_engine(batch_sizes, sequence_lengths, config, weights_dict): + explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + + builder = trt.Builder(TRT_LOGGER) + with builder.create_network(explicit_batch_flag) as network, builder.create_builder_config() as builder_config: + network = builder.create_network(explicit_batch_flag) + builder_config = builder.create_builder_config() + builder_config.set_flag(trt.BuilderFlag.INT8) + + # Create the network + emb_layer = emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_lengths, batch_sizes) + embeddings = 
emb_layer.get_output(0) + mask_idx = emb_layer.get_output(1) + + residual_buffer = None + if config.use_int8: + residual_buffer = emb_layer.get_output(2) + + bert_out = bert_model(config, weights_dict, network, embeddings, mask_idx, residual_buffer) + + squad_logits = squad_output("cls_", config, weights_dict, network, bert_out) + squad_logits_out = squad_logits.get_output(0) + + network.mark_output(squad_logits_out) + + build_start_time = time.time() + plan = builder.build_serialized_network(network, builder_config) + build_time_elapsed = (time.time() - build_start_time) + TRT_LOGGER.log(TRT_LOGGER.INFO, "build engine in {:.3f} Sec".format(build_time_elapsed)) + return plan + +def main(): + parser = argparse.ArgumentParser(description="TensorRT BERT Sample", formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-x", "--onnx", required=False, help="The ONNX model file path.") + parser.add_argument("-pt", "--pytorch", required=False, help="The PyTorch checkpoint file path.") + parser.add_argument("-o", "--output", required=True, default="bert_base_384.engine", help="The bert engine file, ex bert.engine") + parser.add_argument("-b", "--batch-size", nargs='+', help="Batch size(s) to optimize for. The engine will be usable with any batch size below this, but may not be optimal for smaller sizes. Can be specified multiple times to optimize for more than one batch size.", type=int) + parser.add_argument("-s", "--sequence-length", nargs='+', help="Sequence length of the BERT model", type=int) + parser.add_argument("-c", "--config-dir", required=True, + help="The folder containing the bert_config.json, which can be downloaded e.g. from https://github.com/google-research/bert#pre-trained-models or by running download_models.py in dle/TensorFlow/LanguageModeling/BERT/data/pretrained_models_google") + parser.add_argument("-f", "--fp16", action="store_true", help="Indicates that inference should be run in FP16 precision", required=False) + parser.add_argument("-i", "--int8", action="store_true", help="Indicates that inference should be run in INT8 precision", required=False) + parser.add_argument("-j", "--squad-json", default="squad/dev-v1.1.json", help="squad json dataset used for int8 calibration", required=False) + parser.add_argument("-v", "--vocab-file", default="./pre-trained_model/uncased_L-24_H-1024_A-16/vocab.txt", help="Path to file containing entire understandable vocab", required=False) + parser.add_argument("--verbose", action="store_true", help="Turn on verbose logger and set profiling verbosity to DETAILED", required=False) + + args, _ = parser.parse_known_args() + args.batch_size = args.batch_size or [1] + args.sequence_length = args.sequence_length or [128] + + if len(args.sequence_length) not in [1, 3]: + print("Error: You must provide either one or three integers.") + sys.exit(1) + + if len(args.batch_size) not in [1, 3]: + print("Error: You must provide either one or three integers.") + sys.exit(1) + + if args.verbose: + TRT_LOGGER.min_severity = TRT_LOGGER.VERBOSE + + bert_config_path = os.path.join(args.config_dir, "config.json") + TRT_LOGGER.log(TRT_LOGGER.INFO, "Using configuration file: {:}".format(bert_config_path)) + + config = BertConfig(bert_config_path, args.int8) + + if args.onnx != None: + if args.int8: + raise RuntimeError("int8 onnx not supported now!!!") + elif args.pytorch != None: + weights_dict = load_pytorch_weights_and_quant(args.pytorch, config) + else: + raise RuntimeError("You need either specify TF checkpoint using option --ckpt or 
ONNX using option --onnx to build TRT BERT model.") + + # engine = build_engine(args.batch_size, args.workspace_size, args.sequence_length, config, weights_dict, args.squad_json, args.vocab_file, None, args.calib_num, args.verbose) + with build_engine(args.batch_size, args.sequence_length, config, weights_dict) as serialized_engine: + TRT_LOGGER.log(TRT_LOGGER.INFO, "Saving Engine to {:}".format(args.output)) + with open(args.output, "wb") as fout: + fout.write(serialized_engine) + TRT_LOGGER.log(TRT_LOGGER.INFO, "Done.") + +if __name__ == "__main__": + main() diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25018bd1c9f2da211a650f16b335613abb04a4eb --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import onnx +import numpy as np +import tensorrt as trt +import json +import struct +import torch + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + +""" +Attentions Keys +""" +WQ = "self_query_kernel" +BQ = "self_query_bias" +WK = "self_key_kernel" +BK = "self_key_bias" +WV = "self_value_kernel" +BV = "self_value_bias" +WQKV = "self_qkv_kernel" +BQKV = "self_qkv_bias" + +""" +Transformer Keys +""" +W_AOUT = "attention_output_dense_kernel" +B_AOUT = "attention_output_dense_bias" +AOUT_LN_BETA = "attention_output_layernorm_beta" +AOUT_LN_GAMMA = "attention_output_layernorm_gamma" +W_MID = "intermediate_dense_kernel" +B_MID = "intermediate_dense_bias" +W_LOUT = "output_dense_kernel" +B_LOUT = "output_dense_bias" +LOUT_LN_BETA = "output_layernorm_beta" +LOUT_LN_GAMMA = "output_layernorm_gamma" + +""" +Squad Output Keys +""" +SQD_W = "squad_output_weights" +SQD_B = "squad_output_bias" + + +def get_onnx_weight_dict(tensor_dict, config): + N = config.num_attention_heads + H = config.head_size + hidden_size = config.hidden_size + + weights_dict = dict() + for outname, tensor in tensor_dict.items(): + if outname.find("_amax") != -1: + weights_dict[outname] = tensor.flatten() + elif outname.find(BQ) != -1: + prefix = outname[:outname.find(BQ)] + + Wqkv = np.zeros((3, hidden_size, hidden_size), np.float32) + Bqkv = np.zeros((3, hidden_size), np.float32) + + Wqkv[0,:,:] = tensor_dict[prefix + WQ] + Wqkv[1,:,:] = tensor_dict[prefix + WK] + Wqkv[2,:,:] = tensor_dict[prefix + WV] + Bqkv[0,:] = tensor + Bqkv[1,:] = tensor_dict[prefix + BK] + Bqkv[2,:] = tensor_dict[prefix + BV] + + if config.use_trt: + Wqkv = np.ascontiguousarray(Wqkv.reshape((3, N, H, N, H)).transpose((1,0,2,3,4))) + Bqkv = np.ascontiguousarray(Bqkv.reshape((3, N, H)).transpose((1,0,2))) + + weights_dict[prefix + WQKV] = Wqkv.flatten() + weights_dict[prefix + BQKV] = Bqkv.flatten() + weights_dict[prefix + WQKV + "_notrans"] = np.ascontiguousarray(Wqkv.T).flatten() + + elif outname.find(BK) != -1 or outname.find(BV) != -1 or outname.find(WQ) != -1 or outname.find(WK) != -1 or outname.find(WV) != -1: + pass + else: + flat_tensor = np.ascontiguousarray(tensor).flatten() + weights_dict[outname] = flat_tensor + + if outname.find("kernel") != -1 and config.use_trt: + tensor = np.transpose(tensor) + weights_dict[outname + "_notrans"] = np.ascontiguousarray(tensor).flatten() + + return weights_dict + +def onnx_to_trt_name(onnx_name): + """ + Converting variables in the onnx checkpoint to names corresponding to the naming convention used in the TF version, expected by the builder + """ + qkv_strings = {'key', 'value', 'query', 'query_key_value'} + onnx_name = onnx_name.lower() + toks = [t.strip('_') for t in onnx_name.split('.')] + if toks[0] == 'bert': #embeddings or encoder + if toks[1] == 'encoder': #transformer + # Token conversions for sparse checkpoints + if toks[-2] == 'dense_act': + toks[-2] = 'dense' + elif toks[-3] == 'dense_act': + if toks[-2] == 'input_quantizer': + toks[-2] = 'input' + elif toks[-2] == 'weight_quantizer': + toks[-2] = 'kernel' + toks[-3] = 'dense' + elif toks[-2].startswith('matmul'): + toks[-2] = { + 'matmul_q_quantizer': 'qv_a_input_quantizer', + 'matmul_k_quantizer': 'qv_b_input_quantizer', + 'matmul_v_quantizer': 'av_b_input_quantizer', + 'matmul_a_quantizer': 'av_a_input_quantizer', + }[toks[-2].replace('input_', '')] + + # Token conversions for all checkpoints + if toks[-2] == 'layernorm': #bias->beta, weight->gamma + toks[-1] = 'beta' if toks[-1] == 'bias' else 'gamma' + elif (toks[-2] == 'dense' or toks[-2] in 
qkv_strings) and toks[-1] == 'weight': + toks[-1] = 'kernel' + elif (toks[-3] == 'dense' or toks[-3] in qkv_strings) and toks[-1] == 'amax': + if toks[-2] == 'weight_quantizer': + toks[-2] = 'kernel' + elif toks[-2] == 'input_quantizer': + toks[-2] = 'input' + + if 'final_input_quantizer' not in toks[2]: + ind = toks.index('layers')+1 if 'layers' in toks else 3 + toks = toks[ind:] + toks[0] = 'l{}'.format(int(toks[0])) + else: + if toks[-2] == 'layernorm': #bias->beta, weight->gamma + toks[-1] = 'beta' if toks[-1] == 'bias' else 'gamma' + else: #embeddings: drop "_weight" suffix + if toks[-1] == 'amax': + toks[-2] = 'amax' + toks = toks[:-1] + elif 'qa' in onnx_name: + name = 'cls_squad_output_bias' if toks[-1] == 'bias' else 'cls_squad_output_weights' + return name + else: + print("Encountered unknown case:", onnx_name) + assert(False) + parsed = '_'.join(toks) + return parsed + +def pt_to_trt_name(pt_name): + """ + Converting variables in the onnx checkpoint to names corresponding to the naming convention used in the TF version, expected by the builder + """ + qkv_strings = {'key', 'value', 'query', 'query_key_value'} + pt_name = pt_name.lower() + toks = [t.strip('_') for t in pt_name.split('.')] + if toks[0] == 'bert': #embeddings or encoder + if toks[1] == 'encoder': #transformer + if toks[-2] == 'layernorm': #bias->beta, weight->gamma + toks[-1] = 'beta' if toks[-1] == 'bias' else 'gamma' + elif (toks[-2] == 'dense' or toks[-2] in qkv_strings) and toks[-1] == 'weight': + toks[-1] = 'kernel' + + if 'final_input_quantizer' not in toks[2]: + ind = toks.index('layers')+1 if 'layers' in toks else 3 + toks = toks[ind:] + toks[0] = 'l{}'.format(int(toks[0])) + + else: + if toks[-2] == 'layernorm': #bias->beta, weight->gamma + toks[-1] = 'beta' if toks[-1] == 'bias' else 'gamma' + else: #embeddings: drop "_weight" suffix + toks = toks[:-1] + + elif 'qa_outputs' in pt_name: ## + name = 'cls_squad_output_bias' if toks[-1] == 'bias' else 'cls_squad_output_weights' + return name + else: + print("Encountered unknown case:", pt_name) + assert(False) + parsed = '_'.join(toks) + return parsed + +def load_onnx_weights_and_quant(path, config): + """ + Load the weights from the onnx checkpoint + """ + model = onnx.load(path) + weights = model.graph.initializer + # for w in weights: + # print(w.name, w.dims,flush=True) + tensor_dict = dict((onnx_to_trt_name(w.name), np.frombuffer(w.raw_data, np.int8).reshape(w.dims)) + if w.name.split('_')[-1] == 'mask' else + (onnx_to_trt_name(w.name), np.frombuffer(w.raw_data, np.float32).reshape(w.dims)) + for w in weights) + # for key in tensor_dict: + # print(key, tensor_dict[key].shape,flush=True) + + return get_onnx_weight_dict(tensor_dict, config) + +def load_pytorch_weights_and_quant(path, config): + """ + Load the weights from the pytorch checkpoint + """ + state_dict = torch.load(path, map_location='cpu') + # for name in state_dict: + # print(name, state_dict[name].size(),flush=True) + tensor_dict = {pt_to_trt_name(name):val.numpy() for name, val in state_dict.items()} + # for key in tensor_dict: + # print(key, tensor_dict[key].shape,flush=True) + return get_onnx_weight_dict(tensor_dict, config) + +class BertConfig: + def __init__(self, bert_config_path, use_fp16, use_int8=False): + with open(bert_config_path, "r") as f: + data = json.load(f) + self.num_attention_heads = data["num_attention_heads"] + self.hidden_size = data["hidden_size"] + self.intermediate_size = data["intermediate_size"] + self.num_hidden_layers = data["num_hidden_layers"] + self.head_size 
= self.hidden_size // self.num_attention_heads + self.use_fp16 = use_fp16 + self.use_int8 = use_int8 + +if __name__ == '__main__': + bert_config_path = '../bert-large-uncased/bert_config.json' + onnx_model_path = '../bert-large-uncased/bert_large_v1_1_fake_quant.onnx' + weight_save_path = "../bert-large-uncased/bert_large_v1_1.wts" + config = config = BertConfig(bert_config_path, True) + weights_dict = load_onnx_weights_and_quant(onnx_model_path, config) + f = open(weight_save_path, "w") + num = 0 + for key, value in weights_dict.items(): + if key.find('_amax') == -1: + num += 1 + + f.write('{}\n'.format(num)) + for key, value in weights_dict.items(): + print('key: ', key) + if key.find('_amax') != -1: + continue + f.write("{} {}".format(key, len(value))) + print(len(value)) + for v in value: + f.write(" ") + f.write(struct.pack('>f', float(v)).hex()) + f.write("\n") diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils_int8.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils_int8.py new file mode 100644 index 0000000000000000000000000000000000000000..67a53f05b4fbaba98420924abe3a4d7afdbd01bd --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/builder_utils_int8.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import tensorrt as trt +import json +import struct +import torch + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + +""" +Attentions Keys +""" +WQ = "self_query_kernel" +BQ = "self_query_bias" +WK = "self_key_kernel" +BK = "self_key_bias" +WV = "self_value_kernel" +BV = "self_value_bias" +WQKV = "self_qkv_kernel" +BQKV = "self_qkv_bias" + +""" +Transformer Keys +""" +W_AOUT = "attention_output_dense_kernel" +B_AOUT = "attention_output_dense_bias" +AOUT_LN_BETA = "attention_output_layernorm_beta" +AOUT_LN_GAMMA = "attention_output_layernorm_gamma" +W_MID = "intermediate_dense_kernel" +B_MID = "intermediate_dense_bias" +W_LOUT = "output_dense_kernel" +B_LOUT = "output_dense_bias" +LOUT_LN_BETA = "output_layernorm_beta" +LOUT_LN_GAMMA = "output_layernorm_gamma" + +""" +Squad Output Keys +""" +SQD_W = "squad_output_weights" +SQD_B = "squad_output_bias" + +ixrt_name_map = { + "bert.embeddings.LayerNorm.bias": "bert_embeddings_layernorm_beta", + "bert.embeddings.LayerNorm.weight" : "bert_embeddings_layernorm_gamma", + "bert.embeddings.word_embeddings.weight" : "bert_embeddings_word_embeddings", + "bert.embeddings.token_type_embeddings.weight" : "bert_embeddings_token_type_embeddings", + "bert.embeddings.position_embeddings.weight" : "bert_embeddings_position_embeddings", + "qa_outputs.weight" : "cls_squad_output_weights", + "qa_outputs.bias" : "cls_squad_output_bias" +} + +ixrt_atten_name_map = { + "bert.encoder.layer.{}.self_attn.qkv_proj.weight" : "l{}_attention_self_qkv_kernel", + "bert.encoder.layer.{}.self_attn.qkv_proj.bias" : "l{}_attention_self_qkv_bias", + "bert.encoder.layer.{}.self_attn.out_proj.bias" : "l{}_attention_output_dense_bias", + "bert.encoder.layer.{}.self_attn.out_proj.weight" : "l{}_attention_output_dense_kernel", + "bert.encoder.layer.{}.fc1.weight" : "l{}_intermediate_dense_kernel", + "bert.encoder.layer.{}.fc1.bias" : "l{}_intermediate_dense_bias", + "bert.encoder.layer.{}.fc2.weight" : "l{}_output_dense_kernel", + "bert.encoder.layer.{}.fc2.bias" : "l{}_output_dense_bias", + "bert.encoder.layer.{}.self_attn_layer_norm.weight" : "l{}_attention_output_layernorm_gamma", + "bert.encoder.layer.{}.self_attn_layer_norm.bias" : "l{}_attention_output_layernorm_beta", + "bert.encoder.layer.{}.final_layer_norm.weight" : "l{}_output_layernorm_gamma", + "bert.encoder.layer.{}.final_layer_norm.bias" : "l{}_output_layernorm_beta", + "bert.encoder.layer.{}.self_attn.qkv_proj.weight_quant.clip.clip_value_max" : "l{}_attention_self_qkv_wei_amax", + "bert.encoder.layer.{}.self_attn.qkv_proj.input_quant.clip.clip_value_max" : "l{}_attention_self_qkv_in_amax", + "bert.encoder.layer.{}.self_attn.qkv_proj.output_quant.clip.clip_value_max" : "l{}_attention_self_qkv_out_amax", + "bert.encoder.layer.{}.self_attn.attention_quant.clip.clip_value_max" : "l{}_attention_arrange_qkv_amax", + "bert.encoder.layer.{}.self_attn.softmax_in_quant.clip.clip_value_max" : "l{}_attention_softmax_in_amax", + "bert.encoder.layer.{}.self_attn.atten_score_out_quant.clip.clip_value_max" : "l{}_attention_softmax_out_amax", + "bert.encoder.layer.{}.self_attn.out_proj.input_quant.clip.clip_value_max" : "l{}_attention_output_dense_in_amax", + "bert.encoder.layer.{}.self_attn.out_proj.output_quant.clip.clip_value_max" : "l{}_attention_output_dense_out_amax", + "bert.encoder.layer.{}.self_attn.out_proj.weight_quant.clip.clip_value_max" : "l{}_attention_output_dense_wei_amax", + "bert.encoder.layer.{}.fc1.input_quant.clip.clip_value_max" : "l{}_intermediate_dense_in_amax", + 
"bert.encoder.layer.{}.fc1.output_quant.clip.clip_value_max" : "l{}_intermediate_dense_out_amax", + "bert.encoder.layer.{}.fc1.weight_quant.clip.clip_value_max" : "l{}_intermediate_dense_wei_amax", + "bert.encoder.layer.{}.fc2.input_quant.clip.clip_value_max" : "l{}_output_dense_in_amax", + "bert.encoder.layer.{}.fc2_out_quant.clip.clip_value_max" : "l{}_output_dense_out_amax", + "bert.encoder.layer.{}.fc2.weight_quant.clip.clip_value_max" : "l{}_output_dense_wei_amax" +} + +def get_weight_dict(tensor_dict, config): + N = config.num_attention_heads + H = config.head_size + hidden_size = config.hidden_size + + weights_dict = dict() + for outname, tensor in tensor_dict.items(): + if outname.find("_amax") != -1: + weights_dict[outname] = tensor.item() + elif outname.find(BQ) != -1: + prefix = outname[:outname.find(BQ)] + + Wqkv = np.zeros((3, hidden_size, hidden_size), np.float32) + Bqkv = np.zeros((3, hidden_size), np.float32) + + Wqkv[0,:,:] = tensor_dict[prefix + WQ] + Wqkv[1,:,:] = tensor_dict[prefix + WK] + Wqkv[2,:,:] = tensor_dict[prefix + WV] + Bqkv[0,:] = tensor + Bqkv[1,:] = tensor_dict[prefix + BK] + Bqkv[2,:] = tensor_dict[prefix + BV] + + weights_dict[prefix + WQKV] = Wqkv.flatten() + weights_dict[prefix + BQKV] = Bqkv.flatten() + elif outname.find(BK) != -1 or outname.find(BV) != -1 or outname.find(WQ) != -1 or outname.find(WK) != -1 or outname.find(WV) != -1: + pass + else: + flat_tensor = np.ascontiguousarray(tensor).flatten() + weights_dict[outname] = flat_tensor + + return weights_dict + +def pytorch_to_trt_name(state_dict, num_layer): + tensor_dict = {} + for name in ixrt_name_map.keys(): + tensor_dict[ixrt_name_map[name]] = state_dict[name] + + for name in ixrt_atten_name_map.keys(): + for layer_id in range(num_layer): + key_name = name.format(layer_id) + value_name = ixrt_atten_name_map[name].format(layer_id) + tensor_dict[value_name] = state_dict[key_name] + return tensor_dict + +def load_pytorch_weights_and_quant(path, config): + """ + Load the weights from the pytorch checkpoint + """ + state_dict = torch.load(path, map_location='cpu') + tensor_dict = pytorch_to_trt_name(state_dict, config.num_hidden_layers) + return get_weight_dict(tensor_dict, config) + +class BertConfig: + def __init__(self, bert_config_path, use_fp16, use_int8=False, use_trt=False): + with open(bert_config_path, "r") as f: + data = json.load(f) + self.num_attention_heads = data["num_attention_heads"] + self.hidden_size = data["hidden_size"] + self.intermediate_size = data["intermediate_size"] + self.num_hidden_layers = data["num_hidden_layers"] + self.head_size = self.hidden_size // self.num_attention_heads + self.use_fp16 = use_fp16 + self.use_int8 = use_int8 + self.use_trt = use_trt + +if __name__ == '__main__': + bert_config_path = './data/bert-large-uncased/bert_config.json' + pytorch_model_path = './data/bert-large-uncased/bert_large_int8_qat.bin' + weight_save_path = "./data/bert-large-uncased/bert_large_v1_1_int8.wts" + config = BertConfig(bert_config_path, True) + weights_dict = load_pytorch_weights_and_quant(pytorch_model_path, config) + f = open(weight_save_path, "w") + num = 0 + for key, value in weights_dict.items(): + if key.find('_amax') == -1: + num += 1 + + f.write('{}\n'.format(num)) + for key, value in weights_dict.items(): + if key.find('_amax') != -1: + continue + print('key: ', key) + f.write("{} {}".format(key, len(value))) + print(len(value)) + for v in value: + f.write(" ") + f.write(struct.pack('>f', float(v)).hex()) + f.write("\n") + + 
f.write('{}\n'.format(len(weights_dict) - num)) + for key, value in weights_dict.items(): + if key.find('_amax') == -1: + continue + print('key: ', key) + print('value: ', value) + f.write('{} '.format(key)) + f.write(struct.pack('>f', float(weights_dict[key])).hex()) + f.write('\n') diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate-v1.1.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate-v1.1.py new file mode 100644 index 0000000000000000000000000000000000000000..92c4e83bf7f150156108b7ccd99f0a9373222c2a --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate-v1.1.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Obtained from https://rajpurkar.github.io/SQuAD-explorer/ + +""" Official evaluation script for v1.1 of the SQuAD dataset. """ +from __future__ import print_function +from collections import Counter +import string +import re +import argparse +import json +import sys + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return (normalize_answer(prediction) == normalize_answer(ground_truth)) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + +def evaluate(dataset, predictions, f1_acc): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article['paragraphs']: + for qa in paragraph['qas']: + total += 1 + if qa['id'] not in predictions: + message = 'Unanswered question ' + qa['id'] + \ + ' will receive score 0.' 
+ print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x['text'], qa['answers'])) + prediction = predictions[qa['id']] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths) + f1 += metric_max_over_ground_truths( + f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + status = 1 + if (f1 < f1_acc - 0.5): + print("&&&& FAILED TensorRT BERT Squad Accuracy matches reference.") + status = 0 + else: + print("&&&& PASSED TensorRT BERT Squad Accuracy matches reference.") + + return {'exact_match': exact_match, 'f1': f1, "status": status} + +if __name__ == '__main__': + expected_version = '1.1' + parser = argparse.ArgumentParser( + description='Evaluation for SQuAD ' + expected_version) + parser.add_argument('dataset_file', help='Dataset file') + parser.add_argument('prediction_file', help='Prediction File') + parser.add_argument('f1_acc', help='Reference Accuracy') + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if (dataset_json['version'] != expected_version): + print('Evaluation expects v-' + expected_version + + ', but got dataset with v-' + dataset_json['version'], + file=sys.stderr) + dataset = dataset_json['data'] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + f1_acc = float(args.f1_acc) + res = evaluate(dataset, predictions, f1_acc) + print(res) + if res["status"] == 1: + print("pass!") + exit() + else: + print("failed!") + exit(1) diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..49b0dedec85518e852bd3d18e106945273094e27 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/evaluate.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Official evaluation script for v1.1 of the SQuAD dataset. 
""" + +import argparse +import json +import re +import string +import sys +from collections import Counter + + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def evaluate(dataset, predictions): + f1 = exact_match = total = 0 + for article in dataset: + for paragraph in article["paragraphs"]: + for qa in paragraph["qas"]: + total += 1 + if qa["id"] not in predictions: + message = ( + "Unanswered question " + qa["id"] + " will receive score 0." + ) + print(message, file=sys.stderr) + continue + ground_truths = list(map(lambda x: x["text"], qa["answers"])) + prediction = predictions[qa["id"]] + exact_match += metric_max_over_ground_truths( + exact_match_score, prediction, ground_truths + ) + f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths) + + exact_match = 100.0 * exact_match / total + f1 = 100.0 * f1 / total + + return {"exact_match": exact_match, "f1": f1} + + +if __name__ == "__main__": + expected_version = "1.1" + parser = argparse.ArgumentParser( + description="Evaluation for SQuAD " + expected_version + ) + parser.add_argument("dataset_file", help="Dataset file") + parser.add_argument("prediction_file", help="Prediction File") + args = parser.parse_args() + with open(args.dataset_file) as dataset_file: + dataset_json = json.load(dataset_file) + if dataset_json["version"] != expected_version: + print( + "Evaluation expects v-" + + expected_version + + ", but got dataset with v-" + + dataset_json["version"], + file=sys.stderr, + ) + dataset = dataset_json["data"] + with open(args.prediction_file) as prediction_file: + predictions = json.load(prediction_file) + print(json.dumps(evaluate(dataset, predictions))) diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/__init__.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/calibrator.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/calibrator.py new file mode 100644 index 0000000000000000000000000000000000000000..beacc625fae0f73bda3480054e4ecceca85fb240 --- /dev/null +++ 
b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/calibrator.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorrt as trt +import os + +import pycuda.driver as cuda +import pycuda.autoinit +import numpy as np +import helpers.tokenization as tokenization +import helpers.data_processing as dp + +class BertCalibrator(trt.IInt8LegacyCalibrator): + def __init__(self, squad_json, vocab_file, cache_file, batch_size, max_seq_length, num_inputs): + # Whenever you specify a custom constructor for a TensorRT class, + # you MUST call the constructor of the parent explicitly. + trt.IInt8LegacyCalibrator.__init__(self) + + self.cache_file = cache_file + + # Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned. + self.data = dp.read_squad_json(squad_json) + self.max_seq_length = max_seq_length + self.batch_size = batch_size + self.current_index = 0 + self.num_inputs = num_inputs + self.tokenizer = tokenization.BertTokenizer(vocab_file=vocab_file, do_lower_case=True) + self.doc_stride = 128 + self.max_query_length = 64 + + # Allocate enough memory for a whole batch. + self.device_inputs = [cuda.mem_alloc(self.max_seq_length * trt.int32.itemsize * self.batch_size) for binding in range(3)] + + def free(self): + for dinput in self.device_inputs: + dinput.free() + + def get_batch_size(self): + return self.batch_size + + # TensorRT passes along the names of the engine bindings to the get_batch function. + # You don't necessarily have to use them, but they can be useful to understand the order of + # the inputs. The bindings list is expected to have the same ordering as 'names'. 
+ def get_batch(self, names): + if self.current_index + self.batch_size > self.num_inputs: + print("Calibrating index {:} batch size {:} exceed max input limit {:} sentences".format(self.current_index, self.batch_size, self.num_inputs)) + return None + + current_batch = int(self.current_index / self.batch_size) + if current_batch % 10 == 0: + print("Calibrating batch {:}, containing {:} sentences".format(current_batch, self.batch_size)) + + input_ids = [] + segment_ids = [] + input_mask = [] + for i in range(self.batch_size): + example = self.data[self.current_index + i] + features = dp.convert_example_to_features(example.doc_tokens, example.question_text, self.tokenizer, self.max_seq_length, self.doc_stride, self.max_query_length) + if len(input_ids) and len(segment_ids) and len(input_mask): + input_ids = np.concatenate((input_ids, features[0].input_ids)) + segment_ids = np.concatenate((segment_ids, features[0].segment_ids)) + input_mask = np.concatenate((input_mask, features[0].input_mask)) + else: + input_ids = features[0].input_ids + segment_ids = features[0].segment_ids + input_mask = features[0].input_mask + + cuda.memcpy_htod(self.device_inputs[0], input_ids.ravel()) + cuda.memcpy_htod(self.device_inputs[1], segment_ids.ravel()) + cuda.memcpy_htod(self.device_inputs[2], input_mask.ravel()) + + self.current_index += self.batch_size + return self.device_inputs + + def read_calibration_cache(self): + # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. + if os.path.exists(self.cache_file): + with open(self.cache_file, "rb") as f: + return f.read() + + def write_calibration_cache(self, cache): + with open(self.cache_file, "wb") as f: + f.write(cache) + f.flush() + os.fsync(f) + + def get_quantile(self): + return 0.9999 + + def get_regression_cutoff(self): + return 1.0 + + def read_histogram_cache(self, length): + return None + + def write_histogram_cache(self, ptr, length): + return None diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/data_processing.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..712e1a61d29a198eb276f41a9249b0c66e3786ba --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/data_processing.py @@ -0,0 +1,497 @@ +#!/usr/bin/env python3 +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
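The `BertCalibrator` defined in calibrator.py above supplies device-side input batches and a calibration cache for INT8 builds. A minimal sketch of how it might be attached to a TensorRT builder config follows; the wiring, paths, and sizes are assumptions for illustration only and are not taken from this patch (the actual engine build is handled by the build scripts).

```python
# Sketch only: attach BertCalibrator to an INT8 engine build.
# All paths and sizes below are placeholders, not values from this patch.
import tensorrt as trt
from helpers.calibrator import BertCalibrator

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)

calibrator = BertCalibrator(
    squad_json="./data/squad/dev-v1.1.json",          # calibration sentences
    vocab_file="./data/bert-base-uncased/vocab.txt",  # WordPiece vocab
    cache_file="bert_calibration.cache",              # reused on later builds
    batch_size=1,
    max_seq_length=128,
    num_inputs=100,
)
config.int8_calibrator = calibrator
# ... build the network/engine with this config, then release device buffers:
# calibrator.free()
```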
+# + +import helpers.tokenization as tokenization +import collections +import numpy as np +import six +import math +import json + + +def convert_doc_tokens(paragraph_text): + + """ Return the list of tokens from the doc text """ + def is_whitespace(c): + if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: + return True + return False + + doc_tokens = [] + prev_is_whitespace = True + for c in paragraph_text: + if is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + + return doc_tokens + + +def _check_is_max_context(doc_spans, cur_span_index, position): + """Check if this is the 'max context' doc span for the token.""" + + # Because of the sliding window approach taken to scoring documents, a single + # token can appear in multiple documents. E.g. + # Doc: the man went to the store and bought a gallon of milk + # Span A: the man went to the + # Span B: to the store and bought + # Span C: and bought a gallon of + # ... + # + # Now the word 'bought' will have two scores from spans B and C. We only + # want to consider the score with "maximum context", which we define as + # the *minimum* of its left and right context (the *sum* of left and + # right context will always be the same, of course). + # + # In the example the maximum context for 'bought' would be span C since + # it has 1 left context and 3 right context, while span B has 4 left context + # and 0 right context. + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + continue + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + + return cur_span_index == best_span_index + + +def convert_example_to_features(doc_tokens, question_text, tokenizer, max_seq_length, + doc_stride, max_query_length): + """Loads a data file into a list of `InputBatch`s.""" + + query_tokens = tokenizer.tokenize(question_text) + + if len(query_tokens) > max_query_length: + query_tokens = query_tokens[0:max_query_length] + + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + # The -3 accounts for [CLS], [SEP] and [SEP] + max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 + + # We can have documents that are longer than the maximum sequence length. + # To deal with this we do a sliding window approach, where we take chunks + # of the up to our max length with a stride of `doc_stride`. 
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name + "DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, doc_stride) + + _Feature = collections.namedtuple( # pylint: disable=invalid-name + "Feature", + ["input_ids", "input_mask", "segment_ids", "tokens", "token_to_orig_map", "token_is_max_context"]) + + + features = [] + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + segment_ids = [] + tokens.append("[CLS]") + segment_ids.append(0) + for token in query_tokens: + tokens.append(token) + segment_ids.append(0) + tokens.append("[SEP]") + segment_ids.append(0) + + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] + + is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index) + token_is_max_context[len(tokens)] = is_max_context + tokens.append(all_doc_tokens[split_token_index]) + segment_ids.append(1) + tokens.append("[SEP]") + segment_ids.append(1) + + input_ids = tokenizer.convert_tokens_to_ids(tokens) + + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + + # Zero-pad up to the sequence length. + # while len(input_ids) < max_seq_length: + # input_ids.append(0) + # input_mask.append(0) + # segment_ids.append(0) + + # assert len(input_ids) == max_seq_length + # assert len(input_mask) == max_seq_length + # assert len(segment_ids) == max_seq_length + + def create_int_feature(values): + feature = np.asarray(values, dtype=np.int32, order=None) + return feature + + + features.append(_Feature( + input_ids = create_int_feature(input_ids), + input_mask = create_int_feature(input_mask), + segment_ids = create_int_feature(segment_ids), + tokens = tokens, + token_to_orig_map = token_to_orig_map, + token_is_max_context = token_is_max_context + )) + return features + + +def read_squad_json(input_file): + """read from squad json into a list of examples""" + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader)["data"] + + _Example = collections.namedtuple( # pylint: disable=invalid-name + "Example", + ["id", "question_text", "doc_tokens"]) + + examples = [] + for entry in input_data: + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + doc_tokens = convert_doc_tokens(paragraph_text) + + for qa in paragraph["qas"]: + examples.append(_Example( + id = qa["id"], + question_text = qa["question"], + doc_tokens = doc_tokens + )) + + return examples + + +def _get_best_indexes(logits, n_best_size): + """Get the n-best logits from a list.""" + + index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) + + best_indexes = [] + for i in range(len(index_and_score)): + if i >= n_best_size: + break + best_indexes.append(index_and_score[i][0]) + return best_indexes + + +def get_final_text(pred_text, orig_text, do_lower_case): + """Project the tokenized prediction back to the original text.""" + + # When we created the data, we kept track of the alignment between original + # (whitespace tokenized) tokens and our WordPiece 
tokenized tokens. So + # now `orig_text` contains the span of our original text corresponding to the + # span that we predicted. + # + # However, `orig_text` may contain extra characters that we don't want in + # our prediction. + # + # For example, let's say: + # pred_text = steve smith + # orig_text = Steve Smith's + # + # We don't want to return `orig_text` because it contains the extra "'s". + # + # We don't want to return `pred_text` because it's already been normalized + # (the SQuAD eval script also does punctuation stripping/lower casing but + # our tokenizer does additional normalization like stripping accent + # characters). + # + # What we really want to return is "Steve Smith". + # + # Therefore, we have to apply a semi-complicated alignment heuristic between + # `pred_text` and `orig_text` to get a character-to-character alignment. This + # can fail in certain cases in which case we just return `orig_text`. + + def _strip_spaces(text): + ns_chars = [] + ns_to_s_map = collections.OrderedDict() + for (i, c) in enumerate(text): + if c == " ": + continue + ns_to_s_map[len(ns_chars)] = i + ns_chars.append(c) + ns_text = "".join(ns_chars) + return (ns_text, ns_to_s_map) + + # We first tokenize `orig_text`, strip whitespace from the result + # and `pred_text`, and check if they are the same length. If they are + # NOT the same length, the heuristic has failed. If they are the same + # length, we assume the characters are one-to-one aligned. + tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) + + tok_text = " ".join(tokenizer.tokenize(orig_text)) + + start_position = tok_text.find(pred_text) + if start_position == -1: + return orig_text + end_position = start_position + len(pred_text) - 1 + + (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) + (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) + + if len(orig_ns_text) != len(tok_ns_text): + return orig_text + + # We then project the characters in `pred_text` back to `orig_text` using + the character-to-character alignment.
+ tok_s_to_ns_map = {} + for (i, tok_index) in six.iteritems(tok_ns_to_s_map): + tok_s_to_ns_map[tok_index] = i + + orig_start_position = None + if start_position in tok_s_to_ns_map: + ns_start_position = tok_s_to_ns_map[start_position] + if ns_start_position in orig_ns_to_s_map: + orig_start_position = orig_ns_to_s_map[ns_start_position] + + if orig_start_position is None: + return orig_text + + orig_end_position = None + if end_position in tok_s_to_ns_map: + ns_end_position = tok_s_to_ns_map[end_position] + if ns_end_position in orig_ns_to_s_map: + orig_end_position = orig_ns_to_s_map[ns_end_position] + + if orig_end_position is None: + return orig_text + + output_text = orig_text[orig_start_position:(orig_end_position + 1)] + return output_text + + +def _compute_softmax(scores): + """Compute softmax probability over raw logits.""" + if not scores: + return [] + + max_score = None + for score in scores: + if max_score is None or score > max_score: + max_score = score + + exp_scores = [] + total_sum = 0.0 + for score in scores: + x = math.exp(score - max_score) + exp_scores.append(x) + total_sum += x + + probs = [] + for score in exp_scores: + probs.append(score / total_sum) + return probs + + +def get_predictions(doc_tokens, features, results, n_best_size, max_answer_length): + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", + ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) + + prediction = "" + scores_diff_json = 0.0 + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + min_null_feature_index = 0 # the paragraph slice with min mull score + null_start_logit = 0 # the start logit at the slice with min null score + null_end_logit = 0 # the end logit at the slice with min null score + version_2_with_negative = False + + for result in results: + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + feature = features[result.feature_index] + + # if we could have irrelevant answers, get the min score of irrelevant + if version_2_with_negative: + feature_null_score = result.start_logits[0] + result.end_logits[0] + if feature_null_score < score_null: + score_null = feature_null_score + min_null_feature_index = 0 + null_start_logit = result.start_logits[0] + null_end_logit = result.end_logits[0] + + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. 
+ if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=result.feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index])) + + if version_2_with_negative: + prelim_predictions.append( + _PrelimPrediction( + feature_index=result.feature_index, + start_index=0, + end_index=0, + start_logit=null_start_logit, + end_logit=null_end_logit)) + + prelim_predictions = sorted( + prelim_predictions, + key=lambda x: (x.start_logit + x.end_logit), + reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"]) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + + if pred.start_index > 0: # this is a non-null prediction + feature = features[pred.feature_index] + tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)] + tok_text = " ".join(tok_tokens) + + # De-tokenize WordPieces that have been split off. + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = " ".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, True) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + if len(final_text): + nbest.append( + _NbestPrediction( + text=final_text, + start_logit=pred.start_logit, + end_logit=pred.end_logit)) + + # if we didn't inlude the empty option in the n-best, inlcude it + if version_2_with_negative: + if "" not in seen_predictions: + nbest.append( + _NbestPrediction( + text="", start_logit=null_start_logit, + end_logit=null_end_logit)) + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. 
+ if not nbest: + nbest.append( + _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + + assert len(nbest) >= 1 + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + if not best_non_null_entry: + if entry.text: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + + assert len(nbest_json) >= 1 + + null_score_diff_threshold = 0.0 + if not version_2_with_negative: + prediction = nbest_json[0]["text"] + else: + # predict "" iff the null score - the score of best non-null > threshold + score_diff = score_null - best_non_null_entry.start_logit - ( + best_non_null_entry.end_logit) + scores_diff_json = score_diff + if score_diff > null_score_diff_threshold: + prediction = "" + else: + prediction = best_non_null_entry.text + + return prediction, nbest_json, scores_diff_json diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/tokenization.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..434f411df061376e565c13b5a96466175b39383c --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/helpers/tokenization.py @@ -0,0 +1,446 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import unicodedata +import six + + +def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): + """Checks whether the casing config is consistent with the checkpoint name.""" + + # The casing has to be passed in by the user and there is no explicit check + # as to whether it matches the checkpoint. The casing information probably + # should have been stored in the bert_config.json file, but it's not, so + # we have to heuristically detect it to validate. 
+ + if not init_checkpoint: + return + + m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) + if m is None: + return + + model_name = m.group(1) + + lower_models = [ + "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", + "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" + ] + + cased_models = [ + "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", + "multi_cased_L-12_H-768_A-12" + ] + + is_bad_config = False + if model_name in lower_models and not do_lower_case: + is_bad_config = True + actual_flag = "False" + case_name = "lowercased" + opposite_flag = "True" + + if model_name in cased_models and do_lower_case: + is_bad_config = True + actual_flag = "True" + case_name = "cased" + opposite_flag = "False" + + if is_bad_config: + raise ValueError( + "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " + "However, `%s` seems to be a %s model, so you " + "should pass in `--do_lower_case=%s` so that the fine-tuning matches " + "how the model was pre-training. If this error is wrong, please " + "just comment out this check." % (actual_flag, init_checkpoint, + model_name, case_name, opposite_flag)) + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def printable_text(text): + """Returns text encoded in a way suitable for print or `tf.logging`.""" + + # These functions want `str` for both Python2 and Python3, but in one case + # it's a Unicode string and in the other it's a byte string. 
+ if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text + elif isinstance(text, unicode): + return text.encode("utf-8") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding='utf-8') as reader: + while True: + token = convert_to_unicode(reader.readline()) + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def convert_by_vocab(vocab, items): + """Converts a sequence of [tokens|ids] using the vocab.""" + output = [] + for item in items: + output.append(vocab[item]) + return output + + +def convert_tokens_to_ids(vocab, tokens): + return convert_by_vocab(vocab, tokens) + + +def convert_ids_to_tokens(inv_vocab, ids): + return convert_by_vocab(inv_vocab, ids) + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class FullTokenizer(object): + """Runs end-to-end tokenziation.""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.inv_vocab = {v: k for k, v in self.vocab.items()} + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + + return split_tokens + + def convert_tokens_to_ids(self, tokens): + return convert_by_vocab(self.vocab, tokens) + + def convert_ids_to_tokens(self, ids): + return convert_by_vocab(self.inv_vocab, ids) + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True): + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + + def tokenize(self, text): + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, do_lower_case=True): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. 
+ """ + self.do_lower_case = do_lower_case + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = convert_to_unicode(text) + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenziation.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer. + + Returns: + A list of wordpiece tokens. + """ + + text = convert_to_unicode(text) + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/inference.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a85e765c91152562d6180307c2bb1317dc385356 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/inference.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import time +import json +import ctypes +import argparse +import collections +import numpy as np +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit + +import helpers.tokenization as tokenization +import helpers.data_processing as dp +from tqdm import tqdm +import math + +from load_ixrt_plugin import load_ixrt_plugin +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + +def parse_args(): + """ + Parse command line arguments + """ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('-e', '--engine', + help='Path to BERT TensorRT engine') + parser.add_argument("-b", "--batch-size", default=1, help="Batch size for inference.", type=int) + parser.add_argument('-p', '--passage', nargs='*', + help='Text for paragraph/passage for BERT QA', + default='') + parser.add_argument('-pf', '--passage-file', + help='File containing input passage', + default='') + parser.add_argument('-q', '--question', nargs='*', + help='Text for query/question for BERT QA', + default='') + parser.add_argument('-qf', '--question-file', + help='File containing input question', + default='') + parser.add_argument('-sq', '--squad-json', + help='SQuAD json file', + default='') + parser.add_argument('-o', '--output-prediction-file', + help='Output prediction file for SQuAD evaluation', + default='./predictions.json') + parser.add_argument('-v', '--vocab-file', + help='Path to file containing entire understandable vocab') + parser.add_argument('-s', '--sequence-length', + help='The sequence length to use. 
Defaults to 128', + default=128, type=int) + parser.add_argument('--max-query-length', + help='The maximum length of a query in number of tokens. Queries longer than this will be truncated', + default=64, type=int) + parser.add_argument('--max-answer-length', + help='The maximum length of an answer that can be generated', + default=30, type=int) + parser.add_argument('--n-best-size', + help='Total number of n-best predictions to generate in the nbest_predictions.json output file', + default=20, type=int) + parser.add_argument('--doc-stride', + help='When splitting up a long document into chunks, what stride to take between chunks', + default=128, type=int) + parser.add_argument('--target_qps', + help="target qps metric", required=False, type=int) + parser.add_argument("-i", "--int8", action="store_true", help="Indicates that inference should be run in INT8 precision", required=False) + args, _ = parser.parse_known_args() + return args + +if __name__ == '__main__': + args = parse_args() + + paragraph_text = None + squad_examples = None + output_prediction_file = None + + if not args.passage == '': + paragraph_text = ' '.join(args.passage) + elif not args.passage_file == '': + f = open(args.passage_file, 'r') + paragraph_text = f.read() + elif not args.squad_json == '': + squad_examples = dp.read_squad_json(args.squad_json) + output_prediction_file = args.output_prediction_file + else: + paragraph_text = input("Paragraph: ") + + question_text = None + if not args.question == '': + question_text = ' '.join(args.question) + elif not args.question_file == '': + f = open(args.question_file, 'r') + question_text = f.read() + + tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True) + # When splitting up a long document into chunks, how much stride to take between chunks. + doc_stride = args.doc_stride + # The maximum total input sequence length after WordPiece tokenization. + # Sequences longer than this will be truncated, and sequences shorter + max_seq_length = args.sequence_length + + def question_features(tokens, question): + # Extract features from the paragraph and question + return dp.convert_example_to_features(tokens, question, tokenizer, max_seq_length, doc_stride, args.max_query_length) + + load_ixrt_plugin(TRT_LOGGER) + + # The first context created will use the 0th profile. A new context must be created + # for each additional profile needed. Here, we only use batch size 1, thus we only need the first profile. + with open(args.engine, 'rb') as f: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(f.read()) + context = engine.create_execution_context() + + # select engine profile + selected_profile = -1 + num_binding_per_profile = engine.num_bindings // engine.num_optimization_profiles + for idx in range(engine.num_optimization_profiles): + profile_shape = engine.get_profile_shape(profile_index = idx, binding = idx * num_binding_per_profile) + if profile_shape[0][0] <= args.batch_size and profile_shape[2][0] >= args.batch_size and profile_shape[0][1] <= max_seq_length and profile_shape[2][1] >= max_seq_length: + selected_profile = idx + break + if selected_profile == -1: + raise RuntimeError("Could not find any profile that can run batch size {}.".format(args.batch_size)) + + # Create a stream in which to copy inputs/outputs and run inference. 
+ stream = cuda.Stream() + + # if args.use_trt: + # context.active_optimization_profile = selected_profile + # else: + context.set_optimization_profile_async(selected_profile, stream.handle) + binding_idx_offset = selected_profile * num_binding_per_profile + + input_shape = (args.batch_size, max_seq_length) + input_nbytes = trt.volume(input_shape) * 4 + for binding in range(3): + context.set_binding_shape(binding, input_shape) + assert context.all_binding_shapes_specified + + # Allocate device memory for inputs. + d_inputs = [cuda.mem_alloc(input_nbytes) for binding in range(3)] + + # Allocate output buffer by querying the size from the context. This may be different for different input shapes. + h_output = cuda.pagelocked_empty(tuple(context.get_binding_shape(binding_idx_offset + 3)), dtype=np.float32) + d_output = cuda.mem_alloc(h_output.nbytes) + + def inference(features, tokens): + global h_output + + _NetworkOutput = collections.namedtuple( # pylint: disable=invalid-name + "NetworkOutput", + ["start_logits", "end_logits", "feature_index"]) + networkOutputs = [] + + eval_time_elapsed = 0 + for feature_index, feature in enumerate(features): + # Copy inputs + input_ids_batch = np.repeat(np.expand_dims(feature.input_ids, 0), args.batch_size, axis=0) + segment_ids_batch = np.repeat(np.expand_dims(feature.segment_ids, 0), args.batch_size, axis=0) + input_mask_batch = np.repeat(np.expand_dims(feature.input_mask, 0), args.batch_size, axis=0) + + input_ids = cuda.register_host_memory(np.ascontiguousarray(input_ids_batch.ravel())) + segment_ids = cuda.register_host_memory(np.ascontiguousarray(segment_ids_batch.ravel())) + input_mask = cuda.register_host_memory(np.ascontiguousarray(input_mask_batch.ravel())) + + eval_start_time = time.time() + cuda.memcpy_htod_async(d_inputs[0], input_ids, stream) + cuda.memcpy_htod_async(d_inputs[1], segment_ids, stream) + cuda.memcpy_htod_async(d_inputs[2], input_mask, stream) + + # Run inference + context.execute_async_v2(bindings=[0 for i in range(binding_idx_offset)] +[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + # Synchronize the stream + stream.synchronize() + eval_time_elapsed += (time.time() - eval_start_time) + + # Transfer predictions back from GPU + cuda.memcpy_dtoh_async(h_output, d_output, stream) + stream.synchronize() + # for x in h_output[0].reshape(-1,2): + # print(x) + # Only retrieve and post-process the first batch + batch = h_output[0] + + networkOutputs.append(_NetworkOutput( + start_logits = np.array(batch.squeeze()[:, 0]), + end_logits = np.array(batch.squeeze()[:, 1]), + feature_index = feature_index + )) + + eval_time_elapsed /= len(features) + + # Total number of n-best predictions to generate in the nbest_predictions.json output file + n_best_size = 20 + + # The maximum length of an answer that can be generated. 
This is needed + # because the start and end predictions are not conditioned on one another + max_answer_length = 30 + + prediction, nbest_json, scores_diff_json = dp.get_predictions(tokens, features, + networkOutputs, args.n_best_size, args.max_answer_length) + + return eval_time_elapsed, prediction, nbest_json + + def print_single_query(eval_time_elapsed, prediction, nbest_json): + print("------------------------") + print("Running inference in {:.3f} Sentences/Sec".format(args.batch_size/eval_time_elapsed)) + print("------------------------") + + print("Answer: '{}'".format(prediction)) + print("With probability: {:.3f}".format(nbest_json[0]['probability'] * 100.0)) + + def inference_all_dynamic(features_list, squad_examples, sort_index, all_precision): + # h_output = torch.tensor((args.batch_size, max_seq_length, 2)) + global h_output + _NetworkOutput = collections.namedtuple( # pylint: disable=invalid-name + "NetworkOutput", + ["start_logits", "end_logits", "feature_index"]) + networkOutputs = [] + + batch_input_ids = [] + batch_segment_ids = [] + all_token_ids = [] + batch_example_list = [] + batch_feature_list = [] + batch_feature = [] + batch_example = [] + max_batch_length = 0 + seq_length_list = [] + for index in sort_index: + batch_feature.append(features_list[index]) + batch_example.append(squad_examples[index]) + max_batch_length = max(max_batch_length, len(features_list[index].input_ids)) + if args.int8: + max_batch_length = math.ceil(max_batch_length / 2) * 2 + else: + # workround to solve bs=1 10% slow + if args.batch_size == 1: + max_batch_length = math.ceil(max_batch_length / 64) * 64 + seq_length_list.append(len(features_list[index].input_ids)) + if len(batch_feature) == args.batch_size: + batch_input_ids = [ + np.pad(bf.input_ids, (0, max_batch_length - bf.input_ids.shape[0]), 'constant',constant_values = (0)).reshape(1, -1) + for bf in batch_feature + ] + batch_input_ids = np.concatenate(batch_input_ids, axis=0) + batch_segment_ids = [ + np.pad(bf.segment_ids, (0, max_batch_length - bf.segment_ids.shape[0]), 'constant',constant_values = (0)).reshape(1, -1) + for bf in batch_feature + ] + batch_segment_ids = np.concatenate(batch_segment_ids, axis=0) + all_token_ids.append( + [ + batch_input_ids.astype(np.int32), + batch_segment_ids.astype(np.int32) + ] + ) + batch_example_list.append(batch_example) + batch_feature_list.append(batch_feature) + batch_input_ids = [] + batch_segment_ids = [] + batch_feature = [] + batch_example = [] + max_batch_length = 0 + + if len(batch_feature): + batch_input_ids = [ + np.pad(bf.input_ids, (0, max_batch_length - bf.input_ids.shape[0]), 'constant',constant_values = (0)).reshape(1, -1) + for bf in batch_feature + ] + batch_input_ids = np.concatenate(batch_input_ids, axis=0) + batch_segment_ids = [ + np.pad(bf.segment_ids, (0, max_batch_length - bf.segment_ids.shape[0]), 'constant',constant_values = (0)).reshape(1, -1) + for bf in batch_feature + ] + batch_segment_ids = np.concatenate(batch_segment_ids, axis=0) + all_token_ids.append( + [ + batch_input_ids.astype(np.int32), + batch_segment_ids.astype(np.int32) + ] + ) + batch_input_ids = [] + batch_segment_ids = [] + batch_example_list.append(batch_example) + batch_feature_list.append(batch_feature) + + # warm up + for i in range(20): + for binding in range(3): + context.set_binding_shape(binding, (args.batch_size, max_seq_length)) + assert context.all_binding_shapes_specified + cuda.memcpy_htod_async(d_inputs[0], np.zeros((args.batch_size, max_seq_length), dtype=np.int32).ravel(), stream) 
+ cuda.memcpy_htod_async(d_inputs[1], np.zeros((args.batch_size, max_seq_length), dtype=np.int32).ravel(), stream) + context.execute_async_v2(bindings=[0 for i in range(binding_idx_offset)] +[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + stream.synchronize() + + infer_toal_time = 0 + output_index = 0 + for input_ids, segment_ids in tqdm(all_token_ids): + for binding in range(3): + context.set_binding_shape(binding, input_ids.shape) + assert context.all_binding_shapes_specified + + cuda.memcpy_htod_async(d_inputs[0], input_ids.ravel(), stream) + cuda.memcpy_htod_async(d_inputs[1], segment_ids.ravel(), stream) + stream.synchronize() + + infer_start_time = time.time() + context.execute_async_v2(bindings=[0 for i in range(binding_idx_offset)] +[int(d_inp) for d_inp in d_inputs] + [int(d_output)], stream_handle=stream.handle) + stream.synchronize() + infer_end_time = time.time() + infer_time = infer_end_time - infer_start_time + infer_toal_time += infer_time + + cuda.memcpy_dtoh_async(h_output, d_output, stream) + stream.synchronize() + + new_h_output = np.array(h_output.reshape(-1)[:input_ids.shape[0]*input_ids.shape[1]*2]).reshape(input_ids.shape[0], input_ids.shape[1], 2) + for index in range(input_ids.shape[0]): + networkOutputs.append(_NetworkOutput( + start_logits = new_h_output[index, :seq_length_list[output_index], 0], + end_logits = new_h_output[index, :seq_length_list[output_index], 1], + feature_index = index + )) + output_index += 1 + + output_index = 0 + for (be, bf) in zip(batch_example_list, batch_feature_list): + for index in range(len(bf)): + prediction, nbest_json, scores_diff_json = dp.get_predictions(be[index].doc_tokens, bf, + [networkOutputs[output_index]], args.n_best_size, args.max_answer_length) + output_index += 1 + all_precision[be[index].id] = prediction + return infer_toal_time, all_precision + + status = 0 + if squad_examples: + all_predictions = collections.OrderedDict() + + features_list = [] + lengths = [] + + for example_index, example in enumerate(squad_examples): + features = question_features(example.doc_tokens, example.question_text) + features_list.append(features[0]) + lengths.append(len(features[0].input_ids)) + + sort_index = np.argsort(lengths) + infer_time, all_predictions = inference_all_dynamic(features_list, squad_examples, sort_index, all_predictions) + + qps = math.ceil(len(squad_examples)/args.batch_size)*args.batch_size/infer_time + print(f"Latency QPS: {qps} sentences/s") + + with open(output_prediction_file, "w") as f: + f.write(json.dumps(all_predictions, indent=4)) + print("\nOutput dump to {}".format(output_prediction_file)) + + if args.target_qps: + if qps >= args.target_qps: + print(f"target qps: {args.target_qps}, qps: {qps}, pass.") + else: + print(f"target qps: {args.target_qps}, qps: {qps}, failed.") + status = 1 + else: + # Extract tokecs from the paragraph + doc_tokens = dp.convert_doc_tokens(paragraph_text) + + if question_text: + print("\nPassage: {}".format(paragraph_text)) + print("\nQuestion: {}".format(question_text)) + + features = question_features(doc_tokens, question_text) + eval_time_elapsed, prediction, nbest_json = inference(features, doc_tokens) + print_single_query(eval_time_elapsed, prediction, nbest_json) + else: + # If no question text is provided, loop until the question is 'exit' + EXIT_CMDS = ["exit", "quit"] + question_text = input("Question (to exit, type one of {:}): ".format(EXIT_CMDS)) + + while question_text.strip() not in EXIT_CMDS: + features = 
question_features(doc_tokens, question_text)
+                eval_time_elapsed, prediction, nbest_json = inference(features, doc_tokens)
+                print_single_query(eval_time_elapsed, prediction, nbest_json)
+                question_text = input("Question (to exit, type one of {:}): ".format(EXIT_CMDS))
+    del context
+    del engine
+    sys.exit(status)
\ No newline at end of file
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/load_ixrt_plugin.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/load_ixrt_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..444296d57aa370c7f068b57e8e88eb9276470f6f
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/load_ixrt_plugin.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from os.path import join, dirname, exists
+import tensorrt as trt
+import ctypes
+
+current_directory = os.getcwd()
+
+def load_ixrt_plugin(logger=trt.Logger(trt.Logger.WARNING), namespace="", dynamic_path=""):
+    if not dynamic_path:
+        dynamic_path = join(current_directory, "..", "..", "build", "libixrt_plugin.so")
+    if not exists(dynamic_path):
+        raise FileNotFoundError(
+            f"The ixrt_plugin lib {dynamic_path} does not exist, please provide a valid plugin path!")
+    ctypes.CDLL(dynamic_path, mode=ctypes.RTLD_GLOBAL)
+    trt.init_libnvinfer_plugins(logger, namespace)
+    print(f"Loaded plugin from {dynamic_path}")
\ No newline at end of file
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/perf.py b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..369f28f9938597059dce0595d500b3734da3d78d
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/python/ixrt/perf.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import ctypes +import time +import numpy as np +import tensorrt as trt +import pycuda.driver as cuda +import pycuda.autoinit + +import numpy as np + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +from load_ixrt_plugin import load_ixrt_plugin + +class DeviceBuffer(object): + def __init__(self, shape, dtype=trt.int32): + self.buf = cuda.mem_alloc(trt.volume(shape) * 4) + + def binding(self): + return int(self.buf) + + def free(self): + self.buf.free() + + +def main(): + parser = argparse.ArgumentParser(description='BERT Inference Benchmark') + parser.add_argument("-z", "--use_trt", action="store_false", help="Whether to use tensorRT or IxRT") + parser.add_argument("-e", "--engine", help='Path to BERT TensorRT engine') + parser.add_argument('-b', '--batch-size', default=[], action="append", help='Batch size(s) to benchmark. Can be specified multiple times for more than one batch size. 
This script assumes that the engine has been built with one optimization profile for each batch size, and that these profiles are in order of increasing batch size.', type=int) + parser.add_argument('-s', '--sequence-length', default=128, help='Sequence length of the BERT model', type=int) + parser.add_argument('-i', '--iterations', default=200, help='Number of iterations to run when benchmarking each batch size.', type=int) + parser.add_argument('-w', '--warm-up-runs', default=10, help='Number of iterations to run prior to benchmarking.', type=int) + parser.add_argument('-d', '--duration', default=0.0, help='Minimal number of seconds to run when benchmarking each batch size.', type=float) + parser.add_argument('-r', '--random-seed', required=False, default=12345, help='Random seed.', type=int) + parser.add_argument('-t', '--target-qps', default=0, help='Target QPS', type=int) + args, _ = parser.parse_known_args() + args.batch_size = args.batch_size or [1] + + # Import necessary plugins for BERT TensorRT + load_ixrt_plugin(TRT_LOGGER) + + with open(args.engine, 'rb') as f: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(f.read()) + context = engine.create_execution_context() + + # Allocate buffers large enough to store the largest batch size + max_input_shape = (max(args.batch_size), args.sequence_length) + max_output_shape = (max(args.batch_size), args.sequence_length, 2, 1, 1) + buffers = [ + DeviceBuffer(max_input_shape), + DeviceBuffer(max_input_shape), + DeviceBuffer(max_input_shape), + DeviceBuffer(max_output_shape) + ] + + # Prepare random input + pseudo_vocab_size = 30522 + pseudo_type_vocab_size = 2 + np.random.seed(args.random_seed) + test_word_ids = np.random.randint(0, pseudo_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32) + test_segment_ids = np.random.randint(0, pseudo_type_vocab_size, (max(args.batch_size), args.sequence_length), dtype=np.int32) + test_input_mask = np.ones((max(args.batch_size), args.sequence_length), dtype=np.int32) + + # Copy input h2d + cuda.memcpy_htod(buffers[0].buf, test_word_ids.ravel()) + cuda.memcpy_htod(buffers[1].buf, test_segment_ids.ravel()) + cuda.memcpy_htod(buffers[2].buf, test_input_mask.ravel()) + + num_binding_per_profile = engine.num_bindings // engine.num_optimization_profiles + + bench_times = {} + + stream = cuda.Stream() + for batch_size in sorted(args.batch_size): + # # Select engine profile + selected_profile = -1 + for idx in range(engine.num_optimization_profiles): + profile_shape = engine.get_profile_shape(idx, idx * num_binding_per_profile) + if profile_shape[0][0] <= batch_size and profile_shape[2][0] >= batch_size and profile_shape[0][1] <= args.sequence_length and profile_shape[2][1] >= args.sequence_length: + selected_profile = idx + break + if selected_profile == -1: + raise RuntimeError("None of the dynamic shape profiles meets the requirement batch = {} and sequence = {}.".format(batch_size, args.sequence_length)) + context.set_optimization_profile_async(selected_profile, stream.handle) + + # Each profile has unique bindings + binding_idx_offset = selected_profile * num_binding_per_profile + bindings = [0] * binding_idx_offset + [buf.binding() for buf in buffers] + + shapes = { + 0 : (batch_size, args.sequence_length), + 1 : (batch_size, args.sequence_length), + 2 : (batch_size, args.sequence_length), + } + + for binding, shape in shapes.items(): + context.set_binding_shape(binding, shape) + assert context.all_binding_shapes_specified + + # Inference + total_time = 
0 + start = cuda.Event() + end = cuda.Event() + + # Warmup + for _ in range(args.warm_up_runs): + context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) + stream.synchronize() + + # Timing loop + times = [] + actual_iterations = 0 + start_time = time.time() + while actual_iterations < args.iterations or (time.time() - start_time) < args.duration: + start.record(stream) + context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) + end.record(stream) + stream.synchronize() + times.append(end.time_since(start)) + actual_iterations += 1 + + # Compute average time, 95th percentile time and 99th percentile time. + bench_times[batch_size] = times + + [b.free() for b in buffers] + del context + del engine + + for batch_size, times in bench_times.items(): + total_time = sum(times) + avg_time = total_time / float(actual_iterations) + # times.sort() + # percentile95 = times[int(actual_iterations * 0.95)] + # percentile99 = times[int(actual_iterations * 0.99)] + # print("Running {:} iterations with Batch Size: {:}\n\tTotal Time: {:} ms \tAverage Time: {:} ms\t95th Percentile Time: {:} ms\t99th Percentile Time: {:}".format(actual_iterations, batch_size, total_time, avg_time, percentile95, percentile99)) + QPS = 1000.0 / (avg_time / batch_size) + print("BatchSize = {:d}, QPS = {:.3f}".format(batch_size, QPS)) + if QPS >= args.target_qps: + print("Performance Check : Test {:.3f} >= target {:.3f}".format(QPS, args.target_qps)) + print("pass!") + exit() + else: + print("failed!") + exit(1) + + + +if __name__ == '__main__': + main() diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_fp16_ixrt.sh b/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_fp16_ixrt.sh new file mode 100644 index 0000000000000000000000000000000000000000..5c7767fd1cd801ac8a2d3a7b26a519d545304e0d --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_fp16_ixrt.sh @@ -0,0 +1,50 @@ +set -eo pipefail + +BSZ=32 +TGT=87 +USE_TRT=False + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + --use_trt) USE_TRT=${arguments[index]};; + esac +done + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +project_path=$(realpath ${current_path}/..) +checkpoints_path=${project_path}/data/bert_base_uncased_squad +datasets_path=${project_path}/data/ + +echo 'USE_TRT='${USE_TRT} +export USE_TRT=$USE_TRT + +echo "Step1 Build Engine FP16(bert base squad)!" 
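As a side note on the QPS figure reported by perf.py above: pycuda's `Event.time_since` returns milliseconds per batch, so QPS is just 1000 divided by the average milliseconds per sentence. A quick sanity check of that arithmetic with made-up latencies (illustrative only, not part of the patch):

```python
# Made-up per-iteration latencies in milliseconds for one batch size.
times = [20.5, 19.8, 20.1, 20.0]
batch_size = 32

avg_time_ms = sum(times) / len(times)       # average milliseconds per batch
qps = 1000.0 / (avg_time_ms / batch_size)   # sentences per second
print(f"BatchSize = {batch_size}, QPS = {qps:.3f}")   # ~1592 sentences/s
```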
+cd ${project_path}/ixrt +python3 builder.py -x ${checkpoints_path}/bert_base_squad.onnx \ + -w 4096 \ + -o ${checkpoints_path}/bert_base_b${BSZ}.engine \ + -s 1 384 384 \ + -b 1 ${BSZ} ${BSZ}\ + --fp16 \ + -c ${checkpoints_path}/config.json \ + -z ${USE_TRT} + +echo "Step2 Run dev.json and generate json" +python3 inference.py -e ${checkpoints_path}/bert_base_b${BSZ}.engine \ + -s 384 \ + -b ${BSZ} \ + -sq ${datasets_path}/squad/dev-v1.1.json \ + -v ${checkpoints_path}/vocab.txt \ + -o ${checkpoints_path}/predictions-bert_base_b${BSZ}.json \ + -z ${USE_TRT} + +echo "Step3 Inference(test F1-score)" +python3 evaluate-v1.1.py ${datasets_path}/squad/dev-v1.1.json ${checkpoints_path}/predictions-bert_base_b${BSZ}.json ${TGT} \ No newline at end of file diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_int8_ixrt.sh b/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_int8_ixrt.sh new file mode 100644 index 0000000000000000000000000000000000000000..93195450ec16935abbc0739a7c8fc1e842cecf1b --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/script/infer_bert_base_squad_int8_ixrt.sh @@ -0,0 +1,50 @@ +set -eo pipefail + +BSZ=32 +TGT=86 +USE_TRT=False + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + --use_trt) USE_TRT=${arguments[index]};; + esac +done + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +project_path=$(realpath ${current_path}/..) +echo ${project_path} +checkpoints_path=${project_path}/data/bert_base_uncased_squad/ +datasets_path=${project_path}/data/ + +echo 'USE_TRT='${USE_TRT} +export USE_TRT=$USE_TRT + +echo "Step1 Build Engine Int8(bert base squad)!" +cd ${project_path}/ixrt +python3 builder_int8.py -pt ${checkpoints_path}/bert_base_int8_qat.bin \ + -o ${checkpoints_path}/bert_base_int8_b${BSZ}.engine \ + -b 1 ${BSZ} ${BSZ} \ + -s 1 384 384 \ + -i \ + -c ${checkpoints_path} + +echo "Step2 Run dev.json and generate json" +python3 inference.py -e ${checkpoints_path}/bert_base_int8_b${BSZ}.engine \ + -b ${BSZ} \ + -s 384 \ + -sq ${datasets_path}/squad/dev-v1.1.json \ + -v ${checkpoints_path}/vocab.txt \ + -o ${checkpoints_path}/predictions-bert_base_int8_b${BSZ}.json \ + -z ${USE_TRT} \ + -i + +echo "Step3 Inference(test F1-score)" +python3 evaluate-v1.1.py ${datasets_path}/squad/dev-v1.1.json ${checkpoints_path}/predictions-bert_base_int8_b${BSZ}.json ${TGT} \ No newline at end of file diff --git a/models/nlp/language_model/bert_base_squad/ixrt/python/script/prepare.sh b/models/nlp/language_model/bert_base_squad/ixrt/python/script/prepare.sh index 843166dec9d30224e818649f61868e9b968a2f37..18bc8ca1d2496ec382b98e6639568d8bdea27f33 100644 --- a/models/nlp/language_model/bert_base_squad/ixrt/python/script/prepare.sh +++ b/models/nlp/language_model/bert_base_squad/ixrt/python/script/prepare.sh @@ -54,7 +54,7 @@ fi echo "Step 2: Downloading model file and config to ./data/bert-large-uncased" if [ ! 
-d "./bert_base_uncased_squad" ]; then - wget https://drive.google.com/file/d/1_q7SaiZjwysJ3jWAIQT2Ne-duFdgWivR/view?usp=drive_link + wget https://drive.google.com/file/d/1_DJDdKBanqJ6h3VGhH78F9EPgE2wK_Tw/view?usp=drive_link unzip bert_base_uncased_squad.zip -d ./ rm -f bert_base_uncased_squad.zip else diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/api/plugin_loader.cc b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/api/plugin_loader.cc new file mode 100644 index 0000000000000000000000000000000000000000..af5a1c61bbd581452997834ef75575198136d3e1 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/api/plugin_loader.cc @@ -0,0 +1,143 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +// Created on 2023/11/22. +#include +#include +#include +#include + +#include "NvInfer.h" +#include "NvInferPlugin.h" +#include "custom_fc/fcPlugin.h" +#include "emb_layernorm/embLayerNormInt8Plugin.h" +#include "emb_layernorm/embLayerNormPlugin.h" +#include "ffn/ffnPlugin.h" +#include "gelu/geluPlugin.h" +#include "qkv_to_context/qkvToContextInt8Plugin.h" +#include "qkv_to_context/qkvToContextPlugin.h" +#include "skip_layernorm/skipLayerNormInt8Plugin.h" +#include "skip_layernorm/skipLayerNormPlugin.h" +using namespace nvinfer1; +using namespace nvinfer1::plugin; + +namespace nvinfer1 { +namespace plugin { + +extern ILogger* gLogger; + +} // namespace plugin +} // namespace nvinfer1 + +namespace { +// This singleton ensures that each plugin is only registered once for a given +// namespace and type, and attempts of duplicate registration are ignored. 
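The registry class below guards against double registration: each creator is keyed by namespace, name, and version, and `initLibNvInferPlugins` registers the full creator list once. From Python, one way to verify that `libixrt_plugin.so` actually registered its creators is to inspect TensorRT's global plugin registry after loading the library (a sketch assuming the standard TensorRT Python API; the library path and the creator names that appear depend on your build):

```python
import ctypes
import tensorrt as trt

# Load the plugin library; RTLD_GLOBAL makes its symbols visible to TensorRT.
ctypes.CDLL("./build/libixrt_plugin.so", mode=ctypes.RTLD_GLOBAL)  # example path
trt.init_libnvinfer_plugins(trt.Logger(trt.Logger.WARNING), "")

# Every creator registered by initLibNvInferPlugins should show up here.
for creator in trt.get_plugin_registry().plugin_creator_list:
    print(creator.name, creator.plugin_version, creator.plugin_namespace)
```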
+class PluginCreatorRegistry { + public: + static PluginCreatorRegistry& getInstance() { + static PluginCreatorRegistry instance; + return instance; + } + + string GetPluginUniqKey(const AsciiChar* const plugin_namespace, const AsciiChar* const plugin_name, + const AsciiChar* const plugin_version) { + stringstream os; + os << plugin_namespace << "::" << plugin_name << " version " << plugin_version; + return os.str(); + } + + template + void addPluginCreator(void* logger, char const* libNamespace) { + // Make accesses to the plugin creator registry thread safe + std::lock_guard lock(mRegistryLock); + + std::string errorMsg; + std::string verboseMsg; + + std::unique_ptr pluginCreator{new CreatorType{}}; + pluginCreator->setPluginNamespace(libNamespace); + + nvinfer1::plugin::gLogger = static_cast(logger); + std::string pluginType = GetPluginUniqKey(pluginCreator->getPluginNamespace(), pluginCreator->getPluginName(), + pluginCreator->getPluginVersion()); + + if (mRegistryList.find(pluginType) == mRegistryList.end()) { + bool status = getPluginRegistry()->registerCreator(*pluginCreator, libNamespace); + if (status) { + mRegistry.push(std::move(pluginCreator)); + mRegistryList.insert(pluginType); + verboseMsg = "Registered plugin creator - " + pluginType; + } else { + errorMsg = "Could not register plugin creator - " + pluginType; + } + } else { + verboseMsg = "Plugin creator already registered - " + pluginType; + } + + if (logger) { + if (!errorMsg.empty()) { + nvinfer1::plugin::gLogger->log(ILogger::Severity::kERROR, errorMsg.c_str()); + } + if (!verboseMsg.empty()) { + nvinfer1::plugin::gLogger->log(ILogger::Severity::kVERBOSE, verboseMsg.c_str()); + } + } + } + + ~PluginCreatorRegistry() { + std::lock_guard lock(mRegistryLock); + + // Release pluginCreators in LIFO order of registration. + while (!mRegistry.empty()) { + mRegistry.pop(); + } + mRegistryList.clear(); + } + + private: + PluginCreatorRegistry() {} + + std::mutex mRegistryLock; + std::stack> mRegistry; + std::unordered_set mRegistryList; + + public: + PluginCreatorRegistry(PluginCreatorRegistry const&) = delete; + void operator=(PluginCreatorRegistry const&) = delete; +}; + +template +void initializePlugin(void* logger, char const* libNamespace) { + PluginCreatorRegistry::getInstance().addPluginCreator(logger, libNamespace); +} + +} // namespace + +extern "C" { +bool initLibNvInferPlugins(void* logger, const char* libNamespace) { + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); + + return true; +} +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a1f90e6b9e3eeb7375d1c0c8b3d363ece1629b7a --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.cu @@ -0,0 +1,150 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. 
You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "bert_embed_kernel.h" + +namespace nvinfer1::plugin { +namespace backend { + +template +__global__ void IxinferBertEmbedKernel(const int8_t *token_emb, const T *pos_emb, const int *tokens, T *output, + int8_t *pad_mask, int pad_id, int batch_size, int seq_len, int hidden_dim, + float dequant_scale, bool scaled) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * seq_len * hidden_dim) { + return; + } + int batch_idx, seq_idx, dim_idx; + decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = batch_idx * seq_len + seq_idx; + int token = tokens[tokens_idx]; + float4 value; + + if (token == pad_id) { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 1; + } + value.x = 0.f; + value.y = 0.f; + value.z = 0.f; + value.w = 0.f; + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + char4 value_i4 = ((char4 *)token_emb)[token * hidden_dim + dim_idx]; + float4 pemb = ((float4 *)pos_emb)[seq_idx * hidden_dim + dim_idx]; + float scale = dequant_scale; + if (scaled) { + scale *= sqrtf(hidden_dim << 2); + } + value.x = float(value_i4.x) * scale + pemb.x; + value.y = float(value_i4.y) * scale + pemb.y; + value.z = float(value_i4.z) * scale + pemb.z; + value.w = float(value_i4.w) * scale + pemb.w; + } + ((float4 *)output)[idx] = value; +} + +template <> +__global__ void IxinferBertEmbedKernel<__half>(const int8_t *token_emb, const __half *pos_emb, const int *tokens, + __half *output, int8_t *pad_mask, int pad_id, int batch_size, + int seq_len, int hidden_dim, float dequant_scale, bool scaled) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * seq_len * hidden_dim) { + return; + } + int batch_idx, seq_idx, dim_idx; + decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = batch_idx * seq_len + seq_idx; + int token = tokens[tokens_idx]; + float4 value; + + if (token == pad_id) { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 1; + } + value.x = 0.f; + value.y = 0.f; + value.z = 0.f; + value.w = 0.f; + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + int2 value_i8 = ((int2 *)token_emb)[token * hidden_dim + dim_idx]; + float4 pemb = ((float4 *)pos_emb)[seq_idx * hidden_dim + dim_idx]; + __half2 *value_h2 = (__half2 *)(&value); + char2 *value_i2 = (char2 *)(&value_i8); + __half2 *pemb_h2 = (__half2 *)(&pemb); + float scale = dequant_scale; + if (scaled) { + scale *= sqrtf(hidden_dim << 3); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 value_f2; + float2 pemb_f2 = __half22float2(pemb_h2[i]); + value_f2.x = float(value_i2[i].x) * scale + pemb_f2.x; + value_f2.y = float(value_i2[i].y) * scale + pemb_f2.y; + value_h2[i] = __float22half2_rn(value_f2); + } + } + ((float4 *)output)[idx] = value; +} + +template +void cuinferEncEmbI8I(const int8_t *token_emb, const T *pos_emb, const int *tokens, T *output, int8_t *pad_mask, + int pad_id, int batch_size, int seq_len, int hidden_dim, cudaStream_t stream, const T *lang_emb, + const int *lang_id, int multilg_type, float dequant_scale, bool scaled) { + if (hidden_dim % 4 != 0) { + throw 
std::runtime_error("BertEmbed: hidden_size % 4 ! = 0"); + } + if (multilg_type != 0) { + throw std::runtime_error("multilingle not supported"); + } + + const int max_threads = 1024; + + hidden_dim >>= 2; + int nele = batch_size * seq_len * hidden_dim; + int nblock = (nele + max_threads - 1) / max_threads; + + IxinferBertEmbedKernel<<>>( + token_emb, pos_emb, tokens, output, pad_mask, pad_id, batch_size, seq_len, hidden_dim, dequant_scale, scaled); +} + +template <> +void cuinferEncEmbI8I<__half>(const int8_t *token_emb, const __half *pos_emb, const int *tokens, __half *output, + int8_t *pad_mask, int pad_id, int batch_size, int seq_len, int hidden_dim, + cudaStream_t stream, const __half *lang_emb, const int *lang_id, int multilg_type, + float dequant_scale, bool scaled) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("BertEmbed: hidden_size % 8 ! = 0"); + } + if (multilg_type != 0) { + throw std::runtime_error("multilingle not supported"); + } + hidden_dim >>= 3; + const int max_threads = 1024; + int nele = batch_size * seq_len * hidden_dim; + int nblock = (nele + max_threads - 1) / max_threads; + + IxinferBertEmbedKernel<__half><<>>( + token_emb, pos_emb, tokens, output, pad_mask, pad_id, batch_size, seq_len, hidden_dim, dequant_scale, scaled); +} + +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b045e1d0628775be413563002ab28cba6688fd5e --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_embed_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once + +#include "bert_helper.h" + +namespace nvinfer1::plugin { +namespace backend { + +template +void cuinferEncEmbI8I(const int8_t *token_emb, const T *pos_emb, const int *tokens, T *output, int8_t *pad_mask, + int pad_id, int batch_size, int seq_len, int hidden_dim, cudaStream_t stream, const T *lang_emb, + const int *lang_id, int multilg_type, float dequant_scale, bool scaled); +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_helper.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..760823c4534b37f6d1a015e128a1ccf93b8889e0 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_helper.h @@ -0,0 +1,296 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. 
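Stripped of the float4/half2 vectorization, the int8 embedding kernels above perform a simple per-token computation: dequantize the int8 token embedding (optionally rescaled by sqrt(hidden_dim)), add the positional embedding, and zero out padded positions while recording a pad mask. A plain NumPy reference of that logic, as an illustrative model only (function name and shapes are assumptions):

```python
import numpy as np

def bert_embed_i8(token_emb_i8, pos_emb, tokens, pad_id, dequant_scale, scaled=True):
    """token_emb_i8: [vocab, hidden] int8, pos_emb: [seq, hidden] float32,
    tokens: [batch, seq] int32 -> (embeddings [batch, seq, hidden], pad_mask [batch, seq])."""
    hidden = token_emb_i8.shape[1]
    scale = dequant_scale * (np.sqrt(hidden) if scaled else 1.0)
    out = token_emb_i8[tokens].astype(np.float32) * scale + pos_emb[np.newaxis, :, :]
    pad_mask = (tokens == pad_id).astype(np.int8)
    out[tokens == pad_id] = 0.0   # padded tokens produce zeros, mask is set to 1
    return out, pad_mask
```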
You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +#include + +#ifndef C10_WARP_SIZE + +#ifdef __ILUVATAR__ +#define C10_WARP_SIZE 64 +#else +#define C10_WARP_SIZE 32 +#endif + +#endif + +namespace nvinfer1::plugin { +namespace backend { + +const float epsilon = 0.000000000001; +const unsigned int WARP_REDUCE_MASK = 0xffffffff; +const float CUDA_FLOAT_INF_NEG = -100000000.f; // FIXME later +const float CUDA_FLOAT_INF_POS = 100000000.f; // FIXME later +const int CUDA_INT_INF = 2147483647; +const int MAX_THREADS = 1024; + +__forceinline__ __device__ int8_t float2int8(float x, float quant_scale) { + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 127 : i8); + return int8_t(i8); +} + +inline __device__ void WelfordCombine(float val, float *mean, float *m2, float *count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + float delta1 = val - *mean; + *mean += delta1 / *count; + float delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +inline __device__ void WelfordCombine(float b_mean, float b_m2, float b_count, float *mean, float *m2, float *count) { + if (b_count == 0) { + return; + } + float new_count = *count + b_count; + float nb_over_n = b_count / new_count; + float delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +__inline__ __device__ void WelfordWarpReduce(float thread_mean, float thread_m2, float thread_count, float *mean, + float *m2, float *count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = C10_WARP_SIZE / 2; mask > 0; mask /= 2) { + float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); + float b_count = __shfl_down_sync(0xffffffff, *count, mask); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} +// addd by pxl +// block内所有数据完成reduce +// template +__inline__ __device__ void WelfordBlockAllReduce(float thread_mean, float thread_m2, float thread_count, + float *result_mean, float *result_m2, float *result_count) { + __shared__ float mean_shared[warpSize]; + __shared__ float m2_shared[warpSize]; + __shared__ float count_shared[warpSize]; + __shared__ float mean_result_broadcast; + __shared__ float m2_result_broadcast; + __shared__ float count_result_broadcast; + + const int lid = threadIdx.x % warpSize; + const int wid = threadIdx.x / warpSize; + float warp_mean = 0; + float warp_m2 = 0; + float warp_count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count); + __syncthreads(); + + if (lid == 0) { + mean_shared[wid] = warp_mean; + m2_shared[wid] = warp_m2; + count_shared[wid] = warp_count; + } + __syncthreads(); + + if (wid == 0) { + if (threadIdx.x < blockDim.x / warpSize) { + warp_mean = mean_shared[lid]; + warp_m2 = m2_shared[lid]; + warp_count = count_shared[lid]; + + } else { + warp_mean = 0.f; + warp_m2 = 0.f; + 
warp_count = 0.f; + } + __syncwarp(); + + float block_mean = 0; + float block_m2 = 0; + float block_count = 0; + + WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count); + + if (lid == 0) { + mean_result_broadcast = block_mean; + m2_result_broadcast = block_m2; + count_result_broadcast = block_count; + } + } + __syncthreads(); + *result_mean = mean_result_broadcast; + *result_m2 = m2_result_broadcast; + *result_count = count_result_broadcast; +} +__forceinline__ __device__ char4 float42char4(float4 vals, float quant_scale) { + char4 res; + res.x = float2int8(vals.x, quant_scale); + res.y = float2int8(vals.y, quant_scale); + res.z = float2int8(vals.z, quant_scale); + res.w = float2int8(vals.w, quant_scale); + return res; +} + +// load 两个 half2, 保存到 float4 +__forceinline__ __device__ void load_float4_from_half(float4 &vals, __half2 *input, int index) { + __half2 i1 = input[index * 2]; + __half2 i2 = input[index * 2 + 1]; + + vals.x = __half2float(i1.x); + vals.y = __half2float(i1.y); + vals.z = __half2float(i2.x); + vals.w = __half2float(i2.y); +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +__forceinline__ __device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, int hidden_size, + float epsilon, float4 scale, float4 bias) { + float4 norm_value; + norm_value.x = + (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * scale.x + bias.x; + norm_value.y = + (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * scale.y + bias.y; + norm_value.z = + (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * scale.z + bias.z; + norm_value.w = + (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * scale.w + bias.w; + return norm_value; +} + +// for layer norm +__forceinline__ __device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, int hidden_size, + float epsilon, half2 scale_1, half2 scale_2, half2 bias_1, + half2 bias_2) { + float4 norm_value; + norm_value.x = + (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.x) + __half2float(bias_1.x); + norm_value.y = + (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.y) + __half2float(bias_1.y); + norm_value.z = + (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_2.x) + __half2float(bias_2.x); + norm_value.w = + (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_2.y) + __half2float(bias_2.y); + return norm_value; +} +/* Convert half2 into float2, mask inf and -inf */ +__forceinline__ __host__ __device__ float safe_half_to_float(half hval) { + return fmax(fmin(100000.f, __half2float(hval)), -100000.f); +} +__forceinline__ __device__ float4 char4addfloat4_dequant(char4 input_4, float4 residual, + float dequant_scale) { + float4 res; + res.x = __int2float_rn(input_4.x) * dequant_scale + residual.x; + res.y = __int2float_rn(input_4.y) * dequant_scale + residual.y; + res.z = __int2float_rn(input_4.z) * dequant_scale + residual.z; + res.w = __int2float_rn(input_4.w) * dequant_scale + residual.w; + return res; +} +__forceinline__ __device__ float4 char4addhalf2_dequant(char4 input_4, half2 residual_1, half2 residual_2, + float dequant_scale) { + float4 res; + res.x = __int2float_rn(input_4.x) * dequant_scale + safe_half_to_float(residual_1.x); + res.y = __int2float_rn(input_4.y) * 
dequant_scale + safe_half_to_float(residual_1.y); + res.z = __int2float_rn(input_4.z) * dequant_scale + safe_half_to_float(residual_2.x); + res.w = __int2float_rn(input_4.w) * dequant_scale + safe_half_to_float(residual_2.y); + return res; +} + +// gelu +// IxinferBiasGeluI8II8OKernel +template +__forceinline__ __device__ T tanhf_exp(T x) { + // float e1 = __expf(x); + // float e2 = 1.0f / e1; + // return (e1 - e2) / (e1 + e2); + + return (2.f / (1.f + __expf(-2.f * x)) - 1.f); +} + +template +__forceinline__ __device__ T gelu(T x) { + float cdf = 0.5f * (1.0f + tanhf_exp((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +// softmax +__forceinline__ __host__ __device__ int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width, unsigned int mask = 0xffffffff) { +#if !(defined(__HIP_PLATFORM_HCC__) || defined(__ILUVATAR__)) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; } +}; +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = REDUCE_WARP_SIZE / 2; offset > 0; offset /= 2) { + acc_t b = WARP_SHFL_XOR(*sum, offset, REDUCE_WARP_SIZE); + *sum = r(*sum, b); + } +} +/* Convert 3-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int targetid_3dim(int id1, int id2, int id3, int dim2, int dim3) { + return id1 * dim2 * dim3 + id2 * dim3 + id3; +} + +/* Convert 4-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int targetid_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, + int dim4) { + // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; + int res = id4; + + int ld = dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..4e62a800e248bbeb9e98ce695557fcb1553c9826 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.cu @@ -0,0 +1,1002 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "bert_helper.h" +#include "bert_layer_kernel.h" + +/** +@file +Implemented the cuda kernel function and its +that required by transformer model. 
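Before the layer-norm kernels that follow, it may help to see the per-row computation they implement (with the Welford helpers above supplying the running mean and variance) written out in plain NumPy: dequantize the int8 GEMM output, add the residual, layer-normalize with the learned scale and bias, and re-quantize to int8. This is an illustrative reference model for the pre-LN case without a residual bias, not code from the patch; epsilon and the clamped round-to-nearest follow the helper definitions above:

```python
import numpy as np

def residual_ln_i8i_i8o(x_i8, residual, gamma, beta, dequant_scale, quant_scale, eps=1e-12):
    """One row: int8 GEMM output + fp32 residual -> layer norm -> int8 output."""
    x = x_i8.astype(np.float32) * dequant_scale + residual     # dequant + residual add
    mean, var = x.mean(), x.var()                              # what Welford computes online
    y = (x - mean) / np.sqrt(var + eps) * gamma + beta         # layer norm with scale/bias
    y_i8 = np.clip(np.floor(y * quant_scale + 0.5), -127, 127).astype(np.int8)
    return y_i8, x    # x (the pre-LN sum) becomes the new residual in the pre-LN case
```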
+Currently, fp16 and fp32 versions are provided +*/ +namespace nvinfer1::plugin { +namespace backend { +template +__global__ void IxinferResidualBiasLnI8II8OKernel(const int8_t *input, const T *scale, const T *bias, + const T *residual_bias, int8_t *output, T *residual, int hidden_size, + float dequant_scale, float quant_scale, bool is_post_ln, + const T *colsum) { + // register + float vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + residual += block_start; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + // vals = dequant(input) + residual + vals[it] = __int2float_rn(input[element_index]) * dequant_scale + (float)residual[element_index]; + WelfordCombine(vals[it], &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float norm_value = (vals[it] - mean) * rsqrtf(m2 / hidden_size + epsilon) * (float)scale[element_index] + + (float)bias[element_index]; + int8_t res = float2int8(norm_value, quant_scale); + output[element_index] = res; + float res_bias_val = (residual_bias == nullptr) ? 0.0f : (float)residual_bias[element_index]; + + if (is_post_ln) { + residual[element_index] = norm_value + res_bias_val; + } else { + residual[element_index] = vals[it] + res_bias_val; + } + } +} +template +__global__ void IxinferResidualBiasLnI8II8OKernel(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, int8_t *output, __half *residual, + int hidden_size, float dequant_scale, float quant_scale, + bool is_post_ln, const __half *colsum) { + // register + // process 2 data + float4 vals[THREAD_DATA_LEN / 4]; + int block_start = blockIdx.x * hidden_size / 4; + char4 *p_input = (char4 *)input; + char4 *p_output = (char4 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start * 2; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN / 4; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + // vals = dequant(input) + residual + vals[it] = char4addhalf2_dequant(p_input[element_index], p_residual[element_index * 2], + p_residual[element_index * 2 + 1], dequant_scale); + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, 
C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN / 4; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_size, epsilon, + p_scale[element_index * 2], p_scale[element_index * 2 + 1], + p_bias[element_index * 2], p_bias[element_index * 2 + 1]); + + char4 res = float42char4(norm_value, quant_scale); + p_output[element_index] = res; + half2 res_bias_val_1; + half2 res_bias_val_2; + if (residual_bias == nullptr) { + res_bias_val_1.x = __float2half(0.0f); + res_bias_val_1.y = __float2half(0.0f); + res_bias_val_2.x = __float2half(0.0f); + res_bias_val_2.y = __float2half(0.0f); + } else { + res_bias_val_1 = p_residual_bias[element_index * 2]; + res_bias_val_2 = p_residual_bias[element_index * 2 + 1]; + } + half2 r1; + half2 r2; + if (is_post_ln) { + r1.x = __hadd(__float2half(norm_value.x), res_bias_val_1.x); + r1.y = __hadd(__float2half(norm_value.y), res_bias_val_1.y); + r2.x = __hadd(__float2half(norm_value.z), res_bias_val_2.x); + r2.y = __hadd(__float2half(norm_value.w), res_bias_val_2.y); + // res.x = __hadd(__float2half(a.x), b.x); + // res.y = __hadd(__float2half(a.y), b.y); + // p_residual[element_index] = float2addhalf2(norm_value, res_bias_val); + } else { + // p_residual[element_index] = float2addhalf2(vals[it], res_bias_val); + r1.x = __hadd(__float2half(vals[it].x), res_bias_val_1.x); + r1.y = __hadd(__float2half(vals[it].y), res_bias_val_1.y); + r2.x = __hadd(__float2half(vals[it].z), res_bias_val_2.x); + r2.y = __hadd(__float2half(vals[it].w), res_bias_val_2.y); + } + p_residual[element_index * 2] = r1; + p_residual[element_index * 2 + 1] = r2; + } +} + +template +void IxinferResidualBiasLnI8II8O(const int8_t *input, const T *scale, const T *bias, const T *residual_bias, + int8_t *output, T *residual, int batch_tokens, int hidden_size, float dequant_scale, + float quant_scale, int max_thread_per_block, cudaStream_t stream, bool is_post_ln, + const T *colsum) { + if (colsum) { + throw std::runtime_error( + "IxinferResidualBiasLnI8II8O: colsum has not been " + "implemented yet!"); + } + + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE; + + switch (num_warp) { + case 1: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 2: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 3: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 4: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 5: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 6: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, 
bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 7: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 8: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 9: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 10: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 11: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 12: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 13: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 14: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 15: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 16: + IxinferResidualBiasLnI8II8OKernel + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + default: + throw std::runtime_error("IxinferResidualBiasLnI8II8O"); + break; + } +} + +template <> +void IxinferResidualBiasLnI8II8O<__half>(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, int8_t *output, __half *residual, + int batch_tokens, int hidden_size, float dequant_scale, float quant_scale, + int max_thread_per_block, cudaStream_t stream, bool is_post_ln, + const __half *colsum) { + if (colsum) { + throw std::runtime_error( + "IxinferResidualBiasLnI8II8O: colsum has not been " + "implemented yet!"); + } + + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + if (hidden_size % 256 != 0) { + throw std::runtime_error("hidden_size // 256 != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE; + + switch (num_warp) { + case 1: + IxinferResidualBiasLnI8II8OKernel<1> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 2: + IxinferResidualBiasLnI8II8OKernel<2> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 3: + IxinferResidualBiasLnI8II8OKernel<3> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 4: + IxinferResidualBiasLnI8II8OKernel<4> + <<>>(input, scale, bias, residual_bias, output, 
residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 5: + IxinferResidualBiasLnI8II8OKernel<5> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 6: + IxinferResidualBiasLnI8II8OKernel<6> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 7: + IxinferResidualBiasLnI8II8OKernel<7> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 8: + IxinferResidualBiasLnI8II8OKernel<8> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 9: + IxinferResidualBiasLnI8II8OKernel<9> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 10: + IxinferResidualBiasLnI8II8OKernel<10> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 11: + IxinferResidualBiasLnI8II8OKernel<11> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 12: + IxinferResidualBiasLnI8II8OKernel<12> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 13: + IxinferResidualBiasLnI8II8OKernel<13> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 14: + IxinferResidualBiasLnI8II8OKernel<14> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 15: + IxinferResidualBiasLnI8II8OKernel<15> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + case 16: + IxinferResidualBiasLnI8II8OKernel<16> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln, colsum); + break; + default: + throw std::runtime_error("IxinferResidualBiasLnI8II8O"); + break; + } +} + +/* +input = dequant(input) + residual +output = layer_norm(input,scale,bias) +*/ +template +__global__ void IxinferLnResidualI8IKernel(int8_t *input, float *scale, float *bias, float *residual, float *output, + int hidden_size, float dequant_scale) { + // register + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + residual += block_start; + + char4 *p_input = (char4 *)input; + float4 *p_output = (float4 *)output; + float4 *p_residual = (float4 *)residual; + float4 *p_scale = (float4 *)scale; + float4 *p_bias = (float4 *)bias; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + char4 i_value = p_input[element_index]; + float4 r_value = p_residual[element_index]; + + vals[it].x = (float)i_value.x * dequant_scale + r_value.x; + vals[it].y = (float)i_value.y * dequant_scale + r_value.y; + vals[it].z = (float)i_value.z * dequant_scale + r_value.z; + vals[it].w = (float)i_value.w * dequant_scale 
+ r_value.w; + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + m2 = rsqrtf(m2 / hidden_size + epsilon); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + + float4 scale_value = p_scale[element_index]; + float4 bias_value = p_bias[element_index]; + + float4 norm_value; + norm_value.x = (vals[it].x - mean) * m2 * scale_value.x + bias_value.x; + norm_value.y = (vals[it].y - mean) * m2 * scale_value.y + bias_value.y; + norm_value.z = (vals[it].z - mean) * m2 * scale_value.z + bias_value.z; + norm_value.w = (vals[it].w - mean) * m2 * scale_value.w + bias_value.w; + + p_output[element_index] = norm_value; + } +} + +template +__global__ void IxinferLnResidualI8IKernel(int8_t *input, __half *scale, __half *bias, __half *residual, float *output, + int hidden_size, float dequant_scale) { + // register + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + residual += block_start; + + char4 *p_input = (char4 *)input; + float4 *p_output = (float4 *)output; + __half2 *p_residual = (__half2 *)residual; + __half2 *p_scale = (__half2 *)scale; + __half2 *p_bias = (__half2 *)bias; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + char4 i_value = p_input[element_index]; + float4 r_value; + load_float4_from_half(r_value, p_residual, element_index); + + vals[it].x = (float)i_value.x * dequant_scale + r_value.x; + vals[it].y = (float)i_value.y * dequant_scale + r_value.y; + vals[it].z = (float)i_value.z * dequant_scale + r_value.z; + vals[it].w = (float)i_value.w * dequant_scale + r_value.w; + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + m2 = rsqrtf(m2 / hidden_size + epsilon); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + + float4 scale_value; + load_float4_from_half(scale_value, p_scale, element_index); + float4 bias_value; + load_float4_from_half(bias_value, p_bias, element_index); + + float4 norm_value; + norm_value.x = (vals[it].x - mean) * m2 * scale_value.x + bias_value.x; + norm_value.y = (vals[it].y - mean) * m2 * scale_value.y + bias_value.y; + norm_value.z = (vals[it].z 
- mean) * m2 * scale_value.z + bias_value.z;
+        norm_value.w = (vals[it].w - mean) * m2 * scale_value.w + bias_value.w;
+        p_output[element_index] = norm_value;
+        // __half2 p1;
+        // p1.x = __float2half(norm_value.x);
+        // p1.y = __float2half(norm_value.y);
+        // __half2 p2;
+        // p2.x = __float2half(norm_value.z);
+        // p2.y = __float2half(norm_value.w);
+
+        // p_output[element_index * 2] = p1;
+        // p_output[element_index * 2 + 1] = p2;
+    }
+}
+
+template <typename T>
+void IxinferLnResidualI8I(int8_t *input, T *scale, T *bias, T *residual, float *output, int batch_tokens,
+                          int hidden_size, float dequant_scale, cudaStream_t stream) {
+    if (hidden_size > 4096) {
+        throw std::runtime_error("hidden_size should <= 4096");
+    }
+    if (hidden_size / 4 % C10_WARP_SIZE != 0) {
+        throw std::runtime_error("hidden_size / 4 % C10_WARP_SIZE != 0");
+    }
+    dim3 gridSize(batch_tokens);
+    dim3 blockSize(C10_WARP_SIZE);
+
+    int num_warp = hidden_size / C10_WARP_SIZE / 4;
+
+    switch (num_warp) {
+        case 1:
+            IxinferLnResidualI8IKernel<1>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 2:
+            IxinferLnResidualI8IKernel<2>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 3:
+            IxinferLnResidualI8IKernel<3>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 4:
+            IxinferLnResidualI8IKernel<4>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 5:
+            IxinferLnResidualI8IKernel<5>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 6:
+            IxinferLnResidualI8IKernel<6>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 7:
+            IxinferLnResidualI8IKernel<7>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 8:
+            IxinferLnResidualI8IKernel<8>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 9:
+            IxinferLnResidualI8IKernel<9>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 10:
+            IxinferLnResidualI8IKernel<10>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 11:
+            IxinferLnResidualI8IKernel<11>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 12:
+            IxinferLnResidualI8IKernel<12>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 13:
+            IxinferLnResidualI8IKernel<13>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 14:
+            IxinferLnResidualI8IKernel<14>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 15:
+            IxinferLnResidualI8IKernel<15>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        case 16:
+            IxinferLnResidualI8IKernel<16>
+                <<<gridSize, blockSize, 0, stream>>>(input, scale, bias, residual, output, hidden_size, dequant_scale);
+            break;
+        default:
+            throw std::runtime_error("IxinferLnResidualI8I");
+            break;
+    }
+}
+
+template void IxinferLnResidualI8I<float>(int8_t *input, float *scale, float *bias, float *residual, float *output,
+                                          int batch_tokens, int hidden_size, float dequant_scale, cudaStream_t stream);
+template void IxinferLnResidualI8I<__half>(int8_t *input, __half *scale, __half *bias, __half *residual, float *output,
+                                           int batch_tokens, int hidden_size, float dequant_scale, cudaStream_t stream);
+
+template <typename T>
+__global__ void IxinferBiasGeluI8II8OKernel(int8_t *input, int8_t *output, const T
*bias, int feature_dim, + float dequant_scale, float quant_scale) { + int block_start = blockIdx.x * feature_dim; + int start = block_start + threadIdx.x; + int end = block_start + feature_dim; + for (int i = start; i < end; i += blockDim.x) { + int input_index = i; + + float fout = gelu(float(input[input_index]) * dequant_scale + __ldg(&bias[i - block_start])); + + int output_index = i; + output[output_index] = float2int8(fout, quant_scale); + } +} + +/* fp16 version method 1 by pxl*/ +template <> +__global__ void IxinferBiasGeluI8II8OKernel<__half>(int8_t *input, int8_t *output, const __half *bias, int feature_dim, + float dequant_scale, float quant_scale) { + // #pragma unroll + for (int block_index = 0; block_index < 2; block_index++) { + int block_start = (blockIdx.x * 2 + block_index) * feature_dim; + int start = block_start + threadIdx.x * 4; + int end = block_start + feature_dim; + // #pragma unroll + // for (int i = start; i < end; i += blockDim.x*4) + // { + int input_index = start; + char4 *p_input = (char4 *)(input + input_index); + half2 *p_bias = (half2 *)(bias + input_index - block_start); + float fout1 = gelu(float(p_input[0].x) * dequant_scale + __half2float(p_bias[0].x)); + float fout2 = gelu(float(p_input[0].y) * dequant_scale + __half2float(p_bias[0].y)); + float fout3 = gelu(float(p_input[0].z) * dequant_scale + __half2float(p_bias[1].x)); + float fout4 = gelu(float(p_input[0].w) * dequant_scale + __half2float(p_bias[1].y)); + + int output_index = start; + char4 out; + out.x = float2int8(fout1, quant_scale); + out.y = float2int8(fout2, quant_scale); + out.z = float2int8(fout3, quant_scale); + out.w = float2int8(fout4, quant_scale); + char4 *p_output = (char4 *)(output + output_index); + + p_output[0] = out; + } + // } +} + +template +void IxinferBiasGeluI8II8O(int batch_token_num, cudaStream_t stream, int8_t *input, int8_t *output, const T *bias, + int feature_dim, float dequant_scale, float quant_scale) { + IxinferBiasGeluI8II8OKernel + <<>>(input, output, bias, feature_dim, dequant_scale, quant_scale); +} + +template <> +void IxinferBiasGeluI8II8O<__half>(int batch_token_num, cudaStream_t stream, int8_t *input, int8_t *output, + const __half *bias, int feature_dim, float dequant_scale, float quant_scale) { + if (feature_dim / 4 > 4096) { + throw std::runtime_error( + "IxinferBiasGeluI8II8O: feature_dim / 4 > 4096 has not " + "been " + "implemented yet!"); + } + if (feature_dim % 4 != 0) { + throw std::runtime_error( + "IxinferBiasGeluI8II8O: feature_dim % 4 != 0 has not been " + "implemented yet!"); + } + if (batch_token_num % 2 != 0) { + throw std::runtime_error( + "IxinferBiasGeluI8II8O: batch_token_num % 2 != 0 has not " + "been " + "implemented yet!"); + } + IxinferBiasGeluI8II8OKernel<__half><<>>( + input, output, bias, feature_dim, dequant_scale, quant_scale); +} + +/*******************add by pxl ***********/ +template +__global__ void IxinferArrangeEncselfQkvI8II8OKernel(const int8_t *ori_qkv, const T *qkv_bias, int8_t *new_qkv, + int max_batch_dim, int batch_seq_len, int dim_per_head, + int head_num, float quant_scale, float dequant_scale) { + int hidden_size = dim_per_head * head_num; + int batch_id = blockIdx.x / batch_seq_len; + int token_id = blockIdx.x % batch_seq_len; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / dim_per_head; + int dim_id = (i * 4) % dim_per_head; + int target_id = targetid_4dim(batch_id, head_id, token_id, dim_id, head_num, batch_seq_len, dim_per_head); + +#pragma unroll + for (int qkv_idx = 0; qkv_idx < 3; 
qkv_idx++) { + char4 *p_ori_qkv = (char4 *)(ori_qkv + (blockIdx.x * 3 + qkv_idx) * hidden_size); + int qkv_offset = max_batch_dim * qkv_idx; + char4 value; + + value.x = + float2int8(p_ori_qkv[i].x * dequant_scale + (float)qkv_bias[qkv_idx * hidden_size + i * 4], quant_scale); + value.y = float2int8(p_ori_qkv[i].y * dequant_scale + (float)qkv_bias[qkv_idx * hidden_size + i * 4 + 1], + quant_scale); + value.z = float2int8(p_ori_qkv[i].z * dequant_scale + (float)qkv_bias[qkv_idx * hidden_size + i * 4 + 2], + quant_scale); + value.w = float2int8(p_ori_qkv[i].w * dequant_scale + (float)qkv_bias[qkv_idx * hidden_size + i * 4 + 3], + quant_scale); + + char4 *p_new_qkv = (char4 *)(new_qkv + qkv_offset + target_id); + p_new_qkv[0] = value; + } +} + +/*method 1 by pxl + */ +template <> +__global__ void IxinferArrangeEncselfQkvI8II8OKernel<__half>(const int8_t *ori_qkv, const __half *qkv_bias, + int8_t *new_qkv, int max_batch_dim, int batch_seq_len, + int dim_per_head, int head_num, float quant_scale, + float dequant_scale) { + int hidden_size = dim_per_head * head_num; + int batch_id = blockIdx.x / batch_seq_len; + int token_id = blockIdx.x % batch_seq_len; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / dim_per_head; + int dim_id = (i * 4) % dim_per_head; + int target_id = targetid_4dim(batch_id, head_id, token_id, dim_id, head_num, batch_seq_len, dim_per_head); + +#pragma unroll + for (int qkv_idx = 0; qkv_idx < 3; qkv_idx++) { + char4 *p_ori_qkv = (char4 *)(ori_qkv + (blockIdx.x * 3 + qkv_idx) * hidden_size); + int qkv_offset = max_batch_dim * qkv_idx; + char4 value; + + half2 *p_qkv_bias = (half2 *)(qkv_bias + qkv_idx * hidden_size + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + + char4 *p_new_qkv = (char4 *)(new_qkv + qkv_offset + target_id); + p_new_qkv[0] = value; + } +} + +template +void IxinferArrangeEncselfQkvI8II8O(int batch_token_num, int hidden_size, cudaStream_t stream, const int8_t *ori_qkv, + const T *qkv_bias, int8_t *new_qkv, int max_batch_dim, int batch_seq_len, + int dim_per_head, int head_num, int max_thread_per_block, float quant_scale, + float dequant_scale) { + IxinferArrangeEncselfQkvI8II8OKernel<<>>( + ori_qkv, qkv_bias, new_qkv, max_batch_dim, batch_seq_len, dim_per_head, head_num, quant_scale, dequant_scale); +} +template <> +void IxinferArrangeEncselfQkvI8II8O<__half>(int batch_token_num, int hidden_size, cudaStream_t stream, + const int8_t *ori_qkv, const __half *qkv_bias, int8_t *new_qkv, + int max_batch_dim, int batch_seq_len, int dim_per_head, int head_num, + int max_thread_per_block, float quant_scale, float dequant_scale) { + IxinferArrangeEncselfQkvI8II8OKernel<__half><<>>( + ori_qkv, qkv_bias, new_qkv, max_batch_dim, batch_seq_len, dim_per_head, head_num, quant_scale, dequant_scale); +} + +template +__global__ void IxinferCorrelationSoftmaxEncselfI8II8OKernel(int8_t *correlation, const int8_t *src_padding_mask, + int batch_seq_len, float quant_scale, + float dequant_scale) { + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int SOFT_WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / SOFT_WARP_SIZE;
+    int local_idx = threadIdx.x;
+
+    for (int warp_idx = 0; warp_idx < WARP_BATCH; ++warp_idx) {
+        int start_idx = (blockIdx.x * gridDim.y * WARP_BATCH * gridDim.z * batch_seq_len +
+                         (blockIdx.y + gridDim.y * warp_idx) * gridDim.z * batch_seq_len + blockIdx.z * batch_seq_len);
+
+        char4 *p_correlation = (char4 *)(correlation + start_idx);
+        char4 *p_src_padding_mask = (char4 *)(src_padding_mask + blockIdx.x * batch_seq_len);
+
+        // load data from global memory
+        // float
+        float4 elements[WARP_ITERATIONS];
+#pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; ++it) {
+            int element_index = local_idx + it * SOFT_WARP_SIZE;
+            if (element_index < batch_seq_len / 4) {
+                char4 mask = p_src_padding_mask[element_index];
+                char4 correlation_value = p_correlation[element_index];
+
+                elements[it].x =
+                    mask.x ? -std::numeric_limits<float>::infinity() : (float)correlation_value.x * dequant_scale;
+                elements[it].y =
+                    mask.y ? -std::numeric_limits<float>::infinity() : (float)correlation_value.y * dequant_scale;
+                elements[it].z =
+                    mask.z ? -std::numeric_limits<float>::infinity() : (float)correlation_value.z * dequant_scale;
+                elements[it].w =
+                    mask.w ? -std::numeric_limits<float>::infinity() : (float)correlation_value.w * dequant_scale;
+
+            } else {
+                elements[it].x = -std::numeric_limits<float>::infinity();
+                elements[it].y = -std::numeric_limits<float>::infinity();
+                elements[it].z = -std::numeric_limits<float>::infinity();
+                elements[it].w = -std::numeric_limits<float>::infinity();
+            }
+        }
+
+        // compute max_value
+        float max_value = elements[0].x;
+        max_value = (max_value > elements[0].y) ? max_value : elements[0].y;
+        max_value = (max_value > elements[0].z) ? max_value : elements[0].z;
+        max_value = (max_value > elements[0].w) ? max_value : elements[0].w;
+
+#pragma unroll
+        for (int it = 1; it < WARP_ITERATIONS; ++it) {
+            max_value = (max_value > elements[it].x) ? max_value : elements[it].x;
+            max_value = (max_value > elements[it].y) ? max_value : elements[it].y;
+            max_value = (max_value > elements[it].z) ? max_value : elements[it].z;
+            max_value = (max_value > elements[it].w) ?
max_value : elements[it].w; + } + + warp_reduce(&max_value); + + // exp sum + float sum = 0.0f; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[it].x = __expf(elements[it].x - max_value); + elements[it].y = __expf(elements[it].y - max_value); + elements[it].z = __expf(elements[it].z - max_value); + elements[it].w = __expf(elements[it].w - max_value); + + sum += (elements[it].x + elements[it].y + elements[it].z + elements[it].w); + } + + warp_reduce(&sum); + sum = 1.0f / sum; + // store result +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + char4 correlation_value; + if (element_index < batch_seq_len / 4) { + correlation_value.x = float2int8(elements[it].x * sum, quant_scale); + correlation_value.y = float2int8(elements[it].y * sum, quant_scale); + correlation_value.z = float2int8(elements[it].z * sum, quant_scale); + correlation_value.w = float2int8(elements[it].w * sum, quant_scale); + + p_correlation[element_index] = correlation_value; + + } else { + break; + } + } + } +} + +void IxinferCorrelationSoftmaxEncselfI8II8O(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + int8_t *correlation, const int8_t *src_padding_mask, float quant_scale, + float dequant_scale) { + const int NUM_INT8_SOFTMAX_BATCH_WARP = 4; + if (batch_seq_len > 512) { + throw std::runtime_error("batch_seq_len should <= 512"); + } + if (head_num % NUM_INT8_SOFTMAX_BATCH_WARP != 0) { + throw std::runtime_error("head_num % NUM_INT8_SOFTMAX_BATCH_WARP !0"); + } + if (batch_seq_len % 4 != 0) { + throw std::runtime_error("batch_seq_len % 4 != 0"); + } + + int log2_elements = log2_ceil(batch_seq_len / 4); + int next_power_of_two = 1 << log2_elements; + int SOFT_WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE;
+    // dim3 blockSize(batch_size, head_num / NUM_INT8_SOFTMAX_BATCH_WARP,
+    // batch_seq_len);
+    //
+    dim3 grid(batch_size, head_num / NUM_INT8_SOFTMAX_BATCH_WARP, batch_seq_len);
+
+    dim3 block(SOFT_WARP_SIZE);
+
+    switch (log2_elements) {
+        case 0:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<0, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+
+            break;
+
+        case 1:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<1, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+
+        case 2:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<2, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+
+        case 3:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<3, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+
+        case 4:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<4, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+
+        case 5:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<5, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+
+        case 6:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<6, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+        case 7:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<7, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+        case 8:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<8, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+        case 9:
+            IxinferCorrelationSoftmaxEncselfI8II8OKernel<9, NUM_INT8_SOFTMAX_BATCH_WARP>
+                <<<grid, block, 0, stream>>>(correlation, src_padding_mask, batch_seq_len, quant_scale, dequant_scale);
+            break;
+        default:
+            throw std::runtime_error(
+                "ker_correlation_softmax_encself_i8I_i8O_ix_ "
+                "NotImplementedError");
+            break;
+    }
+}
+
+__global__ void IxinferArrangeAttenOutputI8II8OKernel(const int8_t *ori_q, int8_t *new_q, int beam_size,
+                                                      int dim_per_head, int head_num, float quant_scale,
+                                                      float dequant_scale) {
+    int hidden_size = dim_per_head * head_num;
+
+#pragma unroll
+    for (int blockin = 0; blockin < 4; blockin++) {
+        int batch_id = (blockIdx.x * 4 + blockin) / beam_size;
+        // note, for encoder, beam_id is token_id; for decoder, beam_id is beam_id
+        int beam_id = (blockIdx.x * 4 + blockin) % beam_size;
+        int i = threadIdx.x;
+        int out_index = (blockIdx.x * 4 + blockin) * hidden_size + i;
+        int head_id = i / dim_per_head;
+        int dim_id = i % dim_per_head;
+
+        char4 *p_ori_q = (char4 *)ori_q;
+        char4 *p_new_q = (char4 *)new_q;
+        char4 value;
+
+        value = p_ori_q[targetid_4dim(batch_id, head_id, beam_id, dim_id, head_num, beam_size, dim_per_head)];
+        value.x = float2int8(value.x * dequant_scale, quant_scale);
+        value.y = float2int8(value.y * dequant_scale, quant_scale);
+        value.z = float2int8(value.z * dequant_scale, quant_scale);
+        value.w = float2int8(value.w * dequant_scale, quant_scale);
+        p_new_q[out_index] = value;
+    }
+}
+
+void IxinferArrangeAttenOutputI8II8O(int batch_token_num, int hidden_size, cudaStream_t stream, const int8_t *ori_q,
+                                     int8_t *new_q, int beam_size, int dim_per_head, int head_num,
+                                     int max_thread_per_block,
float quant_scale, float dequant_scale) { + int qual_hidden_size = hidden_size >> 2; + int qual_dim_per_head = dim_per_head >> 2; + IxinferArrangeAttenOutputI8II8OKernel<<>>( + ori_q, new_q, beam_size, qual_dim_per_head, head_num, quant_scale, dequant_scale); +} + +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..74ebb17c387e5b94fb7be1b9e7b96d423a5a9987 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_layer_kernel.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once + +#include +#include +#include + +#include + +namespace nvinfer1::plugin { +namespace backend { + +// TODO: support col32 switch for input and output like relu +template +void IxinferBiasGeluI8II8O(int batch_token_num, cudaStream_t stream, int8_t *input, int8_t *output, const T *bias, + int feature_dim, float dequant_scale, float quant_scale); + +template +void IxinferResidualBiasLnI8II8O(const int8_t *input, const T *scale, const T *bias, const T *residual_bias, + int8_t *output, T *residual, int batch_tokens, int hidden_size, float dequant_scale, + float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln = false, const T *colsum = nullptr); + +/***************add by pxl***************************/ + +template +void IxinferArrangeEncselfQkvI8II8O(int batch_token_num, int hidden_size, cudaStream_t stream, const int8_t *ori_qkv, + const T *qkv_bias, int8_t *new_qkv, int max_batch_dim, int batch_seq_len, + int dim_per_head, int head_num, int max_thread_per_block, float quant_scale, + float dequant_scale); +template <> +void IxinferArrangeEncselfQkvI8II8O<__half>(int batch_token_num, int hidden_size, cudaStream_t stream, + const int8_t *ori_qkv, const __half *qkv_bias, int8_t *new_qkv, + int max_batch_dim, int batch_seq_len, int dim_per_head, int head_num, + int max_thread_per_block, float quant_scale, float dequant_scale); + +void IxinferArrangeAttenOutputI8II8O(int batch_token_num, int hidden_size, cudaStream_t stream, const int8_t *ori_q, + int8_t *new_q, int beam_size, int dim_per_head, int head_num, + int max_thread_per_block, float quant_scale, float dequant_scale); +/***************add by pxl end***************************/ + +template +void IxinferLnResidualI8I(int8_t *input, T *scale, T *bias, T *residual, float *output, int batch_tokens, + int hidden_size, float dequant_scale, cudaStream_t stream); + +void IxinferCorrelationSoftmaxEncselfI8II8O(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + int8_t *correlation, const int8_t *src_padding_mask, float quant_scale, + float dequant_scale); + +} // namespace backend +} // namespace 
nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..acb0423870fea31af7fd4f090caa7aaea286d802 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.cu @@ -0,0 +1,465 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include + +#include "bert_helper.h" +#include "bert_mha_kernel.h" +namespace nvinfer1::plugin { +namespace backend { + +template +__forceinline__ __device__ T warpReduceMax(T val) { + for (int mask = (reduceWidth >> 1); mask > 0; mask >>= 1) + val = std::max(val, __shfl_xor_sync(0xffffffff, val, mask, reduceWidth)); + return val; +} + +template +__forceinline__ __device__ T warpReduceSum(T val) { + for (int mask = (reduceWidth >> 1); mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask, reduceWidth); + return val; +} + +__forceinline__ __device__ int8_t float2int8_amax(float x, float amax) { + float quant_scale = 127.f / amax; + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 
127 : i8); + return int8_t(i8); +} + +template +__forceinline__ __device__ void _warp_softmax_16_n(float (*SOFTMAX_IN)[4], float *SOFT_MAX_SUM_VALUES) { + // 16*N的softmax + // softmax, 每16个线程计算一行,一共16行,需要4次 + for (int i = 0; i < 4; ++i) { + // softmax -> max value, + // 每个线程的最大值已经计算出来,现在需要每16个线程计算最大值 + float max_value = SOFT_MAX_SUM_VALUES[i]; + max_value = warpReduceMax(max_value); + SOFT_MAX_SUM_VALUES[i] = 0; + // exp(x-x_max) + for (int Ni = 0; Ni < BLOCK_SIZE_N / 16; ++Ni) { + SOFTMAX_IN[Ni][i] = __expf(SOFTMAX_IN[Ni][i] - max_value); + SOFT_MAX_SUM_VALUES[i] += SOFTMAX_IN[Ni][i]; + } + } + // exp sum, softmax + for (int i = 0; i < 4; ++i) { + float sum_value = SOFT_MAX_SUM_VALUES[i]; + sum_value = 1.f / (warpReduceSum(sum_value)); + for (int Ni = 0; Ni < BLOCK_SIZE_N / 16; ++Ni) { + SOFTMAX_IN[Ni][i] = SOFTMAX_IN[Ni][i] * sum_value; + } + } +} + +template +__forceinline__ __device__ void _warp_softmax_16_n(float (*SOFTMAX_IN)[4], float *SOFT_MAX_VALUES, + float *SOFT_SUM_VALUES) { + // 计算一个sub-block: BLOCK_SIZE_M*BLOCK_SIZE_M 的softmax + // 16*N的softmax + // softmax, 每16个线程计算一行,一共16行,需要4次 + for (int i = 0; i < 4; ++i) { + // softmax -> max value, + // 每个线程的最大值已经计算出来,现在需要每16个线程计算最大值 + SOFT_MAX_VALUES[i] = warpReduceMax(SOFT_MAX_VALUES[i]); + SOFT_SUM_VALUES[i] = 0; + // exp(x-x_max) + for (int Ni = 0; Ni < BLOCK_SIZE_M / 16; ++Ni) { + SOFTMAX_IN[Ni][i] = __expf(SOFTMAX_IN[Ni][i] - SOFT_MAX_VALUES[i]); + SOFT_SUM_VALUES[i] += SOFTMAX_IN[Ni][i]; + } + } + // exp sum, softmax + for (int i = 0; i < 4; ++i) { + SOFT_SUM_VALUES[i] = warpReduceSum(SOFT_SUM_VALUES[i]); + for (int Ni = 0; Ni < BLOCK_SIZE_M / 16; ++Ni) { + SOFTMAX_IN[Ni][i] = SOFTMAX_IN[Ni][i] / SOFT_SUM_VALUES[i]; + } + } +} +template +__forceinline__ __device__ void _mr_fmha_i8_tcu_impl_block(const int8_t *Q, const int8_t *K, const int8_t *V, + const int8_t *mask, int8_t *C, float q_amax, float k_amax, + float v_amax, float s_max, float qk_amax, float r_amax) { + // BLOCK_SIZE_N: K的行数,V的行数,需要是BLOCK_SIZE_M的整数倍 + static_assert(BLOCK_SIZE_N % BLOCK_SIZE_M == 0, "BLOCK_SIZE_N % BLOCK_SIZE_M != 0"); + static_assert(BLOCK_SIZE_N == N, "BLOCK_SIZE_N != N"); + + const float atten_scaler = sqrt(1.f / 64.f); + const float scale_qk_out = q_amax * k_amax / (127.f * 127.f) * atten_scaler; + const float scale_v_out = s_max * v_amax / (127.f * r_amax); + const float softmax_out_scale = 127.f / s_max; + + // 只考虑 M, 切分为 BLOCK_SIZE_M + unsigned warpBase = __ivcorex_readlane(threadIdx.x, 0); + unsigned warpId = warpBase >> 6; // warp的ID + unsigned laneId = __ivcorex_lane_id(); // 每个warp中线程的id, 0-63 + unsigned laneCol = laneId % 16; // lane对应的列, 一个warp的线程被划分为4*16 + unsigned laneRow = laneId / 16; // lane对应的行 + + // A load 到 shared memory + const unsigned SMSizePerStage = BLOCK_SIZE_M * 64 / 4; + const unsigned SNSizePerStage = BLOCK_SIZE_M * 64 / 4; + const unsigned SVSizePerStage = BLOCK_SIZE_M * 64 / 4; + + __shared__ int SLB_T[SMSizePerStage + SNSizePerStage + SVSizePerStage]; + // __shared__ int SLB_T[SMSizePerStage + SNSizePerStage * 2]; + int *SM = SLB_T; + int *SN = SLB_T + SMSizePerStage; + int *SV = SN + SNSizePerStage; + int8_t *SMI8 = reinterpret_cast(SM); + + v4u32 ABase; + ABase.x = (unsigned)(unsigned long long)Q; + ABase.y = (unsigned)((unsigned long long)Q >> 32); + ABase.z = -1u; + ABase.w = 64 * sizeof(char); // K 对应于 StrideA + + v4u32 BBase; + BBase.x = (unsigned)(unsigned long long)K; + BBase.y = (unsigned)((unsigned long long)K >> 32); + BBase.z = -1u; + BBase.w = 64 * sizeof(char); // K 对应于 StrideB + + v4u32 VBase; + VBase.x 
= (unsigned)(unsigned long long)V; + VBase.y = (unsigned)((unsigned long long)V >> 32); + VBase.z = -1u; + VBase.w = 64 * sizeof(char); // K 对应于 StrideB + + char *BC = (char *)(C); + v4u32 CBase; + CBase.x = (unsigned)(unsigned long long)BC; + CBase.y = (unsigned)((unsigned long long)BC >> 32); + CBase.zw = -1u; + + unsigned EmLaneId[4]; + for (unsigned i = 0; i < 4; i++) { + EmLaneId[i] = (laneRow + i) % 4 * 16 + laneCol; + } + + unsigned EmLaneIdB[4]; + for (unsigned i = 0; i < 4; i++) { + int kk = laneId / 4 % 4; + int tt = laneId / 16; + + int index = kk * 64 + tt * 16 + (kk ^ i) * 4 + (laneId % 4); + EmLaneIdB[i] = index; + } + + // load for Q to SM, 每个warp load 16*64 + unsigned gOffsetA = warpId * 1024; + unsigned sOffsetA = (unsigned)(unsigned long long)&SM[warpId * 256]; + __ivcorex_sme_load_16x1b64_rowxfb8(sOffsetA, ABase, gOffsetA, 1); + // load K to SM, 每个warp load 16*64 + unsigned gOffsetB = warpId * 1024; + unsigned sOffsetB = (unsigned)(unsigned long long)&SN[warpId * 256]; + __ivcorex_sme_load_16x1b64_colxfb8(sOffsetB, BBase, gOffsetB, 1); + + // load V to SM, 每个warp load 16*64 + unsigned gOffsetV = warpId * 1024; + unsigned sOffsetV = (unsigned)(unsigned long long)&SV[warpId * 256]; + __ivcorex_sme_load_16x1b64_rowxfb8(sOffsetV, VBase, gOffsetV, 1); + + __syncthreads(); + v2i32 MMA[2]; // 8个int8 + v2i32 MMB; + v2i32 MMS; // 8个int8 + v4i32 MMC[BLOCK_SIZE_M / 16] = {}; // 4个i32 + v4i32 MMV[4] = {}; // 64/16 + + // 第一个block的SOFTMAX + float SOFTMAX_IN_1[BLOCK_SIZE_M / 16][4]; + float SOFT_MAX_VALUES_1[4] = {-std::numeric_limits::infinity()}; + float SOFT_SUM_VALUES_1[4]; + + // 第一个block BLOCK_SIZE_M*BLOCK_SIZE_M + for (int Ni = 0; Ni < BLOCK_SIZE_M / 16; ++Ni) { + // 计算16*16 + for (unsigned Ki = 0; Ki < 2; Ki++) { + // A + unsigned SAI = Ki * 2; // 0,2 + int *SMI = &SM[warpId * 256 + SAI * warpSize]; + MMA[Ki][0] = SMI[EmLaneId[SAI % 4]]; + MMA[Ki][1] = SMI[warpSize + EmLaneId[(SAI + 1) % 4]]; + // B + unsigned SBI = Ki * 2; + int *SNI = &SN[Ni * 256 + (SBI / 4) * 256]; + MMB[0] = SNI[EmLaneIdB[SBI % 4]]; + MMB[1] = SNI[EmLaneIdB[(SBI + 1) % 4]]; + MMC[Ni] = __ivcorex_matrix_mad_i32x4_i8x8(MMA[Ki], MMB, MMC[Ni]); + } + // 反量化为float,为softmax做准备 + int seq_id = Ni * 16 + laneId % 16; + int mask_id = (int)mask[seq_id]; + for (int i = 0; i < 4; ++i) { + SOFTMAX_IN_1[Ni][i] = + mask_id ? 
(float)(MMC[Ni][i]) * scale_qk_out - 10000.f : (float)(MMC[Ni][i]) * scale_qk_out; + SOFT_MAX_VALUES_1[i] = max(SOFTMAX_IN_1[Ni][i], SOFT_MAX_VALUES_1[i]); + } + } + _warp_softmax_16_n(SOFTMAX_IN_1, SOFT_MAX_VALUES_1, SOFT_SUM_VALUES_1); + + // 与V的矩阵乘 -> 16*64 + // 16 * N 的softmax已经计算完成,接下来要计算 S*V: 16*N N*64 的 矩阵乘法 + // 每16*64写入SM + for (int N_step = 0; N_step < BLOCK_SIZE_M / 64; ++N_step) { + // 16*64写入 sm + for (int j = 0; j < 4; j++) { + int Ni = N_step * 4 + j; + for (int i = 0; i < 4; ++i) { + SMI8[warpId * 1024 + j * 16 * 16 + i * 64 + laneCol * 4 + laneRow] = + __float2int_rn(max(min(softmax_out_scale * SOFTMAX_IN_1[Ni][i], 127.f), -128.f)); + } + } + + // N方向上的计算次数,每次计算16*16,N方向需要重复 64/16 + for (int n_index = 0; n_index < 64 / 16; ++n_index) { + // 每次计算 16*32 32*16 + for (unsigned Ki = 0; Ki < 2; Ki++) { + MMS[0] = SM[warpId * 256 + Ki * 16 * 32 / 4 + laneId]; + MMS[1] = SM[warpId * 256 + Ki * 16 * 32 / 4 + 64 + laneId]; + + // V + unsigned SCI = Ki * 2 * (64 / 16) + n_index; + int *SVI = &SV[N_step * 64 * 64 / 4 + SCI * warpSize]; + MMB[0] = SVI[EmLaneId[SCI % 4]]; + MMB[1] = SVI[warpSize * (64 / 16) + EmLaneId[(SCI + (64 / 16)) % 4]]; + MMV[n_index] = __ivcorex_matrix_mad_i32x4_i8x8(MMS, MMB, MMV[n_index]); + } + } + } + + // 在N方向,每次只计算BLOCK_SIZE_M,需要循环 + for (int KNi = 1; KNi < BLOCK_SIZE_N / BLOCK_SIZE_M; ++KNi) { + __syncthreads(); + // load K to SM, 每个warp load 16*64, 第2个block + gOffsetB = warpId * 1024 + KNi * BLOCK_SIZE_M * 64; + sOffsetB = (unsigned)(unsigned long long)&SN[warpId * 256]; + __ivcorex_sme_load_16x1b64_colxfb8(sOffsetB, BBase, gOffsetB, 1); + gOffsetV = warpId * 1024 + KNi * BLOCK_SIZE_M * 64; + sOffsetV = (unsigned)(unsigned long long)&SV[warpId * 256]; + __ivcorex_sme_load_16x1b64_rowxfb8(sOffsetV, VBase, gOffsetV, 1); + __syncthreads(); + + v4i32 MMC[BLOCK_SIZE_M / 16] = {}; + v4i32 MMV2[4] = {}; // 64/16 + float SOFT_MAX_VALUES_2[4] = {-std::numeric_limits::infinity()}; + float SOFT_SUM_VALUES_2[4]; + + for (int Ni = 0; Ni < BLOCK_SIZE_M / 16; ++Ni) { + for (unsigned Ki = 0; Ki < 2; Ki++) { + // B + unsigned SBI = Ki * 2; + int *SNI = &SN[Ni * 256 + (SBI / 4) * 256]; + MMB[0] = SNI[EmLaneIdB[SBI % 4]]; + MMB[1] = SNI[EmLaneIdB[(SBI + 1) % 4]]; + MMC[Ni] = __ivcorex_matrix_mad_i32x4_i8x8(MMA[Ki], MMB, MMC[Ni]); + } + // 反量化为float,为softmax做准备 + int seq_id = BLOCK_SIZE_M * KNi + Ni * 16 + laneId % 16; + int mask_id = (int)mask[seq_id]; + for (int i = 0; i < 4; ++i) { + SOFTMAX_IN_1[Ni][i] = + mask_id ? 
(float)(MMC[Ni][i]) * scale_qk_out - 10000.f : (float)(MMC[Ni][i]) * scale_qk_out; + SOFT_MAX_VALUES_2[i] = max(SOFTMAX_IN_1[Ni][i], SOFT_MAX_VALUES_2[i]); + } + } + _warp_softmax_16_n(SOFTMAX_IN_1, SOFT_MAX_VALUES_2, SOFT_SUM_VALUES_2); + + // 与V的矩阵乘 -> 16*64 + // 16 * N 的softmax已经计算完成,接下来要计算 S*V: 16*N N*64 的 矩阵乘法 + // 每16*64写入SM + for (int N_step = 0; N_step < BLOCK_SIZE_M / 64; ++N_step) { + // 16*64写入 sm + for (int j = 0; j < 4; j++) { + int Ni = N_step * 4 + j; + for (int i = 0; i < 4; ++i) { + SMI8[warpId * 1024 + j * 16 * 16 + i * 64 + laneCol * 4 + laneRow] = + __float2int_rn(max(min(softmax_out_scale * SOFTMAX_IN_1[Ni][i], 127.f), -128.f)); + } + } + + // N方向上的计算次数,每次计算16*16,N方向需要重复 64/16 + for (int n_index = 0; n_index < 64 / 16; ++n_index) { + // 每次计算 16*32 32*16 + for (unsigned Ki = 0; Ki < 2; Ki++) { + MMS[0] = SM[warpId * 256 + Ki * 16 * 32 / 4 + laneId]; + MMS[1] = SM[warpId * 256 + Ki * 16 * 32 / 4 + 64 + laneId]; + + // V + unsigned SCI = Ki * 2 * (64 / 16) + n_index; + int *SVI = &SV[N_step * 64 * 64 / 4 + SCI * warpSize]; + MMB[0] = SVI[EmLaneId[SCI % 4]]; + MMB[1] = SVI[warpSize * (64 / 16) + EmLaneId[(SCI + (64 / 16)) % 4]]; + MMV2[n_index] = __ivcorex_matrix_mad_i32x4_i8x8(MMS, MMB, MMV2[n_index]); + } + } + } + // // flash softmax + for (int i = 0; i < 4; ++i) { + // update max + float max_value = SOFT_MAX_VALUES_1[i] > SOFT_MAX_VALUES_2[i] ? SOFT_MAX_VALUES_1[i] : SOFT_MAX_VALUES_2[i]; + float sf1_scale = __expf(SOFT_MAX_VALUES_1[i] - max_value); + float sf2_scale = __expf(SOFT_MAX_VALUES_2[i] - max_value); + // update exp sum + float sf_sum_value = SOFT_SUM_VALUES_1[i] * sf1_scale + SOFT_SUM_VALUES_2[i] * sf2_scale; + // update s*v + for (int n_index = 0; n_index < 64 / 16; ++n_index) { + float updated_value = (float)MMV[n_index][i] * SOFT_SUM_VALUES_1[i] * sf1_scale / sf_sum_value + + (float)MMV2[n_index][i] * SOFT_SUM_VALUES_2[i] * sf2_scale / sf_sum_value; + + MMV[n_index][i] = (int32_t)updated_value; + } + SOFT_MAX_VALUES_1[i] = max_value; + SOFT_SUM_VALUES_1[i] = sf_sum_value; + } + } + + v4i8 vr[4]; + for (int n_index = 0; n_index < 64 / 16; ++n_index) { + for (int i = 0; i < 4; ++i) { + vr[n_index][i] = __float2int_rn(max(min(scale_v_out * __int2float_rn(MMV[n_index][i]), 127.f), -128.f)); + } + } + for (int i = 0; i < 4; ++i) { + int m = laneId % 4; + int n = (laneId % 16) / 4 + i * 4 + laneRow * 16; + SM[warpId * 256 + m * 64 + n] = reinterpret_cast(vr[i]); + } + for (int i = 0; i < 4; ++i) { + vr[i] = reinterpret_cast(SM[warpId * 256 + i * 64 + laneId]); + } + + v4i8 val[4]; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + val[i][j] = vr[j][i]; + } + } + for (int i = 0; i < 4; ++i) { + __ivcorex_ml_mem_store_i32(reinterpret_cast(val[i]), CBase, laneId * 4, warpId * 1024 + i * 64 * 4, 0); + } + + // for (int n_index = 0; n_index < 64 / 16; ++n_index) { + // for (int i = 0; i < 4; ++i) { + // int m = i * 4 + laneId / 16 + warpId * 16; + // int n = laneId % 16 + n_index * 16; + // C[m * 64 + n] = + // float2int8((float)(MMV[n_index][i]) * scale_v_out, r_amax); + // } + // } +} + +template +__global__ void _mr_fmha_i8_tcu_impl(const int8_t *Q, const int8_t *K, const int8_t *V, const int8_t *mask, int8_t *C, + int head_num, float q_amax, float k_amax, float v_amax, float s_max, float qk_amax, + float r_amax) { + Q += blockIdx.x * M * 64 + blockIdx.y * BLOCK_SIZE_M * 64; + K += blockIdx.x * N * 64; + V += blockIdx.x * N * 64; + C += blockIdx.x * M * 64 + blockIdx.y * BLOCK_SIZE_M * 64; + // C += blockIdx.x * M * N + blockIdx.y * BLOCK_SIZE_M * 
N; + mask += blockIdx.x / head_num * N; + _mr_fmha_i8_tcu_impl_block(Q, K, V, mask, C, q_amax, k_amax, v_amax, s_max, + qk_amax, r_amax); +} + +void IxinferMhaI8Launcher(cudaStream_t &stream, int8_t *q, int8_t *k, int8_t *v, int8_t *mask, int8_t *c, + int batch_size, int head_num, int seq_len, int head_dim, float q_amax, float k_amax, + float v_amax, float s_max, float qk_amax, float r_amax) { + if (head_dim != 64) { + throw std::runtime_error("mha kernel only support head_dim=64"); + } + + // dim3 gridSize_128(batch_size * head_num, seq_len / 128); + // dim3 gridSize_256(batch_size * head_num, seq_len / 256); + // switch (seq_len) { + // case 128: + // _mr_fmha_i8_tcu_impl<128, 128, 128, 128, 16> + // <<>>( + // q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, + // qk_amax, r_amax); + // break; + // case 256: + // _mr_fmha_i8_tcu_impl<256, 256, 256, 256, 16> + // <<>>( + // q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, + // qk_amax, r_amax); + // break; + // case 384: + // _mr_fmha_i8_tcu_impl<384, 384, 128, 384, 16> + // <<>>( + // q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, + // qk_amax, r_amax); + // break; + // case 512: + // _mr_fmha_i8_tcu_impl<512, 512, 128, 512, 16> + // <<>>( + // q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, + // qk_amax, r_amax); + // break; + // default: + // throw std::runtime_error("IxinferMhaI8Launcher parameter error!"); + // break; + // } + dim3 gridSize_64(batch_size * head_num, seq_len / 64); + dim3 gridSize_128(batch_size * head_num, seq_len / 128); + dim3 gridSize_256(batch_size * head_num, seq_len / 256); + switch (seq_len) { + case 64: + _mr_fmha_i8_tcu_impl<64, 64, 64, 64, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 128: + _mr_fmha_i8_tcu_impl<128, 128, 128, 128, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 192: + _mr_fmha_i8_tcu_impl<192, 192, 64, 192, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 256: + _mr_fmha_i8_tcu_impl<256, 256, 256, 256, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 320: + _mr_fmha_i8_tcu_impl<320, 320, 64, 320, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 384: + _mr_fmha_i8_tcu_impl<384, 384, 128, 384, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 448: + _mr_fmha_i8_tcu_impl<448, 448, 64, 448, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + case 512: + _mr_fmha_i8_tcu_impl<512, 512, 128, 512, 16><<>>( + q, k, v, mask, c, head_num, q_amax, k_amax, v_amax, s_max, qk_amax, r_amax); + break; + default: + throw std::runtime_error("IxinferMhaI8Launcher parameter error!"); + break; + } +} + +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..af4453a552a9c88ab7dcd54f46a6286042062174 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/bert/bert_mha_kernel.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. 
+* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +namespace nvinfer1::plugin { +namespace backend { + +void IxinferMhaI8Launcher(cudaStream_t& stream, int8_t* q, int8_t* k, int8_t* v, int8_t* mask, int8_t* c, + int batch_size, int head_num, int seq_len, int head_dim, float q_amax, float k_amax, + float v_amax, float s_max, float qk_amax, float r_amax); +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/cublas/cublas_helper.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/cublas/cublas_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..ef7db0e9ec752d2f72f52bddf77641c83cf5dda2 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/cublas/cublas_helper.h @@ -0,0 +1,310 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include +#include +#include +#include + +#include + +#include "checkMacrosPlugin.h" + +namespace nvinfer1::plugin { +namespace backend { + +/* GPU function guard */ + +/** + * @brief cublasLt gemm without imma + * + * @tparam OutType output dtype + * @tparam ScaleType scale dtype + * @param input_a + * @param input_b + * @param output_c + * @param batch_count + * @param m + * @param n + * @param k + * @param stridea + * @param strideb + * @param stridec + * @param alpha + * @param cublasLt_handle + * @param stream + */ +template +void cublaslt_gemm(const int8_t* input_a, const int8_t* input_b, OutType* output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const ScaleType alpha, + cublasLtHandle_t cublasLt_handle, cudaStream_t stream) { + cublasOperation_t transpose = CUBLAS_OP_T; +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t compute_type = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmul_desc; + cublasLtMatrixLayout_t desc_a = NULL; + cublasLtMatrixLayout_t desc_b = NULL; + cublasLtMatrixLayout_t desc_c = NULL; + + cudaDataType_t out_dtype; + cudaDataType_t scale_dtype; + if (std::is_same::value) { + out_dtype = CUDA_R_32I; + scale_dtype = CUDA_R_32I; + } else if (std::is_same::value) { + out_dtype = CUDA_R_8I; + scale_dtype = CUDA_R_32F; + } else { + throw std::runtime_error("Unsupported output type"); + } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_dtype)); +#else + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type)); + CHECK_GPU_ERROR(cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_dtype, + sizeof(scale_dtype))); +#endif + CHECK_GPU_ERROR( + cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transpose, sizeof(transpose))); + + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_a, CUDA_R_8I, k, m, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_b, CUDA_R_8I, k, n, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_c, out_dtype, m, n, m)); + + if (batch_count > 1) { + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec))); + } + + ScaleType beta = ScaleType(0); + CHECK_GPU_ERROR(cublasLtMatmul(cublasLt_handle, matmul_desc, &alpha, input_a, desc_a, input_b, desc_b, &beta, + output_c, desc_c, output_c, desc_c, NULL, NULL, 0, stream)); + + CHECK_GPU_ERROR(cublasLtMatmulDescDestroy(matmul_desc)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_a)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_b)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_c)); +} + +inline void cublaslt_gemm(const half* input_a, const half* input_b, 
half* output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cublasLtHandle_t cublasLt_handle, cudaStream_t stream) { + cublasOperation_t transpose = CUBLAS_OP_T; +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#else + cudaDataType_t compute_type = CUDA_R_32F; +#endif + cublasLtMatmulDesc_t matmul_desc; + cublasLtMatrixLayout_t desc_a = NULL; + cublasLtMatrixLayout_t desc_b = NULL; + cublasLtMatrixLayout_t desc_c = NULL; + + cudaDataType_t out_dtype = CUDA_R_16F; + cudaDataType_t scale_dtype = CUDA_R_32F; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_dtype)); +#else + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type)); + CHECK_GPU_ERROR(cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_dtype, + sizeof(scale_dtype))); +#endif + CHECK_GPU_ERROR( + cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transpose, sizeof(transpose))); + + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_a, CUDA_R_16F, k, m, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_b, CUDA_R_16F, k, n, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_c, out_dtype, m, n, m)); + + if (batch_count > 1) { + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec))); + } + + float beta = 0.0; + CHECK_GPU_ERROR(cublasLtMatmul(cublasLt_handle, matmul_desc, &alpha, input_a, desc_a, input_b, desc_b, &beta, + output_c, desc_c, output_c, desc_c, NULL, NULL, 0, stream)); + + CHECK_GPU_ERROR(cublasLtMatmulDescDestroy(matmul_desc)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_a)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_b)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_c)); +} + +template void cublaslt_gemm(const int8_t* input_a, const int8_t* input_b, int32_t* output_c, + int batchCount, int m, int n, int k, int64_t stridea, int64_t strideb, + int64_t stridec, const int32_t alpha, cublasLtHandle_t cublasLt_handle, + cudaStream_t stream); + +template void cublaslt_gemm(const int8_t* input_a, const int8_t* input_b, int8_t* output_c, + int batchCount, int m, int n, int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, cublasLtHandle_t cublasLt_handle, + cudaStream_t stream); + +/************add by pxl *************/ +template +void cublaslt_gemm_nn(const int8_t* input_a, const int8_t* input_b, OutType* output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const ScaleType alpha, + cublasLtHandle_t cublasLt_handle, cudaStream_t stream) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cublasComputeType_t 
compute_type = CUBLAS_COMPUTE_32I; +#else + cudaDataType_t compute_type = CUDA_R_32I; +#endif + cublasLtMatmulDesc_t matmul_desc; + cublasLtMatrixLayout_t desc_a = NULL; + cublasLtMatrixLayout_t desc_b = NULL; + cublasLtMatrixLayout_t desc_c = NULL; + + cudaDataType_t out_dtype; + cudaDataType_t scale_dtype; + if (std::is_same::value) { + out_dtype = CUDA_R_32I; + scale_dtype = CUDA_R_32I; + } else if (std::is_same::value) { + out_dtype = CUDA_R_8I; + scale_dtype = CUDA_R_32F; + } else { + throw std::runtime_error("Unsupported output type"); + } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_dtype)); +#else + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type)); + CHECK_GPU_ERROR(cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_dtype, + sizeof(scale_dtype))); +#endif + + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_a, CUDA_R_8I, m, k, m)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_b, CUDA_R_8I, k, n, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_c, out_dtype, m, n, m)); + + if (batch_count > 1) { + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec))); + } + + ScaleType beta = ScaleType(0); + CHECK_GPU_ERROR(cublasLtMatmul(cublasLt_handle, matmul_desc, &alpha, input_a, desc_a, input_b, desc_b, &beta, + output_c, desc_c, output_c, desc_c, NULL, NULL, 0, stream)); + + CHECK_GPU_ERROR(cublasLtMatmulDescDestroy(matmul_desc)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_a)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_b)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_c)); +} + +inline void cublaslt_gemm_nn(const half* input_a, const half* input_b, half* output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cublasLtHandle_t cublasLt_handle, cudaStream_t stream) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; +#else + cudaDataType_t compute_type = CUDA_R_32F; +#endif + cublasLtMatmulDesc_t matmul_desc; + cublasLtMatrixLayout_t desc_a = NULL; + cublasLtMatrixLayout_t desc_b = NULL; + cublasLtMatrixLayout_t desc_c = NULL; + + cudaDataType_t out_dtype = CUDA_R_16F; + cudaDataType_t scale_dtype = CUDA_R_32F; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type, scale_dtype)); +#else + CHECK_GPU_ERROR(cublasLtMatmulDescCreate(&matmul_desc, compute_type)); + CHECK_GPU_ERROR(cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_dtype, + sizeof(scale_dtype))); +#endif + + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_a, CUDA_R_16F, m, k, m)); + 
CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_b, CUDA_R_16F, k, n, k)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutCreate(&desc_c, out_dtype, m, n, m)); + + if (batch_count > 1) { + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridea, + sizeof(stridea))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideb, + sizeof(strideb))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, + sizeof(batch_count))); + CHECK_GPU_ERROR(cublasLtMatrixLayoutSetAttribute(desc_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stridec, + sizeof(stridec))); + } + + float beta = float(0); + CHECK_GPU_ERROR(cublasLtMatmul(cublasLt_handle, matmul_desc, &alpha, input_a, desc_a, input_b, desc_b, &beta, + output_c, desc_c, output_c, desc_c, NULL, NULL, 0, stream)); + + CHECK_GPU_ERROR(cublasLtMatmulDescDestroy(matmul_desc)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_a)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_b)); + CHECK_GPU_ERROR(cublasLtMatrixLayoutDestroy(desc_c)); +} + +template void cublaslt_gemm_nn(const int8_t* input_a, const int8_t* input_b, int32_t* output_c, + int batchCount, int m, int n, int k, int64_t stridea, int64_t strideb, + int64_t stridec, const int32_t alpha, cublasLtHandle_t cublasLt_handle, + cudaStream_t stream); + +template void cublaslt_gemm_nn(const int8_t* input_a, const int8_t* input_b, int8_t* output_c, + int batchCount, int m, int n, int k, int64_t stridea, int64_t strideb, + int64_t stridec, const float alpha, cublasLtHandle_t cublasLt_handle, + cudaStream_t stream); + +} // namespace backend +} // end of namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.cu new file mode 100644 index 0000000000000000000000000000000000000000..1916cca7ae9bf6ebed579ea17553d348459f55d8 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.cu @@ -0,0 +1,428 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include + +#include "ixinfer_gemm_helper.h" + +namespace nvinfer1::plugin { +namespace backend { + +void cuinfer_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream) { + /* TN: input_a: m,k input_b: n,k output_c: n,m */ + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_N; + + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cudaDataType_t computeType = CUDA_R_32I; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + + int lda = k; + int ldb = k; + int ldc = m; + + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, nullptr, customOption); + + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error!, error type: " + std::to_string((int)status) + " !"); + } +} + +void cuinfer_i8_gemm(const int8_t *input_a, const int8_t *input_b, const float *bias, int8_t *output_c, int batch_count, + int m, int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const float beta, const int act_type, cuinferHandle_t &cuinfer_handle, cudaStream_t &stream) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_N; + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cudaDataType_t computeType = CUDA_R_32I; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + if (bias != nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + } + + int lda = k; + int ldb = k; + int ldc = m; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} + +void cuinfer_nn_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream) { + /* TN: input_a: k,m input_b: n,k output_c: n,m */ + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_N; + cuinferOperation_t transb = CUINFER_OP_N; + + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cudaDataType_t computeType = CUDA_R_32I; + cudaDataType_t scaleType = CUDA_R_32F; + 
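+    // Plain NN int8 GEMM: no fused bias or activation, so the custom option below stays CUINFER_BLAS_GEMM_CUSTOM_NONE.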
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + + int lda = m; + int ldb = k; + int ldc = m; + + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, nullptr, customOption); + + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error!"); + } +} + +void cuinfer_nt_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream) { + /* TN: input_a: k,m input_b: k,n output_c: n,m */ + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_N; + cuinferOperation_t transb = CUINFER_OP_T; + + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cudaDataType_t computeType = CUDA_R_32I; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + + int lda = m; + int ldb = n; + int ldc = m; + + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, nullptr, customOption); + + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error!"); + } +} + +void cuinfer_tt_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream) { + /* TN: input_a: k,m input_b: k,n output_c: n,m */ + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_T; + + cudaDataType_t Atype = CUDA_R_8I; + cudaDataType_t Btype = CUDA_R_8I; + cudaDataType_t Ctype = CUDA_R_8I; + cudaDataType_t computeType = CUDA_R_32I; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + + int lda = k; + int ldb = n; + int ldc = m; + + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, nullptr, customOption); + + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error!"); + } +} + +void cuinfer_gemm(const half *input_a, const half *input_b, half *output_c, int batch_count, int m, int n, int k, + int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, cublasHandle_t handle, + cudaStream_t stream) { + /* Performs operation using cublas */ + float beta = 0.0f; + cublasSetStream(handle, stream); + cublasStatus_t status; + if (batch_count <= 1) { + status = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, input_a, CUDA_R_16F, k, input_b, + CUDA_R_16F, k, &beta, output_c, CUDA_R_16F, m, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } else { 
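+        // batch_count > 1: a single strided-batched GEMM call covers every matrix in the batch using the per-matrix strides.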
+ status = cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, input_a, CUDA_R_16F, k, + stridea, input_b, CUDA_R_16F, k, strideb, &beta, output_c, CUDA_R_16F, m, + stridec, batch_count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("cuinfer_gemm error!"); + } +} + +void cuinfer_nn_gemm(const half *input_a, const half *input_b, half *output_c, int batch_count, int m, int n, int k, + int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, cublasHandle_t handle, + cudaStream_t stream) { + /* Performs operation using cublas */ + float beta = 0.0f; + cublasSetStream(handle, stream); + cublasStatus_t status; + if (batch_count <= 1) { + // k,m n,k + status = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, input_a, CUDA_R_16F, m, input_b, + CUDA_R_16F, k, &beta, output_c, CUDA_R_16F, m, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } else { + status = cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, input_a, CUDA_R_16F, m, + stridea, input_b, CUDA_R_16F, k, strideb, &beta, output_c, CUDA_R_16F, m, + stridec, batch_count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + } + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("cuinfer_gemm error!"); + } +} + +void cuinfer_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle, const float swish_alpha) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_N; + cudaDataType_t Atype = CUDA_R_16F; + cudaDataType_t Btype = CUDA_R_16F; + cudaDataType_t Ctype = CUDA_R_16F; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + float tmp; + float *customHostPtr = nullptr; + + if (bias != nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else if (act_type == 20) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_SWISH; + //tmp = 1.7020000219345093; + tmp = swish_alpha; + customHostPtr = &tmp; + } else if (act_type == 21) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_ERF_GELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + // std::cout << "CUINFER_BLAS_GEMM_CUSTOM_NONE" << std::endl; + } + + int lda = k; + int ldb = k; + int ldc = m; + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, customHostPtr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} +void cuinfer_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, const float beta, + const int act_type, cudaStream_t &stream, cuinferHandle_t 
&cuinfer_handle) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_N; + cudaDataType_t Atype = CUDA_R_16F; + cudaDataType_t Btype = CUDA_R_16F; + cudaDataType_t Ctype = CUDA_R_16F; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + if (bias != nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + // std::cout << "CUINFER_BLAS_GEMM_CUSTOM_NONE" << std::endl; + } + + int lda = k; + int ldb = k; + int ldc = m; + // float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} +void cuinfer_nn_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_N; + cuinferOperation_t transb = CUINFER_OP_N; + cudaDataType_t Atype = CUDA_R_16F; + cudaDataType_t Btype = CUDA_R_16F; + cudaDataType_t Ctype = CUDA_R_16F; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + if (bias != nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + } + + int lda = m; + int ldb = k; + int ldc = m; + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} +void cuinfer_nt_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_N; + cuinferOperation_t transb = CUINFER_OP_T; + cudaDataType_t Atype = CUDA_R_16F; + cudaDataType_t Btype = CUDA_R_16F; + cudaDataType_t Ctype = CUDA_R_16F; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + if (bias != 
nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + } + + int lda = m; + int ldb = n; + int ldc = m; + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} + +void cuinfer_tt_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle) { + cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST; + cuinferOperation_t transa = CUINFER_OP_T; + cuinferOperation_t transb = CUINFER_OP_T; + cudaDataType_t Atype = CUDA_R_16F; + cudaDataType_t Btype = CUDA_R_16F; + cudaDataType_t Ctype = CUDA_R_16F; + cudaDataType_t computeType = CUDA_R_32F; + cudaDataType_t scaleType = CUDA_R_32F; + cuinferGEMMCustomOption_t customOption; + if (bias != nullptr) { + if (act_type == 3) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_GELU; + + } else if (act_type == 4) { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS_RELU; + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_HALFBIAS; + } + } else { + customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE; + } + + int lda = k; + int ldb = n; + int ldc = m; + float beta = 0.f; + + cuinferStatus_t status = + cuinferCustomGemm(cuinfer_handle, stream, cuinfer_ptr_mode, transa, transb, m, n, k, &alpha, input_a, Atype, + lda, stridea, input_b, Btype, ldb, strideb, &beta, output_c, Ctype, ldc, stridec, batch_count, + computeType, scaleType, nullptr, (void *)bias, customOption); + if (status != CUINFER_STATUS_SUCCESS) { + throw std::runtime_error("cuinferCustomGemm error, error type: " + std::to_string((int)status) + " !"); + } +} + +} // namespace backend +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..c5757d67eacea380a2ca90d9eb55448551377fd2 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/ixinfer/ixinfer_gemm_helper.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include +#include +#include +#include + +#include + +namespace nvinfer1::plugin { +namespace backend { + +void cuinfer_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream); + +void cuinfer_i8_gemm(const int8_t *input_a, const int8_t *input_b, const float *bias, int8_t *output_c, int batch_count, + int m, int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const float beta, const int act_type, cuinferHandle_t &cuinfer_handle, cudaStream_t &stream); + +void cuinfer_nn_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream); + +void cuinfer_nt_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream); + +void cuinfer_tt_i8_gemm(const int8_t *input_a, const int8_t *input_b, int8_t *output_c, int batch_count, int m, int n, + int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + cuinferHandle_t cuinfer_handle, cudaStream_t stream); + +void cuinfer_gemm(const half *input_a, const half *input_b, half *output_c, int batch_count, int m, int n, int k, + int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, cublasHandle_t cublas_handle, + cudaStream_t stream); + +void cuinfer_nn_gemm(const half *input_a, const half *input_b, half *output_c, int batch_count, int m, int n, int k, + int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, cublasHandle_t cublas_handle, + cudaStream_t stream); + +void cuinfer_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle, + const float swish_alpha = 1.0); +void cuinfer_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, const float beta, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle); +void cuinfer_nn_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle); +void cuinfer_nt_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle); +void cuinfer_tt_gemm(const half *input_a, const half *input_b, const half *bias, half *output_c, int batch_count, int m, + int n, int k, int64_t stridea, int64_t strideb, int64_t stridec, const float alpha, + const int act_type, cudaStream_t &stream, cuinferHandle_t &cuinfer_handle); +} // namespace backend +} // namespace nvinfer1::plugin diff --git 
a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.cu new file mode 100644 index 0000000000000000000000000000000000000000..a064610d4b9953af43f7d2bc36e3d379eee2b80b --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.cu @@ -0,0 +1,918 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include + +#include "transformer_add_norm.h" +#include "transformer_helper.cuh" +namespace nvinfer1 { +namespace plugin { +namespace backend { + +const float epsilon = 0.000000000001; + +template +__global__ void IxinferResidualBiasLnI8II8OKernel(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, int8_t *output, __half *residual, + int hidden_size, float dequant_scale, float quant_scale, + bool is_post_ln) { + // register + // process 2 data + float4 vals[THREAD_DATA_LEN / 4]; + int block_start = blockIdx.x * hidden_size / 4; + char4 *p_input = (char4 *)input; + char4 *p_output = (char4 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start * 2; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN / 4; ++it) { + int element_index = threadIdx.x + it * warpSize; + // vals = dequant(input) + residual + vals[it] = char4addhalf2_dequant(p_input[element_index], p_residual[element_index * 2], + p_residual[element_index * 2 + 1], dequant_scale); + half2 res_bias_val_1; + half2 res_bias_val_2; + if (residual_bias == nullptr) { + res_bias_val_1.x = __float2half(0.0f); + res_bias_val_1.y = __float2half(0.0f); + res_bias_val_2.x = __float2half(0.0f); + res_bias_val_2.y = __float2half(0.0f); + } else { + res_bias_val_1 = p_residual_bias[element_index * 2]; + res_bias_val_2 = p_residual_bias[element_index * 2 + 1]; + } + vals[it].x = vals[it].x + __half2float(res_bias_val_1.x); + vals[it].y = vals[it].y + __half2float(res_bias_val_1.y); + vals[it].z = vals[it].z + __half2float(res_bias_val_2.x); + vals[it].w = vals[it].w + __half2float(res_bias_val_2.y); + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, warpSize); + m2 = 
__shfl_sync(0xffffffff, m2, 0, warpSize); + count = __shfl_sync(0xffffffff, count, 0, warpSize); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN / 4; ++it) { + int element_index = threadIdx.x + it * warpSize; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_size, epsilon, + p_scale[element_index * 2], p_scale[element_index * 2 + 1], + p_bias[element_index * 2], p_bias[element_index * 2 + 1]); + + char4 res = float42char4(norm_value, quant_scale); + p_output[element_index] = res; + + half2 r1; + half2 r2; + if (is_post_ln) { + r1.x = __float2half(norm_value.x); + r1.y = __float2half(norm_value.y); + r2.x = __float2half(norm_value.z); + r2.y = __float2half(norm_value.w); + // res.x = __hadd(__float2half(a.x), b.x); + // res.y = __hadd(__float2half(a.y), b.y); + // p_residual[element_index] = float2addhalf2(norm_value, res_bias_val); + } else { + // p_residual[element_index] = float2addhalf2(vals[it], res_bias_val); + r1.x = __float2half(vals[it].x); + r1.y = __float2half(vals[it].y); + r2.x = __float2half(vals[it].z); + r2.y = __float2half(vals[it].w); + } + p_residual[element_index * 2] = r1; + p_residual[element_index * 2 + 1] = r2; + } +} + +void IxinferResidualBiasLnI8II8O_v2(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, int8_t *output, __half *residual, int batch_tokens, + int hidden_size, float dequant_scale, float quant_scale, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % warpSize != 0) { + throw std::runtime_error("hidden_size // warpSize != 0"); + } + if (hidden_size % 256 != 0) { + throw std::runtime_error("hidden_size // 256 != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int num_warp = hidden_size / warpSize; + + switch (num_warp) { + case 1: + IxinferResidualBiasLnI8II8OKernel<1> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 2: + IxinferResidualBiasLnI8II8OKernel<2> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 3: + IxinferResidualBiasLnI8II8OKernel<3> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 4: + IxinferResidualBiasLnI8II8OKernel<4> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 5: + IxinferResidualBiasLnI8II8OKernel<5> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 6: + IxinferResidualBiasLnI8II8OKernel<6> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 7: + IxinferResidualBiasLnI8II8OKernel<7> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 8: + IxinferResidualBiasLnI8II8OKernel<8> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 9: + IxinferResidualBiasLnI8II8OKernel<9> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 10: + IxinferResidualBiasLnI8II8OKernel<10> + <<>>(input, scale, bias, residual_bias, output, 
residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 11: + IxinferResidualBiasLnI8II8OKernel<11> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 12: + IxinferResidualBiasLnI8II8OKernel<12> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 13: + IxinferResidualBiasLnI8II8OKernel<13> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 14: + IxinferResidualBiasLnI8II8OKernel<14> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 15: + IxinferResidualBiasLnI8II8OKernel<15> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 16: + IxinferResidualBiasLnI8II8OKernel<16> + <<>>(input, scale, bias, residual_bias, output, residual, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + default: + throw std::runtime_error("IxinferResidualBiasLnI8II8O_v2"); + break; + } +} + +template +__global__ void IxinferResidualI8IKernel(int8_t *input, __half *residual_bias, __half *residual, __half *output, + int hidden_size, float dequant_scale) { + // register + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + residual += block_start; + + char4 *p_input = (char4 *)input; + __half2 *p_output = (__half2 *)output; + __half2 *p_residual = (__half2 *)residual; + + __half2 *p_residual_bias = (__half2 *)residual_bias; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + char4 i_value = p_input[element_index]; + float4 r_value; + float4 r_bias_value; + + load_float4_from_half(r_value, p_residual, element_index); + load_float4_from_half(r_bias_value, p_residual_bias, element_index); + + vals[it].x = (float)i_value.x * dequant_scale + r_value.x + (r_bias_value.x); + vals[it].y = (float)i_value.y * dequant_scale + r_value.y + (r_bias_value.y); + vals[it].z = (float)i_value.z * dequant_scale + r_value.z + (r_bias_value.z); + vals[it].w = (float)i_value.w * dequant_scale + r_value.w + (r_bias_value.w); + + half2 r1; + half2 r2; + // if( blockIdx.x==0 && element_index==0 ) + // { + // printf("i_value.x %d dequant_scale %f r_value.x %f r_bias_value.x %f + // vals[it].x %f ",i_value.x , dequant_scale , r_value.x , r_bias_value.x + // , vals[it].x); + // } + + r1.x = __float2half(vals[it].x); + r1.y = __float2half(vals[it].y); + r2.x = __float2half(vals[it].z); + r2.y = __float2half(vals[it].w); + + p_output[element_index * 2] = r1; + p_output[element_index * 2 + 1] = r2; + } +} + +void IxinferResidualBiasI8I(int8_t *input, __half *residual_bias, __half *residual, __half *output, int batch_tokens, + int hidden_size, float dequant_scale, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % warpSize != 0) { + throw std::runtime_error("hidden_size // warpSize != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int num_warp = hidden_size / warpSize / 4; + + switch (num_warp) { + case 1: + IxinferResidualI8IKernel<1><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + 
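+            // Cases 2-16 below differ only in the THREAD_DATA_LEN template argument (hidden_size / warpSize / 4).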
break; + case 2: + IxinferResidualI8IKernel<2><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 3: + IxinferResidualI8IKernel<3><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 4: + IxinferResidualI8IKernel<4><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 5: + IxinferResidualI8IKernel<5><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 6: + IxinferResidualI8IKernel<6><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 7: + IxinferResidualI8IKernel<7><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 8: + IxinferResidualI8IKernel<8><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 9: + IxinferResidualI8IKernel<9><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 10: + IxinferResidualI8IKernel<10><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 11: + IxinferResidualI8IKernel<11><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 12: + IxinferResidualI8IKernel<12><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 13: + IxinferResidualI8IKernel<13><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 14: + IxinferResidualI8IKernel<14><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 15: + IxinferResidualI8IKernel<15><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + case 16: + IxinferResidualI8IKernel<16><<>>(input, residual_bias, residual, output, + hidden_size, dequant_scale); + break; + default: + throw std::runtime_error("IxinferResidualBiasI8I"); + break; + } +} + +template +__global__ void IxinferResidualBiasLn(const half *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, int hidden_size, bool is_post_ln) { + // register + // process 2 data + float2 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size / 2; + half2 *p_input = (half2 *)input; + half2 *p_output = (half2 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + // vals = dequant(input) + residual + half2 value1 = p_input[element_index]; + half2 value2 = p_residual[element_index]; + + vals[it].x = __half2float(value1.x) + __half2float(value2.x); + vals[it].y = __half2float(value1.y) + __half2float(value2.y); + + half2 res_bias_val_1; + if (residual_bias == nullptr) { + res_bias_val_1.x = __float2half(0.0f); + res_bias_val_1.y = __float2half(0.0f); + } else { + res_bias_val_1 = p_residual_bias[element_index]; + } + vals[it].x = vals[it].x + __half2float(res_bias_val_1.x); + vals[it].y = vals[it].y + __half2float(res_bias_val_1.y); + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, 
&thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, warpSize); + m2 = __shfl_sync(0xffffffff, m2, 0, warpSize); + count = __shfl_sync(0xffffffff, count, 0, warpSize); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + float2 norm_value; + half2 scale_1 = p_scale[element_index]; + half2 bias_1 = p_bias[element_index]; + norm_value.x = + (vals[it].x - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.x) + __half2float(bias_1.x); + norm_value.y = + (vals[it].y - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.y) + __half2float(bias_1.y); + + half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + p_output[element_index] = res; + + half2 r1; + if (is_post_ln) { + r1 = res; + } else { + r1.x = __float2half(vals[it].x); + r1.y = __float2half(vals[it].y); + } + p_residual[element_index] = r1; + } +} + +void IxinferResidualBiasLn(const half *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 2048) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if ((hidden_size % 2 == 0) && (hidden_size % (warpSize * 2) != 0)) { + IxinferResidualBiasLnPad(input, scale, bias, residual_bias, output, residual, batch_tokens, hidden_size, stream, + is_post_ln); + } else { + if (hidden_size % (warpSize * 2) != 0) { + throw std::runtime_error("hidden_size // (warpSize*2) != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int num_warp = hidden_size / warpSize / 2; + + switch (num_warp) { + case 1: + IxinferResidualBiasLn<1><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 2: + IxinferResidualBiasLn<2><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 3: + IxinferResidualBiasLn<3><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 4: + IxinferResidualBiasLn<4><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 5: + IxinferResidualBiasLn<5><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 6: + IxinferResidualBiasLn<6><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 7: + IxinferResidualBiasLn<7><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 8: + IxinferResidualBiasLn<8><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 9: + IxinferResidualBiasLn<9><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 10: + IxinferResidualBiasLn<10><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 11: + IxinferResidualBiasLn<11><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 12: + IxinferResidualBiasLn<12><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 13: + IxinferResidualBiasLn<13><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, 
is_post_ln); + break; + case 14: + IxinferResidualBiasLn<14><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 15: + IxinferResidualBiasLn<15><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 16: + IxinferResidualBiasLn<16><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + default: + throw std::runtime_error("IxinferResidualBiasLn"); + break; + } + } +} + +__global__ void IxinferResidualBiasKernel(half *input, half *residual_bias, half *residual, half *output, + int hidden_size) { + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + residual += block_start; + + half2 *p_input = (half2 *)input; + half2 *p_output = (half2 *)output; + half2 *p_residual = (half2 *)residual; + + half2 *p_residual_bias = (half2 *)residual_bias; + + half2 value1 = p_input[threadIdx.x]; + half2 value2 = p_residual[threadIdx.x]; + half2 value3 = p_residual_bias[threadIdx.x]; + + float2 value_out; + value_out.x = __half2float(value1.x) + __half2float(value2.x) + __half2float(value3.x); + value_out.y = __half2float(value1.y) + __half2float(value2.y) + __half2float(value3.y); + + half2 res; + res.x = __float2half(value_out.x); + res.y = __float2half(value_out.y); + + p_output[threadIdx.x] = res; +} + +void IxinferResidualBias(half *input, half *residual_bias, half *residual, half *output, int batch_tokens, + int hidden_size, cudaStream_t stream) { + if (hidden_size / 2 > 4096) { + throw std::runtime_error("hidden_size/2 should <= 4096"); + } + if (hidden_size % 2 != 0) { + throw std::runtime_error("hidden_size % 2!= 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(hidden_size / 2); + IxinferResidualBiasKernel<<>>(input, residual_bias, residual, output, hidden_size); +} + +template +__global__ void IxinferResidualBiasLnPad(const half *input, const half *scale, const half *bias, + const half *residual_bias, half *output, half *residual, int hidden_size, + bool is_post_ln) { + // register + // process 2 data + float2 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size / 2; + half2 *p_input = (half2 *)input; + half2 *p_output = (half2 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + if (element_index < hidden_size / 2) { + // vals = dequant(input) + residual + half2 value1 = p_input[element_index]; + half2 value2 = p_residual[element_index]; + + vals[it].x = __half2float(value1.x) + __half2float(value2.x); + vals[it].y = __half2float(value1.y) + __half2float(value2.y); + + half2 res_bias_val_1; + if (residual_bias == nullptr) { + res_bias_val_1.x = __float2half(0.0f); + res_bias_val_1.y = __float2half(0.0f); + } else { + res_bias_val_1 = p_residual_bias[element_index]; + } + vals[it].x = vals[it].x + __half2float(res_bias_val_1.x); + vals[it].y = vals[it].y + __half2float(res_bias_val_1.y); + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + } + } + + 
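+    // Lanes with element_index >= hidden_size / 2 skipped the guarded loop above, so their Welford accumulators keep a zero count and do not affect the reduction below.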
// mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, warpSize); + m2 = __shfl_sync(0xffffffff, m2, 0, warpSize); + count = __shfl_sync(0xffffffff, count, 0, warpSize); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + if (element_index < hidden_size / 2) { + float2 norm_value; + half2 scale_1 = p_scale[element_index]; + half2 bias_1 = p_bias[element_index]; + norm_value.x = (vals[it].x - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.x) + + __half2float(bias_1.x); + norm_value.y = (vals[it].y - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.y) + + __half2float(bias_1.y); + + half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + p_output[element_index] = res; + + half2 r1; + if (is_post_ln) { + r1 = res; + } else { + r1.x = __float2half(vals[it].x); + r1.y = __float2half(vals[it].y); + } + p_residual[element_index] = r1; + } + } +} + +void IxinferResidualBiasLnPad(const half *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 2048) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % 2 != 0) { + throw std::runtime_error("hidden_size % 2 != 0"); + } + + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int neareast_hidden_size = hidden_size; + if (neareast_hidden_size % (warpSize * 2) != 0) { + neareast_hidden_size = neareast_hidden_size + warpSize * 2 - neareast_hidden_size % (warpSize * 2); + } + + int num_warp = neareast_hidden_size / warpSize / 2; + + switch (num_warp) { + case 1: + IxinferResidualBiasLnPad<1><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 2: + IxinferResidualBiasLnPad<2><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 3: + IxinferResidualBiasLnPad<3><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 4: + IxinferResidualBiasLnPad<4><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 5: + IxinferResidualBiasLnPad<5><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 6: + IxinferResidualBiasLnPad<6><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 7: + IxinferResidualBiasLnPad<7><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 8: + IxinferResidualBiasLnPad<8><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 9: + IxinferResidualBiasLnPad<9><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 10: + IxinferResidualBiasLnPad<10><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 11: + IxinferResidualBiasLnPad<11><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 12: + IxinferResidualBiasLnPad<12><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 13: + IxinferResidualBiasLnPad<13><<>>(input, scale, bias, 
residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 14: + IxinferResidualBiasLnPad<14><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 15: + IxinferResidualBiasLnPad<15><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 16: + IxinferResidualBiasLnPad<16><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + default: + std::cout << "hidden size: " << hidden_size << std::endl; + throw std::runtime_error("IxinferResidualBiasLnPad not supported!"); + break; + } +} + +template +__global__ void IxinferResidualBiasLn(const float *input, const half *scale, const half *bias, + const half *residual_bias, float *output, float *residual, int hidden_size, + bool is_post_ln) { + // register + // process 1 data + float vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + float *p_input = (float *)input; + float *p_output = (float *)output; + float *p_residual = (float *)residual; + half *p_scale = (half *)scale; + half *p_bias = (half *)bias; + half *p_residual_bias = (half *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + // vals = dequant(input) + residual + float value1 = p_input[element_index]; + float value2 = p_residual[element_index]; + + vals[it] = value1 + value2; + + float res_bias_val_1; + if (residual_bias == nullptr) { + res_bias_val_1 = 0.0f; + } else { + res_bias_val_1 = __half2float(p_residual_bias[element_index]); + } + vals[it] = vals[it] + res_bias_val_1; + + WelfordCombine(vals[it], &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, warpSize); + m2 = __shfl_sync(0xffffffff, m2, 0, warpSize); + count = __shfl_sync(0xffffffff, count, 0, warpSize); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + float norm_value; + float scale_1 = __half2float(p_scale[element_index]); + float bias_1 = __half2float(p_bias[element_index]); + norm_value = (vals[it] - mean) * rsqrtf(m2 / hidden_size + epsilon) * scale_1 + bias_1; + + p_output[element_index] = norm_value; + + float r1; + if (is_post_ln) { + r1 = norm_value; + } else { + r1 = vals[it]; + } + p_residual[element_index] = r1; + } +} + +void IxinferResidualBiasLn(const float *input, const half *scale, const half *bias, const half *residual_bias, + float *output, float *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % warpSize != 0) { + throw std::runtime_error("hidden_size // warpSize != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int num_warp = hidden_size / warpSize; + + switch (num_warp) { + case 1: + IxinferResidualBiasLn<1><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 2: + IxinferResidualBiasLn<2><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + 
break; + case 3: + IxinferResidualBiasLn<3><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 4: + IxinferResidualBiasLn<4><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 5: + IxinferResidualBiasLn<5><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 6: + IxinferResidualBiasLn<6><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 7: + IxinferResidualBiasLn<7><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 8: + IxinferResidualBiasLn<8><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 9: + IxinferResidualBiasLn<9><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 10: + IxinferResidualBiasLn<10><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 11: + IxinferResidualBiasLn<11><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 12: + IxinferResidualBiasLn<12><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 13: + IxinferResidualBiasLn<13><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 14: + IxinferResidualBiasLn<14><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 15: + IxinferResidualBiasLn<15><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + case 16: + IxinferResidualBiasLn<16><<>>(input, scale, bias, residual_bias, output, + residual, hidden_size, is_post_ln); + break; + default: + throw std::runtime_error("IxinferResidualBiasLn"); + break; + } +} + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..975e8bc46bb2f1b877b4d9d9642dd99c9ba82116 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_add_norm.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include +#include + +#include + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +/* +transformer add && norm module +input: output of self_attn [bsz*seq_len,hsz] +scale: [hsz] layer_norm_weight +bias: [hsz] layer_norm_bias +residual_bias: [hsz] self_attn_out_proj_bias +output: [bsz*seq_len,hsz] +residual: [bsz*seq_len,hsz] + +is_post_ln: whether layer norm is applied afterwards (post-LN) + +x = dequant(input) + residual_bias + residual +y = layer_norm(x) +output = quant(y) + +1. is_post_ln is True +residual = y +2. is_post_ln is False +residual = x +*/ +// bert v2: unlike the lightseq implementation, the residual no longer pre-adds the bias of the next linear layer
+void IxinferResidualBiasLnI8II8O_v2(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, int8_t *output, __half *residual, int batch_tokens, + int hidden_size, float dequant_scale, float quant_scale, cudaStream_t stream, + bool is_post_ln); + +/* +output = dequant(input) + residual_bias + residual +*/ +void IxinferResidualBiasI8I(int8_t *input, __half *residual_bias, __half *residual, __half *output, int batch_tokens, + int hidden_size, float dequant_scale, cudaStream_t stream); + +/* +@jian.wang +transformer add && norm module +input: output of self_attn [batch_tokens,hidden_size] +scale: [hidden_size] layer_norm_weight +bias: [hidden_size] layer_norm_bias +residual_bias: [hidden_size] self_attn_out_proj_bias +output: [batch_tokens,hidden_size] +residual: [batch_tokens,hidden_size] + +is_post_ln: whether layer norm is applied afterwards (post-LN) + +x = input + residual_bias + residual +y = layer_norm(x) +output = y + +1. is_post_ln is True +residual = y +2. is_post_ln is False +residual = x +*/ +void IxinferResidualBiasLn(const float *input, const half *scale, const half *bias, const half *residual_bias, + float *output, float *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln); +void IxinferResidualBiasLn(const half *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln); + +/* +output = input + residual_bias + residual +*/ +void IxinferResidualBias(half *input, half *residual_bias, half *residual, half *output, int batch_tokens, + int hidden_size, cudaStream_t stream); +void IxinferResidualBiasLnPad(const half *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, int batch_tokens, int hidden_size, cudaStream_t stream, + bool is_post_ln); +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.cu new file mode 100644 index 0000000000000000000000000000000000000000..f2e04a2b0b7c5c0de0a6a64b93c92eabe63a7e82 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.cu @@ -0,0 +1,1001 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "transformer_arrange.h" +#include "transformer_helper.cuh" + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +/*method 1 by pxl + */ +__global__ void IxinferArrangeDecSelfQkvI8II8OKernel(const int8_t *ori_qkv, const __half *qkv_bias, int8_t *new_q, + int8_t *new_k, int8_t *new_v, int8_t *prev_k, int8_t *prev_v, + int batch_size, int dim_per_head, int head_num, int seq_idx, + int seq_len_pad, int max_seq_len, float quant_scale, + float dequant_scale) { + int hidden_size = dim_per_head * head_num; + int batch_id = blockIdx.x % batch_size; + int token_id = blockIdx.x / batch_size; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / dim_per_head; + int dim_id = (i * 4) % dim_per_head; + + char4 value; + char4 *p_ori_qkv; + half2 *p_qkv_bias; + char4 *p_new_qkv; + int target_id; + // q + p_ori_qkv = (char4 *)(ori_qkv + blockIdx.x * 3 * hidden_size); + p_qkv_bias = (half2 *)(qkv_bias + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + target_id = batch_id * head_num * 4 * dim_per_head + head_id * 4 * dim_per_head + dim_id; + p_new_qkv = (char4 *)(new_q + target_id); + p_new_qkv[0] = value; + + // k + p_ori_qkv = (char4 *)(ori_qkv + blockIdx.x * 3 * hidden_size + hidden_size); + p_qkv_bias = (half2 *)(qkv_bias + hidden_size + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + target_id = batch_id * head_num * seq_len_pad * dim_per_head + head_id * seq_len_pad * dim_per_head + + seq_idx * dim_per_head + dim_id; + p_new_qkv = (char4 *)(new_k + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = batch_id * head_num * max_seq_len * dim_per_head + head_id * max_seq_len * dim_per_head + + seq_idx * dim_per_head + dim_id; + p_new_qkv = (char4 *)(prev_k + target_id); + p_new_qkv[0] = value; + + // v + p_ori_qkv = (char4 *)(ori_qkv + blockIdx.x * 3 * hidden_size + 2 * hidden_size); + p_qkv_bias = (half2 *)(qkv_bias + hidden_size * 2 + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + target_id = batch_id * head_num * seq_len_pad * dim_per_head + head_id * seq_len_pad * dim_per_head + + seq_idx * dim_per_head + dim_id; + p_new_qkv = (char4 *)(new_v + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = batch_id * head_num * max_seq_len * dim_per_head + head_id * 
max_seq_len * dim_per_head + + seq_idx * dim_per_head + dim_id; + p_new_qkv = (char4 *)(prev_v + target_id); + p_new_qkv[0] = value; +} + +/* @jian.wang +for transformer decoder self-attn + +qkv = dequant(ori_qkv) + qkv_bias +qkv = quant(qkv) + +q,k,v split +qkv都为: 1,batch_size,head_num,head_dim +qkv 0123->0213 + +q写到 batch_size*head_num,4,head_dim seq_idx = 0 +1. kv写到: batch_size*head_num,nearesr_4(seq_len),head_dim seq_idx = seq_idx +2. kv写到prev_k,prev_v: batch_size*head_num,max_seq_len,head_dim seq_idx = +seq_idx +*/ +void IxinferArrangeDecSelfQkvI8II8O(int batch_token_num, int hidden_size, const int8_t *ori_qkv, const __half *qkv_bias, + int8_t *new_q, int8_t *new_k, int8_t *new_v, int8_t *prev_k, int8_t *prev_v, + int seq_idx, int batch_seq_len, int dim_per_head, int head_num, int max_seq_len, + float quant_scale, float dequant_scale, cudaStream_t stream) { + if (hidden_size / 4 > 4096) { + throw std::runtime_error("hidden_size / 4 > 4096"); + } + if (head_num * dim_per_head != hidden_size) { + throw std::runtime_error("head_num * dim_per_head!=hidden_size"); + } + if (dim_per_head != 64) { + throw std::runtime_error("dim_per_head!=64"); + } + if (batch_seq_len != 1) { + throw std::runtime_error("batch_seq_len!=1"); + } + int seq_len_pad = nearest_4(seq_idx + 1); + int batch_size = batch_token_num / batch_seq_len; + IxinferArrangeDecSelfQkvI8II8OKernel<<>>( + ori_qkv, qkv_bias, new_q, new_k, new_v, prev_k, prev_v, batch_size, dim_per_head, head_num, seq_idx, + seq_len_pad, max_seq_len, quant_scale, dequant_scale); +} + +__global__ void IxinferDecSelfKvCatI8II8O(int8_t *k, int8_t *v, const int8_t *prev_k, const int8_t *prev_v, + int max_seq_len, int seq_len_pad, int seq_idx, int head_dim) { + k += blockIdx.x * seq_len_pad * head_dim; + v += blockIdx.x * seq_len_pad * head_dim; + prev_k += blockIdx.x * max_seq_len * head_dim; + prev_v += blockIdx.x * max_seq_len * head_dim; + char4 *p_k = (char4 *)k; + char4 *p_v = (char4 *)v; + char4 *p_prev_k = (char4 *)prev_k; + char4 *p_prev_v = (char4 *)prev_v; + int idx = threadIdx.x; + while (idx < seq_idx * head_dim / 4) { + p_k[idx] = p_prev_k[idx]; + p_v[idx] = p_prev_v[idx]; + idx += blockDim.x; + } +} + +/* @ jian.wang +k: bsz*head_num, seq_len_pad, head_dim +prev_k: bsz*head_num, seq_idx, head_dim + +将 prev_k 写到 k[bsz*head_num, :seq_idx, head_dim] +*/ +void IxinferDecSelfKvCatI8II8O(int8_t *k, int8_t *v, int8_t *prev_k, int8_t *prev_v, int batches, int head_num, + int seq_idx, int head_dim, int max_seq_len, cudaStream_t stream) { + if (seq_idx <= 0) { + throw std::runtime_error("seq_idx<=0"); + } + if (head_dim % 4 != 0) { + throw std::runtime_error("head_dim%4!=0"); + } + + int seq_len_pad = nearest_4(seq_idx + 1); + + IxinferDecSelfKvCatI8II8O<<>>(k, v, prev_k, prev_v, max_seq_len, seq_len_pad, + seq_idx, head_dim); +} + +__global__ void IxinferDecAttnOutArrangeI8II8OKernel(const int8_t *ori_q, int8_t *new_q, int bsz, int tgt_len, + int head_dim) { + char4 *p_ori_q = (char4 *)ori_q; + char4 *p_new_q; + int elem_idx = threadIdx.x + blockIdx.x * blockDim.x; + while (elem_idx < bsz * tgt_len * head_dim / 4) { + int i8_elem_idx = elem_idx * 4; + int tgt_len_head_dim = tgt_len * head_dim; + + int bsz_idx = i8_elem_idx / tgt_len_head_dim; + int tgt_len_idx = i8_elem_idx % tgt_len_head_dim / head_dim; + int dim_idx = i8_elem_idx % tgt_len_head_dim % head_dim; + + int tgt_index = tgt_len_idx * bsz * head_dim + bsz_idx * head_dim + dim_idx; + + p_new_q = (char4 *)(new_q + tgt_index); + p_new_q[0] = p_ori_q[elem_idx]; + + elem_idx += gridDim.x * 
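+        // grid-stride loop: each iteration moves one char4 (4 int8 lanes), remapping a flat
+        // index in [rows, tgt_len, head_dim] to [tgt_len, rows, head_dim], i.e. transposing
+        // the first two dimensions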
blockDim.x; + } +} + +/* @ jian.wang +ori_q: bsz*head_num, tgt_len, head_dim +new_q: tgt_len, bsz*head_num, head_dim + +将 ori_q 的 0,1维度转置 +*/ +void IxinferDecAttnOutArrangeI8II8O(int8_t *ori_q, int8_t *new_q, int bsz, int tgt_len, int head_dim, + cudaStream_t stream) { + if (bsz * tgt_len * head_dim % 4 != 0) { + throw std::runtime_error("bsz*tgt_len*head_dim%4!=0"); + } + int num_threads = 512; + int num_blocks = ((bsz * tgt_len * head_dim / 4 - 1 + num_threads) / num_threads); + num_blocks = std::max(num_blocks, 128); + IxinferDecAttnOutArrangeI8II8OKernel<<>>(ori_q, new_q, bsz, tgt_len, head_dim); +} + +__global__ void IxinferArrangeDecEncQkvI8II8OKernel(const int8_t *ori_q, const int8_t *ori_kv, const __half *q_bias, + const __half *kv_bias, int8_t *new_q, int8_t *new_k, int8_t *new_v, + int bsz, int head_dim, int head_num, int tgt_len, int src_len, + float quant_scale, float dequant_scale) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x % bsz; + int token_id = blockIdx.x / bsz; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / head_dim; + int dim_id = (i * 4) % head_dim; + + char4 value; + char4 *p_ori_qkv; + half2 *p_qkv_bias; + char4 *p_new_qkv; + int target_id; + + if (token_id == 0) { + // q + // tgt_len,bsz,hsz + p_ori_qkv = (char4 *)(ori_q + batch_id * hidden_size); + p_qkv_bias = (half2 *)(q_bias + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + // bsz,head_num,tgt_len,head_dim + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (char4 *)(new_q + target_id); + p_new_qkv[0] = value; + } + + // k + p_ori_qkv = (char4 *)(ori_kv + blockIdx.x * 2 * hidden_size); + p_qkv_bias = (half2 *)(kv_bias + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + token_id * head_dim + dim_id; + p_new_qkv = (char4 *)(new_k + target_id); + p_new_qkv[0] = value; + + // v + p_ori_qkv = (char4 *)(ori_kv + blockIdx.x * 2 * hidden_size + hidden_size); + p_qkv_bias = (half2 *)(kv_bias + hidden_size + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + token_id * head_dim + dim_id; + p_new_qkv = (char4 *)(new_v + target_id); + p_new_qkv[0] = value; +} + +/* @jian.wang +for transformer 
decoder enc-attn + +q = dequant(ori_q) + q_bias +q: tgt_len,bsz,head_num,head_dim -> bsz,head_num,tgt_len,head_dim +只有tgt_len_idx = 0才需要写 + + +kv = dequant(ori_qkv) + qkv_bias +kv = quant(qkv) + +kv都为: src_len,batch_size,head_num,head_dim +kv src_len,bsz,head_num,head_dim -> bsz,head_num,src_len,head_dim +*/ +void IxinferArrangeDecEncQkvI8II8O(int8_t *ori_q, int8_t *ori_kv, int8_t *new_q, int8_t *new_k, int8_t *new_v, + __half *q_bias, __half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, + int src_len, float quant_scale, float dequant_scale, cudaStream_t stream) { + int hidden_size = head_dim * head_num; + if (hidden_size / 4 > 4096) { + throw std::runtime_error("hidden_size / 4 > 4096"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeDecEncQkvI8II8OKernel<<>>( + ori_q, ori_kv, q_bias, kv_bias, new_q, new_k, new_v, bsz, head_dim, head_num, tgt_len, src_len, quant_scale, + dequant_scale); +} + +__global__ void IxinferArrangeDecEncQI8II8OKernel(const int8_t *ori_q, const __half *q_bias, int8_t *new_q, int bsz, + int head_dim, int head_num, int tgt_len, float quant_scale, + float dequant_scale) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x % bsz; + int token_id = blockIdx.x / bsz; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / head_dim; + int dim_id = (i * 4) % head_dim; + + char4 value; + char4 *p_ori_qkv; + half2 *p_qkv_bias; + char4 *p_new_qkv; + int target_id; + + // q + // tgt_len,bsz,hsz + p_ori_qkv = (char4 *)(ori_q + batch_id * hidden_size); + p_qkv_bias = (half2 *)(q_bias + i * 4); + value.x = float2int8(float(p_ori_qkv[i].x) * dequant_scale + __half2float(p_qkv_bias[0].x), quant_scale); + value.y = float2int8(float(p_ori_qkv[i].y) * dequant_scale + __half2float(p_qkv_bias[0].y), quant_scale); + value.z = float2int8(float(p_ori_qkv[i].z) * dequant_scale + __half2float(p_qkv_bias[1].x), quant_scale); + value.w = float2int8(float(p_ori_qkv[i].w) * dequant_scale + __half2float(p_qkv_bias[1].y), quant_scale); + // bsz,head_num,tgt_len,head_dim + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (char4 *)(new_q + target_id); + p_new_qkv[0] = value; +} + +/* @jian.wang +for transformer decoder enc-attn + +q = dequant(ori_q) + q_bias +q: tgt_len,bsz,head_num,head_dim -> bsz,head_num,tgt_len,head_dim +只有tgt_len_idx = 0才需要写 +*/ +void IxinferArrangeDecEncQI8II8O(int8_t *ori_q, int8_t *new_q, __half *q_bias, int bsz, int head_num, int head_dim, + int tgt_len, float quant_scale, float dequant_scale, cudaStream_t stream) { + int hidden_size = head_dim * head_num; + if (hidden_size / 4 > 4096) { + throw std::runtime_error("hidden_size / 4 > 4096"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeDecEncQI8II8OKernel<<>>( + ori_q, q_bias, new_q, bsz, head_dim, head_num, tgt_len, quant_scale, dequant_scale); +} + +__global__ void IxinferArrangeDecSelfQkvKernel(const half *ori_qkv, const half *qkv_bias, half *new_q, half *new_k, + half *new_v, half *prev_k, half *prev_v, int seq_idx, int head_dim, + int head_num, int tgt_len, int src_len, int max_dec_len) { + int hsz = head_dim * head_num; + int batch_id = blockIdx.x; + int token_id = 0; + + int i = threadIdx.x; // 1个线程处理2个数据 + + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 *p_ori_qkv; + half2 *p_qkv_bias; + half2 value; + half2 *p_new_qkv; + int target_id; + + // q + // 1,bsz,hsz -> bsz,head_num,0,head_dim + 
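+    // each thread owns one half2 (two consecutive hidden units); ori_qkv stores the fused
+    // q|k|v projection per token (stride hsz * 3), the bias is added in FP32, and the result
+    // is scattered into the per-head layouts [bsz, head_num, tgt_len/src_len, head_dim]
+    // (k and v are additionally mirrored into the prev_k/prev_v caches of length max_dec_len)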
p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3); + p_qkv_bias = (half2 *)(qkv_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; + + // k + // 1,bsz,hsz -> bsz,head_num,seq_idx,head_dim + p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3 + hsz); + p_qkv_bias = (half2 *)(qkv_bias + hsz + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = + batch_id * head_num * max_dec_len * head_dim + head_id * max_dec_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(prev_k + target_id); + p_new_qkv[0] = value; + + // v + // 1,bsz,hsz -> bsz,head_num,seq_idx,head_dim + p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3 + hsz * 2); + p_qkv_bias = (half2 *)(qkv_bias + hsz * 2 + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(new_v + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = + batch_id * head_num * max_dec_len * head_dim + head_id * max_dec_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(prev_v + target_id); + p_new_qkv[0] = value; +} + +void IxinferArrangeDecSelfQkv(const half *ori_qkv, const half *qkv_bias, half *new_q, half *new_k, half *new_v, + half *prev_k, half *prev_v, int bsz, int seq_idx, int head_dim, int head_num, int tgt_len, + int src_len, int max_dec_len, cudaStream_t stream) { + int hsz = head_num * head_dim; + if (hsz / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 != 0"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeDecSelfQkvKernel<<>>(ori_qkv, qkv_bias, new_q, new_k, new_v, prev_k, prev_v, + seq_idx, head_dim, head_num, tgt_len, src_len, + max_dec_len); +}; + +__global__ void IxinferDecSelfKvCatKernel(half *k, half *v, half *prev_k, half *prev_v, int seq_idx, int head_dim, + int src_len, int max_dec_len) { + k += blockIdx.x * src_len * head_dim; + v += blockIdx.x * src_len * head_dim; + prev_k += blockIdx.x * max_dec_len * head_dim; + prev_v += blockIdx.x * max_dec_len * head_dim; + + half2 *p_k = (half2 *)k; + half2 *p_v = (half2 *)v; + half2 *p_prev_k = (half2 *)prev_k; + half2 *p_prev_v = (half2 *)prev_v; + int idx = threadIdx.x; + while (idx < seq_idx * head_dim / 2) { + p_k[idx] = p_prev_k[idx]; + p_v[idx] = p_prev_v[idx]; + idx += blockDim.x; + } +} + +void IxinferDecSelfKvCat(half *k, half *v, half *prev_k, half *prev_v, int bsz, int head_num, int seq_idx, int head_dim, + int src_len, int max_dec_len, cudaStream_t stream) { + if (seq_idx <= 0) { + throw std::runtime_error("seq_idx<=0"); + } + if (head_dim % 2 != 0) { + throw std::runtime_error("head_dim%2!=0"); + } 
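+    // the kernel below is indexed by blockIdx.x over the bsz * head_num cache rows and copies
+    // the previously generated positions (rows [0, seq_idx)) from prev_k/prev_v into the
+    // working k/v buffers with half2 loads; the remaining checks make sure that cached window
+    // actually fits the buffers being read and written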
+ if (src_len > max_dec_len) { + throw std::runtime_error("src_len>max_dec_len"); + } + if (seq_idx + 1 > src_len) { + throw std::runtime_error("seq_idx+1>src_len"); + } + IxinferDecSelfKvCatKernel<<>>(k, v, prev_k, prev_v, seq_idx, head_dim, src_len, + max_dec_len); +}; + +__global__ void IxinferDecAttnOutArrangeKernel(const half *ori_q, half *new_q, int bsz, int tgt_len, int head_dim) { + half2 *p_ori_q = (half2 *)ori_q; + half2 *p_new_q; + int elem_idx = threadIdx.x + blockIdx.x * blockDim.x; + while (elem_idx < bsz * tgt_len * head_dim / 2) { + int half_elem_idx = elem_idx * 2; + int tgt_len_head_dim = tgt_len * head_dim; + + int bsz_idx = half_elem_idx / tgt_len_head_dim; + int tgt_len_idx = half_elem_idx % tgt_len_head_dim / head_dim; + int dim_idx = half_elem_idx % tgt_len_head_dim % head_dim; + + int tgt_index = tgt_len_idx * bsz * head_dim + bsz_idx * head_dim + dim_idx; + + p_new_q = (half2 *)(new_q + tgt_index); + p_new_q[0] = p_ori_q[elem_idx]; + + elem_idx += gridDim.x * blockDim.x; + } +} + +void IxinferDecAttnOutArrange(half *ori_q, half *new_q, int bsz, int tgt_len, int head_dim, cudaStream_t stream) { + if (bsz * tgt_len * head_dim % 2 != 0) { + throw std::runtime_error("bsz*tgt_len*head_dim%2!=0"); + } + int num_threads = 512; + int num_blocks = ((bsz * tgt_len * head_dim / 2 - 1 + num_threads) / num_threads); + num_blocks = std::min(num_blocks, 128); + IxinferDecAttnOutArrangeKernel<<>>(ori_q, new_q, bsz, tgt_len, head_dim); +} + +__global__ void IxinferArrangeDecEncQKernel(const half *ori_q, const half *q_bias, half *new_q, int bsz, int head_dim, + int head_num, int tgt_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x % bsz; + int token_id = blockIdx.x / bsz; + + int i = threadIdx.x; // 1个线程处理2个数据 + + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 value; + half2 *p_ori_qkv; + half2 *p_qkv_bias; + half2 *p_new_qkv; + int target_id; + + // q + // tgt_len,bsz,hsz + p_ori_qkv = (half2 *)(ori_q + batch_id * hidden_size); + p_qkv_bias = (half2 *)(q_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + // bsz,head_num,tgt_len,head_dim + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; +} + +void IxinferArrangeDecEncQ(half *ori_q, half *new_q, half *q_bias, int bsz, int head_num, int head_dim, int tgt_len, + cudaStream_t stream) { + int hidden_size = head_dim * head_num; + if (hidden_size / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeDecEncQKernel<<>>(ori_q, q_bias, new_q, bsz, head_dim, head_num, + tgt_len); +} + +__global__ void IxinferArrangeDecEncQkvKernel(const half *ori_q, const half *ori_kv, const half *q_bias, + const half *kv_bias, half *new_q, half *new_k, half *new_v, int bsz, + int head_dim, int head_num, int tgt_len, int src_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x % bsz; + int token_id = blockIdx.x / bsz; + + int i = threadIdx.x; // 1个线程处理2个数据 + + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 value; + half2 *p_ori_qkv; + half2 *p_qkv_bias; + half2 *p_new_qkv; + int target_id; + + if (token_id == 0) { + // q + // tgt_len,bsz,hsz + p_ori_qkv = (half2 *)(ori_q + batch_id * 
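+        // ori_q is laid out [tgt_len, bsz, hsz]; during incremental decoding only target row 0
+        // carries the new token, so the batch offset alone addresses it and only blocks with
+        // token_id == 0 write new_q (the k/v sections below cover every source position)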
hidden_size); + p_qkv_bias = (half2 *)(q_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + // bsz,head_num,tgt_len,head_dim + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; + } + + // k + p_ori_qkv = (half2 *)(ori_kv + blockIdx.x * 2 * hidden_size); + p_qkv_bias = (half2 *)(kv_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + token_id * head_dim + dim_id; + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = value; + + // v + p_ori_qkv = (half2 *)(ori_kv + blockIdx.x * 2 * hidden_size + hidden_size); + p_qkv_bias = (half2 *)(kv_bias + hidden_size + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + token_id * head_dim + dim_id; + p_new_qkv = (half2 *)(new_v + target_id); + p_new_qkv[0] = value; +} + +/* @jian.wang +for transformer decoder enc-attn + +ori_q: [tgt_len,bsz,head_num*head_dim] +ori_kv: [src_len,bsz,head_num*head_dim*2] + +1. new_q 的计算, 同 IxinferArrangeDecEncQ +ori_q = ori_q + bias +new_q[:,:,0,:] = ori_q.permute(1,2,0,3)[:,:,0,:] +2. new_k, new_v的计算 +ori_kv = ori_kv + kv_bias +k,v = ori_kv.chunk(2,dim=-1) + +new_k = k.permute(1,2,0,3) +new_v = v.permute(1,2,0,3) +*/ +void IxinferArrangeDecEncQkv(half *ori_q, half *ori_kv, half *new_q, half *new_k, half *new_v, half *q_bias, + half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, int src_len, + cudaStream_t stream) { + int hidden_size = head_dim * head_num; + if (hidden_size / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeDecEncQkvKernel<<>>( + ori_q, ori_kv, q_bias, kv_bias, new_q, new_k, new_v, bsz, head_dim, head_num, tgt_len, src_len); +} + +void __global__ IxinferArrangeEncQkvKernel(half *ori_qkv, half *qkv_bias, half *new_q, half *new_k, half *new_v, + int head_dim, int head_num, int batch_seq_len, int fmha_seq_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + + int i = threadIdx.x; // 1个线程处理2个数据 + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 *p_ori_qkv = (half2 *)(ori_qkv + batch_id * batch_seq_len * hidden_size * 3 + token_id * hidden_size * 3); + half2 *p_qkv_bias = (half2 *)(qkv_bias); + half2 value; + half2 *p_new_qkv; + + int target_id = batch_id * head_num * fmha_seq_len * head_dim + head_id * fmha_seq_len * head_dim + + token_id * head_dim + dim_id; + /* q */ + value = __hadd2(p_ori_qkv[i], p_qkv_bias[i]); + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; + /* k */ + p_ori_qkv += hidden_size / 2; + p_qkv_bias += hidden_size / 2; + value = __hadd2(p_ori_qkv[i], p_qkv_bias[i]); + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = value; + /* v */ + p_ori_qkv += hidden_size / 2; + p_qkv_bias += hidden_size / 2; + value = __hadd2(p_ori_qkv[i], p_qkv_bias[i]); + p_new_qkv = (half2 
*)(new_v + target_id); + p_new_qkv[0] = value; +} + +void IxinferArrangeEncQkv(half *ori_qkv, half *ori_qkv_bias, half *new_q, half *new_k, half *new_v, int bsz, + int head_num, int head_dim, int ori_seq_len, int fmha_seq_len, cudaStream_t stream) { + int hsz = head_num * head_dim; + if (hsz / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 != 0"); + } + if (head_dim % 2 != 0) { + throw std::runtime_error("head_dim %2 != 0"); + } + // std::cout << "ori_seq_len: " << ori_seq_len << std::endl; + // std::cout << "fmha_seq_len: " << fmha_seq_len << std::endl; + dim3 blockSize(bsz, ori_seq_len); + IxinferArrangeEncQkvKernel<<>>(ori_qkv, ori_qkv_bias, new_q, new_k, new_v, head_dim, + head_num, ori_seq_len, fmha_seq_len); +} +// add by pxl for stable diffusion cross attentio arrange qkv and arrange out +__global__ void IxinferCrossAttnOutArrangeKernel(const half *ori_q, half *new_q, int bsz, int tgt_len, int num_head, + int head_dim) { + half2 *p_ori_q = (half2 *)ori_q; + half2 *p_new_q; + int hidden_size = num_head * head_dim; + int elem_idx = threadIdx.x + blockIdx.x * blockDim.x; + while (elem_idx < bsz * tgt_len * hidden_size / 2) { + int half_elem_idx = elem_idx * 2; + int tgt_len_hidden_size = tgt_len * hidden_size; + + int bsz_idx = half_elem_idx / tgt_len_hidden_size; + int mod = half_elem_idx % tgt_len_hidden_size; + + int head_id = mod / (tgt_len * head_dim); + mod = mod % (tgt_len * head_dim); + int tgt_len_idx = mod / head_dim; + int dim_idx = mod % head_dim; + + // int tgt_index = tgt_len_idx * bsz * head_dim + bsz_idx * head_dim + dim_idx; + int tgt_index = (bsz_idx * tgt_len + tgt_len_idx) * hidden_size + head_id * head_dim + dim_idx; + + p_new_q = (half2 *)(new_q + tgt_index); + p_new_q[0] = p_ori_q[elem_idx]; + + elem_idx += gridDim.x * blockDim.x; + } +} +// cross attention: arrangeout +void IxinferCrossAttnOutArrange(half *ori_q, half *new_q, int bsz, int tgt_len, int num_head, int head_dim, + cudaStream_t stream) { + if (bsz * num_head * tgt_len * head_dim % 2 != 0) { + throw std::runtime_error("bsz*tgt_len*head_dim%2!=0"); + } + int num_threads = 512; + int num_blocks = ((bsz * tgt_len * num_head * head_dim / 2 - 1 + num_threads) / num_threads); + num_blocks = std::min(num_blocks, 128); + IxinferCrossAttnOutArrangeKernel<<>>(ori_q, new_q, bsz, tgt_len, num_head, + head_dim); +} +__global__ void IxinferArrangeCrossQkvKernel(const half *ori_q, const half *ori_kv, const half *q_bias, + const half *kv_bias, half *new_q, half *new_k, half *new_v, int bsz, + int head_num, int head_dim, int tgt_len, int src_len, int fmha_src_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x % bsz; + int token_id = blockIdx.x / bsz; + + int i = threadIdx.x; // 1个线程处理2个数据 + // if(blockIdx.x==0 && threadIdx.x==0) + // { + // printf("cuda %f %f %f \n",__half2float(ori_q[0]), __half2float(q_bias[0]),__half2float(ori_q[0]) + + // __half2float(q_bias[0])); + // } + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 value; + half2 *p_ori_qkv; + half2 *p_qkv_bias; + half2 *p_new_qkv; + int target_id; + + if (token_id < tgt_len) { + // q + // tgt_len,bsz,hsz + p_ori_qkv = (half2 *)(ori_q + (batch_id * tgt_len + token_id) * hidden_size); + p_qkv_bias = (half2 *)(q_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + // 
bsz,head_num,tgt_len,head_dim + target_id = + batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + token_id * head_dim + dim_id; + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; + // if(blockIdx.x==0 && threadIdx.x==0) + // { + // printf("cuda %f %f %f \n",__half2float(p_ori_qkv[i].x), + // __half2float(p_qkv_bias[0].x),__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); printf("cuda + // %f %f %f \n",__half2float(ori_q[0]), __half2float(q_bias[0]),__half2float(ori_q[0]) + + // __half2float(q_bias[0])); + // } + } + if (token_id < src_len) { + // k + p_ori_qkv = (half2 *)(ori_kv + (batch_id * src_len + token_id) * 2 * hidden_size); + p_qkv_bias = (half2 *)(kv_bias + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * fmha_src_len * head_dim + head_id * fmha_src_len * head_dim + + token_id * head_dim + dim_id; + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = value; + + // v + p_ori_qkv = (half2 *)(ori_kv + (batch_id * src_len + token_id) * 2 * hidden_size + hidden_size); + p_qkv_bias = (half2 *)(kv_bias + hidden_size + i * 2); + value.x = __float2half(__half2float(p_ori_qkv[i].x) + __half2float(p_qkv_bias[0].x)); + value.y = __float2half(__half2float(p_ori_qkv[i].y) + __half2float(p_qkv_bias[0].y)); + target_id = batch_id * head_num * fmha_src_len * head_dim + head_id * fmha_src_len * head_dim + + token_id * head_dim + dim_id; + p_new_qkv = (half2 *)(new_v + target_id); + p_new_qkv[0] = value; + } +} + +/* @xuelu.peng +for transformer/stable_diffusion cross atten tgt_len <=4096 +ori_q: [bsz,tgt_len,head_num*head_dim] +ori_kv: [bsz,src_len,head_num*head_dim*2] + +1. new_q 的计算, +ori_q = ori_q + bias +new_q = ori_q.permute(1,2,0,3) +2. new_k, new_v的计算 +ori_kv = ori_kv + kv_bias +k,v = ori_kv.chunk(2,dim=-1) +new_k = k.permute(1,2,0,3).padded fmha_src_len +new_v = v.permute(1,2,0,3).padded fmha_src_len +*/ +void IxinferArrangeCrossQkv(half *ori_q, half *ori_kv, half *new_q, half *new_k, half *new_v, half *q_bias, + half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, int src_len, + int fmha_src_len, cudaStream_t stream) { + int hidden_size = head_dim * head_num; + if (hidden_size / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + // if (head_dim != 64) { + // throw std::runtime_error("head_dim!=64"); + // } + if (head_dim % 2 != 0) { + throw std::runtime_error("head_dim%2 != 0"); + } + int max_len = src_len > tgt_len ? 
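+    // max_len sizes the launch so that there is (presumably) one block for every (batch, token)
+    // pair up to the longer of the two sequences; inside the kernel the token_id < tgt_len and
+    // token_id < src_len guards decide whether a block writes q, k/v (padded to fmha_src_len),
+    // or both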
src_len : tgt_len; + + IxinferArrangeCrossQkvKernel<<>>( + ori_q, ori_kv, q_bias, kv_bias, new_q, new_k, new_v, bsz, head_num, head_dim, tgt_len, src_len, fmha_src_len); +} +__global__ void IxinferEncAttnOutArrangeKernel(const half *ori_q, half *new_q, const int bsz, const int ori_seq_len, + const int fmha_seq_len, const int head_num, const int head_dim) { + half2 *p_ori_q = (half2 *)ori_q; + half2 *p_new_q = (half2 *)new_q; + + int batch_token_num = ori_seq_len * head_dim * head_num; + int hidden_size = head_dim * head_num; + int date_length = bsz * ori_seq_len * head_num * head_dim; + + int elem_idx = threadIdx.x + blockIdx.x * blockDim.x; + while (elem_idx < date_length / 2) { + int half_elem_idx = elem_idx * 2; + + int bsz_idx = half_elem_idx / batch_token_num; + int seq_idx = half_elem_idx % batch_token_num / hidden_size; + int head_idx = half_elem_idx % batch_token_num % hidden_size / head_dim; + int dim_idx = half_elem_idx % batch_token_num % hidden_size % head_dim; + + int src_index = bsz_idx * head_num * fmha_seq_len * head_dim + head_idx * fmha_seq_len * head_dim + + seq_idx * head_dim + dim_idx; + + p_new_q[elem_idx] = p_ori_q[src_index / 2]; + + elem_idx += gridDim.x * blockDim.x; + } +} + +void IxinferEncAttnOutArrange(half *ori_q, half *new_q, int bsz, int ori_seq_len, int fmha_seq_len, int head_num, + int head_dim, cudaStream_t stream) { + if (bsz * ori_seq_len * head_num * head_dim % 2 != 0) { + throw std::runtime_error("bsz * ori_seq_len * head_num * head_dim % 2 != 0"); + } + int data_length = bsz * ori_seq_len * head_num * head_dim / 2; + int num_threads = 512; + int num_blocks = ((data_length - 1 + num_threads) / num_threads); + num_blocks = std::min(num_blocks, 128); + IxinferEncAttnOutArrangeKernel<<>>(ori_q, new_q, bsz, ori_seq_len, fmha_seq_len, + head_num, head_dim); +} + +void __global__ IxinferArrangeGPT2ContextQkvKernel(half *ori_qkv, half *new_q, half *new_k, half *new_v, half *prev_key, + half *prev_value, int head_dim, int head_num, int ori_seq_len, + int fmha_seq_len, int prev_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + + half2 *p_ori_qkv = (half2 *)(ori_qkv + batch_id * ori_seq_len * hidden_size * 3 + token_id * hidden_size * 3); + half2 *p_new_q = (half2 *)new_q; + half2 *p_new_k = (half2 *)new_k; + half2 *p_new_v = (half2 *)new_v; + half2 *p_prev_key = (half2 *)prev_key; + half2 *p_prev_value = (half2 *)prev_value; + + int cur_idx = threadIdx.x; // 1个线程处理2个数据 + + while (cur_idx < hidden_size / 2) { + int head_id = (cur_idx * 2) / head_dim; + int dim_id = (cur_idx * 2) % head_dim; + + int target_id = batch_id * head_num * fmha_seq_len * head_dim + head_id * fmha_seq_len * head_dim + + token_id * head_dim + dim_id; + // q + p_new_q[target_id / 2] = p_ori_qkv[cur_idx]; + // k + p_new_k[target_id / 2] = p_ori_qkv[cur_idx + hidden_size / 2]; + // v + p_new_v[target_id / 2] = p_ori_qkv[cur_idx + hidden_size]; + + // prev_key, prev_value + int target_id_prev = + batch_id * head_num * prev_len * head_dim + head_id * prev_len * head_dim + token_id * head_dim + dim_id; + p_prev_key[target_id_prev / 2] = p_ori_qkv[cur_idx + hidden_size / 2]; + p_prev_value[target_id_prev / 2] = p_ori_qkv[cur_idx + hidden_size]; + + cur_idx += blockDim.x; + } +} + +void IxinferArrangeGPT2ContextQkv(half *ori_qkv, half *new_q, half *new_k, half *new_v, half *prev_key, + half *prev_value, int bsz, int head_num, int head_dim, int ori_seq_len, + int fmha_seq_len, int prev_len, cudaStream_t stream) { + if (head_dim % 2 != 0) 
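+    // the kernel below moves one half2 (two half values) per element step, so head_dim must be
+    // even for the vectorized indices to line up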
{ + throw std::runtime_error("head_dim % 2 != 0"); + } + dim3 blockSize(bsz, ori_seq_len); + IxinferArrangeGPT2ContextQkvKernel<<>>( + ori_qkv, new_q, new_k, new_v, prev_key, prev_value, head_dim, head_num, ori_seq_len, fmha_seq_len, prev_len); +} + +__global__ void IxinferArrangeGPT2SelfQkvKernel(int32_t *length, const half *ori_qkv, half *new_q, half *new_k, + half *new_v, half *prev_k, half *prev_v, int head_dim, int head_num, + int tgt_len, int src_len, int max_dec_len) { + int hsz = head_dim * head_num; + int batch_id = blockIdx.x; + int token_id = 0; + + int i = threadIdx.x; // 1个线程处理2个数据 + int seq_idx = length[batch_id] - 1; + + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 *p_ori_qkv; + half2 *p_qkv_bias; + half2 value; + half2 *p_new_qkv; + int target_id; + + // q + // 1,bsz,hsz -> bsz,head_num,0,head_dim + p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3); + value = p_ori_qkv[i]; + target_id = batch_id * head_num * tgt_len * head_dim + head_id * tgt_len * head_dim + dim_id; + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = value; + + // k + // 1,bsz,hsz -> bsz,head_num,seq_idx,head_dim + p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3 + hsz); + value = p_ori_qkv[i]; + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = + batch_id * head_num * max_dec_len * head_dim + head_id * max_dec_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(prev_k + target_id); + p_new_qkv[0] = value; + + // v + // 1,bsz,hsz -> bsz,head_num,seq_idx,head_dim + p_ori_qkv = (half2 *)(ori_qkv + batch_id * hsz * 3 + hsz * 2); + value = p_ori_qkv[i]; + target_id = batch_id * head_num * src_len * head_dim + head_id * src_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(new_v + target_id); + p_new_qkv[0] = value; + // update to prev + target_id = + batch_id * head_num * max_dec_len * head_dim + head_id * max_dec_len * head_dim + seq_idx * head_dim + dim_id; + p_new_qkv = (half2 *)(prev_v + target_id); + p_new_qkv[0] = value; +} + +void IxinferArrangeGPT2SelfQkv(int32_t *length, const half *ori_qkv, half *new_q, half *new_k, half *new_v, + half *prev_k, half *prev_v, int bsz, int head_dim, int head_num, int tgt_len, + int src_len, int max_dec_len, cudaStream_t stream) { + int hsz = head_num * head_dim; + if (hsz / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 != 0"); + } + if (head_dim != 64) { + throw std::runtime_error("head_dim!=64"); + } + + IxinferArrangeGPT2SelfQkvKernel<<>>(length, ori_qkv, new_q, new_k, new_v, prev_k, prev_v, + head_dim, head_num, tgt_len, src_len, max_dec_len); +}; + +__global__ void IxinferGPT2SelfKvCatKernel(half *k, half *v, half *prev_k, half *prev_v, int head_dim, int src_len, + int max_dec_len) { + k += blockIdx.x * src_len * head_dim; + v += blockIdx.x * src_len * head_dim; + prev_k += blockIdx.x * max_dec_len * head_dim; + prev_v += blockIdx.x * max_dec_len * head_dim; + + half2 *p_k = (half2 *)k; + half2 *p_v = (half2 *)v; + half2 *p_prev_k = (half2 *)prev_k; + half2 *p_prev_v = (half2 *)prev_v; + int idx = threadIdx.x; + while (idx < src_len * head_dim / 2) { + p_k[idx] = p_prev_k[idx]; + p_v[idx] = p_prev_v[idx]; + idx += blockDim.x; + } +} + +void IxinferGPT2SelfKvCat(half *k, half *v, half *prev_k, half *prev_v, int bsz, int head_num, int 
head_dim, + int src_len, int max_dec_len, cudaStream_t stream) { + if (head_dim % 2 != 0) { + throw std::runtime_error("head_dim%2!=0"); + } + if (src_len > max_dec_len) { + throw std::runtime_error("src_len>max_dec_len"); + } + IxinferGPT2SelfKvCatKernel<<>>(k, v, prev_k, prev_v, head_dim, src_len, + max_dec_len); +} + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.h new file mode 100644 index 0000000000000000000000000000000000000000..153269eb2a3f4974108b818bb0a91211f7f8d30a --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_arrange.h @@ -0,0 +1,277 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +#include + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +/* @jian.wang +for transformer decoder self-attn + +qkv = dequant(ori_qkv) + qkv_bias +qkv = quant(qkv) + +q,k,v split +qkv都为: 1,batch_size,head_num,head_dim +qkv 0123->0213 + +q写到 batch_size*head_num,4,head_dim seq_idx = 0 +1. kv写到: batch_size*head_num,nearesr_4(seq_len),head_dim seq_idx = seq_idx +2. 
kv写到prev_k,prev_v: batch_size*head_num,max_seq_len,head_dim seq_idx = +seq_idx +*/ +void IxinferArrangeDecSelfQkvI8II8O(int batch_token_num, int hidden_size, const int8_t *ori_qkv, const __half *qkv_bias, + int8_t *new_q, int8_t *new_k, int8_t *new_v, int8_t *prev_k, int8_t *prev_v, + int seq_idx, int batch_seq_len, int dim_per_head, int head_num, int max_seq_len, + float quant_scale, float dequant_scale, cudaStream_t stream); + +/* @ jian.wang +k: bsz*head_num, seq_len_pad, head_dim +prev_k: bsz*head_num, seq_idx, head_dim + +将 prev_k 写到 k[bsz*head_num, :seq_idx, head_dim] +*/ +void IxinferDecSelfKvCatI8II8O(int8_t *k, int8_t *v, int8_t *prev_k, int8_t *prev_v, int batches, int head_num, + int seq_idx, int head_dim, int max_seq_len, cudaStream_t stream); + +/* @ jian.wang +ori_q: bsz*head_num, tgt_len, head_dim +new_q: tgt_len, bsz*head_num, head_dim + +将 ori_q 的 0,1维度转置 +*/ +void IxinferDecAttnOutArrangeI8II8O(int8_t *ori_q, int8_t *new_q, int bsz, int tgt_len, int head_dim, + cudaStream_t stream); + +/* @jian.wang +for transformer decoder enc-attn + +q = dequant(ori_q) + q_bias +q: tgt_len,bsz,head_num,head_dim -> bsz,head_num,tgt_len,head_dim +只有tgt_len_idx = 0才需要写 + + +kv = dequant(ori_qkv) + qkv_bias +kv = quant(qkv) + +kv都为: src_len,batch_size,head_num,head_dim +kv src_len,bsz,head_num,head_dim -> bsz,head_num,src_len,head_dim +*/ +void IxinferArrangeDecEncQkvI8II8O(int8_t *ori_q, int8_t *ori_kv, int8_t *new_q, int8_t *new_k, int8_t *new_v, + __half *q_bias, __half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, + int src_len, float quant_scale, float dequant_scale, cudaStream_t stream); + +/* @jian.wang +for transformer decoder enc-attn + +q = dequant(ori_q) + q_bias +q: tgt_len,bsz,head_num,head_dim -> bsz,head_num,tgt_len,head_dim +只有tgt_len_idx = 0才需要写 +*/ +void IxinferArrangeDecEncQI8II8O(int8_t *ori_q, int8_t *new_q, __half *q_bias, int bsz, int head_num, int head_dim, + int tgt_len, float quant_scale, float dequant_scale, cudaStream_t stream); + +/* @jian.wang +for transformer decoder self-attn + +ori_qkv:[1,bsz,hsz*3] +qkv_bias: [hsz*3] + +new_q: [bsz,head_num,tgt_len,head_dim] +new_k: [bsz,head_num,src_len,head_dim] +new_v: [bsz,head_num,src_len,head_dim] + +prev_k,prev_v: [bsz,head_num,max_dec_len,head_dim] + +step 1: +qkv = ori_qkv + qkv_bias +step 2: +q,k,v split qkv都为: [1,bsz,head_num,head_dim] +qkv 维度调整 0123->1203 [bsz,head_num,1,head_dim] +step 3: +new_q[bsz,head_num,0,head_dim] = q +new_k[bsz,head_num,seq_idx,head_dim] = k +new_v[bsz,head_num,seq_idx,head_dim] = v + +prev_k[bsz,head_num,seq_idx,head_dim] = k +prev_v[bsz,head_num,seq_idx,head_dim] = v + +*/ +void IxinferArrangeDecSelfQkv(const half *ori_qkv, const half *qkv_bias, half *new_q, half *new_k, half *new_v, + half *prev_k, half *prev_v, int bsz, int seq_idx, int head_dim, int head_num, int tgt_len, + int src_len, int max_dec_len, cudaStream_t stream); + +/* @ jian.wang +k,v: [bsz,head_num,src_len,head_dim] +prev_k,prev_v: [bsz,head_num,max_dec_len,head_dim] + +k[:,:,:seq_idx,:] = prev_k[:,:,:seq_idx,:] +v[:,:,:seq_idx,:] = prev_v[:,:,:seq_idx,:] +*/ +void IxinferDecSelfKvCat(half *k, half *v, half *prev_k, half *prev_v, int bsz, int head_num, int seq_idx, int head_dim, + int src_len, int max_dec_len, cudaStream_t stream); + +/* @ jian.wang +ori_q: bsz, tgt_len, head_dim +new_q: tgt_len, bsz, head_dim + +将 ori_q 的 0,1维度转置 +*/ +void IxinferDecAttnOutArrange(half *ori_q, half *new_q, int bsz, int tgt_len, int head_dim, cudaStream_t stream); + +/* @jian.wang +for transformer decoder enc-attn + 
+ori_q: [tgt_len,bsz,head_num,head_dim] +new_q: [bsz,head_num,tgt_len,head_dim] + +ori_q = ori_q + q_bias + + +new_q[:,:,0,:] = ori_q.permute(1,2,0,3)[:,:,0,:] + +只有tgt_len_idx = 0才需要写 +*/ +void IxinferArrangeDecEncQ(half *ori_q, half *new_q, half *q_bias, int bsz, int head_num, int head_dim, int tgt_len, + cudaStream_t stream); + +/* @jian.wang +for transformer decoder enc-attn + +ori_q: [tgt_len,bsz,head_num*head_dim] +ori_kv: [src_len,bsz,head_num*head_dim*2] + +1. new_q 的计算, 同 IxinferArrangeDecEncQ +ori_q = ori_q + bias +new_q[:,:,0,:] = ori_q.permute(1,2,0,3)[:,:,0,:] +2. new_k, new_v的计算 +ori_kv = ori_kv + kv_bias +k,v = ori_kv.chunk(2,dim=-1) + +new_k = k.permute(1,2,0,3) +new_v = v.permute(1,2,0,3) +*/ +void IxinferArrangeDecEncQkv(half *ori_q, half *ori_kv, half *new_q, half *new_k, half *new_v, half *q_bias, + half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, int src_len, + cudaStream_t stream); + +/* +ori_qkv: [batch_tokens,hsz*3] batch_tokens>=bsz*ori_seq_len +ori_qkv_bias: [hsz*3] + +1. split qkv +ori_qkv = ori_qkv + ori_qkv_bias +q,k,v = torch.chunk(ori_qkv[:bsz*ori_seq_len,:],3,dim=-1) +qkv arrange +q: [bsz,head_num,ori_seq_len,head_dim] +k: [bsz,head_num,ori_seq_len,head_dim] +v: [bsz,head_num,ori_seq_len,head_dim] + +new_q[:,:,:ori_seq_len] = q[:,:,:ori_seq_len] +new_k[:,:,:ori_seq_len] = k[:,:,:ori_seq_len] +new_v[:,:,:ori_seq_len] = v[:,:,:ori_seq_len] +*/ +void IxinferArrangeEncQkv(half *ori_qkv, half *ori_qkv_bias, half *new_q, half *new_k, half *new_v, int bsz, + int head_num, int head_dim, int ori_seq_len, int fmha_seq_len, cudaStream_t stream); + +/* @ jian.wang +ori_q: bsz, head_num, fmha_seq_len, head_dim +new_q: bsz, ori_seq_len, head_num, head_dim + +ori_q = ori_q.permute(0,2,1,3) +new_q = ori_q[:,:ori_seq_len,head_num,head_dim] +*/ +void IxinferEncAttnOutArrange(half *ori_q, half *new_q, int bsz, int ori_seq_len, int fmha_seq_len, int head_num, + int head_dim, cudaStream_t stream); + +/* +ori_qkv: [bsz,ori_seq_len,hsz*3] +new_q: [bsz,head_num,fmha_seq_len,head_dim] +new_k: [bsz,head_num,fmha_seq_len,head_dim] +new_v: [bsz,head_num,fmha_seq_len,head_dim] +prev_key: [bsz,head_num,prev_len,head_dim] +prev_value: [bsz,head_num,prev_len,head_dim] + +1. 
split qkv +q,k,v = torch.chunk(ori_qkv[:bsz*ori_seq_len,:],3,dim=-1) +qkv arrange +q: [bsz,head_num,ori_seq_len,head_dim] +k: [bsz,head_num,ori_seq_len,head_dim] +v: [bsz,head_num,ori_seq_len,head_dim] + +new_q[:,:,:ori_seq_len] = q[:,:,:ori_seq_len] +new_k[:,:,:ori_seq_len] = k[:,:,:ori_seq_len] +new_v[:,:,:ori_seq_len] = v[:,:,:ori_seq_len] +prev_key[:,:,:ori_seq_len] = q[:,:,:ori_seq_len] +prev_value[:,:,:ori_seq_len] = k[:,:,:ori_seq_len] +*/ +void IxinferArrangeGPT2ContextQkv(half *ori_qkv, half *new_q, half *new_k, half *new_v, half *prev_key, + half *prev_value, int bsz, int head_num, int head_dim, int ori_seq_len, + int fmha_seq_len, int prev_len, cudaStream_t stream); + +/* @jian.wang +for GPT2 decoder self-attn + +ori_qkv: [bsz,1,hsz*3] +length: [bsz] +seq_idx = length[batch_idx] + +new_q: [bsz,head_num,tgt_len,head_dim] +new_k: [bsz,head_num,src_len,head_dim] +new_v: [bsz,head_num,src_len,head_dim] + +prev_k,prev_v: [bsz,head_num,max_dec_len,head_dim] + +step 1: +q,k,v split qkv都为: [bsz,1,head_num,head_dim] +qkv 维度调整 0123->1203 [bsz,head_num,1,head_dim] +step 3: +new_q[bsz,head_num,0,head_dim] = q +new_k[bsz,head_num,seq_idx,head_dim] = k +new_v[bsz,head_num,seq_idx,head_dim] = v + +prev_k[bsz,head_num,seq_idx,head_dim] = k +prev_v[bsz,head_num,seq_idx,head_dim] = v + +*/ +void IxinferArrangeGPT2SelfQkv(int32_t *length, const half *ori_qkv, half *new_q, half *new_k, half *new_v, + half *prev_k, half *prev_v, int bsz, int head_dim, int head_num, int tgt_len, + int src_len, int max_dec_len, cudaStream_t stream); + +/* @ jian.wang +k,v: [bsz,head_num,src_len,head_dim] +prev_k,prev_v: [bsz,head_num,max_dec_len,head_dim] + +k[:,:,:src_len,:] = prev_k[:,:,:src_len,:] +v[:,:,:src_len,:] = prev_v[:,:,:src_len,:] +*/ +void IxinferGPT2SelfKvCat(half *k, half *v, half *prev_k, half *prev_v, int bsz, int head_num, int head_dim, + int src_len, int max_dec_len, cudaStream_t stream); +// cross attention +void IxinferCrossAttnOutArrange(half *ori_q, half *new_q, int bsz, int tgt_len, int num_head, int head_dim, + cudaStream_t stream); +void IxinferArrangeCrossQkv(half *ori_q, half *ori_kv, half *new_q, half *new_k, half *new_v, half *q_bias, + half *kv_bias, int bsz, int head_num, int head_dim, int tgt_len, int src_len, + int fmha_src_len, cudaStream_t stream); + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.cu new file mode 100644 index 0000000000000000000000000000000000000000..5753c8ccc9f791ab83a98f00b96b33784a1f32a5 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.cu @@ -0,0 +1,608 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
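+// Reference of what the embedding kernels below compute, written as a compact sketch
+// (the CUDA code vectorizes this with float4/half2 loads and fuses the LayerNorm variant):
+//
+//   for each (b, s):
+//       tok  = tokens[b][s];   type = type_ids[b][s];
+//       if (tok == pad_id) { pad_mask[b][s] = 1; output[b][s][:] = 0; }
+//       else {
+//           pad_mask[b][s] = 0;
+//           output[b][s][:] = token_emb[tok] + pos_emb[s] + type_emb[type];
+//           // IxinferBertEmbedLn additionally applies layer_norm(., scale, bias)
+//       }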
+*/ +#include "transformer_embed.h" +#include "transformer_helper.cuh" + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +const float epsilon = 0.000000000001; + +__global__ void IxinferDecEmbedKernel(const float *pos_emb_weight, const float *token_emb_weight, float *output, + const int32_t *token_ids, const int pos_idx, const float embed_scale) { + int token_id = token_ids[blockIdx.x]; + float value = token_emb_weight[token_id * blockDim.x + threadIdx.x] * embed_scale + + pos_emb_weight[pos_idx * blockDim.x + threadIdx.x]; + output[blockIdx.x * blockDim.x + threadIdx.x] = value; +} + +__global__ void IxinferDecEmbedKernel(const __half *pos_emb_weight, const __half *token_emb_weight, __half *output, + const int32_t *token_ids, const int pos_idx, const float embed_scale) { + half2 *p_pos_emb_weight = (half2 *)pos_emb_weight; + half2 *p_token_emb_weight = (half2 *)token_emb_weight; + half2 *p_out = (half2 *)output; + int token_id = token_ids[blockIdx.x]; + + half2 p_value = p_pos_emb_weight[pos_idx * blockDim.x + threadIdx.x]; + half2 t_value = p_token_emb_weight[token_id * blockDim.x + threadIdx.x]; + + half2 value; + value.x = __float2half(__half2float(p_value.x) + __half2float(t_value.x) * embed_scale); + value.y = __float2half(__half2float(p_value.y) + __half2float(t_value.y) * embed_scale); + + p_out[blockIdx.x * blockDim.x + threadIdx.x] = value; +} + +void IxinferDecEmbed(float *pos_emb_weight, float *token_emb_weight, float *output, int32_t *token_ids, int bsz, + int seq_len, int padding_idx, int embed_dim, float embed_scale, cudaStream_t stream) { + if (embed_dim > 4096) { + throw std::runtime_error("embed_dim>4096"); + } + IxinferDecEmbedKernel<<>>(pos_emb_weight, token_emb_weight, output, token_ids, + seq_len + padding_idx, embed_scale); +} + +void IxinferDecEmbed(__half *pos_emb_weight, __half *token_emb_weight, __half *output, int32_t *token_ids, int bsz, + int seq_len, int padding_idx, int embed_dim, float embed_scale, cudaStream_t stream) { + if (embed_dim / 2 > 4096) { + throw std::runtime_error("embed_dim/2>4096"); + } + if (embed_dim % 2 != 0) { + throw std::runtime_error("embed_dim % 2 !=0"); + } + IxinferDecEmbedKernel<<>>(pos_emb_weight, token_emb_weight, output, token_ids, + seq_len + padding_idx, embed_scale); +} + +void __global__ IxinferEncTokenEmbedKernel(int32_t *tokens, half *weight, int32_t *mask, half *output, int pad_idx, + float embed_scale) { + int hidden_size = blockDim.x * 2; + int batch_idx = blockIdx.x; + int seq_idx = blockIdx.y; + + int token_id = tokens[batch_idx * gridDim.y + seq_idx]; + // update mask + if (threadIdx.x == 0) { + mask[batch_idx * gridDim.y + seq_idx] = (token_id == pad_idx) ? 
1 : 0; + } + half2 *p_weight = (half2 *)(weight + token_id * hidden_size); + half2 *p_out = (half2 *)(output + batch_idx * gridDim.y * hidden_size + seq_idx * hidden_size); + + half2 w = p_weight[threadIdx.x]; + half2 res; + res.x = __float2half(__half2float(w.x) * embed_scale); + res.y = __float2half(__half2float(w.y) * embed_scale); + p_out[threadIdx.x] = res; +} + +void IxinferEncTokenEmbed(int32_t *tokens, half *weight, int32_t *mask, half *output, int bsz, int seq_len, int hsz, + int pad_idx, float embed_scale, cudaStream_t stream) { + if (hsz / 2 > 4096) { + throw std::runtime_error("hsz/2>4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 !=0"); + } + dim3 blockSize(bsz, seq_len); + IxinferEncTokenEmbedKernel<<>>(tokens, weight, mask, output, pad_idx, embed_scale); +} + +void __global__ IxinferEncPositionKernel(const int32_t *mask, int32_t *pos, const int seq_len, const int pad_id) { + mask += threadIdx.x * seq_len; + pos += threadIdx.x * seq_len; + int p = pad_id; + for (int i = 0; i < seq_len; ++i) { + if (mask[i]) { + // padding + pos[i] = pad_id; + } else { + p += 1; + pos[i] = p; + } + } +} + +void IxinferEncPosition(int32_t *mask, int32_t *pos, int bsz, int seq_len, int pad_idx, cudaStream_t stream) { + if (bsz > 4096) { + throw std::runtime_error("bsz>4096"); + } + IxinferEncPositionKernel<<<1, bsz, 0, stream>>>(mask, pos, seq_len, pad_idx); +} + +void __global__ IxinferEncPosEmbedKernel(int32_t *pos, half *weight, half *token_embed, half *output) { + int hidden_size = blockDim.x * 2; + int batch_idx = blockIdx.x; + int seq_idx = blockIdx.y; + + int token_id = pos[batch_idx * gridDim.y + seq_idx]; + + half2 *p_weight = (half2 *)(weight + token_id * hidden_size); + half2 *p_token_embed = (half2 *)(token_embed + batch_idx * gridDim.y * hidden_size + seq_idx * hidden_size); + half2 *p_out = (half2 *)(output + batch_idx * gridDim.y * hidden_size + seq_idx * hidden_size); + p_out[threadIdx.x] = __hadd2(p_weight[threadIdx.x], p_token_embed[threadIdx.x]); +} + +void IxinferEncPosEmbed(int32_t *pos, half *weight, half *token_embed, half *output, int bsz, int seq_len, int hsz, + cudaStream_t stream) { + if (hsz / 2 > 4096) { + throw std::runtime_error("hsz/2>4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 !=0"); + } + dim3 blockSize(bsz, seq_len); + IxinferEncPosEmbedKernel<<>>(pos, weight, token_embed, output); +} + +void IxinferEncEmbed(int32_t *tokens, half *token_weight, half *pos_weight, int32_t *pos_buffer, int32_t *mask, + half *output, int bsz, int seq_len, int hsz, int pad_idx, float embed_scale, cudaStream_t stream) { + IxinferEncTokenEmbed(tokens, token_weight, mask, output, bsz, seq_len, hsz, pad_idx, embed_scale, stream); + IxinferEncPosition(mask, pos_buffer, bsz, seq_len, pad_idx, stream); + IxinferEncPosEmbed(pos_buffer, pos_weight, output, output, bsz, seq_len, hsz, stream); +} + +__global__ void IxinferBertEmbedKernel(const __half *token_emb, const __half *pos_emb, const __half *type_emb, + const int *tokens, __half *output, int *pad_mask, int *type_ids, int pad_id, + int batch_size, int seq_len, int hidden_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * seq_len * hidden_dim) { + return; + } + int batch_idx, seq_idx, dim_idx; + decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = batch_idx * seq_len + seq_idx; + int token = tokens[tokens_idx]; + int token_type = type_ids[tokens_idx]; + + float4 value; + + if (token == pad_id) { + if (dim_idx == 0) { + 
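+            // only the thread covering the first float4 of this token writes the pad mask,
+            // so the flag is stored once per token rather than once per hidden-dim chunk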
pad_mask[tokens_idx] = 1; + } + value.x = 0.f; + value.y = 0.f; + value.z = 0.f; + value.w = 0.f; + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + value = ((float4 *)token_emb)[token * hidden_dim + dim_idx]; + float4 pemb = ((float4 *)pos_emb)[seq_idx * hidden_dim + dim_idx]; + float4 temb = ((float4 *)type_emb)[token_type * hidden_dim + dim_idx]; + __half2 *value_h2 = (__half2 *)(&value); + __half2 *pemb_h2 = (__half2 *)(&pemb); + __half2 *temb_h2 = (__half2 *)(&temb); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 value_f2 = __half22float2(value_h2[i]); + float2 pemb_f2 = __half22float2(pemb_h2[i]); + float2 temb_f2 = __half22float2(temb_h2[i]); + value_f2.x += (pemb_f2.x + temb_f2.x); + value_f2.y += (pemb_f2.y + temb_f2.y); + value_h2[i] = __float22half2_rn(value_f2); + } + } + ((float4 *)output)[idx] = value; +} + +void IxinferBertEmbed(const __half *token_emb, const __half *pos_emb, const __half *type_emb, const int *tokens, + __half *output, int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + const int MAX_THREADS = 512; + hidden_dim >>= 3; + int nele = batch_size * seq_len * hidden_dim; + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + + IxinferBertEmbedKernel<<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, + type_ids, pad_id, batch_size, seq_len, hidden_dim); +} + +template +__global__ void IxinferBertEmbedKernel(const T *token_emb, const T *pos_emb, const T *type_emb, const int *tokens, + T *output, int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, + int hidden_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * seq_len * hidden_dim) { + return; + } + int batch_idx, seq_idx, dim_idx; + decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = batch_idx * seq_len + seq_idx; + int token = tokens[tokens_idx]; + int token_type = type_ids[tokens_idx]; + float4 value; + + if (token == pad_id) { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 1; + } + value.x = 0.f; + value.y = 0.f; + value.z = 0.f; + value.w = 0.f; + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + value = ((float4 *)token_emb)[token * hidden_dim + dim_idx]; + float4 pemb = ((float4 *)pos_emb)[seq_idx * hidden_dim + dim_idx]; + float4 temb = ((float4 *)type_emb)[token_type * hidden_dim + dim_idx]; + value.x += (pemb.x + temb.x); + value.y += (pemb.y + temb.y); + value.z += (pemb.z + temb.z); + value.w += (pemb.w + temb.w); + // value.x += (pemb.x); + // value.y += (pemb.y); + // value.z += (pemb.z); + // value.w += (pemb.w); + } + ((float4 *)output)[idx] = value; +} + +template <> +__global__ void IxinferBertEmbedKernel<__half>(const __half *token_emb, const __half *pos_emb, const __half *type_emb, + const int *tokens, __half *output, int *pad_mask, int *type_ids, + int pad_id, int batch_size, int seq_len, int hidden_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * seq_len * hidden_dim) { + return; + } + int batch_idx, seq_idx, dim_idx; + decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = batch_idx * seq_len + seq_idx; + int token = tokens[tokens_idx]; + int token_type = type_ids[tokens_idx]; + + float4 value; + + if (token == pad_id) { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 1; + } + value.x = 0.f; + value.y = 0.f; + value.z = 0.f; + 
value.w = 0.f; + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + value = ((float4 *)token_emb)[token * hidden_dim + dim_idx]; + float4 pemb = ((float4 *)pos_emb)[seq_idx * hidden_dim + dim_idx]; + float4 temb = ((float4 *)type_emb)[token_type * hidden_dim + dim_idx]; + __half2 *value_h2 = (__half2 *)(&value); + __half2 *pemb_h2 = (__half2 *)(&pemb); + __half2 *temb_h2 = (__half2 *)(&temb); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 value_f2 = __half22float2(value_h2[i]); + float2 pemb_f2 = __half22float2(pemb_h2[i]); + float2 temb_f2 = __half22float2(temb_h2[i]); + value_f2.x += (pemb_f2.x + temb_f2.x); + value_f2.y += (pemb_f2.y + temb_f2.y); + value_h2[i] = __float22half2_rn(value_f2); + } + } + ((float4 *)output)[idx] = value; +} +template +void IxinferBertEmbed(const T *token_emb, const T *pos_emb, const T *type_emb, const int *tokens, T *output, + int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, int hidden_dim, + cudaStream_t stream) { + const int MAX_THREADS = 512; + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nele = batch_size * seq_len * hidden_dim; + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + + IxinferBertEmbedKernel<<>>( + token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, pad_id, batch_size, seq_len, hidden_dim); +} + +template <> +void IxinferBertEmbed<__half>(const __half *token_emb, const __half *pos_emb, const __half *type_emb, const int *tokens, + __half *output, int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, + int hidden_dim, cudaStream_t stream) { + const int MAX_THREADS = 512; + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nele = batch_size * seq_len * hidden_dim; + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + + IxinferBertEmbedKernel<__half><<>>( + token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, pad_id, batch_size, seq_len, hidden_dim); +} + +template void IxinferBertEmbed(const float *token_emb, const float *pos_emb, const float *type_emb, + const int *tokens, float *output, int *pad_mask, int *type_ids, int pad_id, + int batch_size, int seq_len, int hidden_dim, cudaStream_t stream); + +template +__global__ void IxinferBertEmbedLnKernel(const __half *token_emb, const __half *pos_emb, const __half *type_emb, + const int *tokens, __half *output, int *pad_mask, int *type_ids, int pad_id, + int batch_size, int seq_len, int hidden_dim, const __half *scale, + const __half *bias) { + // register + float2 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_dim; + // one line start + // input += block_start; + output += block_start; + + // __half2 *p_input = (__half2 *)input; + __half2 *p_output = (__half2 *)output; + __half2 *p_scale = (__half2 *)scale; + __half2 *p_bias = (__half2 *)bias; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + // __half2 value = p_input[element_index]; + // int idx = block_start+element_index; + + int batch_idx, seq_idx, dim_idx; + batch_idx = blockIdx.x / seq_len; + seq_idx = blockIdx.x % seq_len; + dim_idx = element_index; + // decompose_3dim(idx, seq_len, hidden_dim, &batch_idx, &seq_idx, &dim_idx); + int tokens_idx = blockIdx.x; + int token = tokens[tokens_idx]; + int token_type = 
type_ids[tokens_idx]; + + half2 value; + + if (token == pad_id) { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 1; + } + value.x = __float2half(0.f); + value.y = __float2half(0.f); + + } else { + if (dim_idx == 0) { + pad_mask[tokens_idx] = 0; + } + value = ((half2 *)(token_emb + token * hidden_dim + dim_idx * 2))[0]; + half2 pemb = ((half2 *)(pos_emb + seq_idx * hidden_dim + dim_idx * 2))[0]; + half2 temb = ((half2 *)(type_emb + token_type * hidden_dim + dim_idx * 2))[0]; + + vals[it].x = __half2float(value.x) + __half2float(pemb.x) + __half2float(temb.x); + vals[it].y = __half2float(value.y) + __half2float(pemb.y) + __half2float(temb.y); + + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, warpSize); + m2 = __shfl_sync(0xffffffff, m2, 0, warpSize); + count = __shfl_sync(0xffffffff, count, 0, warpSize); + m2 = rsqrtf(m2 / hidden_dim + epsilon); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * warpSize; + + __half2 scale_value = p_scale[element_index]; + __half2 bias_value = p_bias[element_index]; + + float2 norm_value; + norm_value.x = (vals[it].x - mean) * m2 * __half2float(scale_value.x) + __half2float(bias_value.x); + norm_value.y = (vals[it].y - mean) * m2 * __half2float(scale_value.y) + __half2float(bias_value.y); + + __half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + int token = tokens[tokens_idx]; + if (token == pad_id) { + res.x = __float2half(0.f); + res.y = __float2half(0.f); + p_output[element_index] = res; + } else { + p_output[element_index] = res; + } + } + } +} + +void IxinferBertEmbedLn(const half *token_emb, const half *pos_emb, const half *type_emb, + const int *tokens, half *output, int *pad_mask, int *type_ids, int pad_id, + int batch_size, int seq_len, int hidden_size, const half *scale, const half *bias, + cudaStream_t stream) { + if (hidden_size > 2048) { + throw std::runtime_error("hidden_size should <= 2048"); + } + if (hidden_size / 2 % warpSize != 0) { + throw std::runtime_error("hidden_size / 2 // warpSize != 0"); + } + int batch_tokens = batch_size * seq_len; + dim3 gridSize(batch_tokens); + dim3 blockSize(warpSize); + + int num_warp = hidden_size / warpSize / 2; + + switch (num_warp) { + case 1: + IxinferBertEmbedLnKernel<1> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 2: + IxinferBertEmbedLnKernel<2> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 3: + IxinferBertEmbedLnKernel<3> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 4: + IxinferBertEmbedLnKernel<4> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 5: + IxinferBertEmbedLnKernel<5> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 6: + IxinferBertEmbedLnKernel<6> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, 
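IxinferBertEmbedLn fuses the embedding sum with the LayerNorm that follows it: one warp handles one token row, accumulates mean and variance with the Welford helpers, applies (x - mean) * rsqrt(var + eps) * scale + bias, and writes zeros for padded rows. Below is a hedged PyTorch reference for the end-to-end result; the epsilon value is an assumption (the kernel uses an `epsilon` constant defined elsewhere in the sources) and all names are illustrative.

```python
import torch

def bert_embed_ln_reference(token_emb, pos_emb, type_emb, tokens, type_ids,
                            pad_id, scale, bias, eps=1e-12):
    """Reference for the fused embedding + LayerNorm path (sketch only)."""
    pad_mask = (tokens == pad_id)
    seq_len = tokens.shape[1]
    positions = torch.arange(seq_len, device=tokens.device)
    x = (token_emb[tokens] + pos_emb[positions].unsqueeze(0) + type_emb[type_ids]).float()
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)   # population variance, as the kernel computes
    y = (x - mean) * torch.rsqrt(var + eps) * scale.float() + bias.float()
    y = y.masked_fill(pad_mask.unsqueeze(-1), 0.0)      # padded rows are zeroed
    return y.half(), pad_mask.to(torch.int32)
```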
+ pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 7: + IxinferBertEmbedLnKernel<7> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 8: + IxinferBertEmbedLnKernel<8> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 9: + IxinferBertEmbedLnKernel<9> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 10: + IxinferBertEmbedLnKernel<10> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 11: + IxinferBertEmbedLnKernel<11> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 12: + IxinferBertEmbedLnKernel<12> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 13: + IxinferBertEmbedLnKernel<13> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 14: + IxinferBertEmbedLnKernel<14> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 15: + IxinferBertEmbedLnKernel<15> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 16: + IxinferBertEmbedLnKernel<16> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + default: + throw std::runtime_error("IxinferBertEmbedLn"); + break; + } +} + +__global__ void IxinferGptEmbedKernel(const int32_t *token_ids, const half *token_weight, const half *pos_weight, + half *output, const int hidden_size) { + int batch_idx = blockIdx.x; + int seq_idx = blockIdx.y; + + int token_id = token_ids[seq_idx * gridDim.x + batch_idx]; + int pos_id = seq_idx; + + output += batch_idx * gridDim.y * hidden_size + seq_idx * hidden_size; + token_weight += token_id * hidden_size; + pos_weight += pos_id * hidden_size; + + half2 *p_out = (half2 *)(output); + half2 *p_token_wei = (half2 *)(token_weight); + half2 *p_pos_wei = (half2 *)(pos_weight); + + int idx = threadIdx.x; + while (idx < hidden_size / 2) { + p_out[idx] = __hadd2(p_token_wei[idx], p_pos_wei[idx]); + idx += blockDim.x; + } +} +void IxinferGptEmbed(int32_t *token_ids, half *token_weight, half *pos_weight, half *output, int batch_size, + int seq_len, int hidden_size, cudaStream_t stream) { + if (hidden_size % 2 != 0) { + throw std::runtime_error("hidden_size % 2 != 0"); + } + dim3 gridSize(batch_size, seq_len); + IxinferGptEmbedKernel<<>>(token_ids, token_weight, pos_weight, output, hidden_size); +}; + +__global__ void IxinferGptEmbedKernel(const int32_t *token_ids, const int32_t *length, const half *token_weight, + const half *pos_weight, half *output, const int hidden_size) { + int batch_idx = blockIdx.x; + + int token_id = token_ids[batch_idx]; + int pos_id = length[batch_idx] - 1; + + output += batch_idx * hidden_size; + token_weight += token_id * hidden_size; + pos_weight += pos_id * hidden_size; + + half2 *p_out = (half2 *)(output); + half2 *p_token_wei = (half2 
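The switch above only instantiates THREAD_DATA_LEN = hidden_size / (2 * warpSize) for values 1 through 16 and rejects anything else. The IxinferGptEmbed kernels that follow implement GPT-style embeddings: the prefill variant reads token_ids in [seq_len, batch] layout, adds the positional row for each sequence index, and writes [batch, seq_len, hidden]; the single-step variant (continued below) embeds one new token per sequence at position length - 1. A PyTorch sketch of both, with illustrative names:

```python
import torch

def gpt_embed_prefill(token_ids_sb, token_weight, pos_weight):
    """token_ids_sb: [seq_len, batch] int64 -> output [batch, seq_len, hidden] (sketch)."""
    seq_len, bsz = token_ids_sb.shape
    pos_ids = torch.arange(seq_len, device=token_ids_sb.device)
    out = token_weight[token_ids_sb] + pos_weight[pos_ids].unsqueeze(1)  # [seq, batch, hidden]
    return out.transpose(0, 1).contiguous()                              # [batch, seq, hidden]

def gpt_embed_step(token_ids_b, length_b, token_weight, pos_weight):
    """Single decode step: one new token per sequence, position = length - 1 (sketch)."""
    return token_weight[token_ids_b] + pos_weight[length_b - 1]          # [batch, hidden]
```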
*)(token_weight); + half2 *p_pos_wei = (half2 *)(pos_weight); + + int idx = threadIdx.x; + while (idx < hidden_size / 2) { + p_out[idx] = __hadd2(p_token_wei[idx], p_pos_wei[idx]); + idx += blockDim.x; + } +} +void IxinferGptEmbed(int32_t *token_ids, int32_t *length, half *token_weight, half *pos_weight, half *output, + int batch_size, int hidden_size, cudaStream_t stream) { + if (hidden_size % 2 != 0) { + throw std::runtime_error("hidden_size % 2 != 0"); + } + dim3 gridSize(batch_size); + IxinferGptEmbedKernel<<>>(token_ids, length, token_weight, pos_weight, output, + hidden_size); +}; + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.h new file mode 100644 index 0000000000000000000000000000000000000000..47fa8b1862db8b52cc9271167d92eb5e2a0518da --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_embed.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +#include + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +/* +by jian.wang +transformer decoder对应的embedding + +pos_emb_weights: [max_pos,embed_dim] +token_emb_weight: [vocab_size,embed_dim] +output: [1,bsz,embed_dim] + +token_ids: [bsz,1] + +x1 = torch.nn.functional.embedding(token_ids, token_emb_weights) +x2 = torch.nn.functional.embedding([seq_len+padding_idx], pos_emb_weights) + +output = x1 * embed_scale + x2 +*/ +void IxinferDecEmbed(float *pos_emb_weight, float *token_emb_weight, float *output, int32_t *token_ids, int bsz, + int seq_len, int padding_idx, int embed_dim, float embed_scale, cudaStream_t stream); +void IxinferDecEmbed(__half *pos_emb_weight, __half *token_emb_weight, __half *output, int32_t *token_ids, int bsz, + int seq_len, int padding_idx, int embed_dim, float embed_scale, cudaStream_t stream); + +/* +tokens: [bsz,seq_len] +weight: [vocab_size,hsz] +mask: [bsz,seq_len] +output: [bsz,seq_len,hsz] + +output = torch.nn.functional.embedding(tokens, weight,padding_idx=pad_idx) * +embed_scale mask = tokens.eq(pad_idx) +*/ +void IxinferEncTokenEmbed(int32_t *tokens, half *weight, int32_t *mask, half *output, int bsz, int seq_len, int hsz, + int pad_idx, float embed_scale, cudaStream_t stream); + +/* +mask: [bsz,seq_len] +pos: [bsz,seq_len] + +根据mask,计算position值 +例如: pad_idx = 1 +mask: 1,1,0,0 +pos: 1,1,2,3 +*/ +void IxinferEncPosition(int32_t *mask, int32_t *pos, int bsz, int seq_len, int pad_idx, cudaStream_t stream); + +/* +pos: [bsz,seq_len] +weight: [max_pos,hsz] +token_embed: [bsz,seq_len,hsz] +output: [bsz,seq_len,hsz] + +output = token_embed + torch.nn.functional.embedding(pos, weight) +*/ +void IxinferEncPosEmbed(int32_t *pos, half *weight, half *token_embed, half *output, 
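The comments in transformer_embed.h are partly in Chinese; the one for IxinferEncPosition says positions are derived from the padding mask, e.g. with pad_idx = 1 and mask 1,1,0,0 (1 marks padding) the positions are 1,1,2,3. In other words, padded slots keep pad_idx and real tokens count up from pad_idx + 1, the same rule fairseq uses for learned positional embeddings. A small sketch of that rule, with illustrative names:

```python
import torch

def enc_position_reference(pad_mask, pad_idx):
    """pad_mask: [bsz, seq_len], 1 marks padding. Padded slots stay at pad_idx,
    real tokens count up from pad_idx + 1 (sketch)."""
    not_pad = (pad_mask == 0).to(torch.int64)
    return torch.cumsum(not_pad, dim=1) * not_pad + pad_idx

# Worked example from the comment: pad_idx = 1, mask = [1, 1, 0, 0] -> [[1, 1, 2, 3]]
print(enc_position_reference(torch.tensor([[1, 1, 0, 0]]), pad_idx=1))
```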
int bsz, int seq_len, int hsz, + cudaStream_t stream); +/* +依次执行 +IxinferEncTokenEmbed +IxinferEncPosition +IxinferEncPosEmbed +*/ +void IxinferEncEmbed(int32_t *tokens, half *token_weight, half *pos_weight, int32_t *pos_buffer, int32_t *mask, + half *output, int bsz, int seq_len, int hsz, int pad_idx, float embed_scale, cudaStream_t stream); + +void IxinferBertEmbed(const __half *token_emb, const __half *pos_emb, const __half *type_emb, const int *tokens, + __half *output, int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, + int hidden_dim, cudaStream_t stream); + +void IxinferBertEmbedLn(const half *token_emb, const half *pos_emb, const half *type_emb, const int *tokens, half *output, + int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, int hidden_size, + const half *scale, const half *bias, cudaStream_t stream); + +/* +gpt embedding +1. seq_len > 1 +token_ids: [seq_len,batch_size] +token_weight: [vocab_size,hidden_size] +pos_weight: [max_pos,hidden_size] + +pos_ids = range(seq_len) +output = embedding(token_ids) + embedding(pos_ids) +2. seq_len == 1 +token_ids: [1,batch_size] +token_weight: [vocab_size,hidden_size] +pos_weight: [max_pos,hidden_size] + +output = embedding(token_ids) + embedding(pos_ids) + +output = output.transpose(0,1) # batch_size,seq_len,hidden_size +*/ +void IxinferGptEmbed(int32_t *token_ids, half *token_weight, half *pos_weight, half *output, int batch_size, + int seq_len, int hidden_size, cudaStream_t stream); +void IxinferGptEmbed(int32_t *token_ids, int32_t *length, half *token_weight, half *pos_weight, half *output, + int batch_size, int hidden_size, cudaStream_t stream); + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_helper.cuh b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_helper.cuh new file mode 100644 index 0000000000000000000000000000000000000000..36fa151c3a66897647e58a7ea75d7ece7c38c751 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_helper.cuh @@ -0,0 +1,279 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +namespace nvinfer1 { +namespace plugin { +namespace backend { + +__forceinline__ int nearest_4(int x) { + if (x % 4 == 0) { + return x; + } else { + int padding = 4 - x % 4; + return x + padding; + } +} + +__forceinline__ int nearest_2(int x) { + if (x % 2 == 0) { + return x; + } else { + int padding = 2 - x % 2; + return x + padding; + } +} + +__forceinline__ int nearest_num(int x, int value) { + if (x % value == 0) { + return x; + } else { + int padding = value - x % value; + return x + padding; + } +} + +__device__ int8_t float2int8(float x, float quant_scale) { + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? 
-127 : (i8 > 127 ? 127 : i8); + return int8_t(i8); +} + +__device__ void WelfordCombine(float val, float *mean, float *m2, float *count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + float delta1 = val - *mean; + *mean += delta1 / *count; + float delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +__device__ void WelfordCombine(float b_mean, float b_m2, float b_count, float *mean, float *m2, float *count) { + if (b_count == 0) { + return; + } + float new_count = *count + b_count; + float nb_over_n = b_count / new_count; + float delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +__device__ void WelfordWarpReduce(float thread_mean, float thread_m2, float thread_count, float *mean, float *m2, + float *count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = warpSize / 2; mask > 0; mask /= 2) { + float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); + float b_count = __shfl_down_sync(0xffffffff, *count, mask); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} +// addd by pxl +// block内所有数据完成reduce +// template +__inline__ __device__ void WelfordBlockAllReduce(float thread_mean, float thread_m2, float thread_count, + float *result_mean, float *result_m2, float *result_count) { + __shared__ float mean_shared[warpSize]; + __shared__ float m2_shared[warpSize]; + __shared__ float count_shared[warpSize]; + __shared__ float mean_result_broadcast; + __shared__ float m2_result_broadcast; + __shared__ float count_result_broadcast; + + const int lid = threadIdx.x % warpSize; + const int wid = threadIdx.x / warpSize; + float warp_mean = 0; + float warp_m2 = 0; + float warp_count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count); + __syncthreads(); + + if (lid == 0) { + mean_shared[wid] = warp_mean; + m2_shared[wid] = warp_m2; + count_shared[wid] = warp_count; + } + __syncthreads(); + + if (wid == 0) { + if (threadIdx.x < blockDim.x / warpSize) { + warp_mean = mean_shared[lid]; + warp_m2 = m2_shared[lid]; + warp_count = count_shared[lid]; + + } else { + warp_mean = 0.f; + warp_m2 = 0.f; + warp_count = 0.f; + } + __syncwarp(); + + float block_mean = 0; + float block_m2 = 0; + float block_count = 0; + + WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count); + + if (lid == 0) { + mean_result_broadcast = block_mean; + m2_result_broadcast = block_m2; + count_result_broadcast = block_count; + } + } + __syncthreads(); + *result_mean = mean_result_broadcast; + *result_m2 = m2_result_broadcast; + *result_count = count_result_broadcast; +} +// load 两个 half2, 保存到 float4 +__device__ void load_float4_from_half(float4 &vals, __half2 *input, int index) { + __half2 i1 = input[index * 2]; + __half2 i2 = input[index * 2 + 1]; + + vals.x = __half2float(i1.x); + vals.y = __half2float(i1.y); + vals.z = __half2float(i2.x); + vals.w = __half2float(i2.y); +} + +__device__ char4 float42char4(float4 vals, float quant_scale) { + char4 res; + res.x = float2int8(vals.x, quant_scale); + res.y = float2int8(vals.y, quant_scale); + res.z = float2int8(vals.z, quant_scale); + res.w = float2int8(vals.w, quant_scale); + return res; +} + +__device__ float4 char4addhalf2_dequant(char4 input_4, 
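The Welford helpers above are how the LayerNorm-style kernels accumulate mean and variance in a single pass: each lane folds values in one at a time, and partial (mean, m2, count) triples are then merged across lanes, warps, and the block. A plain-Python sketch of the per-value update and the merge step, checked against a direct computation (illustrative only):

```python
def welford_combine(val, mean, m2, count):
    # Fold one value into a running (mean, m2, count) triple.
    count += 1.0
    delta1 = val - mean
    mean += delta1 / count
    delta2 = val - mean
    m2 += delta1 * delta2
    return mean, m2, count

def welford_merge(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    # Merge two partial triples (what the warp/block reductions do).
    if n_b == 0.0:
        return mean_a, m2_a, n_a
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mean, m2, n

xs = [0.5, -1.25, 3.0, 2.5, -0.75, 1.0]
mean, m2, n = 0.0, 0.0, 0.0
for x in xs:
    mean, m2, n = welford_combine(x, mean, m2, n)
true_mean = sum(xs) / len(xs)
true_var = sum((x - true_mean) ** 2 for x in xs) / len(xs)
assert abs(mean - true_mean) < 1e-9 and abs(m2 / n - true_var) < 1e-9
```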
half2 residual_1, half2 residual_2, float dequant_scale) { + float4 res; + res.x = __int2float_rn(input_4.x) * dequant_scale + __half2float(residual_1.x); + res.y = __int2float_rn(input_4.y) * dequant_scale + __half2float(residual_1.y); + res.z = __int2float_rn(input_4.z) * dequant_scale + __half2float(residual_2.x); + res.w = __int2float_rn(input_4.w) * dequant_scale + __half2float(residual_2.y); + return res; +} + +__device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, int hidden_size, float epsilon, + half2 scale_1, half2 scale_2, half2 bias_1, half2 bias_2) { + float4 norm_value; + norm_value.x = + (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.x) + __half2float(bias_1.x); + norm_value.y = + (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_1.y) + __half2float(bias_1.y); + norm_value.z = + (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_2.x) + __half2float(bias_2.x); + norm_value.w = + (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * __half2float(scale_2.y) + __half2float(bias_2.y); + return norm_value; +} + +// softmax +__forceinline__ __host__ __device__ int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} +template +__device__ T WARP_SHFL_XOR(T value, int laneMask, int width) { + unsigned int mask = 0xffffffff; +#if !(defined(__HIP_PLATFORM_HCC__) || defined(__ILUVATAR__)) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +struct Add { + __device__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ T operator()(T a, T b) const { return a < b ? b : a; } +}; +template class ReduceOp> +__device__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = REDUCE_WARP_SIZE / 2; offset > 0; offset /= 2) { + acc_t b = WARP_SHFL_XOR(*sum, offset, REDUCE_WARP_SIZE); + *sum = r(*sum, b); + } +} + +__device__ void warp_argmax(float &value, int32_t &idx) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + float next_value = WARP_SHFL_XOR(value, offset, warpSize); + float next_idx = WARP_SHFL_XOR(idx, offset, warpSize); + if (next_value > value) { + value = next_value; + idx = next_idx; + } + } +} + +// gelu +// IxinferBiasGeluI8II8OKernel +template +__device__ T tanhf_exp(T x) { + // float e1 = __expf(x); + // float e2 = 1.0f / e1; + // return (e1 - e2) / (e1 + e2); + + return (2.f / (1.f + __expf(-2.f * x)) - 1.f); +} + +template +__device__ T gelu(T x) { + float cdf = 0.5f * (1.0f + tanhf_exp((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +/* fp16 gelu */ +template <> +__forceinline__ __device__ __half2 gelu<__half2>(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git 
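Both gelu overloads above use the common tanh approximation, where 0.7978845608028654 is sqrt(2/pi) and tanhf_exp rewrites tanh in terms of a single exponential. A quick reference comparing it with the exact erf-based GELU; the tolerance is loose, since the approximation is accurate to a few 1e-4 in absolute terms (a sketch, not part of the sources):

```python
import math

def gelu_tanh(x):
    # Tanh approximation used by the kernels: sqrt(2/pi) ~= 0.7978845608028654
    return 0.5 * x * (1.0 + math.tanh(0.7978845608028654 * (x + 0.044715 * x ** 3)))

def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in (-3.0, -1.0, -0.1, 0.0, 0.5, 2.0):
    assert abs(gelu_tanh(x) - gelu_exact(x)) < 1e-3
```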
a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.cu new file mode 100644 index 0000000000000000000000000000000000000000..efe30993a586fecd9282e3a2eacbd82d8cf5a726 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.cu @@ -0,0 +1,494 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "transformer_helper.cuh" +#include "transformer_softmax.h" +namespace nvinfer1 { +namespace plugin { +namespace backend { + +template +__global__ void IxinferCorrelationSoftmaxEncselfKernel(__half *correlation, const int *src_padding_mask, + const int batch_seq_len) { + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int SOFT_WARP_SIZE = (next_power_of_two < warpSize) ? next_power_of_two : warpSize; + constexpr int WARP_ITERATIONS = next_power_of_two / SOFT_WARP_SIZE; + + int head_num = blockDim.y; + int seq_len = gridDim.y; + int start_idx = (blockIdx.x * head_num * seq_len * batch_seq_len + threadIdx.y * seq_len * batch_seq_len + + blockIdx.y * batch_seq_len); + + half2 *p_correlation = (half2 *)(correlation + start_idx); + int32_t *p_mask = (int32_t *)(src_padding_mask + blockIdx.x * batch_seq_len); + + int local_idx = threadIdx.x; + + float2 elements[WARP_ITERATIONS]; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + if (element_index < batch_seq_len / 2) { + half2 correlation_value = p_correlation[element_index]; + + elements[it].x = + p_mask[element_index * 2] ? -std::numeric_limits::infinity() : __half2float(correlation_value.x); + elements[it].y = p_mask[element_index * 2 + 1] ? -std::numeric_limits::infinity() + : __half2float(correlation_value.y); + + } else { + elements[it].x = -std::numeric_limits::infinity(); + elements[it].y = -std::numeric_limits::infinity(); + } + } + + // compute max_value + float max_value = elements[0].x; + max_value = (max_value > elements[0].y) ? max_value : elements[0].y; + +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value = (max_value > elements[it].x) ? max_value : elements[it].x; + max_value = (max_value > elements[it].y) ? 
max_value : elements[it].y; + } + + warp_reduce(&max_value); + + // exp sum + float sum = 0.0f; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[it].x = __expf(elements[it].x - max_value); + elements[it].y = __expf(elements[it].y - max_value); + + sum += (elements[it].x + elements[it].y); + } + + warp_reduce(&sum); + sum = 1.0f / sum; + // store result +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + half2 correlation_value; + if (element_index < batch_seq_len / 2) { + correlation_value.x = __float2half(elements[it].x * sum); + correlation_value.y = __float2half(elements[it].y * sum); + + p_correlation[element_index] = correlation_value; + + } else { + break; + } + } +} + +void IxinferCorrelationSoftmaxEncself(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + __half *correlation, const int *src_padding_mask) { + if (batch_seq_len > 4096) { + throw std::runtime_error("batch_seq_len should <= 4096"); + } + if (batch_seq_len % 2 != 0) { + throw std::runtime_error("batch_seq_len % 2 != 0"); + } + + int log2_elements = log2_ceil(batch_seq_len / 2); + int next_power_of_two = 1 << log2_elements; + int WARP_SIZE = (next_power_of_two < warpSize) ? next_power_of_two : warpSize; + + dim3 grid(batch_size, batch_seq_len); + + dim3 block(WARP_SIZE, head_num); + + switch (log2_elements) { + case 0: + IxinferCorrelationSoftmaxEncselfKernel<0> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 1: + IxinferCorrelationSoftmaxEncselfKernel<1> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 2: + IxinferCorrelationSoftmaxEncselfKernel<2> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 3: + IxinferCorrelationSoftmaxEncselfKernel<3> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 4: + IxinferCorrelationSoftmaxEncselfKernel<4> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 5: + IxinferCorrelationSoftmaxEncselfKernel<5> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + + case 6: + IxinferCorrelationSoftmaxEncselfKernel<6> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 7: + IxinferCorrelationSoftmaxEncselfKernel<7> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 8: + IxinferCorrelationSoftmaxEncselfKernel<8> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 9: + IxinferCorrelationSoftmaxEncselfKernel<9> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 10: + IxinferCorrelationSoftmaxEncselfKernel<10> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 11: + IxinferCorrelationSoftmaxEncselfKernel<11> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + case 12: + IxinferCorrelationSoftmaxEncselfKernel<12> + <<>>(correlation, src_padding_mask, batch_seq_len); + break; + default: + throw std::runtime_error("IxinferCorrelationSoftmaxEncself NotImplementedError"); + break; + } +} + +template +__global__ void IxinferCorrelationSoftmaxEncselfKernel(__half *correlation, const int *src_padding_mask, + const int batch_seq_len, bool is_causal_mask) { + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int SOFT_WARP_SIZE = (next_power_of_two < warpSize) ? 
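IxinferCorrelationSoftmaxEncself applies an in-place softmax over the last axis of the [batch, head, seq, seq] attention scores, forcing keys whose src_padding_mask entry is 1 to -inf first; the log2_elements switch only picks the per-lane iteration count for the padded power-of-two row width. A PyTorch reference of the math (a sketch, with illustrative names):

```python
import torch

def correlation_softmax_encself(correlation, src_padding_mask):
    """correlation: [bsz, head, seq, seq]; src_padding_mask: [bsz, seq], 1 = padding.
    Reference for the non-causal kernel (sketch only)."""
    scores = correlation.float()
    key_pad = src_padding_mask.bool()[:, None, None, :]   # broadcast over head and query dims
    scores = scores.masked_fill(key_pad, float("-inf"))
    return torch.softmax(scores, dim=-1).to(correlation.dtype)
```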
next_power_of_two : warpSize; + constexpr int WARP_ITERATIONS = next_power_of_two / SOFT_WARP_SIZE; + + int head_num = blockDim.y; + int seq_len = gridDim.y; + int start_idx = (blockIdx.x * head_num * seq_len * batch_seq_len + threadIdx.y * seq_len * batch_seq_len + + blockIdx.y * batch_seq_len); + + half2 *p_correlation = (half2 *)(correlation + start_idx); + int32_t *p_mask = (int32_t *)(src_padding_mask + blockIdx.x * batch_seq_len); + + int local_idx = threadIdx.x; + + float2 elements[WARP_ITERATIONS]; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + if (element_index < batch_seq_len / 2) { + half2 correlation_value = p_correlation[element_index]; + if (is_causal_mask) { //因果mask,下三角是0,上三角是1 + if (!p_mask[element_index * 2] && blockIdx.y >= element_index * 2) { + elements[it].x = __half2float(correlation_value.x); + } else { + elements[it].x = -std::numeric_limits::infinity(); + } + if (!p_mask[element_index * 2 + 1] && blockIdx.y >= element_index * 2 + 1) { + elements[it].y = __half2float(correlation_value.y); + } else { + elements[it].y = -std::numeric_limits::infinity(); + } + } else { + elements[it].x = p_mask[element_index * 2] ? -std::numeric_limits::infinity() + : __half2float(correlation_value.x); + elements[it].y = p_mask[element_index * 2 + 1] ? -std::numeric_limits::infinity() + : __half2float(correlation_value.y); + } + + } else { + elements[it].x = -std::numeric_limits::infinity(); + elements[it].y = -std::numeric_limits::infinity(); + } + } + + // compute max_value + float max_value = elements[0].x; + max_value = (max_value > elements[0].y) ? max_value : elements[0].y; + +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value = (max_value > elements[it].x) ? max_value : elements[it].x; + max_value = (max_value > elements[it].y) ? max_value : elements[it].y; + } + + warp_reduce(&max_value); + + // exp sum + float sum = 0.0f; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[it].x = __expf(elements[it].x - max_value); + elements[it].y = __expf(elements[it].y - max_value); + + sum += (elements[it].x + elements[it].y); + } + + warp_reduce(&sum); + sum = 1.0f / sum; + // store result +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + half2 correlation_value; + if (element_index < batch_seq_len / 2) { + correlation_value.x = __float2half(elements[it].x * sum); + correlation_value.y = __float2half(elements[it].y * sum); + + p_correlation[element_index] = correlation_value; + + } else { + break; + } + } +} + +void IxinferCorrelationSoftmaxEncself(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + __half *correlation, const int *src_padding_mask, bool is_causal_mask) { + if (batch_seq_len > 1024) { + throw std::runtime_error("batch_seq_len should <= 1024"); + } + if (batch_seq_len % 2 != 0) { + throw std::runtime_error("batch_seq_len % 2 != 0"); + } + + int log2_elements = log2_ceil(batch_seq_len / 2); + int next_power_of_two = 1 << log2_elements; + int WARP_SIZE = (next_power_of_two < warpSize) ? 
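This overload adds a causal option; the Chinese comment in the kernel notes that the lower triangle of the causal mask is kept and the upper triangle is masked, i.e. query q may attend to key k only when k <= q and key k is not padding. A PyTorch reference for the is_causal_mask = true path (sketch only, illustrative names):

```python
import torch

def correlation_softmax_causal(correlation, src_padding_mask):
    """Reference for the causal variant: key k contributes to query q only when
    k <= q and key k is not padding (sketch only)."""
    bsz, heads, q_len, k_len = correlation.shape
    scores = correlation.float()
    key_pad = src_padding_mask.bool()[:, None, None, :]
    causal = torch.ones(q_len, k_len, device=correlation.device).triu(diagonal=1).bool()
    scores = scores.masked_fill(key_pad | causal, float("-inf"))
    return torch.softmax(scores, dim=-1).to(correlation.dtype)
```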
next_power_of_two : warpSize; + + dim3 grid(batch_size, batch_seq_len); + + dim3 block(WARP_SIZE, head_num); + + switch (log2_elements) { + case 0: + IxinferCorrelationSoftmaxEncselfKernel<0> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 1: + IxinferCorrelationSoftmaxEncselfKernel<1> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 2: + IxinferCorrelationSoftmaxEncselfKernel<2> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 3: + IxinferCorrelationSoftmaxEncselfKernel<3> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 4: + IxinferCorrelationSoftmaxEncselfKernel<4> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 5: + IxinferCorrelationSoftmaxEncselfKernel<5> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + + case 6: + IxinferCorrelationSoftmaxEncselfKernel<6> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + case 7: + IxinferCorrelationSoftmaxEncselfKernel<7> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + case 8: + IxinferCorrelationSoftmaxEncselfKernel<8> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + case 9: + IxinferCorrelationSoftmaxEncselfKernel<9> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + case 10: + IxinferCorrelationSoftmaxEncselfKernel<10> + <<>>(correlation, src_padding_mask, batch_seq_len, is_causal_mask); + break; + default: + throw std::runtime_error("IxinferCorrelationSoftmaxEncself NotImplementedError"); + break; + } +} + +template +__global__ void xSoftmaxKernel(__half *correlation, __half *out, const bool *mask, const int batch_seq_len, bool broadcast_two_dim) { + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int SOFT_WARP_SIZE = (next_power_of_two < warpSize) ? next_power_of_two : warpSize; + constexpr int WARP_ITERATIONS = next_power_of_two / SOFT_WARP_SIZE; + + int head_num = blockDim.y; + int seq_len = gridDim.y; + int start_idx = (blockIdx.x * head_num * seq_len * batch_seq_len + threadIdx.y * seq_len * batch_seq_len + + blockIdx.y * batch_seq_len); + + half2 *p_correlation = (half2 *)(correlation + start_idx); + half2 *p_out = (half2 *)(out + start_idx); + + /*[bsz,1,1,sequence*/ /*[bsz,1,sequence, sequence*/ + int32_t mask_index =broadcast_two_dim ? (blockIdx.x * batch_seq_len): (blockIdx.x * 1 * seq_len * batch_seq_len + blockIdx.y * batch_seq_len); + + + bool *p_mask = + (bool *)(mask + mask_index); // 1 mean broadcast + + int local_idx = threadIdx.x; + + float2 elements[WARP_ITERATIONS]; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + if (element_index < batch_seq_len / 2) { + half2 correlation_value = p_correlation[element_index]; + + elements[it].x = + p_mask[element_index * 2] ? -std::numeric_limits::infinity() : __half2float(correlation_value.x); + + elements[it].y = p_mask[element_index * 2 + 1] ? -std::numeric_limits::infinity() : __half2float(correlation_value.y); + + } else { + elements[it].x = -std::numeric_limits::infinity(); + elements[it].y = -std::numeric_limits::infinity(); + } + } + + // compute max_value + float max_value = elements[0].x; + max_value = (max_value > elements[0].y) ? 
max_value : elements[0].y; + +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value = (max_value > elements[it].x) ? max_value : elements[it].x; + max_value = (max_value > elements[it].y) ? max_value : elements[it].y; + } + + warp_reduce(&max_value); + + // exp sum + float sum = 0.0f; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[it].x = __expf(elements[it].x - max_value); + elements[it].y = __expf(elements[it].y - max_value); + + sum += (elements[it].x + elements[it].y); + } + + warp_reduce(&sum); + sum = 1.0f / sum; + // store result +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * SOFT_WARP_SIZE; + half2 correlation_value; + if (element_index < batch_seq_len / 2) { + if (max_value == -std::numeric_limits::infinity()) { + p_out[element_index].x = 0.f; + p_out[element_index].y = 0.f; // all data masked + } else { + correlation_value.x = + p_mask[element_index * 2] ? __float2half(0.f) : __float2half(elements[it].x * sum); + correlation_value.y = + p_mask[element_index * 2 + 1] ? __float2half(0.f) : __float2half(elements[it].y * sum); + p_out[element_index] = correlation_value; + } + } else { + break; + } + } +} + +void xSoftmax(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, __half *correlation, __half *out, + const bool *mask, bool broadcast_two_dim) { + if (batch_seq_len > 1024) { + throw std::runtime_error("batch_seq_len should <= 1024"); + } + if (batch_seq_len % 2 != 0) { + throw std::runtime_error("batch_seq_len % 2 != 0"); + } + + int log2_elements = log2_ceil(batch_seq_len / 2); + int next_power_of_two = 1 << log2_elements; + int WARP_SIZE = (next_power_of_two < warpSize) ? next_power_of_two : warpSize; + + dim3 grid(batch_size, batch_seq_len); + + dim3 block(WARP_SIZE, head_num); + + switch (log2_elements) { + case 0: + xSoftmaxKernel<0><<>>(correlation, out, mask, batch_seq_len, broadcast_two_dim); + break; + + case 1: + xSoftmaxKernel<1><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + + case 2: + xSoftmaxKernel<2><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + + case 3: + xSoftmaxKernel<3><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + + case 4: + xSoftmaxKernel<4><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + + case 5: + xSoftmaxKernel<5><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + + case 6: + xSoftmaxKernel<6><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + case 7: + xSoftmaxKernel<7><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + case 8: + xSoftmaxKernel<8><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + case 9: + xSoftmaxKernel<9><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + case 10: + xSoftmaxKernel<10><<>>(correlation, out, mask, batch_seq_len,broadcast_two_dim); + break; + default: + throw std::runtime_error("xSoftmaxKernel NotImplementedError"); + break; + } +} + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5a8a1ed2850d631d3198ecc4a377445b394d63 --- /dev/null +++ 
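xSoftmax is the boolean-mask variant: the mask is [bsz, 1, 1, seq] when broadcast_two_dim is set, otherwise [bsz, 1, seq, seq]. Going by the kernel logic above, a true mask entry contributes -inf (the header comment words the convention the other way round, so treat this as an assumption taken from the .cu code); masked keys end up with probability 0, and a row whose keys are all masked is written as zeros rather than NaN. A reference sketch with illustrative names:

```python
import torch

def x_softmax_reference(scores, mask):
    """scores: [bsz, head, seq, seq]; mask: [bsz, 1, 1, seq] or [bsz, 1, seq, seq] bool,
    True = masked out (per the kernel above). Fully masked rows become all zeros."""
    s = scores.float().masked_fill(mask, float("-inf"))
    probs = torch.softmax(s, dim=-1)
    all_masked = mask.broadcast_to(scores.shape).all(dim=-1, keepdim=True)
    probs = torch.where(all_masked, torch.zeros_like(probs), probs)
    # Masked positions themselves are also written as exact zeros by the kernel.
    probs = probs.masked_fill(mask, 0.0)
    return probs.to(scores.dtype)

# broadcast_two_dim=True corresponds to a [bsz, 1, 1, seq] padding-style mask;
# broadcast_two_dim=False to a full [bsz, 1, seq, seq] attention mask.
```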
b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/backend/transformer/transformer_softmax.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +#include + +namespace nvinfer1 { +namespace plugin { +namespace backend { +/* +correlation: [batch_size,head_num,batch_seq_len,batch_seq_len] +src_padding_mask: [batch_size,batch_seq_len] 1表示paddding + +*/ +void IxinferCorrelationSoftmaxEncself(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + half *correlation, const int *src_padding_mask); + +void IxinferCorrelationSoftmaxEncself(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, + __half *correlation, const int *src_padding_mask, bool is_causal_mask); + +/* +correlation: [batch_size,head_num,batch_seq_len,batch_seq_len] +mask: [batch_size,1, batch_seq_len,batch_seq_len] 数据 false 表示mask掉。dim1 support broadcast + +*/ + +void xSoftmax(int batch_size, int batch_seq_len, int head_num, cudaStream_t stream, __half *correlation, __half *out, + const bool *src_padding_mask,bool broadcast_two_dim); + +} // namespace backend +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/bertCommon.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/bertCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..d9ea59cbef2f50217ce004c811c4e2c8aeeae1b9 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/bertCommon.h @@ -0,0 +1,222 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include + +#include +#include +#include +#include + +#include "NvInfer.h" +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "checkMacrosPlugin.h" + +namespace nvinfer1::plugin { +namespace bert { + +constexpr uint32_t BDIM = 0; // batch dimension +constexpr uint32_t SDIM = 1; // seq len dimension +constexpr uint32_t HDIM = 2; // hidden dimension + +#define TRT_UNUSED (void) + +template +struct CudaDeleter { + void operator()(T* buf) { IXRT_PLUGIN_CUASSERT(cudaFree(buf)); } +}; + +template +using cuda_unique_ptr = std::unique_ptr>; + +inline uint32_t getElementSize(nvinfer1::DataType t) noexcept { + switch (t) { + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kBOOL: + case nvinfer1::DataType::kUINT8: + case nvinfer1::DataType::kINT8: + return 1; + case DataType::kUNKNOWN: + case DataType::kINT64: + case DataType::kFLOAT64: + break; + } + return 0; +} + +inline int64_t getWeightsSize(nvinfer1::Weights const& w, nvinfer1::DataType type) { + return w.count * getElementSize(type); +} + +template +using cuda_shared_ptr = std::shared_ptr; + +template +void make_cuda_shared(cuda_shared_ptr& ptr, void* cudaMem) { + ptr.reset(static_cast(cudaMem), bert::CudaDeleter()); +} + +struct WeightsWithOwnership : public nvinfer1::Weights { + ILogger* logger_; + WeightsWithOwnership() { + values = nullptr; + count = 0; + } + ~WeightsWithOwnership() { operator delete[](const_cast(values)); } + + WeightsWithOwnership(WeightsWithOwnership const&) = delete; + WeightsWithOwnership operator=(WeightsWithOwnership const&) = delete; + WeightsWithOwnership(WeightsWithOwnership const&&) = delete; + WeightsWithOwnership operator=(WeightsWithOwnership const&&) = delete; + + void convertAndCopy(nvinfer1::Weights const& src, nvinfer1::DataType type, float scale = 1) { + this->type = type; + this->count = src.count; + + if (type == nvinfer1::DataType::kFLOAT) { + auto destBuf = new float[src.count]; + this->values = destBuf; + + if (src.type == nvinfer1::DataType::kFLOAT) { + gLogInfo << "Float Weights(Host) => Float Array(Host)" << endl; + std::copy_n(static_cast(src.values), src.count, destBuf); + } else { + IXRT_PLUGIN_ASSERT(src.type == nvinfer1::DataType::kHALF); + + gLogInfo << "Half Weights(Host) => Float Array(Host)" << endl; + auto const s = static_cast(src.values); + auto d = static_cast(const_cast(this->values)); + + for (auto it = 0; it < src.count; it++) { + d[it] = __half2float(s[it]); + } + } + } else if (type == nvinfer1::DataType::kHALF) { + auto destBuf = new half[src.count]; + this->values = destBuf; + + if (src.type == nvinfer1::DataType::kHALF) { + gLogInfo << "Half Weights(Host) => Half Array(Host)" << endl; + std::copy_n(static_cast(src.values), src.count, destBuf); + } else { + IXRT_PLUGIN_ASSERT(src.type == nvinfer1::DataType::kFLOAT); + + gLogInfo << "Float Weights(Host) => Half Array(Host)" << endl; + auto const s = static_cast(src.values); + auto d = static_cast(const_cast(this->values)); + + for (auto it = 0; it < src.count; it++) { + d[it] = __float2half(s[it]); + } + } + } else if (type == nvinfer1::DataType::kINT8) { + auto destBuf = new int8_t[src.count]; + this->values = destBuf; + + if (src.type == nvinfer1::DataType::kFLOAT) { + gLogInfo << "Float Weights(Host) => Int8 Array(Host)" << endl; + auto const s = static_cast(src.values); + auto d = static_cast(const_cast(this->values)); + + for (auto it = 0; it < src.count; it++) { + 
int32_t v = static_cast(std::roundf(s[it] / scale)); + d[it] = v <= -127 ? -127 : (v >= 127 ? 127 : v); + } + } else if (src.type == nvinfer1::DataType::kINT8) { + gLogInfo << "Int8 Weights(Host) => Int8 Array(Host)" << endl; + std::copy_n(static_cast(src.values), src.count, destBuf); + } else { + throw std::runtime_error("Unsupported DataType specified for plugin."); + } + } else { + throw std::runtime_error("Unsupported DataType specified for plugin."); + } + } + + void convertAndCopy(char const*& srcBuf, size_t count, nvinfer1::DataType type) noexcept { + this->type = type; + this->count = count; + auto const nbBytes = getWeightsSize(*this, type); + auto destBuf = new char[nbBytes]; + this->values = destBuf; + + std::copy_n(srcBuf, nbBytes, destBuf); + srcBuf += nbBytes; + } +}; + +template +inline void copyToDevice(WeightsWithOwnership& hostWeights, size_t nbBytes, cuda_unique_ptr& cudaWeights) { + if (hostWeights.values) { + void* cudaMem{nullptr}; + IXRT_PLUGIN_CUASSERT(cudaMalloc(&cudaMem, nbBytes)); + IXRT_PLUGIN_CUASSERT(cudaMemcpy(cudaMem, hostWeights.values, nbBytes, cudaMemcpyHostToDevice)); + cudaWeights.reset(static_cast(cudaMem)); + } +} + +template +inline void serFromDev(char*& buffer, T const* data, size_t nbElem) { + const size_t len = sizeof(T) * nbElem; + IXRT_PLUGIN_CUASSERT(cudaMemcpy(buffer, static_cast(data), len, cudaMemcpyDeviceToHost)); + buffer += len; +} + +template +inline T* deserToDev(char const*& buffer, size_t nbElem) { + void* dev{nullptr}; + const size_t len = sizeof(T) * nbElem; + IXRT_PLUGIN_CUASSERT(cudaMalloc(&dev, len)); + IXRT_PLUGIN_CUASSERT(cudaMemcpy(dev, buffer, len, cudaMemcpyHostToDevice)); + + buffer += len; + return static_cast(dev); +} + +inline nvinfer1::DataType fieldTypeToDataType(const nvinfer1::PluginFieldType ftype) { + switch (ftype) { + case nvinfer1::PluginFieldType::kFLOAT32: { + gLogInfo << "PluginFieldType is Float32" << endl; + return nvinfer1::DataType::kFLOAT; + } + case nvinfer1::PluginFieldType::kFLOAT16: { + gLogInfo << "PluginFieldType is Float16" << endl; + return nvinfer1::DataType::kHALF; + } + case nvinfer1::PluginFieldType::kINT32: { + gLogInfo << "PluginFieldType is Int32" << endl; + return nvinfer1::DataType::kINT32; + } + case nvinfer1::PluginFieldType::kINT8: { + gLogInfo << "PluginFieldType is Int8" << endl; + return nvinfer1::DataType::kINT8; + } + default: + throw std::invalid_argument("No corresponding datatype for plugin field type"); + } +} + +inline int64_t volume(nvinfer1::Dims const& d) { + return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies{}); +} +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7bad76bc923984a0dd5b35cefa57c31b6fa66e68 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.cpp @@ -0,0 +1,46 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. 
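The kINT8 branch of convertAndCopy quantizes float weights symmetrically on the host: divide by the calibration scale, round, and clamp to [-127, 127] (so -128 is never produced, matching the clamp in the device-side float2int8 helper). A small NumPy sketch of that conversion and the matching dequantization; names and the example scale are illustrative, and np.rint stands in for the C rounding:

```python
import numpy as np

def quantize_weights(weights_fp32, scale):
    """Symmetric per-tensor quantization as in convertAndCopy's kINT8 path (sketch)."""
    q = np.rint(weights_fp32 / scale)
    return np.clip(q, -127, 127).astype(np.int8)

def dequantize_weights(weights_int8, scale):
    return weights_int8.astype(np.float32) * scale

w = np.array([0.02, -0.5, 1.3, -2.0], dtype=np.float32)
scale = 0.01
q = quantize_weights(w, scale)  # [2, -50, 127, -127] after clamping
assert np.allclose(dequantize_weights(q, scale), [0.02, -0.5, 1.27, -1.27])
```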
You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "checkMacrosPlugin.h" + +#include "NvInferRuntimeCommon.h" + +namespace nvinfer1 { +namespace plugin { + +ILogger* gLogger{nullptr}; + +template +int32_t LogStream::Buf::sync() { + std::string s = str(); + while (!s.empty() && s.back() == '\n') { + s.pop_back(); + } + if (gLogger != nullptr) { + gLogger->log(kSeverity, s.c_str()); + } + str(""); + return 0; +} + +// These use gLogger, and therefore require initLibNvInferPlugins() to be called with a logger +// (otherwise, it will not log) +LogStream gLogError; +LogStream gLogWarning; +LogStream gLogInfo; +LogStream gLogVerbose; + +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..478e53ea2c46a4eca47a87c1d470e4c21c45671d --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/checkMacrosPlugin.h @@ -0,0 +1,205 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +#include +#include +#include +#include + +#include "NvInfer.h" +#include "NvInferRuntime.h" + +// Logs failed assertion and aborts. +// Aborting is undesirable and will be phased-out from the plugin module, at which point +// PLUGIN_ASSERT will perform the same function as PLUGIN_VALIDATE. 
+using namespace std; + +namespace nvinfer1 { +namespace plugin { + +#ifdef _MSC_VER +#define FN_NAME __FUNCTION__ +#else +#define FN_NAME __func__ +#endif + +#define IXRT_PLUGIN_CHECK_VALUE(value, msg) \ + { \ + if (not(value)) { \ + std::cerr << __FILE__ << " (" << __LINE__ << ")" \ + << "-" << __FUNCTION__ << " : " \ + << " Plugin assert error: " << msg << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } + +#define IXRT_PLUGIN_ASSERT(value) \ + { \ + if (not(value)) { \ + std::cerr << __FILE__ << " (" << __LINE__ << ")" \ + << "-" << __FUNCTION__ << " : " \ + << " Plugin assert false" << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } + +#define IXRT_PLUGIN_CHECK_CUDA(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line: %d\n", __LINE__); \ + printf(" Error code: %d\n", error_code); \ + printf(" Error text: %s\n", cudaGetErrorString(error_code)); \ + exit(1); \ + } \ + } while (0) + +inline void caughtError(const std::exception& e) { std::cerr << e.what() << std::endl; } + +#define IXRT_PLUGIN_FAIL(msg) \ + do { \ + std::ostringstream stream; \ + stream << "Assertion failed: " << msg << "\n" \ + << __FILE__ << ':' << __LINE__ << "\n" \ + << "Aborting..." \ + << "\n"; \ + IXRT_PLUGIN_CHECK_CUDA(cudaDeviceReset()); \ + abort; \ + } while (0) + +inline void throwCudaError(char const* file, char const* function, int32_t line, int32_t status, char const* msg) { + std::cerr << file << " (" << line << ")" + << "-" << function << " : " << msg << std::endl; + std::exit(EXIT_FAILURE); +} + +#define IXRT_PLUGIN_CUASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != cudaSuccess) { \ + const char* msg = cudaGetErrorString(s_); \ + throwCudaError(__FILE__, FN_NAME, __LINE__, s_, msg); \ + } \ + } + +#undef CUINFER_CHECK +#define CUINFER_CHECK(func) \ + do { \ + cuinferStatus_t status = (func); \ + if (status != CUINFER_STATUS_SUCCESS) { \ + std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \ + << cuinferGetErrorString(status) << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } while (0) + +static std::string _cudaGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "CUBLAS_UNKNOW"; +} + +template +void check_gpu_error(T result, char const* const func, const char* const file, int const line) { + if (result) { + throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + std::to_string(line) + + "): " + (_cudaGetErrorString(result)) + "\n"); + } +} + +#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) + +template +class LogStream : public std::ostream { + class Buf : public std::stringbuf { + public: + int32_t sync() 
override; + }; + + Buf buffer; + std::mutex mLogStreamMutex; + + public: + std::mutex& getMutex() { return mLogStreamMutex; } + LogStream() : std::ostream(&buffer){}; +}; + +// Use mutex to protect multi-stream write to buffer +template +LogStream& operator<<(LogStream& stream, T const& msg) { + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + os << msg; + return stream; +} + +// Special handling static numbers +template +inline LogStream& operator<<(LogStream& stream, int32_t num) { + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + os << num; + return stream; +} + +// Special handling std::endl +template +inline LogStream& operator<<(LogStream& stream, std::ostream& (*f)(std::ostream&)) { + std::lock_guard guard(stream.getMutex()); + auto& os = static_cast(stream); + // os << f; + return stream; +} + +extern LogStream gLogError; +extern LogStream gLogWarning; +extern LogStream gLogInfo; +extern LogStream gLogVerbose; +} // namespace plugin +} // namespace nvinfer1 diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/common_def.cuh b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/common_def.cuh new file mode 100644 index 0000000000000000000000000000000000000000..21c0469424556076177453c24af96de5187d7718 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/common_def.cuh @@ -0,0 +1,64 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once + +#include + +#include +namespace nvinfer1::plugin { +#ifdef __ILUVATAR__ +static const int kMaxThreadNbPerBlock = 1024; +static const int kMaxBlockNbPerSM = 8; +static const int kWarpSize = 64; +static const dim3 kMaxBlockDimension = {4096, 4096, 64}; +static const dim3 kMaxGridDimension = {4294967295, 65536, 65536}; +static const int kNbThreadsPerBlockGainBestPerformance = 1024; +static const int kMaxSharedMemSizePerBlock = (128 * 1024 * 4); +static const int kNbSmemLane = 64; +static const int kNbBytesPerSmemLane = 4; +#else +static const int kMaxThreadNbPerBlock = 1024; +static const int kMaxBlockNbPerSM = 8; +static const int kWarpSize = 32; +static const dim3 kMaxBlockDimension = {1024, 1024, 64}; +static const dim3 kMaxGridDimension = {2147483647, 65535, 65535}; +static const int kNbThreadsPerBlockGainBestPerformance = 256; +static const int kMaxSharedMemSizePerBlock = 48 * 1024 * 4; +static const int kNbSmemLane = 32; +static const int kNbBytesPerSmemLane = 4; +#endif + +static const int kNbCe = 4; +static const int kNbCuPerCe = 4; +static const int kNbSppPerCu = 4; + +static const float kLog2e = 1.442695040888963387; + +#define DivUp(x, y) (((x) + (y)-1) / (y)) + +__device__ __forceinline__ float floatExp(float x) { return __builtin_exp2f(kLog2e * x); } + +__device__ __forceinline__ float floatLog(float x) { return __logf(x); } + +__forceinline__ int nearest_num(int x, int value) { + if (x % value == 0) { + return x; + } else { + int padding = value - x % value; + return x + padding; + } +} +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/kernels/cuda_helper.cuh b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/kernels/cuda_helper.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c23258bb4dfa4b6327f8dd6cec6a26c326ed81b5 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/kernels/cuda_helper.cuh @@ -0,0 +1,23 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#define DEVICE_FUNC __device__ __forceinline__ +namespace nvinfer1::plugin { +constexpr float LOG2E = 1.442695040888963387; + +DEVICE_FUNC float _exp(float x) { return __builtin_exp2f(LOG2E * x); } +DEVICE_FUNC float dequantize(int8_t x, float scale) { return scale * static_cast(x); } +DEVICE_FUNC float sigmoid(float x) { return __ivcorex_rcpf((1.f + _exp(0.f - x))); } +} // namespace nvinfer1::plugin \ No newline at end of file diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0085b94b6dd0f7e0204ce4060cfd5c2a334af8e1 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.cpp @@ -0,0 +1,47 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. 
+* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "plugin.h" +#include "checkMacrosPlugin.h" + +namespace nvinfer1 +{ +namespace plugin +{ + +void validateRequiredAttributesExist(std::set requiredFieldNames, PluginFieldCollection const* fc) +{ + for (int32_t i = 0; i < fc->nbFields; i++) + { + requiredFieldNames.erase(fc->fields[i].name); + } + if (!requiredFieldNames.empty()) + { + std::stringstream msg{}; + msg << "PluginFieldCollection missing required fields: {"; + char const* separator = ""; + for (auto const& field : requiredFieldNames) + { + msg << separator << field; + separator = ", "; + } + msg << "}"; + std::string msg_str = msg.str(); + IXRT_PLUGIN_CHECK_VALUE(false, msg_str.c_str()); + } +} + +} // namespace plugin +} // namespace nvinfer1 \ No newline at end of file diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..110da3522e155c2209b1feb1e9b3659a1c2a7088 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/plugin.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "NvInferRuntimeCommon.h"
+
+typedef enum
+{
+    STATUS_SUCCESS = 0,
+    STATUS_FAILURE = 1,
+    STATUS_BAD_PARAM = 2,
+    STATUS_NOT_SUPPORTED = 3,
+    STATUS_NOT_INITIALIZED = 4
+} pluginStatus_t;
+
+namespace nvinfer1 {
+
+namespace plugin {
+
+// Write values into buffer
+template <typename T>
+void write(char*& buffer, const T& val) {
+    std::memcpy(buffer, &val, sizeof(T));
+    buffer += sizeof(T);
+}
+
+// Read values from buffer
+template <typename T>
+T read(const char*& buffer) {
+    T val{};
+    std::memcpy(&val, buffer, sizeof(T));
+    buffer += sizeof(T);
+    return val;
+}
+
+void validateRequiredAttributesExist(std::set<std::string> requiredFieldNames, PluginFieldCollection const* fc);
+
+}  // namespace plugin
+
+}  // namespace nvinfer1
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/serialize.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/serialize.h
new file mode 100644
index 0000000000000000000000000000000000000000..a2ac72d7f8b35e05626e20659b730b707b3e157a
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/common/serialize.h
@@ -0,0 +1,132 @@
+/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+* All Rights Reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License"); you may
+* not use this file except in compliance with the License. You may obtain
+* a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/
+#pragma once
+
+#include <cassert>
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include <iostream>
+using std::cerr;
+using std::cout;
+using std::endl;
+
+template <typename T>
+inline void serialize_value(void** buffer, T const& value);
+
+template <typename T>
+inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value);
+
+namespace
+{
+
+template <typename T, class Enable = void>
+struct Serializer
+{
+};
+
+template <typename T>
+struct Serializer<T,
+    typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value || std::is_pod<T>::value>::type>
+{
+    static size_t serialized_size(T const&)
+    {
+        return sizeof(T);
+    }
+    static void serialize(void** buffer, T const& value)
+    {
+        ::memcpy(*buffer, &value, sizeof(T));
+        reinterpret_cast<char*&>(*buffer) += sizeof(T);
+    }
+    static void deserialize(void const** buffer, size_t* buffer_size, T* value)
+    {
+        assert(*buffer_size >= sizeof(T));
+        ::memcpy(value, *buffer, sizeof(T));
+        reinterpret_cast<char const*&>(*buffer) += sizeof(T);
+        *buffer_size -= sizeof(T);
+    }
+};
+
+template <>
+struct Serializer<const char*>
+{
+    static size_t serialized_size(const char* value)
+    {
+        return strlen(value) + 1;
+    }
+    static void serialize(void** buffer, const char* value)
+    {
+        ::strcpy(static_cast<char*>(*buffer), value);
+        reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
+    }
+    static void deserialize(void const** buffer, size_t* buffer_size, const char** value)
+    {
+        *value = static_cast<char const*>(*buffer);
+        size_t data_size = strnlen(*value, *buffer_size) + 1;
+        assert(*buffer_size >= data_size);
+        reinterpret_cast<char const*&>(*buffer) += data_size;
+        *buffer_size -= data_size;
+    }
+};
+
+template <typename T>
+struct Serializer<std::vector<T>,
+                  typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value || std::is_pod<T>::value>::type>
+{
+    static size_t serialized_size(std::vector<T> const& value)
+    {
+        return sizeof(value.size()) + value.size() * sizeof(T);
+    }
+    static void serialize(void** buffer, std::vector<T> const& value)
+    {
+        serialize_value(buffer, value.size());
+        size_t nbyte = value.size() * sizeof(T);
+        ::memcpy(*buffer, value.data(), nbyte);
+        reinterpret_cast<char*&>(*buffer) += nbyte;
+    }
+    static void deserialize(void const** buffer, size_t* buffer_size, std::vector<T>* value)
+    {
+        size_t size;
+        deserialize_value(buffer, buffer_size, &size);
+        value->resize(size);
+        size_t nbyte = value->size() * sizeof(T);
+        assert(*buffer_size >= nbyte);
+        ::memcpy(value->data(), *buffer, nbyte);
+        reinterpret_cast<char const*&>(*buffer) += nbyte;
+        *buffer_size -= nbyte;
+    }
+};
+
+} // namespace
+
+template <typename T>
+inline size_t serialized_size(T const& value)
+{
+    return Serializer<T>::serialized_size(value);
+}
+
+template <typename T>
+inline void serialize_value(void** buffer, T const& value)
+{
+    return Serializer<T>::serialize(buffer, value);
+}
+
+template <typename T>
+inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value)
+{
+    return Serializer<T>::deserialize(buffer, buffer_size, value);
+}
\ No newline at end of file
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b9358f7d2ee1cc44a039f581489f9f3c6862bf9b
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cpp
@@ -0,0 +1,427 @@
+/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+* All Rights Reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License"); you may
+* not use this file except in compliance with the License.
You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "cuda_runtime_api.h" +#include "driver_types.h" +#include "fcPlugin.h" +#include "plugin.h" +#include "serialize.h" +#ifdef __ILUVATAR__ +#include "backend/ixinfer/ixinfer_gemm_helper.h" +#else +#include "backend/cublas/cublas_helper.h" +#endif + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; +using namespace nvinfer1::plugin::backend; + +namespace { +char const* const kFC_VERSION{"2"}; +char const* const kFC_NAME{"CustomFCPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection FCInt8PluginDynamicCreator::mFC{}; +std::vector FCInt8PluginDynamicCreator::mPluginAttributes; + +FCInt8PluginDynamicCreator::FCInt8PluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("out_dims", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("W", nullptr, PluginFieldType::kINT8, 1)); + mPluginAttributes.emplace_back(PluginField("fc_amax", nullptr, PluginFieldType::kFLOAT32, 2)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* FCInt8PluginDynamicCreator::getPluginName() const noexcept { return kFC_NAME; } + +char const* FCInt8PluginDynamicCreator::getPluginVersion() const noexcept { return kFC_VERSION; } + +PluginFieldCollection const* FCInt8PluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* FCInt8PluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "Creating FCInt8PluginDynamicCreator..." << endl; + IXRT_PLUGIN_ASSERT(name != nullptr); + IXRT_PLUGIN_ASSERT(fc != nullptr); + + int32_t outDims = 0; + Weights W{DataType::kINT8, nullptr, 0LL}; + Weights Bias{DataType::kFLOAT, nullptr, 0LL}; + plugin::validateRequiredAttributesExist({"out_dims", "W", "fc_amax"}, fc); + vector weight_scale; + + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string fieldName(fc->fields[i].name); + if (fieldName.compare("out_dims") == 0) { + outDims = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building outDims: " << outDims << endl; + } + + if (fieldName.compare("W") == 0) { + gLogInfo << "Building W..." << endl; + W.values = fc->fields[i].data; + W.count = fc->fields[i].length; + W.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is W int8: " << (W.type == DataType::kINT8) << endl; + } + + if (fieldName.compare("Bias") == 0) { + gLogInfo << "Building Bias..." << endl; + Bias.values = fc->fields[i].data; + Bias.count = fc->fields[i].length; + Bias.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is Bias float32: " << (Bias.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("fc_amax") == 0) { + gLogInfo << "Building fc_amax..." 
<< endl; + for (auto j = 0; j < fc->fields[i].length; j++) { + auto value = static_cast(fc->fields[i].data)[j]; + weight_scale.emplace_back(value / 127); + } + } + } + + if (outDims <= 0) { + gLogInfo << "Invalid output dimension" << endl; + } + if (W.count == 0 || W.values == nullptr || W.count < outDims) { + gLogInfo << "Invalid weights" << endl; + } + + DataType type = DataType::kINT8; + return new FCInt8PluginDynamic(name, type, outDims, W, Bias, weight_scale); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* FCInt8PluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call FCInt8PluginDynamic::destroy() + try { + return new FCInt8PluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void FCInt8PluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FCInt8PluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(FCInt8PluginDynamicCreator); +//#########################################################################// +FCInt8PluginDynamic::FCInt8PluginDynamic(std::string const name, DataType const type, int32_t const outDim, + Weights const& W, Weights const& Bias, vector const& scale) + : mLayerName(name), + mType(type), + mOutDim(outDim), + mNumParams(W.count), + mNmax(0), + mK(0), + mWdev(nullptr), + mNumBias(Bias.count), + mScale(scale), + mBiasdev(nullptr) { + if (W.type == nvinfer1::DataType::kFLOAT) { + float weight_max = std::numeric_limits::min(); + for (int64_t wb = 0, we = W.count; wb < we; ++wb) { + float val = static_cast(W.values)[wb]; + weight_max = std::max(weight_max, std::abs(val)); + } + // mWeightScale = 127 / weight_max; + } + + mW.convertAndCopy(W, DataType::kINT8, scale[0]); + copyToDevice(mW, getWeightsSize(mW, DataType::kINT8), mWdev); + if (Bias.values != nullptr) { + mBias.convertAndCopy(Bias, DataType::kFLOAT); + copyToDevice(mBias, getWeightsSize(mBias, DataType::kFLOAT), mBiasdev); + } +} + +FCInt8PluginDynamic::FCInt8PluginDynamic(std::string const name, void const* data, size_t length) + : mLayerName(name), mWdev(nullptr), mBiasdev(nullptr) { + gLogInfo << "FCInt8PluginDynamic deserialize" << endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mOutDim); + deserialize_value(&data, &length, &mNumParams); + deserialize_value(&data, &length, &mNmax); + deserialize_value(&data, &length, &mK); + deserialize_value(&data, &length, &mNumBias); + deserialize_value(&data, &length, &mScale); + + char const* d = static_cast(data); + + mW.convertAndCopy(d, mNumParams, DataType::kINT8); + copyToDevice(mW, getWeightsSize(mW, DataType::kINT8), mWdev); + if (mNumBias > 0) { + mBias.convertAndCopy(d, mNumBias, DataType::kFLOAT); + copyToDevice(mBias, getWeightsSize(mBias, DataType::kFLOAT), mBiasdev); + } +} + +// IPluginV2 Methods +char const* FCInt8PluginDynamic::getPluginType() const noexcept { return kFC_NAME; } + +char const* FCInt8PluginDynamic::getPluginVersion() const noexcept { return kFC_VERSION; } + +int32_t FCInt8PluginDynamic::getNbOutputs() const noexcept { return 1; } + 
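+// Note on the INT8 scale handling used by enqueue() further below: the INT8 GEMM result is
+// rescaled with dequant_scale = input_scale * weight_scale / output_scale, where weight_scale
+// is the first "fc_amax" value divided by 127 and output_scale is either the second "fc_amax"
+// value (when two are given) or the output tensor's quantization scale.
+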
+int32_t FCInt8PluginDynamic::initialize() noexcept { + gLogInfo << "FCInt8PluginDynamic initialize" << endl; + return 0; +} + +void FCInt8PluginDynamic::terminate() noexcept { gLogInfo << "FCInt8PluginDynamic terminate" << endl; } + +size_t FCInt8PluginDynamic::getSerializationSize() const noexcept { + return sizeof(mType) + sizeof(mOutDim) + sizeof(mNumParams) + sizeof(mNmax) + sizeof(mK) + sizeof(mNumBias) + + mScale.size() * sizeof(float) + sizeof(mScale.size()) + getElementSize(DataType::kINT8) * mNumParams + + getElementSize(DataType::kFLOAT) * mNumBias; +} + +void FCInt8PluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mOutDim); + serialize_value(&buffer, mNumParams); + serialize_value(&buffer, mNmax); + serialize_value(&buffer, mK); + serialize_value(&buffer, mNumBias); + serialize_value(&buffer, mScale); + + char* d = static_cast(buffer); + serFromDev(d, static_cast(mWdev.get()), mNumParams * getElementSize(DataType::kINT8)); + + if (mNumBias > 0) { + serFromDev(d, static_cast(mBiasdev.get()), mNumBias * getElementSize(DataType::kFLOAT)); + } +} + +void FCInt8PluginDynamic::destroy() noexcept { + gLogInfo << "FCInt8PluginDynamic destroy" << endl; + mWdev.reset(nullptr); + if (mNumBias > 0) { + mBiasdev.reset(nullptr); + } + delete this; +} + +void FCInt8PluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FCInt8PluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType FCInt8PluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputTypes != nullptr); + IXRT_PLUGIN_ASSERT(inputTypes[0] == DataType::kINT8); + return inputTypes[0]; +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* FCInt8PluginDynamic::clone() const noexcept { + try { + gLogInfo << "FCInt8PluginDynamic clone" << endl; + + auto* p = new FCInt8PluginDynamic(mLayerName, mType, mOutDim, mW, mBias, mScale); + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs FCInt8PluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept { + try { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(outputIndex == 0); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + DimsExprs ret; + if (inputs[0].nbDims == 5){ + ret.nbDims = 5; // plugin support + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mOutDim); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + }else if(inputs[0].nbDims == 3){ + ret.nbDims = 3;//onnx support + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mOutDim); + } + return ret; + + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool FCInt8PluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(inOut != nullptr); + + PluginTensorDesc const& in = inOut[pos]; + if (pos == 
0) { + return (in.type == mType) && (in.format == TensorFormat::kLINEAR); + } + PluginTensorDesc const& prev = inOut[pos - 1]; + + // output + return in.type == prev.type && in.format == prev.format; +} + +void FCInt8PluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { + try { + // Validate input arguments + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(outputs != nullptr); + IXRT_PLUGIN_ASSERT(mType == inputs[0].desc.type); + auto const& inDims0 = inputs[0].desc.dims; + + IXRT_PLUGIN_ASSERT(inDims0.nbDims == 3 || inDims0.nbDims == 5); + mK = inDims0.d[HDIM]; // hiddensize + // IXRT_PLUGIN_ASSERT(hiddenSize * mOutDim == mNumParams); + // IXRT_PLUGIN_ASSERT(inDims0.d[3] == 1); + // IXRT_PLUGIN_ASSERT(inDims0.d[4] == 1); +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferCreate(&cuinfer_handle)); +#else + CHECK_GPU_ERROR(cublasLtCreate(&blaslt_handle)); +#endif + } catch (std::exception const& e) { + caughtError(e); + } +} + +size_t FCInt8PluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { + int32_t const B = inputs[0].dims.d[BDIM]; + int32_t const S = inputs[0].dims.d[SDIM]; + int32_t const oE = outputs[0].dims.d[HDIM]; + if (mNumBias > 0) { + return B * S * oE * sizeof(int8_t); + } else { + return 0; + } +} + +int32_t FCInt8PluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workSpace, + cudaStream_t stream) noexcept { + gLogInfo << "in FCInt8PluginDynamic.." << endl; + try { +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream)); +#endif + int32_t const S = inputDesc->dims.d[SDIM]; + int32_t const B = inputDesc->dims.d[BDIM]; + int32_t const E = inputDesc->dims.d[HDIM]; + int32_t const oE = outputDesc->dims.d[HDIM]; + int32_t const n = S * B; + IXRT_PLUGIN_ASSERT(n >= 0); + + float qkv_in_scale = inputDesc[0].scale; + float qkv_wei_scale = mScale[0]; + float output_scale = outputDesc[0].scale; + float qkv_out_scale; + if (mScale.size() == 2) { + qkv_out_scale = mScale[1]; + } else { + qkv_out_scale = output_scale; + } +#ifdef __ILUVATAR__ + int8_t* buffer = static_cast(workSpace); +#else + int32_t* buffer = static_cast(workSpace); +#endif + if (mType == DataType::kINT8) { + auto const* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto weight = static_cast(mWdev.get()); + + float dequant_scale = (qkv_in_scale * qkv_wei_scale) / qkv_out_scale; + + if (mBiasdev.get() != nullptr) { +#ifdef __ILUVATAR__ + cuinfer_i8_gemm(weight, input, nullptr, buffer, 1, oE, n, E, 0, 0, 0, dequant_scale, 0.0, 0, + cuinfer_handle, stream); + dequantGemmWithBias(buffer, static_cast(mBiasdev.get()), output, B * S, oE, qkv_out_scale, + 1.0 / output_scale, stream); +#else + cublaslt_gemm(weight, input, buffer, 1, oE, n, E, 0, 0, 0, 1, blaslt_handle, stream); + dequantGemmWithBias(buffer, static_cast(mBiasdev.get()), output, B * S, oE, dequant_scale, + qkv_out_scale, 1.0 / output_scale, stream); +#endif + } else { +#ifdef __ILUVATAR__ + cuinfer_i8_gemm(weight, input, nullptr, output, 1, oE, n, E, 0, 0, 0, dequant_scale, 0.0, 0, + cuinfer_handle, stream); +#else + cublaslt_gemm(weight, input, buffer, 1, oE, n, E, 0, 0, 0, 1, blaslt_handle, stream); + + quantGemm(buffer, 
output, B * S, oE, dequant_scale, stream); +#endif + } + } else { + gLogError << "Unsupported type error, expected [kINT8], but received " << static_cast(mType) + << endl; + return STATUS_FAILURE; + } + return STATUS_SUCCESS; + } catch (std::exception const& e) { + caughtError(e); + } + return STATUS_FAILURE; +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..d2c1b0cd56b8d30007e2cc360ce5ce4a9f3f0a76 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcInt8Plugin.cu @@ -0,0 +1,467 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "backend/bert/bert_helper.h" +#include "fcPlugin.h" +using namespace nvinfer1::plugin::backend; +namespace nvinfer1::plugin { +namespace bert { +template +__global__ void dequant_gemm_without_bias(const int8_t* input, int8_t* output, int hidden_size, float dequant_scale, + float quant_scale, int num_per_tca) { + float4 val[THREAD_DATA_LEN]; + + int block_start = blockIdx.x * hidden_size; + input += block_start; + output += block_start; + + char4* p_input = (char4*)input; + char4* p_output = (char4*)output; + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * num_per_tca; + + val[it].x = __int2float_rn(p_input[element_index].x) * dequant_scale; + val[it].y = __int2float_rn(p_input[element_index].y) * dequant_scale; + val[it].z = __int2float_rn(p_input[element_index].z) * dequant_scale; + val[it].w = __int2float_rn(p_input[element_index].w) * dequant_scale; + + char4 res = float42char4(val[it], quant_scale); + p_output[element_index] = res; + } +} + +template +__global__ void dequant_gemm_with_bias(const int8_t* input, const float* bias, int8_t* output, int hidden_size, + float dequant_scale, float quant_scale, int num_per_tca) { + float4 val[THREAD_DATA_LEN]; + + int block_start = blockIdx.x * hidden_size; + input += block_start; + output += block_start; + + char4* p_input = (char4*)input; + float4* p_bias = (float4*)bias; + char4* p_output = (char4*)output; + + float4 bias_val; +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * num_per_tca; + bias_val.x = p_bias[element_index].x; + bias_val.y = p_bias[element_index].y; + bias_val.z = p_bias[element_index].z; + bias_val.w = p_bias[element_index].w; + + val[it].x = __int2float_rn(p_input[element_index].x) * dequant_scale + bias_val.x; + val[it].y = __int2float_rn(p_input[element_index].y) * dequant_scale + bias_val.y; + val[it].z = __int2float_rn(p_input[element_index].z) * dequant_scale + bias_val.z; + val[it].w = __int2float_rn(p_input[element_index].w) * dequant_scale + bias_val.w; + + char4 res = float42char4(val[it], quant_scale); + p_output[element_index] = res; + } 
+} + +template +__global__ void dequant_gemm_with_bias(const int32_t* input, const float* bias, int8_t* output, int hidden_size, + float quant_scale1, float dequant_scale, float quant_scale2, int num_per_tca) { + float4 val[THREAD_DATA_LEN]; + + int block_start = blockIdx.x * hidden_size; + input += block_start; + output += block_start; + + int4* p_input = (int4*)input; + float4* p_bias = (float4*)bias; + char4* p_output = (char4*)output; + + float4 bias_val; +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * num_per_tca; + bias_val.x = p_bias[element_index].x; + bias_val.y = p_bias[element_index].y; + bias_val.z = p_bias[element_index].z; + bias_val.w = p_bias[element_index].w; + + char4 q_input; + q_input.x = float2int8(p_input[element_index].x * 1.0, quant_scale1); + q_input.y = float2int8(p_input[element_index].y * 1.0, quant_scale1); + q_input.z = float2int8(p_input[element_index].z * 1.0, quant_scale1); + q_input.w = float2int8(p_input[element_index].w * 1.0, quant_scale1); + + val[it].x = __int2float_rn(q_input.x) * dequant_scale + bias_val.x; + val[it].y = __int2float_rn(q_input.y) * dequant_scale + bias_val.y; + val[it].z = __int2float_rn(q_input.z) * dequant_scale + bias_val.z; + val[it].w = __int2float_rn(q_input.w) * dequant_scale + bias_val.w; + + char4 res = float42char4(val[it], quant_scale2); + p_output[element_index] = res; + } +} + +void dequantGemmWithoutBias(int8_t* input, int8_t* output, int batch_seq_len, int hidden_size, float dequant_scale, + float quant_scale, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + int num_per_tca = 64; + dim3 gridSize(batch_seq_len); + dim3 blockSize(num_per_tca); + + int num_warp = hidden_size / num_per_tca / 4; + + switch (num_warp) { + case 1: + dequant_gemm_without_bias<1><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 2: + dequant_gemm_without_bias<2><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 3: + dequant_gemm_without_bias<3><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 4: + dequant_gemm_without_bias<4><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 5: + dequant_gemm_without_bias<5><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 6: + dequant_gemm_without_bias<6><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 7: + dequant_gemm_without_bias<7><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 8: + dequant_gemm_without_bias<8><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 9: + dequant_gemm_without_bias<9><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 10: + dequant_gemm_without_bias<10><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 11: + dequant_gemm_without_bias<11><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 12: + dequant_gemm_without_bias<12><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 13: + dequant_gemm_without_bias<13><<>>(input, output, hidden_size, 
dequant_scale, + quant_scale, num_per_tca); + break; + case 14: + dequant_gemm_without_bias<14><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 15: + dequant_gemm_without_bias<15><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + case 16: + dequant_gemm_without_bias<16><<>>(input, output, hidden_size, dequant_scale, + quant_scale, num_per_tca); + break; + default: + throw std::runtime_error("dequantGemmWithoutBias"); + break; + } +} + +void dequantGemmWithBias(int8_t* input, float* bias, int8_t* output, int batch_seq_len, int hidden_size, + float dequant_scale, float quant_scale, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + int num_per_tca = 64; + dim3 gridSize(batch_seq_len); + dim3 blockSize(num_per_tca); + + int num_warp = hidden_size / num_per_tca / 4; + + switch (num_warp) { + case 1: + dequant_gemm_with_bias<1><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 2: + dequant_gemm_with_bias<2><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 3: + dequant_gemm_with_bias<3><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 4: + dequant_gemm_with_bias<4><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 5: + dequant_gemm_with_bias<5><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 6: + dequant_gemm_with_bias<6><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 7: + dequant_gemm_with_bias<7><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 8: + dequant_gemm_with_bias<8><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 9: + dequant_gemm_with_bias<9><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 10: + dequant_gemm_with_bias<10><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 11: + dequant_gemm_with_bias<11><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 12: + dequant_gemm_with_bias<12><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 13: + dequant_gemm_with_bias<13><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 14: + dequant_gemm_with_bias<14><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 15: + dequant_gemm_with_bias<15><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + case 16: + dequant_gemm_with_bias<16><<>>(input, bias, output, hidden_size, + dequant_scale, quant_scale, num_per_tca); + break; + default: + throw std::runtime_error("dequantGemmWithBias with int8_t input"); + break; + } +} + +void dequantGemmWithBias(int32_t* input, float* bias, int8_t* output, int batch_seq_len, int hidden_size, + float quant_scale1, float dequant_scale, float quant_scale2, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % C10_WARP_SIZE 
!= 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + int num_per_tca = 64; + dim3 gridSize(batch_seq_len); + dim3 blockSize(num_per_tca); + + int num_warp = hidden_size / num_per_tca / 4; + + switch (num_warp) { + case 1: + dequant_gemm_with_bias<1><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 2: + dequant_gemm_with_bias<2><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 3: + dequant_gemm_with_bias<3><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 4: + dequant_gemm_with_bias<4><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 5: + dequant_gemm_with_bias<5><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 6: + dequant_gemm_with_bias<6><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 7: + dequant_gemm_with_bias<7><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 8: + dequant_gemm_with_bias<8><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 9: + dequant_gemm_with_bias<9><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 10: + dequant_gemm_with_bias<10><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 11: + dequant_gemm_with_bias<11><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 12: + dequant_gemm_with_bias<12><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 13: + dequant_gemm_with_bias<13><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 14: + dequant_gemm_with_bias<14><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 15: + dequant_gemm_with_bias<15><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + case 16: + dequant_gemm_with_bias<16><<>>( + input, bias, output, hidden_size, quant_scale1, dequant_scale, quant_scale2, num_per_tca); + break; + default: + throw std::runtime_error("dequantGemmWithBias with int32_t input"); + break; + } +} + +template +__global__ void quant_gemm(const int32_t* input, int8_t* output, int hidden_size, float quant_scale, int num_per_tca) { + float4 val[THREAD_DATA_LEN]; + + int block_start = blockIdx.x * hidden_size; + input += block_start; + output += block_start; + + int4* p_input = (int4*)input; + char4* p_output = (char4*)output; + + float4 bias_val; +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * num_per_tca; + char4 q_input; + q_input.x = float2int8(p_input[element_index].x * 1.0, quant_scale); + q_input.y = float2int8(p_input[element_index].y * 1.0, quant_scale); + q_input.z = float2int8(p_input[element_index].z * 1.0, quant_scale); + q_input.w = float2int8(p_input[element_index].w * 1.0, quant_scale); + + p_output[element_index] = q_input; + } +} + +void quantGemm(int32_t* input, int8_t* output, int 
batch_seq_len, int hidden_size, float dequant_scale, + cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + int num_per_tca = 64; + dim3 gridSize(batch_seq_len); + dim3 blockSize(num_per_tca); + + int num_warp = hidden_size / num_per_tca / 4; + + switch (num_warp) { + case 1: + quant_gemm<1><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 2: + quant_gemm<2><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 3: + quant_gemm<3><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 4: + quant_gemm<4><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 5: + quant_gemm<5><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 6: + quant_gemm<6><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 7: + quant_gemm<7><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 8: + quant_gemm<8><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 9: + quant_gemm<9><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 10: + quant_gemm<10><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 11: + quant_gemm<11><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 12: + quant_gemm<12><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 13: + quant_gemm<13><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 14: + quant_gemm<14><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 15: + quant_gemm<15><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 16: + quant_gemm<16><<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + default: + throw std::runtime_error("quantGemm"); + break; + } +} + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..26115bb3ae9921b32f7e71eda999e22850377961 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.cpp @@ -0,0 +1,406 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "fcPlugin.h" + +#include "NvInferRuntimeCommon.h" +#include "backend/ixinfer/ixinfer_gemm_helper.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; +using namespace nvinfer1::plugin::backend; + +namespace { +char const* const kFC_VERSION{"1"}; +char const* const kFC_NAME{"CustomFCPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection FCPluginDynamicCreator::mFC{}; +std::vector FCPluginDynamicCreator::mPluginAttributes; + +FCPluginDynamicCreator::FCPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("out_dims", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("W", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("B", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("act_type", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("swish_alpha", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* FCPluginDynamicCreator::getPluginName() const noexcept { return kFC_NAME; } + +char const* FCPluginDynamicCreator::getPluginVersion() const noexcept { return kFC_VERSION; } + +PluginFieldCollection const* FCPluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* FCPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "Creating FCPluginDynamicCreator..." << endl; + IXRT_PLUGIN_ASSERT(name != nullptr); + IXRT_PLUGIN_ASSERT(fc != nullptr); + + int32_t outDims = 0; + int32_t typeId = -1; + int32_t act_type = -1; + Weights W{DataType::kFLOAT, nullptr, 0LL}; + Weights B{DataType::kFLOAT, nullptr, 0LL}; + plugin::validateRequiredAttributesExist({"out_dims", "type_id", "W"}, fc); + float swish_alpha = 1.7020000219345093; + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string fieldName(fc->fields[i].name); + if (fieldName.compare("out_dims") == 0) { + outDims = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building outDims: " << outDims << endl; + } + + if (fieldName.compare("type_id") == 0) { + typeId = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building typeId: " << typeId << endl; + } + + if (fieldName.compare("W") == 0) { + gLogInfo << "Building W..." << endl; + W.values = fc->fields[i].data; + W.count = fc->fields[i].length; + W.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is W float32: " << (W.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("B") == 0) { + gLogInfo << "Building B..." 
<< endl; + B.values = fc->fields[i].data; + B.count = fc->fields[i].length; + B.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is B float32: " << (B.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("act_type") == 0) { + if (fc->fields[i].data != nullptr) { + act_type = static_cast(fc->fields[i].data)[0]; + } + gLogInfo << "Building act_type: " << act_type << endl; + } + if (fieldName.compare("swish_alpha") == 0) { + if (fc->fields[i].data != nullptr) { + swish_alpha = static_cast(fc->fields[i].data)[0]; + } + } + } + + if (outDims <= 0) { + gLogInfo << "Invalid output dimension" << endl; + } + if (typeId < 0 || typeId > 1) { + gLogInfo << "Invalid type id" << typeId << endl; + } + if (W.count == 0 || W.values == nullptr || W.count < outDims) { + gLogInfo << "Invalid weights" << endl; + } + + DataType type = typeId == 0 ? DataType::kFLOAT : DataType::kHALF; + return new FCPluginDynamic(name, type, outDims, act_type, W, B, swish_alpha); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* FCPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call FCPluginDynamic::destroy() + try { + return new FCPluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void FCPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FCPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(FCPluginDynamicCreator); +//#########################################################################// +FCPluginDynamic::FCPluginDynamic(std::string const name, DataType const type, int32_t const outDim, + int32_t const act_type, Weights const& W, Weights const& B, float alpha) + : mLayerName(name), + mType(type), + mOutDim(outDim), + mActType(act_type), + mNumParams(W.count), + mNumBias(B.count), + mWdev(nullptr), + mBdev(nullptr), + mSwishAlpha(alpha) { + mW.convertAndCopy(W, mType); + copyToDevice(mW, getWeightsSize(mW, mType), mWdev); + if (mNumBias) { + mB.convertAndCopy(B, mType); + copyToDevice(mB, getWeightsSize(mB, mType), mBdev); + } +} + +FCPluginDynamic::FCPluginDynamic(std::string const name, void const* data, size_t length) + : mLayerName(name), mWdev(nullptr) { + gLogInfo << "FCPluginDynamic deserialize" << endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mOutDim); + deserialize_value(&data, &length, &mActType); + deserialize_value(&data, &length, &mNumParams); + deserialize_value(&data, &length, &mNumBias); + deserialize_value(&data, &length, &mSwishAlpha); + + char const* d = static_cast(data); + + mW.convertAndCopy(d, mNumParams, mType); + copyToDevice(mW, getWeightsSize(mW, mType), mWdev); + if (mNumBias) { + mB.convertAndCopy(d, mNumBias, mType); + copyToDevice(mB, getWeightsSize(mB, mType), mBdev); + } +} + +// IPluginV2 Methods +char const* FCPluginDynamic::getPluginType() const noexcept { return kFC_NAME; } + +char const* FCPluginDynamic::getPluginVersion() const noexcept { return kFC_VERSION; } + +int32_t FCPluginDynamic::getNbOutputs() const noexcept { return 1; } + +int32_t 
FCPluginDynamic::initialize() noexcept { + gLogInfo << "FCPluginDynamic initialize" << endl; + return 0; +} + +void FCPluginDynamic::terminate() noexcept { gLogInfo << "FCPluginDynamic terminate" << endl; } + +size_t FCPluginDynamic::getSerializationSize() const noexcept { + size_t wordSize = getElementSize(mType); + return wordSize * (mNumParams + mNumBias) + sizeof(mType) + sizeof(mOutDim) + sizeof(mActType) + + sizeof(mNumParams) + sizeof(mNumBias) + sizeof(mSwishAlpha); +} + +void FCPluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mOutDim); + serialize_value(&buffer, mActType); + serialize_value(&buffer, mNumParams); + serialize_value(&buffer, mNumBias); + serialize_value(&buffer, mSwishAlpha); + + size_t wordSize = getElementSize(mType); + char* d = static_cast(buffer); + serFromDev(d, static_cast(mWdev.get()), mNumParams * wordSize); + if (mNumBias) { + serFromDev(d, static_cast(mBdev.get()), mNumBias * wordSize); + } +} + +void FCPluginDynamic::destroy() noexcept { + gLogInfo << "FCPluginDynamic destroy" << endl; + mWdev.reset(nullptr); + if (mNumBias) { + mBdev.reset(nullptr); + } + delete this; +} + +void FCPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FCPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType FCPluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputTypes != nullptr); + IXRT_PLUGIN_ASSERT(inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF); + return inputTypes[0]; +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* FCPluginDynamic::clone() const noexcept { + try { + gLogInfo << "FCPluginDynamic clone" << endl; + + auto* p = new FCPluginDynamic(mLayerName, mType, mOutDim, mActType, mW, mB, mSwishAlpha); + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs FCPluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept { + try { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(outputIndex == 0); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + DimsExprs ret; + if (inputs[0].nbDims == 5) { + ret.nbDims = 5; // plugin support + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mOutDim); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + } else if (inputs[0].nbDims == 3) { + ret.nbDims = 3; // onnx support + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mOutDim); + } else if (inputs[0].nbDims == 2) { + ret.nbDims = 2; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = exprBuilder.constant(mOutDim); + } else if (inputs[0].nbDims == 4) { + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = inputs[0].d[2]; + ret.d[3] = exprBuilder.constant(mOutDim); + } else { + std::cerr << "CustomFC doesn't support input nb_dim=" << inputs[0].nbDims << std::endl; + IXRT_PLUGIN_ASSERT(false); + } + + return ret; + } catch (std::exception const& e) { + caughtError(e); + } + return 
DimsExprs{}; +} + +bool FCPluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(inOut != nullptr); + + PluginTensorDesc const& in = inOut[pos]; + if (pos == 0) { + return (in.type == mType) && (in.format == TensorFormat::kLINEAR); + } + PluginTensorDesc const& prev = inOut[pos - 1]; + + // output + return in.type == prev.type && in.format == prev.format; +} + +void FCPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { + try { + // Validate input arguments + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(outputs != nullptr); + IXRT_PLUGIN_ASSERT(mType == inputs[0].desc.type); + auto const& inDims0 = inputs[0].desc.dims; + + IXRT_PLUGIN_ASSERT(inDims0.nbDims == 2 || inDims0.nbDims == 3 || inDims0.nbDims == 5); + // IXRT_PLUGIN_ASSERT(inDims0.d[3] == 1); + // IXRT_PLUGIN_ASSERT(inDims0.d[4] == 1); +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferCreate(&cuinfer_handle)); +#else + CHECK_GPU_ERROR(cublasLtCreate(&blaslt_handle)); +#endif + } catch (std::exception const& e) { + caughtError(e); + } +} + +size_t FCPluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { + return 0; +} + +int32_t FCPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workSpace, + cudaStream_t stream) noexcept { + gLogInfo << "in FCPluginDynamic.." << endl; + try { +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream)); +#endif + int32_t n = 1; + int32_t k = 1; + int32_t m = mOutDim; + int32_t floop = inputDesc->dims.nbDims == 5 ? 3 : inputDesc->dims.nbDims; + for (int i = 0; i < floop - 1; ++i) { + n *= inputDesc->dims.d[i]; + } + k = inputDesc->dims.d[floop - 1]; + IXRT_PLUGIN_ASSERT(n >= 0); + + if (mType == DataType::kHALF) { + auto const* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto weight = static_cast(mWdev.get()); + half* bias = nullptr; + if (mNumBias) { + bias = static_cast(mBdev.get()); + } + +#ifdef __ILUVATAR__ + cuinfer_gemm(weight, input, bias, output, 1, m, n, k, 0, 0, 0, 1.0f, mActType, stream, cuinfer_handle, + mSwishAlpha); +#else + cublaslt_gemm(weight, input, output, 1, m, n, k, 0, 0, 0, 1.0f, blaslt_handle, stream); +#endif + } else { + gLogError << "Unsupported type error, expected [kHALF,kFLOAT], but received " << static_cast(mType) + << endl; + return STATUS_FAILURE; + } + return STATUS_SUCCESS; + } catch (std::exception const& e) { + caughtError(e); + } + return STATUS_FAILURE; +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..b1bc35202493008de4c1ee534d3e314102939025 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/custom_fc/fcPlugin.h @@ -0,0 +1,225 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. 
+* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +#include + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "backend/cublas/cublas_helper.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { + +void quantGemm(int32_t* input, int8_t* output, int batch_seq_len, int hidden_size, float dequant_scale, + cudaStream_t stream); + +void dequantGemmWithBias(int32_t* input, float* bias, int8_t* output, int batch_seq_len, int hidden_size, + float dequant_scale1, float dequant_scale2, float quant_scale, cudaStream_t stream); + +void dequantGemmWithBias(int8_t* input, float* bias, int8_t* output, int batch_seq_len, int hidden_size, + float dequant_scale, float quant_scale, cudaStream_t stream); + +void dequantGemmWithoutBias(int8_t* input, int8_t* output, int batch_seq_len, int hidden_size, float dequant_scale, + float quant_scale, cudaStream_t stream); + +class FCPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + FCPluginDynamic(std::string const name, nvinfer1::DataType const type, int32_t const outDim, int32_t const act_type, + nvinfer1::Weights const& W, nvinfer1::Weights const& B, float alpha); + + FCPluginDynamic(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make FCPluginDynamic without arguments, so we + // delete default constructor. 
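+    // Deserialization instead goes through the (name, data, length) constructor above, and
+    // clone() rebuilds the device buffers from the host-side copies mW / mB.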
+ FCPluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + std::string const mLayerName; + std::string mNamespace; + + nvinfer1::DataType mType; + size_t mActType; + size_t mOutDim; // leading dim + size_t mNumParams; + size_t mNumBias; + + bert::WeightsWithOwnership mW; + bert::cuda_unique_ptr mWdev; + bert::WeightsWithOwnership mB; + bert::cuda_unique_ptr mBdev; + float mSwishAlpha; + +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif + cudaStream_t stream; +}; + +class FCPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + FCPluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +class FCInt8PluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + FCInt8PluginDynamic(std::string const name, nvinfer1::DataType const type, int32_t const outDim, + nvinfer1::Weights const& W, nvinfer1::Weights const& Bias, vector const& scale); + + FCInt8PluginDynamic(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make FCInt8PluginDynamic without arguments, so we + // delete default constructor. 
+ FCInt8PluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + std::string const mLayerName; + std::string mNamespace; + + nvinfer1::DataType mType; + size_t mOutDim; // leading dim + size_t mNumParams; + int32_t mNmax; + int32_t mK; + int32_t mNumBias; + + vector mScale; + + bert::WeightsWithOwnership mW; + bert::cuda_unique_ptr mWdev; + + bert::WeightsWithOwnership mBias; + bert::cuda_unique_ptr mBiasdev; + +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif + cudaStream_t stream; +}; + +class FCInt8PluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + FCInt8PluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..21c1348771e0cabbc3fc2b4ee75c6b521a1ebeba --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cpp @@ -0,0 
+1,504 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "embLayerNormInt8Plugin.h" + +#include "NvInferImpl.h" +#include "NvInferRuntimeCommon.h" +#include "checkMacrosPlugin.h" +#include "common_def.cuh" +#include "driver_types.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace { +char const* EMB_LAYER_NORM_INT8_VERSION{"2"}; +char const* EMB_LAYER_NORM_INT8_NAME{"CustomEmbLayerNormPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection EmbLayerNormInt8PluginDynamicCreator::mFC{}; +std::vector EmbLayerNormInt8PluginDynamicCreator::mPluginAttributes; + +EmbLayerNormInt8PluginDynamicCreator::EmbLayerNormInt8PluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings")); + mPluginAttributes.emplace_back(PluginField("output_fp16")); + mPluginAttributes.emplace_back(PluginField("full_mask")); + mPluginAttributes.emplace_back(PluginField("mha_type_id")); + mPluginAttributes.emplace_back(PluginField("pad_id")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* EmbLayerNormInt8PluginDynamicCreator::getPluginName() const noexcept { return EMB_LAYER_NORM_INT8_NAME; } + +char const* EmbLayerNormInt8PluginDynamicCreator::getPluginVersion() const noexcept { + return EMB_LAYER_NORM_INT8_VERSION; +} + +PluginFieldCollection const* EmbLayerNormInt8PluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2DynamicExt* EmbLayerNormInt8PluginDynamicCreator::createPlugin(char const* name, + PluginFieldCollection const* fc) noexcept { + try { + IXRT_PLUGIN_ASSERT(fc != nullptr); + gLogInfo << "EmbLayerNormInt8PluginDynamic createPlugin." << endl; + std::set const requiredAttributes{ + "bert_embeddings_layernorm_beta", "bert_embeddings_layernorm_gamma", + "bert_embeddings_word_embeddings", "bert_embeddings_token_type_embeddings", + "bert_embeddings_position_embeddings", + }; + + bool output_fp16 = false; + bool useFullMask = false; + Weights beta{}; + Weights gamma{}; + Weights word_emb{}; + Weights pos_emb{}; + Weights tok_emb{}; + int32_t mhaTypeId = 0; + int32_t pad_id = 0; + + for (auto i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + if (field_name.compare("bert_embeddings_layernorm_beta") == 0) { + gLogInfo << "Building bert_embeddings_layernorm_beta..." 
<< endl; + beta.values = fc->fields[i].data; + beta.count = fc->fields[i].length; + beta.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) { + gLogInfo << "Building bert_embeddings_layernorm_gamma..." << endl; + gamma.values = fc->fields[i].data; + gamma.count = fc->fields[i].length; + gamma.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_word_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_word_embeddings..." << endl; + word_emb.values = fc->fields[i].data; + word_emb.count = fc->fields[i].length; + word_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_token_type_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_token_type_embeddings..." << endl; + tok_emb.values = fc->fields[i].data; + tok_emb.count = fc->fields[i].length; + tok_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_position_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_position_embeddings..." << endl; + pos_emb.values = fc->fields[i].data; + pos_emb.count = fc->fields[i].length; + pos_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("output_fp16") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32); + output_fp16 = static_cast(fc->fields[i].data)[0] != 0; + gLogInfo << "Building output_fp16: " << output_fp16 << endl; + } + + if (field_name.compare("full_mask") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32); + useFullMask = static_cast(fc->fields[i].data)[0] != 0; + gLogInfo << "Building full_mask: " << useFullMask << endl; + } + + if (field_name.compare("mha_type_id") == 0) { + mhaTypeId = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_ASSERT(mhaTypeId >= 0 && mhaTypeId < 3); + gLogInfo << "Building mha typeId: " << mhaTypeId << endl; + } + + if (field_name.compare("pad_id") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32) + pad_id = *static_cast(fc->fields[i].data); + } + } + gLogInfo << "Building EmbLayerNormInt8PluginDynamic Plugin..." << endl; + DataType mhaType = static_cast(mhaTypeId); + EmbLayerNormInt8PluginDynamic* p = + new EmbLayerNormInt8PluginDynamic(name, output_fp16 ? 
DataType::kHALF : DataType::kFLOAT, mhaType, beta, + gamma, word_emb, pos_emb, tok_emb, useFullMask, pad_id); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2DynamicExt* EmbLayerNormInt8PluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + try { + IXRT_PLUGIN_ASSERT(serialData != nullptr); + return new EmbLayerNormInt8PluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void EmbLayerNormInt8PluginDynamicCreator::setPluginNamespace(char const* pluginNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(pluginNamespace != nullptr); + mNamespace = pluginNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* EmbLayerNormInt8PluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(EmbLayerNormInt8PluginDynamicCreator); + +//#########################################################################// +EmbLayerNormInt8PluginDynamic::EmbLayerNormInt8PluginDynamic(std::string const& name, DataType const type, + DataType const mhaType, Weights const& beta, + Weights const& gamma, Weights const& wordEmb, + Weights const& posEmb, Weights const& tokEmb, + bool const useFullMask, int32_t padId) + : mLayerName(name), + mHiddenSize(beta.count), + mEmbType(type), + mUseFullMask(useFullMask), + mMhaType(mhaType), + mPadId(padId) { + IXRT_PLUGIN_ASSERT(beta.count == gamma.count); + IXRT_PLUGIN_ASSERT(mHiddenSize > 0U); + IXRT_PLUGIN_ASSERT(wordEmb.count % mHiddenSize == 0); + IXRT_PLUGIN_ASSERT(posEmb.count % mHiddenSize == 0); + IXRT_PLUGIN_ASSERT(tokEmb.count % mHiddenSize == 0); + mWordVocabSize = wordEmb.count / mHiddenSize; + mPosVocabSize = posEmb.count / mHiddenSize; + mTokVocabSize = tokEmb.count / mHiddenSize; + + mBeta.convertAndCopy(beta, nvinfer1::DataType::kHALF); + mGamma.convertAndCopy(gamma, nvinfer1::DataType::kHALF); + mWordEmb.convertAndCopy(wordEmb, mEmbType); + mTokEmb.convertAndCopy(tokEmb, mEmbType); + mPosEmb.convertAndCopy(posEmb, mEmbType); + + copyToDevice(mGamma, sizeof(half) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(half) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mEmbType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mEmbType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mEmbType), mTokEmbDev); +} + +EmbLayerNormInt8PluginDynamic::EmbLayerNormInt8PluginDynamic(std::string const& name, void const* data, size_t length) + : mLayerName(name), + mGammaDev(nullptr), + mBetaDev(nullptr), + mWordEmbDev(nullptr), + mTokEmbDev(nullptr), + mPosEmbDev(nullptr) { + gLogInfo << "EmbLayerNormInt8PluginDynamic deserialize." 
<< endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mEmbType); + deserialize_value(&data, &length, &mMhaType); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mSeqLen); + deserialize_value(&data, &length, &mPadId); + deserialize_value(&data, &length, &mWordVocabSize); + deserialize_value(&data, &length, &mPosVocabSize); + deserialize_value(&data, &length, &mTokVocabSize); + deserialize_value(&data, &length, &mUseFullMask); + + char const* d = static_cast(data); + mBeta.convertAndCopy(d, mHiddenSize, nvinfer1::DataType::kHALF); + mGamma.convertAndCopy(d, mHiddenSize, nvinfer1::DataType::kHALF); + mWordEmb.convertAndCopy(d, mHiddenSize * mWordVocabSize, mEmbType); + mPosEmb.convertAndCopy(d, mHiddenSize * mPosVocabSize, mEmbType); + mTokEmb.convertAndCopy(d, mHiddenSize * mTokVocabSize, mEmbType); + + copyToDevice(mGamma, sizeof(half) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(half) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mEmbType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mEmbType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mEmbType), mTokEmbDev); +} + +// IPluginV2 Methods +char const* EmbLayerNormInt8PluginDynamic::getPluginType() const noexcept { return EMB_LAYER_NORM_INT8_NAME; } + +char const* EmbLayerNormInt8PluginDynamic::getPluginVersion() const noexcept { return EMB_LAYER_NORM_INT8_VERSION; } + +int32_t EmbLayerNormInt8PluginDynamic::getNbOutputs() const noexcept { return 3; } + +int32_t EmbLayerNormInt8PluginDynamic::initialize() noexcept { return 0; } + +void EmbLayerNormInt8PluginDynamic::terminate() noexcept { + gLogInfo << "EmbLayerNormInt8PluginDynamic terminate." << endl; +} + +size_t EmbLayerNormInt8PluginDynamic::getSerializationSize() const noexcept { + size_t const wordSize = getElementSize(mEmbType); + return sizeof(mEmbType) * 2 // mEmbType, mMhaType + + sizeof(mHiddenSize) * 6 // mHiddenSize, mSeqLen, 3*VocabSize, mPadId + + sizeof(mUseFullMask) // mask type + + 2 * sizeof(half) * mHiddenSize // beta + gamma + + wordSize * mHiddenSize * mWordVocabSize // word emb + + wordSize * mHiddenSize * mPosVocabSize // pos emb + + wordSize * mHiddenSize * mTokVocabSize // tok emb + ; +} + +void EmbLayerNormInt8PluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mEmbType); + serialize_value(&buffer, mMhaType); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mSeqLen); + serialize_value(&buffer, mPadId); + serialize_value(&buffer, mWordVocabSize); + serialize_value(&buffer, mPosVocabSize); + serialize_value(&buffer, mTokVocabSize); + serialize_value(&buffer, mUseFullMask); + + char* d = static_cast(buffer); + serFromDev(d, mBetaDev.get(), mHiddenSize); + serFromDev(d, mGammaDev.get(), mHiddenSize); + size_t const wordSize = getElementSize(mEmbType); + serFromDev(d, static_cast(mWordEmbDev.get()), mHiddenSize * mWordVocabSize * wordSize); + serFromDev(d, static_cast(mPosEmbDev.get()), mHiddenSize * mPosVocabSize * wordSize); + serFromDev(d, static_cast(mTokEmbDev.get()), mHiddenSize * mTokVocabSize * wordSize); +} + +void EmbLayerNormInt8PluginDynamic::destroy() noexcept { + gLogInfo << "EmbLayerNormInt8PluginDynamic destroy." 
<< endl; + // This gets called when the network containing plugin is destroyed + mGammaDev.reset(nullptr); + mBetaDev.reset(nullptr); + mWordEmbDev.reset(nullptr); + mPosEmbDev.reset(nullptr); + mTokEmbDev.reset(nullptr); + delete this; +} + +void EmbLayerNormInt8PluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* EmbLayerNormInt8PluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType EmbLayerNormInt8PluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index >= 0 && index <= 2); + if (index == 0) { + return mMhaType; + } + if (index == 1) { + return DataType::kINT8; + } + return DataType::kHALF; +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* EmbLayerNormInt8PluginDynamic::clone() const noexcept { + try { + gLogInfo << "EmbLayerNormInt8PluginDynamic clone." << endl; + + auto p = new EmbLayerNormInt8PluginDynamic(mLayerName, mEmbType, mMhaType, mBeta, mGamma, mWordEmb, mPosEmb, + mTokEmb, mUseFullMask); + p->mSeqLen = mSeqLen; + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs EmbLayerNormInt8PluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, + int32_t nbInputs, IExprBuilder& exprBuilder) noexcept { + try { + // Input should be input ids and token ids and the input mask + // Output should be the embeddings tensor and mask indices + IXRT_PLUGIN_ASSERT(nbInputs == 3); + + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == 2); // BxS + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[2].nbDims); + + IXRT_PLUGIN_ASSERT(outputIndex >= 0 || outputIndex <= 2); + + if (outputIndex == 0) { + DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[BDIM]; + ret.d[1] = inputs[0].d[SDIM]; + ret.d[2] = exprBuilder.constant(mHiddenSize); + // ret.d[3] = exprBuilder.constant(1); + // ret.d[4] = exprBuilder.constant(1); + return ret; + } + if (outputIndex == 1) { + DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = inputs[0].d[BDIM]; + ret.d[1] = inputs[0].d[SDIM]; + return ret; + } + + DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[BDIM]; + ret.d[1] = inputs[0].d[SDIM]; + ret.d[2] = exprBuilder.constant(mHiddenSize); + // ret.d[3] = exprBuilder.constant(1); + // ret.d[4] = exprBuilder.constant(1); + return ret; + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool EmbLayerNormInt8PluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, + int32_t nbInputs, int32_t nbOutputs) noexcept { + // 3 inputs of size BxS + IXRT_PLUGIN_ASSERT(nbInputs == 3); + IXRT_PLUGIN_ASSERT(nbOutputs == 3); + + PluginTensorDesc const& desc = inOut[pos]; + if (desc.format != TensorFormat::kLINEAR) { + return false; + } + if (pos == 0) { + return desc.type == DataType::kINT32; + } + + PluginTensorDesc const& prev = inOut[pos - 1]; + if (pos == 1 || pos == 2) { + return desc.type == DataType::kINT32 && desc.format == prev.format; + } + + // emb_out + if (pos == 3 || pos == 4) { + return desc.type == mMhaType && desc.format == prev.format; + } + // residual + return desc.type == DataType::kHALF; +} + +void EmbLayerNormInt8PluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t 
nbInputs, + DynamicPluginTensorDesc const* outputs, + int32_t nbOutputs) noexcept { + gLogInfo << "EmbLayerNormInt8PluginDynamic configurePlugin." << endl; + + // Validate input arguments + IXRT_PLUGIN_ASSERT(nbOutputs == 3); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + + IXRT_PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 2); + int32_t const S = inputs[0].desc.dims.d[SDIM]; + mSeqLen = S; + int32_t const B = inputs[0].desc.dims.d[BDIM]; + TRT_UNUSED B; + IXRT_PLUGIN_ASSERT(mSeqLen == static_cast(inputs[1].desc.dims.d[SDIM])); + IXRT_PLUGIN_ASSERT(B == inputs[1].desc.dims.d[BDIM]); + IXRT_PLUGIN_ASSERT(mSeqLen == static_cast(inputs[2].desc.dims.d[SDIM])); + IXRT_PLUGIN_ASSERT(B == inputs[2].desc.dims.d[BDIM]); + + IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 3); + IXRT_PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[SDIM]) == mSeqLen); + IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[BDIM] == B); + IXRT_PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[2]) == mHiddenSize); + // IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[4] == 1); + + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.d[1] == S); + + IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.nbDims == 3); + IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.d[SDIM] == outputs[0].desc.dims.d[SDIM]); + IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.d[BDIM] == outputs[0].desc.dims.d[BDIM]); + IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.d[2] == outputs[0].desc.dims.d[2]); + // IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(outputs[2].desc.dims.d[4] == 1); +} + +size_t EmbLayerNormInt8PluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + int32_t const B = inputs[0].dims.d[BDIM]; + int32_t const S = inputs[0].dims.d[SDIM]; + return B * S * sizeof(int32_t); +} + +int32_t EmbLayerNormInt8PluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept { + gLogInfo << "in EmbLayerNormInt8PluginDynamic.." 
<< endl; + try { + int32_t const B = inputDesc->dims.d[BDIM]; + int32_t const S = inputDesc->dims.d[SDIM]; + int32_t status = STATUS_SUCCESS; + int32_t fmha_S = S; + int32_t batch_tokens = B * fmha_S; + + // Our plugin outputs only one tensor + auto const inputIds = static_cast(inputs[0]); + auto const segmentIds = static_cast(inputs[1]); + + half const* beta = mBetaDev.get(); + half const* gamma = mGammaDev.get(); + auto output = static_cast(outputs[0]); + auto mNewMask = static_cast(outputs[1]); + auto residual = static_cast(outputs[2]); + auto const wordEmb = static_cast(mWordEmbDev.get()); + auto const tokEmb = static_cast(mTokEmbDev.get()); + auto const posEmb = static_cast(mPosEmbDev.get()); + + float l0_qkv_in_amax = outputDesc[0].scale * 127; + + auto mask_idx = static_cast(workspace); + status = embLayerNorm(stream, static_cast(mHiddenSize), B, S, inputIds, segmentIds, beta, gamma, + wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, residual, output, mask_idx, + mPadId, l0_qkv_in_amax); + + IxinferMaskPad(mask_idx, mNewMask, B, S, mHiddenSize, fmha_S, batch_tokens, stream); + + if (status != cudaSuccess) { + return status; + } + + return status; + } catch (std::exception const& e) { + caughtError(e); + } + return STATUS_FAILURE; +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..3cca4d181bf40d158654ec8c314433abb91f2130 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.cu @@ -0,0 +1,374 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "backend/bert/bert_helper.h" +#include "backend/transformer/transformer_embed.h" +#include "embLayerNormInt8Plugin.h" + +namespace nvinfer1::plugin { +using namespace backend; +namespace bert { + +template +__global__ void IxinferResidualI8O(const float *input, int8_t *output, int hidden_size, float quant_scale) { + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + + input += block_start; + output += block_start; + + float4 *p_input = (float4 *)input; + char4 *p_output = (char4 *)output; + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + vals[it].x = p_input[element_index].x; + vals[it].y = p_input[element_index].y; + vals[it].z = p_input[element_index].z; + vals[it].w = p_input[element_index].w; + + char4 res = float42char4(vals[it], quant_scale); + p_output[element_index] = res; + } +} + +template +__global__ void IxinferResidualI8O(const half *input, int8_t *output, int hidden_size, float quant_scale) { + // register + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size; + // one line start + input += block_start; + output += block_start; + + __half2 *p_input = (__half2 *)input; + char4 *p_output = (char4 *)output; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + load_float4_from_half(vals[it], p_input, element_index); + + char4 res = float42char4(vals[it], quant_scale); + p_output[element_index] = res; + } +} + +template +void IxinferResidualI8OLauncher(const T *input, int8_t *output, int batch_tokens, int hidden_size, float quant_scale, + cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size / 4 % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE / 4; + + switch (num_warp) { + case 1: + IxinferResidualI8O<1><<>>(input, output, hidden_size, quant_scale); + break; + case 2: + IxinferResidualI8O<2><<>>(input, output, hidden_size, quant_scale); + break; + case 3: + IxinferResidualI8O<3><<>>(input, output, hidden_size, quant_scale); + break; + case 4: + IxinferResidualI8O<4><<>>(input, output, hidden_size, quant_scale); + break; + case 5: + IxinferResidualI8O<5><<>>(input, output, hidden_size, quant_scale); + break; + case 6: + IxinferResidualI8O<6><<>>(input, output, hidden_size, quant_scale); + break; + case 7: + IxinferResidualI8O<7><<>>(input, output, hidden_size, quant_scale); + break; + case 8: + IxinferResidualI8O<8><<>>(input, output, hidden_size, quant_scale); + break; + case 9: + IxinferResidualI8O<9><<>>(input, output, hidden_size, quant_scale); + break; + case 10: + IxinferResidualI8O<10><<>>(input, output, hidden_size, quant_scale); + break; + case 11: + IxinferResidualI8O<11><<>>(input, output, hidden_size, quant_scale); + break; + case 12: + IxinferResidualI8O<12><<>>(input, output, hidden_size, quant_scale); + break; + case 13: + IxinferResidualI8O<13><<>>(input, output, hidden_size, quant_scale); + break; + case 14: + IxinferResidualI8O<14><<>>(input, output, hidden_size, quant_scale); + break; + case 15: + IxinferResidualI8O<15><<>>(input, output, hidden_size, quant_scale); + break; + case 16: + IxinferResidualI8O<16><<>>(input, output, hidden_size, quant_scale); + break; + default: + throw 
std::runtime_error("IxinferResidualI8OLauncher"); + break; + } +} + +cudaError_t embLayerNorm(cudaStream_t stream, int E, int B, int S, int32_t const *inputIds, int32_t const *segmentIds, + half const *beta, half const *gamma, half const *wordEmb, half const *posEmb, + half const *tokEmb, int32_t const wordSize, int32_t const tokSize, half *buffer, + int8_t *output, int32_t *maskIdx, int32_t padId, float l0_qkv_in_amax) { + backend::IxinferBertEmbedLn(wordEmb, posEmb, tokEmb, inputIds, buffer, maskIdx, (int *)segmentIds, padId, B, S, E, + gamma, beta, stream); + IxinferResidualI8OLauncher(buffer, output, B * S, E, 127.0 / l0_qkv_in_amax, stream); + return cudaSuccess; +} + +template +__global__ void IxinferBertEmbedLnKernel(const float *token_emb, const float *pos_emb, const float *type_emb, + const int *tokens, float *output, int *pad_mask, int *type_ids, int pad_id, + int batch_size, int seq_len, int hidden_dim, const float *scale, + const float *bias) { + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_dim; + int batch_idx, seq_idx; + batch_idx = blockIdx.x / seq_len; + seq_idx = blockIdx.x % seq_len; + + int tokens_idx = blockIdx.x; + int token = tokens[tokens_idx]; + int token_type = type_ids[tokens_idx]; + + output += block_start; + + float4 *p_output = (float4 *)output; + + float4 *p_scale = (float4 *)scale; + float4 *p_bias = (float4 *)bias; + float4 *p_value = (float4 *)(token_emb + token * hidden_dim); + float4 *p_pemb = (float4 *)(pos_emb + seq_idx * hidden_dim); + float4 *p_temb = (float4 *)(type_emb + token_type * hidden_dim); + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + + if (token == pad_id) { + if (element_index == 0) { + pad_mask[tokens_idx] = 1; + } + vals[it] = make_float4(0.f, 0.f, 0.f, 0.f); + + } else { + if (element_index == 0) { + pad_mask[tokens_idx] = 0; + } + + vals[it].x = p_value[element_index].x + p_pemb[element_index].x + p_temb[element_index].x; + vals[it].y = p_value[element_index].y + p_pemb[element_index].y + p_temb[element_index].y; + vals[it].z = p_value[element_index].z + p_pemb[element_index].z + p_temb[element_index].z; + vals[it].w = p_value[element_index].w + p_pemb[element_index].w + p_temb[element_index].w; + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + } + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float4 scale_value = p_scale[element_index]; + float4 bias_value = p_bias[element_index]; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_dim, epsilon, scale_value, bias_value); + int tokens_idx = blockIdx.x; + + int token = tokens[tokens_idx]; + if (token == pad_id) { + p_output[element_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } else { + p_output[element_index] = norm_value; + } + } +} + +void IxinferBertEmbedLn(const float 
*token_emb, const float *pos_emb, const float *type_emb, const int *tokens, + float *output, int *pad_mask, int *type_ids, int pad_id, int batch_size, int seq_len, + int hidden_size, const float *scale, const float *bias, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + int batch_tokens = batch_size * seq_len; + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + int num_warp = hidden_size / C10_WARP_SIZE / 4; + + switch (num_warp) { + case 1: + IxinferBertEmbedLnKernel<1> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 2: + IxinferBertEmbedLnKernel<2> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 3: + IxinferBertEmbedLnKernel<3> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 4: + IxinferBertEmbedLnKernel<4> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 5: + IxinferBertEmbedLnKernel<5> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 6: + IxinferBertEmbedLnKernel<6> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 7: + IxinferBertEmbedLnKernel<7> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 8: + IxinferBertEmbedLnKernel<8> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 9: + IxinferBertEmbedLnKernel<9> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 10: + IxinferBertEmbedLnKernel<10> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 11: + IxinferBertEmbedLnKernel<11> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 12: + IxinferBertEmbedLnKernel<12> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 13: + IxinferBertEmbedLnKernel<13> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 14: + IxinferBertEmbedLnKernel<14> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 15: + IxinferBertEmbedLnKernel<15> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + case 16: + IxinferBertEmbedLnKernel<16> + <<>>(token_emb, pos_emb, type_emb, tokens, output, pad_mask, type_ids, + pad_id, batch_size, seq_len, hidden_size, scale, bias); + break; + default: + throw 
std::runtime_error("IxinferBertEmbedLn"); + break; + } +} + +cudaError_t embLayerNorm(cudaStream_t stream, int E, int B, int S, int32_t const *inputIds, int32_t const *segmentIds, + float const *beta, float const *gamma, float const *wordEmb, float const *posEmb, + float const *tokEmb, int32_t const wordSize, int32_t const tokSize, float *buffer, + int8_t *output, int32_t *maskIdx, int32_t padId, float l0_qkv_in_amax) { + IxinferBertEmbedLn(wordEmb, posEmb, tokEmb, inputIds, buffer, maskIdx, (int *)segmentIds, padId, B, S, E, gamma, + beta, stream); + + IxinferResidualI8OLauncher(buffer, output, B * S, E, 127.0 / l0_qkv_in_amax, stream); + return cudaSuccess; +} + +void __global__ IxinferMaskPadKernel(const int32_t *mask, int8_t *new_mask, int bsz, int ori_seq_len, int hsz, + int fmha_seq_len) { + int batch_idx = blockIdx.x; + int seq_idx = blockIdx.y; + + if (seq_idx < ori_seq_len) { + if (threadIdx.x == 0) { + new_mask[batch_idx * fmha_seq_len + seq_idx] = mask[batch_idx * ori_seq_len + seq_idx]; + } + } else { + new_mask[batch_idx * fmha_seq_len + seq_idx] = 1; + } +} + +void IxinferMaskPad(int32_t *mask, int8_t *new_mask, int bsz, int ori_seq_len, int hsz, int fmha_seq_len, + int batch_tokens, cudaStream_t stream) { + if (hsz / 2 > 4096) { + throw std::runtime_error("hsz/2>4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 !=0"); + } + if (ori_seq_len > fmha_seq_len) { + throw std::runtime_error("ori_seq_len > fmha_seq_len"); + } + if (bsz * ori_seq_len > batch_tokens) { + throw std::runtime_error("bsz*ori_seq_len > batch_tokens"); + } + dim3 blockSize(bsz, fmha_seq_len); + IxinferMaskPadKernel<<>>(mask, new_mask, bsz, ori_seq_len, hsz, fmha_seq_len); +} + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..9aa560c751853f2236b0cccd76ebdf1915fb3178 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormInt8Plugin.h @@ -0,0 +1,133 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include +#include + +#include +#include +#include +#include +#include + +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { + +void IxinferBertEmbedLn(const float* token_emb, const float* pos_emb, const float* type_emb, const int* tokens, + float* output, int* pad_mask, int* type_ids, int pad_id, int batch_size, int seq_len, + int hidden_size, const float* scale, const float* bias, cudaStream_t stream); + +cudaError_t embLayerNorm(cudaStream_t stream, int E, int B, int S, int32_t const* inputIds, int32_t const* segmentIds, + half const* beta, half const* gamma, half const* wordEmb, half const* posEmb, + half const* tokEmb, int32_t const wordSize, int32_t const tokSize, half* buffer, + int8_t* output, int32_t* maskIdx, int32_t padId, float token_embed_amax_); + +void IxinferMaskPad(int32_t* mask, int8_t* new_mask, int bsz, int ori_seq_len, int hsz, int fmha_seq_len, + int batch_tokens, cudaStream_t stream); + +class EmbLayerNormInt8PluginDynamic : public IPluginV2DynamicExt { + public: + EmbLayerNormInt8PluginDynamic(std::string const& name, nvinfer1::DataType const type, + nvinfer1::DataType const mhaType, nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, nvinfer1::Weights const& word_emb, + nvinfer1::Weights const& pos_emb, nvinfer1::Weights const& tok_emb, + bool const useFullMask, int32_t padId = 0); + EmbLayerNormInt8PluginDynamic(std::string const& name, void const* data, size_t length); + EmbLayerNormInt8PluginDynamic() noexcept = delete; + ~EmbLayerNormInt8PluginDynamic() override = default; + + // IPluginV2 methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext methods + DataType getOutputDataType(int32_t index, DataType const* inputType, int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt methods + IPluginV2DynamicExt* clone() const noexcept override; + DimsExprs getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + private: + const std::string mLayerName; + std::string mNamespace; + size_t mHiddenSize; + size_t mSeqLen; + size_t mPadId; + DataType mEmbType; + bool mUseFullMask; + DataType mMhaType; + size_t mWordVocabSize, mPosVocabSize, mTokVocabSize; + cuda_unique_ptr mGammaDev; + cuda_unique_ptr mBetaDev; + cuda_unique_ptr mWordEmbDev; + cuda_unique_ptr mTokEmbDev; + 
cuda_unique_ptr mPosEmbDev; + // cuda_unique_ptr mNewMask; + WeightsWithOwnership mBeta; + WeightsWithOwnership mGamma; + WeightsWithOwnership mWordEmb; + WeightsWithOwnership mTokEmb; + WeightsWithOwnership mPosEmb; +}; + +class EmbLayerNormInt8PluginDynamicCreator : public IPluginCreator { + public: + EmbLayerNormInt8PluginDynamicCreator(); + + ~EmbLayerNormInt8PluginDynamicCreator() override = default; + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + PluginFieldCollection const* getFieldNames() noexcept override; + + IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override; + + IPluginV2DynamicExt* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + private: + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77d1544c125baae43b6e811539ae1997022a5893 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cpp @@ -0,0 +1,481 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "embLayerNormPlugin.h" + +#include "NvInferImpl.h" +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "common_def.cuh" +#include "driver_types.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace { +char const* EMB_LAYER_NORM_VERSION{"1"}; +char const* EMB_LAYER_NORM_NAME{"CustomEmbLayerNormPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection EmbLayerNormPluginDynamicCreator::mFC{}; +std::vector EmbLayerNormPluginDynamicCreator::mPluginAttributes; + +EmbLayerNormPluginDynamicCreator::EmbLayerNormPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_beta")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_layernorm_gamma")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_word_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_token_type_embeddings")); + mPluginAttributes.emplace_back(PluginField("bert_embeddings_position_embeddings")); + mPluginAttributes.emplace_back(PluginField("output_fp16")); + mPluginAttributes.emplace_back(PluginField("full_mask")); + mPluginAttributes.emplace_back(PluginField("mha_type_id")); + mPluginAttributes.emplace_back(PluginField("pad_id")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* EmbLayerNormPluginDynamicCreator::getPluginName() const noexcept { return EMB_LAYER_NORM_NAME; } + +char const* EmbLayerNormPluginDynamicCreator::getPluginVersion() const noexcept { return EMB_LAYER_NORM_VERSION; } + +PluginFieldCollection const* EmbLayerNormPluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2DynamicExt* EmbLayerNormPluginDynamicCreator::createPlugin(char const* name, + PluginFieldCollection const* fc) noexcept { + try { + IXRT_PLUGIN_ASSERT(fc != nullptr); + gLogInfo << "EmbLayerNormPluginDynamic createPlugin." << endl; + std::set const requiredAttributes{ + "bert_embeddings_layernorm_beta", "bert_embeddings_layernorm_gamma", + "bert_embeddings_word_embeddings", "bert_embeddings_token_type_embeddings", + "bert_embeddings_position_embeddings", + }; + + bool output_fp16 = false; + bool useFullMask = false; + Weights beta{}; + Weights gamma{}; + Weights word_emb{}; + Weights pos_emb{}; + Weights tok_emb{}; + int32_t mhaTypeId = 0; + int32_t pad_id = 0; + + for (auto i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + if (field_name.compare("bert_embeddings_layernorm_beta") == 0) { + gLogInfo << "Building bert_embeddings_layernorm_beta..." << endl; + beta.values = fc->fields[i].data; + beta.count = fc->fields[i].length; + beta.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) { + gLogInfo << "Building bert_embeddings_layernorm_gamma..." << endl; + gamma.values = fc->fields[i].data; + gamma.count = fc->fields[i].length; + gamma.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_word_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_word_embeddings..." 
<< endl; + word_emb.values = fc->fields[i].data; + word_emb.count = fc->fields[i].length; + word_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_token_type_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_token_type_embeddings..." << endl; + tok_emb.values = fc->fields[i].data; + tok_emb.count = fc->fields[i].length; + tok_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_position_embeddings") == 0) { + gLogInfo << "Building bert_embeddings_position_embeddings..." << endl; + pos_emb.values = fc->fields[i].data; + pos_emb.count = fc->fields[i].length; + pos_emb.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("output_fp16") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32); + output_fp16 = static_cast(fc->fields[i].data)[0] != 0; + gLogInfo << "Building output_fp16" << output_fp16 << endl; + } + + if (field_name.compare("full_mask") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32); + useFullMask = static_cast(fc->fields[i].data)[0] != 0; + gLogInfo << "Building full_mask:" << useFullMask << endl; + } + + if (field_name.compare("mha_type_id") == 0) { + mhaTypeId = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_ASSERT(mhaTypeId >= 0 && mhaTypeId < 3); + gLogInfo << "Building mha typeId: " << mhaTypeId << endl; + } + + if (field_name.compare("pad_id") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32) + pad_id = *static_cast(fc->fields[i].data); + } + } + gLogInfo << "Building EmbLayerNormPluginDynamic Plugin..." << endl; + DataType mhaType = static_cast(mhaTypeId); + EmbLayerNormPluginDynamic* p = + new EmbLayerNormPluginDynamic(name, output_fp16 ? 
DataType::kHALF : DataType::kFLOAT, mhaType, beta, gamma, + word_emb, pos_emb, tok_emb, useFullMask, pad_id); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2DynamicExt* EmbLayerNormPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + try { + IXRT_PLUGIN_ASSERT(serialData != nullptr); + return new EmbLayerNormPluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void EmbLayerNormPluginDynamicCreator::setPluginNamespace(char const* pluginNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(pluginNamespace != nullptr); + mNamespace = pluginNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* EmbLayerNormPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(EmbLayerNormPluginDynamicCreator); + +//#########################################################################// +EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, DataType const type, + DataType const mhaType, Weights const& beta, Weights const& gamma, + Weights const& wordEmb, Weights const& posEmb, + Weights const& tokEmb, bool const useFullMask, int32_t padId) + : mLayerName(name), + mHiddenSize(beta.count), + mEmbType(type), + mUseFullMask(useFullMask), + mMhaType(mhaType), + mPadId(padId) { + IXRT_PLUGIN_ASSERT(beta.count == gamma.count); + IXRT_PLUGIN_ASSERT(mHiddenSize > 0U); + IXRT_PLUGIN_ASSERT(wordEmb.count % mHiddenSize == 0); + IXRT_PLUGIN_ASSERT(posEmb.count % mHiddenSize == 0); + IXRT_PLUGIN_ASSERT(tokEmb.count % mHiddenSize == 0); + mWordVocabSize = wordEmb.count / mHiddenSize; + mPosVocabSize = posEmb.count / mHiddenSize; + mTokVocabSize = tokEmb.count / mHiddenSize; + + mBeta.convertAndCopy(beta, nvinfer1::DataType::kHALF); + mGamma.convertAndCopy(gamma, nvinfer1::DataType::kHALF); + mWordEmb.convertAndCopy(wordEmb, mEmbType); + mTokEmb.convertAndCopy(tokEmb, mEmbType); + mPosEmb.convertAndCopy(posEmb, mEmbType); + + copyToDevice(mGamma, sizeof(half) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(half) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mEmbType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mEmbType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mEmbType), mTokEmbDev); +} + +EmbLayerNormPluginDynamic::EmbLayerNormPluginDynamic(std::string const& name, void const* data, size_t length) + : mLayerName(name), + mGammaDev(nullptr), + mBetaDev(nullptr), + mWordEmbDev(nullptr), + mTokEmbDev(nullptr), + mPosEmbDev(nullptr) { + gLogInfo << "EmbLayerNormPluginDynamic deserialize." 
<< endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mEmbType); + deserialize_value(&data, &length, &mMhaType); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mSeqLen); + deserialize_value(&data, &length, &mPadId); + deserialize_value(&data, &length, &mWordVocabSize); + deserialize_value(&data, &length, &mPosVocabSize); + deserialize_value(&data, &length, &mTokVocabSize); + deserialize_value(&data, &length, &mUseFullMask); + + char const* d = static_cast(data); + mBeta.convertAndCopy(d, mHiddenSize, nvinfer1::DataType::kHALF); + mGamma.convertAndCopy(d, mHiddenSize, nvinfer1::DataType::kHALF); + mWordEmb.convertAndCopy(d, mHiddenSize * mWordVocabSize, mEmbType); + mPosEmb.convertAndCopy(d, mHiddenSize * mPosVocabSize, mEmbType); + mTokEmb.convertAndCopy(d, mHiddenSize * mTokVocabSize, mEmbType); + + copyToDevice(mGamma, sizeof(half) * mGamma.count, mGammaDev); + copyToDevice(mBeta, sizeof(half) * mBeta.count, mBetaDev); + copyToDevice(mWordEmb, getWeightsSize(mWordEmb, mEmbType), mWordEmbDev); + copyToDevice(mPosEmb, getWeightsSize(mPosEmb, mEmbType), mPosEmbDev); + copyToDevice(mTokEmb, getWeightsSize(mTokEmb, mEmbType), mTokEmbDev); +} + +// IPluginV2 Methods +char const* EmbLayerNormPluginDynamic::getPluginType() const noexcept { return EMB_LAYER_NORM_NAME; } + +char const* EmbLayerNormPluginDynamic::getPluginVersion() const noexcept { return EMB_LAYER_NORM_VERSION; } + +int32_t EmbLayerNormPluginDynamic::getNbOutputs() const noexcept { return 2; } + +int32_t EmbLayerNormPluginDynamic::initialize() noexcept { return 0; } + +void EmbLayerNormPluginDynamic::terminate() noexcept { gLogInfo << "EmbLayerNormPluginDynamic terminate." << endl; } + +size_t EmbLayerNormPluginDynamic::getSerializationSize() const noexcept { + size_t const wordSize = getElementSize(mEmbType); + return sizeof(mEmbType) * 2 // mEmbType, mMhaType + + sizeof(mHiddenSize) * 6 // mHiddenSize, mSeqLen, 3*VocabSize, mPadId + + sizeof(mUseFullMask) // mask type + + 2 * sizeof(half) * mHiddenSize // beta + gamma + + wordSize * mHiddenSize * mWordVocabSize // word emb + + wordSize * mHiddenSize * mPosVocabSize // pos emb + + wordSize * mHiddenSize * mTokVocabSize // tok emb + ; +} + +void EmbLayerNormPluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mEmbType); + serialize_value(&buffer, mMhaType); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mSeqLen); + serialize_value(&buffer, mPadId); + serialize_value(&buffer, mWordVocabSize); + serialize_value(&buffer, mPosVocabSize); + serialize_value(&buffer, mTokVocabSize); + serialize_value(&buffer, mUseFullMask); + + char* d = static_cast(buffer); + serFromDev(d, mBetaDev.get(), mHiddenSize); + serFromDev(d, mGammaDev.get(), mHiddenSize); + size_t const wordSize = getElementSize(mEmbType); + serFromDev(d, static_cast(mWordEmbDev.get()), mHiddenSize * mWordVocabSize * wordSize); + serFromDev(d, static_cast(mPosEmbDev.get()), mHiddenSize * mPosVocabSize * wordSize); + serFromDev(d, static_cast(mTokEmbDev.get()), mHiddenSize * mTokVocabSize * wordSize); +} + +void EmbLayerNormPluginDynamic::destroy() noexcept { + gLogInfo << "EmbLayerNormPluginDynamic destroy." 
<< endl; + // This gets called when the network containing plugin is destroyed + mGammaDev.reset(nullptr); + mBetaDev.reset(nullptr); + mWordEmbDev.reset(nullptr); + mPosEmbDev.reset(nullptr); + mTokEmbDev.reset(nullptr); + delete this; +} + +void EmbLayerNormPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* EmbLayerNormPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType EmbLayerNormPluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0 || index == 1); + if (index == 0) { + IXRT_PLUGIN_ASSERT(mMhaType == DataType::kHALF || mMhaType == DataType::kFLOAT); + return mMhaType; + } + return DataType::kINT32; +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* EmbLayerNormPluginDynamic::clone() const noexcept { + try { + gLogInfo << "EmbLayerNormPluginDynamic clone." << endl; + + auto p = new EmbLayerNormPluginDynamic(mLayerName, mEmbType, mMhaType, mBeta, mGamma, mWordEmb, mPosEmb, + mTokEmb, mUseFullMask); + p->mSeqLen = mSeqLen; + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs EmbLayerNormPluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept { + try { + // Input should be input ids and token ids and the input mask + // Output should be the embeddings tensor and mask indices + IXRT_PLUGIN_ASSERT(nbInputs == 3); + + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == 2); // BxS + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[2].nbDims); + + IXRT_PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1); + + if (outputIndex == 0) { + DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mHiddenSize); + // ret.d[3] = exprBuilder.constant(1); + // ret.d[4] = exprBuilder.constant(1); + return ret; + } + + DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = inputs[0].d[BDIM]; + ret.d[1] = inputs[0].d[SDIM]; + return ret; + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool EmbLayerNormPluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept { + // 3 inputs of size BxS + IXRT_PLUGIN_ASSERT(nbInputs == 3); + IXRT_PLUGIN_ASSERT(nbOutputs == 2); + + PluginTensorDesc const& desc = inOut[pos]; + if (desc.format != TensorFormat::kLINEAR) { + return false; + } + if (pos == 0) { + return desc.type == DataType::kINT32; + } + + PluginTensorDesc const& prev = inOut[pos - 1]; + if (pos == 1 || pos == 2) { + return desc.type == DataType::kINT32 && desc.format == prev.format; + } + + // embedded sequence + if (pos == 3) { + return desc.type == mMhaType && desc.format == prev.format; + } + // mask + return desc.type == ((mMhaType == DataType::kHALF) ? DataType::kINT32 : mMhaType); +} + +void EmbLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { + gLogInfo << "EmbLayerNormPluginDynamic configurePlugin." 
<< endl; + + // Validate input arguments + IXRT_PLUGIN_ASSERT(nbOutputs == 2); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + + IXRT_PLUGIN_ASSERT(inputs[0].desc.dims.nbDims == 2); + int32_t const S = inputs[0].desc.dims.d[SDIM]; + mSeqLen = S; + int32_t const B = inputs[0].desc.dims.d[BDIM]; + TRT_UNUSED B; + IXRT_PLUGIN_ASSERT(mSeqLen == static_cast(inputs[1].desc.dims.d[SDIM])); + IXRT_PLUGIN_ASSERT(B == inputs[1].desc.dims.d[BDIM]); + IXRT_PLUGIN_ASSERT(mSeqLen == static_cast(inputs[2].desc.dims.d[SDIM])); + IXRT_PLUGIN_ASSERT(B == inputs[2].desc.dims.d[BDIM]); + + int32_t fmha_S = nearest_num(S, 64); + IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.nbDims == 3); + IXRT_PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[SDIM]) == mSeqLen); + IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[BDIM] == B); + IXRT_PLUGIN_ASSERT(static_cast(outputs[0].desc.dims.d[2]) == mHiddenSize); + // IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(outputs[0].desc.dims.d[4] == 1); + + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.nbDims == 2); + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.d[0] == B); + IXRT_PLUGIN_ASSERT(outputs[1].desc.dims.d[1] == S); +} + +size_t EmbLayerNormPluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { + return 0; +} + +int32_t EmbLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept { + gLogInfo << "in EmbLayerNormPluginDynamic.." << endl; + try { + int32_t const B = inputDesc->dims.d[BDIM]; + int32_t const S = inputDesc->dims.d[SDIM]; + int32_t status = STATUS_SUCCESS; + int32_t fmha_S = S; + int32_t batch_tokens = B * fmha_S; + + // Our plugin outputs only one tensor + auto const inputIds = static_cast(inputs[0]); + auto const segmentIds = static_cast(inputs[1]); + + half const* beta = mBetaDev.get(); + half const* gamma = mGammaDev.get(); + if (mMhaType == DataType::kFLOAT) { + gLogError << "embLayerNormPlugin float type not supported!" << endl; + return STATUS_NOT_SUPPORTED; + } else if (mMhaType == DataType::kHALF) { + auto output = static_cast(outputs[0]); + auto mNewMask = static_cast(outputs[1]); + auto const wordEmb = static_cast(mWordEmbDev.get()); + auto const tokEmb = static_cast(mTokEmbDev.get()); + auto const posEmb = static_cast(mPosEmbDev.get()); + + status = embLayerNorm(stream, static_cast(mHiddenSize), B, S, inputIds, segmentIds, beta, gamma, + wordEmb, posEmb, tokEmb, mWordVocabSize, mTokVocabSize, output, mNewMask, mPadId); + + if (status != cudaSuccess) { + return status; + } + } else { + gLogError << "Unsupported type error, expected [kHALF,kFLOAT], but received " + << static_cast(mMhaType) << endl; + + return STATUS_NOT_SUPPORTED; + } + + return status; + } catch (std::exception const& e) { + caughtError(e); + } + return STATUS_FAILURE; +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..df621789571400d430fae56a61e0594ec1eefa55 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.cu @@ -0,0 +1,65 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. 
+* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "backend/bert/bert_helper.h" +#include "backend/transformer/transformer_embed.h" +#include "embLayerNormPlugin.h" + +using namespace nvinfer1::plugin::backend; + +namespace nvinfer1::plugin { +namespace bert { + +cudaError_t embLayerNorm(cudaStream_t stream, int E, int B, int S, int32_t const* inputIds, int32_t const* segmentIds, + half const* beta, half const* gamma, half const* wordEmb, half const* posEmb, + half const* tokEmb, int32_t const wordSize, int32_t const tokSize, half* output, + int32_t* maskIdx, int32_t padId) { + IxinferBertEmbedLn(wordEmb, posEmb, tokEmb, inputIds, output, maskIdx, (int*)segmentIds, padId, B, S, E, gamma, + beta, stream); + return cudaSuccess; +} + +void __global__ IxinferMaskPadKernel(const int32_t* mask, int32_t* new_mask, int bsz, int ori_seq_len, int hsz, + int fmha_seq_len) { + int batch_idx = blockIdx.x; + int seq_idx = blockIdx.y; + + if (seq_idx < ori_seq_len) { + new_mask[batch_idx * fmha_seq_len + seq_idx] = mask[batch_idx * ori_seq_len + seq_idx]; + } else { + new_mask[batch_idx * fmha_seq_len + seq_idx] = 1; + } +} + +void IxinferMaskPad(int32_t* mask, int32_t* new_mask, int bsz, int ori_seq_len, int hsz, int fmha_seq_len, + int batch_tokens, cudaStream_t stream) { + if (hsz / 2 > 4096) { + throw std::runtime_error("hsz/2>4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 !=0"); + } + if (ori_seq_len > fmha_seq_len) { + throw std::runtime_error("ori_seq_len > fmha_seq_len"); + } + if (bsz * ori_seq_len > batch_tokens) { + throw std::runtime_error("bsz*ori_seq_len > batch_tokens"); + } + dim3 blockSize(bsz, fmha_seq_len); + IxinferMaskPadKernel<<>>(mask, new_mask, bsz, ori_seq_len, hsz, fmha_seq_len); +} + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..f9ee0c6dd230ccd61bfbf6df3c3532e4b17918b3 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/emb_layernorm/embLayerNormPlugin.h @@ -0,0 +1,130 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once +#include +#include + +#include +#include +#include +#include +#include + +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { + +cudaError embLayerNorm(cudaStream_t stream, int E, int B, int S, int32_t const* inputIds, int32_t const* segmentIds, + half const* beta, half const* gamma, half const* wordEmb, half const* posEmb, half const* tokEmb, + int32_t const wordSize, int32_t const tokSize, half* output, int32_t* maskIdx, int32_t padId); + +void IxinferMaskPad(int32_t* mask, int32_t* new_mask, int bsz, int ori_seq_len, int hsz, int fmha_seq_len, + int batch_tokens, cudaStream_t stream); +void IxinferEncPad(half* query, half* new_query, int32_t* mask, int32_t* new_mask, int bsz, int ori_seq_len, int hsz, + int fmha_seq_len, int batch_tokens, cudaStream_t stream); + +class EmbLayerNormPluginDynamic : public IPluginV2DynamicExt { + public: + EmbLayerNormPluginDynamic(std::string const& name, nvinfer1::DataType const type, nvinfer1::DataType const mhaType, + nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, + nvinfer1::Weights const& word_emb, nvinfer1::Weights const& pos_emb, + nvinfer1::Weights const& tok_emb, bool const useFullMask, int32_t padId = 0); + EmbLayerNormPluginDynamic(std::string const& name, void const* data, size_t length); + EmbLayerNormPluginDynamic() noexcept = delete; + ~EmbLayerNormPluginDynamic() override = default; + + // IPluginV2 methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* libNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext methods + DataType getOutputDataType(int32_t index, DataType const* inputType, int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt methods + IPluginV2DynamicExt* clone() const noexcept override; + DimsExprs getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + + private: + const std::string mLayerName; + std::string mNamespace; + size_t mHiddenSize; + size_t mSeqLen; + size_t mBatchSize; + size_t mPadId; + DataType mEmbType; + bool mUseFullMask; + DataType mMhaType; + size_t mWordVocabSize, mPosVocabSize, mTokVocabSize; + cuda_unique_ptr mGammaDev; + cuda_unique_ptr mBetaDev; + cuda_unique_ptr mWordEmbDev; + cuda_unique_ptr mTokEmbDev; + cuda_unique_ptr mPosEmbDev; + // cuda_unique_ptr mNewMask; + WeightsWithOwnership mBeta; + WeightsWithOwnership mGamma; + WeightsWithOwnership mWordEmb; + 
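+ // Host-side copies (WeightsWithOwnership) are kept so clone() can rebuild the plugin; the cuda_unique_ptr members above hold the device copies used by enqueue().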
WeightsWithOwnership mTokEmb; + WeightsWithOwnership mPosEmb; +}; + +class EmbLayerNormPluginDynamicCreator : public IPluginCreator { + public: + EmbLayerNormPluginDynamicCreator(); + + ~EmbLayerNormPluginDynamicCreator() override = default; + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + PluginFieldCollection const* getFieldNames() noexcept override; + + IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override; + + IPluginV2DynamicExt* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + private: + static PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1945a91d23e454f89fbd0fa70b16c2b836db7d5 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.cpp @@ -0,0 +1,397 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "ffnPlugin.h" + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "backend/ixinfer/ixinfer_gemm_helper.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "gelu/geluPlugin.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; +using namespace nvinfer1::plugin::backend; + +namespace { +char const* const kFFN_VERSION{"1"}; +char const* const kFFN_NAME{"CustomFFNPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection FFNPluginDynamicCreator::mFFN{}; +std::vector FFNPluginDynamicCreator::mPluginAttributes; + +FFNPluginDynamicCreator::FFNPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("out_dims", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("act_type", nullptr, PluginFieldType::kINT32, 1)); + + mFFN.nbFields = mPluginAttributes.size(); + mFFN.fields = mPluginAttributes.data(); +} + +char const* FFNPluginDynamicCreator::getPluginName() const noexcept { return kFFN_NAME; } + +char const* FFNPluginDynamicCreator::getPluginVersion() const noexcept { return kFFN_VERSION; } + +PluginFieldCollection const* FFNPluginDynamicCreator::getFieldNames() noexcept { return &mFFN; } + +IPluginV2* FFNPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "Creating FFNPluginDynamicCreator..." << endl; + IXRT_PLUGIN_ASSERT(name != nullptr); + IXRT_PLUGIN_ASSERT(fc != nullptr); + + int32_t outDims = 0; + int32_t typeId = -1; + int32_t act_type = -1; + Weights W1{DataType::kFLOAT, nullptr, 0LL}; + Weights W2{DataType::kFLOAT, nullptr, 0LL}; + Weights B1{DataType::kFLOAT, nullptr, 0LL}; + plugin::validateRequiredAttributesExist({"out_dims", "type_id", "W1", "W2", "B1"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string fieldName(fc->fields[i].name); + if (fieldName.compare("out_dims") == 0) { + outDims = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building outDims: " << outDims << endl; + } + + if (fieldName.compare("type_id") == 0) { + typeId = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building typeId: " << typeId << endl; + } + + if (fieldName.compare("W1") == 0) { + gLogInfo << "Building W1..." << endl; + W1.values = fc->fields[i].data; + W1.count = fc->fields[i].length; + W1.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is W1 float32: " << (W1.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("W2") == 0) { + gLogInfo << "Building W2..." << endl; + W2.values = fc->fields[i].data; + W2.count = fc->fields[i].length; + W2.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is W2 float32: " << (W2.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("B1") == 0) { + gLogInfo << "Building B1..." << endl; + B1.values = fc->fields[i].data; + B1.count = fc->fields[i].length; + B1.type = fieldTypeToDataType(fc->fields[i].type); + gLogInfo << "Is B1 float32: " << (B1.type == DataType::kFLOAT) << endl; + } + + if (fieldName.compare("act_type") == 0) { + gLogInfo << "Building act_type..." 
<< endl; + act_type = static_cast(fc->fields[i].data)[0]; + gLogInfo << "Building act_type: " << act_type << endl; + } + } + + if (outDims <= 0) { + gLogInfo << "Invalid output dimension" << endl; + } + if (typeId < 0 || typeId > 1) { + gLogInfo << "Invalid type id" << typeId << endl; + } + if (W1.count == 0 || W1.values == nullptr) { + gLogInfo << "Invalid weights W1" << endl; + } + if (W2.count == 0 || W2.values == nullptr) { + gLogInfo << "Invalid weights W2" << endl; + } + if (B1.count == 0 || B1.values == nullptr) { + gLogInfo << "Invalid weights B1" << endl; + } + + DataType type = typeId == 0 ? DataType::kFLOAT : DataType::kHALF; + return new FFNPluginDynamic(name, type, outDims, act_type, W1, W2, B1); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* FFNPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call FFNPluginDynamic::destroy() + try { + return new FFNPluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void FFNPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FFNPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(FFNPluginDynamicCreator); +//#########################################################################// +FFNPluginDynamic::FFNPluginDynamic(std::string const name, DataType const type, int32_t const outDim, + int32_t const act_type, Weights const& W1, Weights const& W2, Weights const& B1) + : mLayerName(name), + mType(type), + mHiddenSize(outDim), + mActType(act_type), + mWdev1(nullptr), + mWdev2(nullptr), + mBdev1(nullptr) { + mW1.convertAndCopy(W1, mType); + mW2.convertAndCopy(W2, mType); + mB1.convertAndCopy(B1, mType); + copyToDevice(mW1, getWeightsSize(mW1, mType), mWdev1); + copyToDevice(mW2, getWeightsSize(mW2, mType), mWdev2); + copyToDevice(mB1, getWeightsSize(mB1, mType), mBdev1); +} + +FFNPluginDynamic::FFNPluginDynamic(std::string const name, void const* data, size_t length) + : mLayerName(name), mWdev1(nullptr), mWdev2(nullptr), mBdev1(nullptr) { + gLogInfo << "FFNPluginDynamic deserialize" << endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mActType); + + char const* d = static_cast(data); + + mW1.convertAndCopy(d, mHiddenSize * mHiddenSize * 4, mType); + copyToDevice(mW1, getWeightsSize(mW1, mType), mWdev1); + + mW2.convertAndCopy(d, mHiddenSize * mHiddenSize * 4, mType); + copyToDevice(mW2, getWeightsSize(mW2, mType), mWdev2); + + mB1.convertAndCopy(d, mHiddenSize * 4, mType); + copyToDevice(mB1, getWeightsSize(mB1, mType), mBdev1); +} + +// IPluginV2 Methods +char const* FFNPluginDynamic::getPluginType() const noexcept { return kFFN_NAME; } + +char const* FFNPluginDynamic::getPluginVersion() const noexcept { return kFFN_VERSION; } + +int32_t FFNPluginDynamic::getNbOutputs() const noexcept { return 1; } + +int32_t FFNPluginDynamic::initialize() noexcept { + gLogInfo << "FFNPluginDynamic initialize" << endl; + return 0; +} + +void FFNPluginDynamic::terminate() noexcept { gLogInfo << 
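+ // Serialized blob layout, mirrored by the deserializing constructor: mType, mHiddenSize, mActType, then W1 (4H x H), W2 (H x 4H) and B1 (4H) copied back from device memory.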
"FFNPluginDynamic terminate" << endl; } + +size_t FFNPluginDynamic::getSerializationSize() const noexcept { + size_t wordSize = getElementSize(mType); + return wordSize * (mHiddenSize * mHiddenSize * 8 + mHiddenSize * 4) + sizeof(mType) + sizeof(mHiddenSize) + + sizeof(mActType); +} + +void FFNPluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mActType); + + size_t wordSize = getElementSize(mType); + char* d = static_cast(buffer); + serFromDev(d, static_cast(mWdev1.get()), 4 * mHiddenSize * mHiddenSize * wordSize); + serFromDev(d, static_cast(mWdev2.get()), 4 * mHiddenSize * mHiddenSize * wordSize); + serFromDev(d, static_cast(mBdev1.get()), 4 * mHiddenSize * wordSize); +} + +void FFNPluginDynamic::destroy() noexcept { + gLogInfo << "FFNPluginDynamic destroy" << endl; + mWdev1.reset(nullptr); + mWdev2.reset(nullptr); + mBdev1.reset(nullptr); + delete this; +} + +void FFNPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* FFNPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType FFNPluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputTypes != nullptr); + IXRT_PLUGIN_ASSERT(inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF); + return inputTypes[0]; +} + +// IPluginV2DynamicExt Methods +IPluginV2DynamicExt* FFNPluginDynamic::clone() const noexcept { + try { + gLogInfo << "FFNPluginDynamic clone" << endl; + + auto* p = new FFNPluginDynamic(mLayerName, mType, mHiddenSize, mActType, mW1, mW2, mB1); + p->setPluginNamespace(mNamespace.c_str()); + + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs FFNPluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept { + try { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(outputIndex == 0); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + DimsExprs ret; + + if(inputs[0].nbDims == 5){ + ret.nbDims = 5; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mHiddenSize); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + + }else if(inputs[0].nbDims == 3){ + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mHiddenSize); + } + + return ret; + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool FFNPluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(inOut != nullptr); + + PluginTensorDesc const& in = inOut[pos]; + if (pos == 0) { + return (in.type == mType) && (in.format == TensorFormat::kLINEAR); + } + PluginTensorDesc const& prev = inOut[pos - 1]; + + // output + return in.type == prev.type && in.format == prev.format; +} + +void FFNPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t 
nbOutputs) noexcept { + try { + // Validate input arguments + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(outputs != nullptr); + IXRT_PLUGIN_ASSERT(mType == inputs[0].desc.type); + auto const& inDims0 = inputs[0].desc.dims; + + IXRT_PLUGIN_ASSERT(inDims0.nbDims == 3 || inDims0.nbDims == 5); + // IXRT_PLUGIN_ASSERT(inDims0.d[3] == 1); + // IXRT_PLUGIN_ASSERT(inDims0.d[4] == 1); +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferCreate(&cuinfer_handle)); +#else + CHECK_GPU_ERROR(cublasLtCreate(&blaslt_handle)); +#endif + } catch (std::exception const& e) { + caughtError(e); + } +} + +size_t FFNPluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { + int32_t const S = inputs[0].dims.d[SDIM]; + int32_t const B = inputs[0].dims.d[BDIM]; + return B * S * 4 * mHiddenSize * sizeof(half); +} + +int32_t FFNPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workSpace, + cudaStream_t stream) noexcept { + gLogInfo << "in FFNPluginDynamic.." << endl; + try { +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream)); +#endif + int32_t const S = inputDesc->dims.d[SDIM]; + int32_t const B = inputDesc->dims.d[BDIM]; + int32_t const n = S * B; + IXRT_PLUGIN_ASSERT(n >= 0); + + if (mType == DataType::kHALF) { + auto const* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto weight1 = static_cast(mWdev1.get()); + auto weight2 = static_cast(mWdev2.get()); + auto bias1 = static_cast(mBdev1.get()); + auto buffer = static_cast(workSpace); + +#ifdef __ILUVATAR__ + cuinfer_gemm(weight1, input, bias1, buffer, 1, mHiddenSize * 4, n, mHiddenSize, 0, 0, 0, 1.0f, mActType, + stream, cuinfer_handle); + cuinfer_gemm(weight2, buffer, nullptr, output, 1, mHiddenSize, n, 4 * mHiddenSize, 0, 0, 0, 1.0f, -1, + stream, cuinfer_handle); +#else + cublaslt_gemm(weight1, input, buffer, 1, mHiddenSize * 4, n, mHiddenSize, 0, 0, 0, 1.0f, blaslt_handle, + stream); + computeGeluBias(buffer, buffer, bias1, 4 * mHiddenSize, n, stream); + cublaslt_gemm(weight2, buffer, output, 1, mHiddenSize, n, mHiddenSize * 4, 0, 0, 0, 1.0f, blaslt_handle, + stream); +#endif + } else { + gLogError << "Unsupported type error, expected [kHALF], but received " << static_cast(mType) + << endl; + return STATUS_FAILURE; + } + return STATUS_SUCCESS; + } catch (std::exception const& e) { + caughtError(e); + } + return STATUS_FAILURE; +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..e80be54dce928294ecb5ca1cd86cf4bf15cb21e7 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/ffn/ffnPlugin.h @@ -0,0 +1,213 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. 
You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +#include + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "backend/cublas/cublas_helper.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { + +class FFNPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + FFNPluginDynamic(std::string const name, nvinfer1::DataType const type, int32_t const outDim, + int32_t const out_type, nvinfer1::Weights const& W1, nvinfer1::Weights const& W2, + nvinfer1::Weights const& B1); + + FFNPluginDynamic(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make FFNPluginDynamic without arguments, so we + // delete default constructor. + FFNPluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + std::string const mLayerName; + std::string mNamespace; + + nvinfer1::DataType mType; + size_t mHiddenSize; + size_t mActType; + + bert::WeightsWithOwnership mW1; + bert::WeightsWithOwnership mB1; + bert::WeightsWithOwnership mW2; + bert::cuda_unique_ptr mWdev1; + bert::cuda_unique_ptr mWdev2; + bert::cuda_unique_ptr mBdev1; + +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif + cudaStream_t stream; +}; + +class FFNPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + FFNPluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + 
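+ // Typical build-time flow (sketch only; assumes the standard plugin registry, not part of this patch):
+ //   auto* creator = getPluginRegistry()->getPluginCreator("CustomFFNPluginDynamic_IxRT", "1");
+ //   std::vector<nvinfer1::PluginField> fields = {/* out_dims, type_id, act_type, W1, W2, B1 */};
+ //   nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
+ //   auto* plugin = creator->createPlugin("ffn", &fc);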
+ nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFFN; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +class FFNInt8PluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + FFNInt8PluginDynamic(std::string const name, nvinfer1::DataType const type, int32_t const outDim, + nvinfer1::Weights const& W, nvinfer1::Weights const& Bias, vector const& scale); + + FFNInt8PluginDynamic(std::string const name, void const* data, size_t length); + + // It doesn't make sense to make FFNInt8PluginDynamic without arguments, so we + // delete default constructor. + FFNInt8PluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + std::string const mLayerName; + std::string mNamespace; + + nvinfer1::DataType mType; + size_t mOutDim; // leading dim + size_t mNumParams; + int32_t mNmax; + int32_t mK; + int32_t mNumBias; + + vector mScale; + + bert::WeightsWithOwnership mW; + bert::cuda_unique_ptr mWdev; + + bert::WeightsWithOwnership mBias; + bert::cuda_unique_ptr mBiasdev; + +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif + cudaStream_t stream; +}; + +class FFNInt8PluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + FFNInt8PluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection 
const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..530d983eb1205158904f8d3fbe57c06ebf10cfaa --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cpp @@ -0,0 +1,347 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "geluPlugin.h" + +#include + +#include "NvInferRuntimeCommon.h" +#include "backend/bert/bert_layer_kernel.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace { +char const* const kGELU_IXRT_PLUGIN_VERSION{"1"}; +char const* const kGELU_IXRT_PLUGIN_NAME{"CustomGeluPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection GeluPluginDynamicCreator::mFC{}; +std::vector GeluPluginDynamicCreator::mPluginAttributes; + +GeluPluginDynamicCreator::GeluPluginDynamicCreator() { + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1)); + + // Fill PluginFieldCollection with PluginField arguments metadata + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* GeluPluginDynamicCreator::getPluginName() const noexcept { return kGELU_IXRT_PLUGIN_NAME; } + +char const* GeluPluginDynamicCreator::getPluginVersion() const noexcept { return kGELU_IXRT_PLUGIN_VERSION; } + +PluginFieldCollection const* GeluPluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* GeluPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogVerbose << "GeluPluginDynamicCreator createPlugin\n"; + IXRT_PLUGIN_ASSERT(fc != nullptr); + + Weights bias{DataType::kFLOAT, nullptr, 0}; + int32_t typeId = -1; + plugin::validateRequiredAttributesExist({"type_id"}, fc); + int32_t ld = 0; + + for (int32_t i = 0; i < fc->nbFields; i++) { + IXRT_PLUGIN_ASSERT(fc->fields[i].name != nullptr); + std::string fieldName(fc->fields[i].name); + + if 
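+ // Recognized fields: type_id picks the tensor datatype (0 = float, 1 = half, 2 = int8), an optional bias enables the fused bias+GELU path, and ld falls back to the bias length when not given.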
(fieldName.compare("type_id") == 0) { + typeId = *static_cast(fc->fields[i].data); + } + if (fieldName.compare("bias") == 0) { + bias.values = fc->fields[i].data; + bias.count = fc->fields[i].length; + bias.type = fieldTypeToDataType(fc->fields[i].type); + if (ld == 0) { + ld = bias.count; + } + } + if (fieldName.compare("ld") == 0) { + ld = *static_cast(fc->fields[i].data); + } + } + + if (typeId < 0 || typeId > 3) { + gLogError << "GeluPluginDynamicCreator: invalid typeId " << typeId << std::endl; + return nullptr; + } + + return new GeluPluginDynamic(name, static_cast(typeId), bias, ld); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* GeluPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call GeluPluginDynamic::destroy() + try { + return new GeluPluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void GeluPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* GeluPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(GeluPluginDynamicCreator); +//#########################################################################// +GeluPluginDynamic::GeluPluginDynamic(const std::string name, const DataType type, Weights const& bias, const int ld) + : mLayerName(name), mType(type), mLd(ld), mNumBias(bias.count) { + if (mNumBias > 0) { + mBias.convertAndCopy(bias, DataType::kHALF); + copyToDevice(mBias, getWeightsSize(mBias, DataType::kHALF), mBiasDev); + } +} + +GeluPluginDynamic::GeluPluginDynamic(const std::string name, void const* data, size_t length) : mLayerName(name) { + gLogVerbose << "GeluPluginDynamic deserialize\n"; + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mLd); + deserialize_value(&data, &length, &mNumBias); + + if (mNumBias > 0) { + IXRT_PLUGIN_ASSERT(mLd > 0); + char const* d = static_cast(data); + mBias.convertAndCopy(d, mNumBias, DataType::kHALF); + copyToDevice(mBias, getWeightsSize(mBias, DataType::kHALF), mBiasDev); + } +} + +// IPluginV2 Methods + +char const* GeluPluginDynamic::getPluginType() const noexcept { return kGELU_IXRT_PLUGIN_NAME; } + +char const* GeluPluginDynamic::getPluginVersion() const noexcept { return kGELU_IXRT_PLUGIN_VERSION; } + +int32_t GeluPluginDynamic::getNbOutputs() const noexcept { return 1; } + +int32_t GeluPluginDynamic::initialize() noexcept { + gLogVerbose << "GeluPluginDynamic initalize\n"; + return 0; +} + +void GeluPluginDynamic::terminate() noexcept { gLogVerbose << "GeluPluginDynamic terminate\n"; } + +size_t GeluPluginDynamic::getSerializationSize() const noexcept { + const size_t wordSize = getElementSize(mType); + return sizeof(mType) + sizeof(mLd) + sizeof(mNumBias) + mNumBias * sizeof(half); +} + +void GeluPluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mLd); + serialize_value(&buffer, mNumBias); + if (mNumBias > 0) { + IXRT_PLUGIN_ASSERT(mLd > 0); + char* d = static_cast(buffer); + + serFromDev(d, static_cast(mBiasDev.get()), mLd * getElementSize(DataType::kHALF)); + } +} + +void GeluPluginDynamic::destroy() noexcept 
{ + gLogVerbose << "GeluPluginDynamic destroy\n"; + // This gets called when the network containing plugin is destroyed + if (mNumBias > 0) { + mBiasDev.reset(); + } + delete this; +} + +void GeluPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { + try { + IXRT_PLUGIN_ASSERT(libNamespace != nullptr); + mNamespace = libNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* GeluPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +nvinfer1::DataType GeluPluginDynamic::getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept { + try { + IXRT_PLUGIN_ASSERT(index == 0); + IXRT_PLUGIN_ASSERT(inputTypes != nullptr); + IXRT_PLUGIN_ASSERT(inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || + inputTypes[0] == DataType::kINT8); + return inputTypes[0]; + } catch (std::exception const& e) { + caughtError(e); + } + return DataType{}; +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* GeluPluginDynamic::clone() const noexcept { + try { + gLogVerbose << "GeluPluginDynamic clone\n"; + auto* plugin = new GeluPluginDynamic(mLayerName, mType, mBias, mLd); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +nvinfer1::DimsExprs GeluPluginDynamic::getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, + int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept { + try { + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(outputIndex == 0); + return inputs[0]; + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool GeluPluginDynamic::supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, int32_t nbOutputs) noexcept { + try { + IXRT_PLUGIN_ASSERT(inOut != nullptr); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + IXRT_PLUGIN_ASSERT(pos >= 0); + IXRT_PLUGIN_ASSERT(pos < nbInputs + nbOutputs); + } catch (std::exception const& e) { + caughtError(e); + return false; + } + + PluginTensorDesc const& input = inOut[0]; + if (pos == 0) { + return (input.type == mType) && (input.format == TensorFormat::kLINEAR); + } + if (pos == 1) { + PluginTensorDesc const& output = inOut[1]; + return (input.type == output.type) && (output.format == TensorFormat::kLINEAR) && (output.type == mType); + } + return false; +} + +void GeluPluginDynamic::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { + gLogVerbose << "GeluPluginDynamic configurePlugin\n"; + + try { + IXRT_PLUGIN_ASSERT(in != nullptr); + IXRT_PLUGIN_ASSERT(nbInputs == 1); + IXRT_PLUGIN_ASSERT(mType == in[0].desc.type); + IXRT_PLUGIN_ASSERT(mType == DataType::kHALF || mType == DataType::kINT8); + } catch (std::exception const& e) { + caughtError(e); + } +} + +size_t GeluPluginDynamic::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + return 0; +} + +template +int32_t GeluPluginDynamic::enqueueTyped(void const* input_, void* output_, int32_t const inputVolume, + cudaStream_t stream) noexcept { + TDataType const* input = static_cast(input_); + TDataType* output = static_cast(output_); + 
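+ // The flat input volume is split into (inputVolume / mLd) rows of mLd elements; when a bias is present each row gets the bias added inside the fused bias+GELU kernel, otherwise the plain element-wise GELU is used.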
int32_t const cols = inputVolume / mLd; + int32_t const rows = mLd; + + if (mNumBias > 0) { + TDataType const* bias = static_cast(mBiasDev.get()); + return computeGeluBias(output, input, bias, rows, cols, stream); + } else { + return computeGelu(stream, inputVolume, input, output); + } +} + +int32_t GeluPluginDynamic::enqueueInt8(void const* input_, void* output_, float dequant_scale, float quant_scale, + int32_t const inputVolume, cudaStream_t stream) noexcept { + int8_t const* input = static_cast(input_); + int8_t* output = static_cast(output_); + int32_t const cols = inputVolume / mLd; + int32_t const rows = mLd; + + if (mNumBias > 0) { + half const* bias = static_cast(mBiasDev.get()); + return computeGeluI8O8Bias(output, input, bias, rows, cols, dequant_scale, quant_scale, stream); + } else { + return computeGeluI8O8(stream, inputVolume, input, output, dequant_scale, quant_scale); + } +} + +int32_t GeluPluginDynamic::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) noexcept { + gLogInfo << "in GeluPluginDynamic.." << endl; + try { + IXRT_PLUGIN_ASSERT(inputDesc != nullptr); + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(outputs != nullptr); + } catch (std::exception const& e) { + caughtError(e); + return STATUS_FAILURE; + } + + int32_t const inputVolume = volume(inputDesc[0].dims); + int32_t batch_token_num = inputDesc[0].dims.d[BDIM] * inputDesc[0].dims.d[SDIM]; + + // Our plugin outputs only one tensor. + // Launch CUDA kernel wrapper and save its return value. + switch (mType) { + case DataType::kFLOAT: + return enqueueTyped(inputs[0], outputs[0], inputVolume, stream); + case DataType::kHALF: + return enqueueTyped(inputs[0], outputs[0], inputVolume, stream); + case DataType::kINT8: { + int8_t* input = (int8_t*)(inputs[0]); + int8_t* output = (int8_t*)(outputs[0]); + backend::IxinferBiasGeluI8II8O(batch_token_num, stream, (int8_t*)input, (int8_t*)output, + static_cast(mBiasDev.get()), mLd, inputDesc[0].scale, + 1.0 / outputDesc[0].scale); + return STATUS_SUCCESS; + } + default: + return STATUS_FAILURE; + } +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..8bfd74f6e8469cf2977329ac341f83e6d447403f --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.cu @@ -0,0 +1,149 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "checkMacrosPlugin.h" +#include "geluPlugin.h" + +namespace nvinfer1::plugin { +namespace bert { +// constants for approximating the normal cdf +constexpr float A = 0.5f; +constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI) +constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI) + +template +__global__ void geluKernel(const half a, const half b, const half c, int n, const half* input, half* output) { + const int idx = blockIdx.x * TPB + threadIdx.x; + + if (idx < n) { + const half in = input[idx]; + const half cdf = a + a * __float2half(tanh(__half2float(in * (c * in * in + b)))); + output[idx] = in * cdf; + } +} + +template +__global__ void geluKernel(const float a, const float b, const float c, int n, const float* input, float* output) { + const int idx = blockIdx.x * TPB + threadIdx.x; + + if (idx < n) { + const float in = input[idx]; + const float cdf = a + a * tanh(in * (c * in * in + b)); + output[idx] = in * cdf; + } +} + +template +__global__ void geluKernel(const float a, const float b, const float c, int n, const int8_t* input, int8_t* output, + float dequant_scale, float quant_scale) { + const int idx = blockIdx.x * TPB + threadIdx.x; + + if (idx < n) { + const float in = float(input[idx]) * dequant_scale; + const float cdf = a + a * tanh(in * (c * in * in + b)); + float i8_f = in * cdf * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 127 : i8); + output[idx] = int8_t(i8); + } +} + +int computeGelu(cudaStream_t stream, int n, const float* input, float* output) { + constexpr int blockSize = 256; + const int gridSize = (n + blockSize - 1) / blockSize; + geluKernel<<>>(A, B, C, n, input, output); + + return 0; +} + +int computeGelu(cudaStream_t stream, int n, const half* input, half* output) { + constexpr int blockSize = 256; + const int gridSize = (n + blockSize - 1) / blockSize; + geluKernel<<>>(A, B, C, n, input, output); + + return 0; +} + +int32_t computeGeluI8O8(cudaStream_t stream, int n, const int8_t* input, int8_t* output, float dequant_scale, + float quant_scale) { + constexpr int blockSize = 256; + const int gridSize = (n + blockSize - 1) / blockSize; + geluKernel<<>>(A, B, C, n, input, output, dequant_scale, quant_scale); + + return 0; +} + +template +__global__ void geluBiasKernel(const half a, const half b, const half c, half* output, const half* input, + const half* bias, const int ld) { + const int offset = blockIdx.x * ld; + + for (int it = threadIdx.x; it < ld; it += TPB) { + const int idx = it + offset; + const half in = input[idx] + bias[it]; + const half cdf = a + a * __float2half(tanh(__half2float(in * (c * in * in + b)))); + output[idx] = in * cdf; + } +} + +template +__global__ void geluBiasKernel(const float a, const float b, const float c, float* output, const float* input, + const float* bias, const int ld) { + const int offset = blockIdx.x * ld; + + for (int it = threadIdx.x; it < ld; it += TPB) { + const int idx = it + offset; + const float in = input[idx] + bias[it]; + const float cdf = a + a * tanh(in * (c * in * in + b)); + output[idx] = in * cdf; + } +} + +template +__global__ void geluBiasKernel(const float a, const float b, const float c, int8_t* output, const int8_t* input, + const half* bias, float dequant_scale, float quant_scale, const int ld) { + const int offset = blockIdx.x * ld; + + for (int it = threadIdx.x; it < ld; it += TPB) { + const int idx = it + offset; + const float in = float(input[idx]) * dequant_scale + __half2float(bias[it]); + const float cdf = a + a 
* tanh(in * (c * in * in + b)); + float i8_f = in * cdf * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 127 : i8); + output[idx] = int8_t(i8); + } +} + +int computeGeluBias(float* output, const float* input, const float* bias, const int ld, const int cols, + cudaStream_t stream) { + geluBiasKernel<256><<>>(A, B, C, output, input, bias, ld); + return cudaPeekAtLastError(); +} + +int computeGeluBias(half* output, const half* input, const half* bias, const int ld, const int cols, + cudaStream_t stream) { + geluBiasKernel<256><<>>(A, B, C, output, input, bias, ld); + return cudaPeekAtLastError(); +} + +int32_t computeGeluI8O8Bias(int8_t* output, const int8_t* input, const half* bias, const int ld, const int cols, + float dequant_scale, float quant_scale, cudaStream_t stream) { + geluBiasKernel<256><<>>(A, B, C, output, input, bias, dequant_scale, quant_scale, ld); + return cudaPeekAtLastError(); +} + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..b1d1e4651e19e582c1c9c79f4de4443bc21a02eb --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/gelu/geluPlugin.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { +int32_t computeGelu(cudaStream_t stream, int32_t n, float const* input, float* output); + +int32_t computeGelu(cudaStream_t stream, int32_t n, half const* input, half* output); + +int32_t computeGeluI8O8(cudaStream_t stream, int n, const int8_t* input, int8_t* output, float dequant_scale, + float quant_scale); + +int32_t computeGeluBias(float* output, float const* input, float const* bias, int32_t const ld, int32_t const cols, + cudaStream_t stream); + +int32_t computeGeluBias(half* output, half const* input, half const* bias, int32_t const ld, int32_t const cols, + cudaStream_t stream); + +int32_t computeGeluI8O8Bias(int8_t* output, const int8_t* input, const half* bias, const int ld, const int cols, + float dequant_scale, float quant_scale, cudaStream_t stream); + +class GeluPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + GeluPluginDynamic(const std::string name, const nvinfer1::DataType type, nvinfer1::Weights const& bias, + const int ld); + + GeluPluginDynamic(const std::string name, void const* data, size_t length); + + // It doesn't make sense to make GeluPluginDynamic without arguments, so we delete + // default constructor. 
+ GeluPluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + // Helper method for enqueue() + template + int32_t enqueueTyped(void const* input, void* output, int32_t const inputVolume, cudaStream_t stream) noexcept; + int32_t enqueueInt8(void const* input_, void* output_, float dequant_scale, float quant_scale, + int32_t const inputVolume, cudaStream_t stream) noexcept; + + const std::string mLayerName; + std::string mNamespace; + + nvinfer1::DataType mType; + bert::WeightsWithOwnership mBias; + bert::cuda_unique_ptr mBiasDev; + size_t mLd; + size_t mNumBias; +}; + +class GeluPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + GeluPluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..072ba6880fe02199c1a8244a6e018b4b2d97f46d --- /dev/null +++ 
b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cpp @@ -0,0 +1,313 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "qkvToContextInt8Plugin.h" + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "common_def.cuh" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace { +char const* const kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_VERSION{"3"}; +char const* const kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_NAME{"CustomQKVToContextPluginDynamic_IxRT"}; +} // namespace + +PluginFieldCollection QKVToContextInt8PluginDynamicCreator::mFC{}; +std::vector QKVToContextInt8PluginDynamicCreator::mPluginAttributes; + +constexpr uint32_t IIDX = 0; // index of the input tensor +constexpr uint32_t MIDX = 1; // index of the mask +/* +dq_probs: +_arrange_qkv_amax +_softmax_in_amax +_softmax_out_amax +*/ +QKVToContextInt8PluginDynamicCreator::QKVToContextInt8PluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("dq_probs", nullptr, PluginFieldType::kFLOAT32, 3)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* QKVToContextInt8PluginDynamicCreator::getPluginName() const noexcept { + return kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_NAME; +} + +char const* QKVToContextInt8PluginDynamicCreator::getPluginVersion() const noexcept { + return kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_VERSION; +} + +PluginFieldCollection const* QKVToContextInt8PluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* QKVToContextInt8PluginDynamicCreator::createPlugin(char const* name, + PluginFieldCollection const* fc) noexcept { + try { + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextInt8PluginDynamic constructor. 
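    // Illustrative only: roughly how a caller could populate the fields this factory
    // parses; the numeric values below are placeholders, not calibrated amax data.
    // dq_probs carries {arrange_qkv_amax, softmax_in_amax, softmax_out_amax}, and each
    // amax later becomes a scale of amax / 127 inside fused_multihead_attetion_int8.
    //   int32_t hidden = 768, heads = 12;
    //   float dq[3] = {4.0f, 8.0f, 1.0f};
    //   PluginField fields[] = {{"hidden_size", &hidden, PluginFieldType::kINT32, 1},
    //                           {"num_heads", &heads, PluginFieldType::kINT32, 1},
    //                           {"dq_probs", dq, PluginFieldType::kFLOAT32, 3}};
    //   PluginFieldCollection fcExample{3, fields};
    //   IPluginV2* p = creator.createPlugin("qkv_to_context_int8", &fcExample);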
+ int32_t numHeads{-1}; + + vector dqProbs; + + plugin::validateRequiredAttributesExist({"hidden_size", "num_heads"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + + if (field_name.compare("hidden_size") == 0) { + hiddenSize = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(hiddenSize > 0, + ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + gLogInfo << "Building hiddenSize: " << hiddenSize << endl; + } + if (field_name.compare("num_heads") == 0) { + numHeads = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + gLogInfo << "Building numHeads: " << numHeads << endl; + } + if (field_name.compare("dq_probs") == 0) { + IXRT_PLUGIN_CHECK_VALUE(fc->fields[i].length > 0, + ("QKV: dpProbs can not be empty, error: [dpProbs.length == 0]!")); + gLogInfo << "Building dqProbs: ["; + for (auto j = 0; j < fc->fields[i].length; j++) { + dqProbs.emplace_back(static_cast((fc->fields[i].data))[j]); + gLogInfo << std::setprecision(5) << dqProbs[j]; + } + gLogInfo << "]" << endl; + } + } + + QKVToContextInt8PluginDynamic* p = new QKVToContextInt8PluginDynamic(name, hiddenSize, numHeads, dqProbs); + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* QKVToContextInt8PluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + try { + // This object will be deleted when the network is destroyed, which will + // call QKVToContextInt8PluginDynamic::destroy() noexcept + return new QKVToContextInt8PluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void QKVToContextInt8PluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + mNamespace = libNamespace; +} + +char const* QKVToContextInt8PluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(QKVToContextInt8PluginDynamicCreator); +//#########################################################################// +QKVToContextInt8PluginDynamic::QKVToContextInt8PluginDynamic(std::string const& name, int32_t const hiddenSize, + int32_t const numHeads, vector const dqProbs) + : mLayerName(name), + mS(0), + mB(0), + mHeadSize(hiddenSize / numHeads), + mHiddenSize(hiddenSize), + mNumHeads(numHeads), + mDqProbs(dqProbs) {} + +QKVToContextInt8PluginDynamic::QKVToContextInt8PluginDynamic(std::string const& name, void const* data, size_t length) + : mLayerName(name) { + gLogInfo << "deserialize QKVToContextInt8PluginDynamic" << endl; + deserialize_value(&data, &length, &mNumHeads); + deserialize_value(&data, &length, &mHeadSize); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mDqProbs); +} + +// IPluginV2 Methods +char const* QKVToContextInt8PluginDynamic::getPluginType() const noexcept { + return kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_NAME; +} + +char const* QKVToContextInt8PluginDynamic::getPluginVersion() const noexcept { + return kQKV_TO_CONTEXT_INT8_IXRT_PLUGIN_VERSION; +} + +int32_t QKVToContextInt8PluginDynamic::getNbOutputs() const noexcept { return 1; } + +int32_t QKVToContextInt8PluginDynamic::initialize() noexcept { return 0; } + +void QKVToContextInt8PluginDynamic::terminate() noexcept {} + +size_t QKVToContextInt8PluginDynamic::getSerializationSize() const noexcept { + return 
sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mHiddenSize) + mDqProbs.size() * sizeof(float) + + sizeof(mDqProbs.size()); +} + +void QKVToContextInt8PluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mNumHeads); + serialize_value(&buffer, mHeadSize); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mDqProbs); +} + +void QKVToContextInt8PluginDynamic::destroy() noexcept { delete this; } + +void QKVToContextInt8PluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } + +char const* QKVToContextInt8PluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods +DataType QKVToContextInt8PluginDynamic::getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0) + return DataType::kINT8; +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* QKVToContextInt8PluginDynamic::clone() const noexcept { + try { + QKVToContextInt8PluginDynamic* ret = + new QKVToContextInt8PluginDynamic(mLayerName, mHiddenSize, mNumHeads, mDqProbs); + + ret->setPluginNamespace(mNamespace.c_str()); + return ret; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +DimsExprs QKVToContextInt8PluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, + int32_t nbInputs, IExprBuilder& exprBuilder) noexcept { + // input [B, S, 3*E] int8 + // pad_mask [B, S] int8 + + // output [B, S, E] int8 + IXRT_PLUGIN_ASSERT(outputIndex == 0); + // Copy over everything + DimsExprs output(inputs[IIDX]); + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + output.d[HDIM] = exprBuilder.constant(mHiddenSize); + // output.d[HDIM] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[IIDX].d[HDIM], *three); + return output; +} +bool QKVToContextInt8PluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, + int32_t nbInputs, int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 2); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + return (inOut[pos].type == DataType::kINT8) && (inOut[pos].format == TensorFormat::kLINEAR); +} + +void QKVToContextInt8PluginDynamic::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, + DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 2); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[IIDX].desc; + PluginTensorDesc const& outDesc = out[0].desc; + IXRT_PLUGIN_ASSERT(inDesc.dims.nbDims == 3 || inDesc.dims.nbDims == 5) + IXRT_PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * mHiddenSize); + // IXRT_PLUGIN_ASSERT(inDesc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(inDesc.dims.d[4] == 1); + + PluginTensorDesc const& maskDesc = in[MIDX].desc; + IXRT_PLUGIN_ASSERT(maskDesc.dims.nbDims == 2); + IXRT_PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[0]); + IXRT_PLUGIN_ASSERT(maskDesc.dims.d[1] == inDesc.dims.d[1]); + + const int32_t S = inDesc.dims.d[SDIM]; + + IXRT_PLUGIN_ASSERT(outDesc.dims.nbDims == 3 || outDesc.dims.nbDims == 5); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[BDIM] == inDesc.dims.d[BDIM]); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[SDIM] == S); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[HDIM] == mHiddenSize); + // IXRT_PLUGIN_ASSERT(outDesc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(outDesc.dims.d[4] == 1); +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferCreate(&cuinfer_handle)); +#else + 
CHECK_GPU_ERROR(cublasLtCreate(&blaslt_handle)); +#endif +} + +size_t QKVToContextInt8PluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + const int32_t B = inputs[0].dims.d[BDIM]; + const int32_t S = inputs->dims.d[SDIM]; + const int32_t E = inputs->dims.d[HDIM]; + IXRT_PLUGIN_ASSERT(E == 3 * mHiddenSize); + int64_t buffer_size = B * S * E; + return (B * S * E + buffer_size) * sizeof(int8_t) + buffer_size * sizeof(int32_t); +} + +int32_t QKVToContextInt8PluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept { + try { +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferSetStream(cuinfer_handle, 0)); +#endif + int32_t const B = inputDesc[0].dims.d[BDIM]; + int32_t const S = inputDesc[0].dims.d[SDIM]; + + float qkv_out_amax_ = inputDesc[0].scale * 127; + float linear_in_amax_ = outputDesc[0].scale * 127; + float arrange_qkv_amax_ = mDqProbs[0]; + float softmax_in_amax_ = mDqProbs[1]; + float softmax_out_amax_ = mDqProbs[2]; + + int8_t* qkv_buffer_ = (int8_t*)inputs[0]; + int8_t* qkv_out_ = (int8_t*)outputs[0]; + int8_t* mask_ = (int8_t*)inputs[1]; + + int64_t buffer_size = B * S * mHiddenSize; + int64_t buffer_size2 = B * S * S * mNumHeads; + int8_t* q_buffer_ = static_cast(workspace); + int8_t* k_buffer_ = q_buffer_ + buffer_size; + int8_t* v_buffer_ = k_buffer_ + buffer_size; + int8_t* qk_buffer_ = v_buffer_ + buffer_size; + int32_t* qk_out_ = reinterpret_cast(qk_buffer_ + buffer_size2); +#ifdef __ILUVATAR__ + auto status = + fused_multihead_attetion_int8(qkv_buffer_, mask_, q_buffer_, k_buffer_, v_buffer_, qk_out_, qkv_out_, + qk_buffer_, B, S, mHeadSize, mNumHeads, mHiddenSize, arrange_qkv_amax_, + softmax_in_amax_, softmax_out_amax_, linear_in_amax_, cuinfer_handle, stream); +#else + auto status = + fused_multihead_attetion_int8(qkv_buffer_, mask_, q_buffer_, k_buffer_, v_buffer_, qk_out_, qkv_out_, + qk_buffer_, B, S, mHeadSize, mNumHeads, mHiddenSize, arrange_qkv_amax_, + softmax_in_amax_, softmax_out_amax_, linear_in_amax_, blaslt_handle, stream); +#endif + return cudaPeekAtLastError(); + } catch (std::exception const& e) { + caughtError(e); + return -1; + } +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..d352b0225e4bb9e923b0ed136496ca3da4d69c24 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.cu @@ -0,0 +1,277 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "backend/bert/bert_helper.h" +#include "backend/bert/bert_layer_kernel.h" +#include "backend/cublas/cublas_helper.h" +#include "backend/ixinfer/ixinfer_gemm_helper.h" +#include "qkvToContextInt8Plugin.h" + +using namespace nvinfer1::plugin::backend; + +namespace nvinfer1::plugin { +namespace bert { +const int _max_thread_per_block = 1024; +const float _quant_range = 127.0; + +__global__ void IxinferArrangeEncselfQkvI8II8ONoBias(const int8_t* ori_qkv, int8_t* new_qkv, int max_batch_dim, + int batch_seq_len, int dim_per_head, int head_num) { + int hidden_size = dim_per_head * head_num; + int batch_id = blockIdx.x / batch_seq_len; + int token_id = blockIdx.x % batch_seq_len; + + int i = threadIdx.x; // 1个线程处理4个数据 + + int head_id = (i * 4) / dim_per_head; + int dim_id = (i * 4) % dim_per_head; + int target_id = targetid_4dim(batch_id, head_id, token_id, dim_id, head_num, batch_seq_len, dim_per_head); + +#pragma unroll + for (int qkv_idx = 0; qkv_idx < 3; qkv_idx++) { + char4* p_ori_qkv = (char4*)(ori_qkv + (blockIdx.x * 3 + qkv_idx) * hidden_size); + int qkv_offset = max_batch_dim * qkv_idx; + char4* p_new_qkv = (char4*)(new_qkv + qkv_offset + target_id); + p_new_qkv[0] = p_ori_qkv[i]; + } +} + +#ifdef __ILUVATAR__ +cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8_t* q_buffer, int8_t* k_buffer, + int8_t* v_buffer, int32_t* qk_out, int8_t* qkv_out, int8_t* qk_buffer, + int batch_size, int batch_seq_len, int head_dim, int head_num, + int hidden_size, float arrange_qkv_amax, float softmax_in_amax, + float softmax_out_amax, float linear_in_amax, cuinferHandle_t& cuinfer_handle, + cudaStream_t& stream) { + int batch_token_num = batch_size * batch_seq_len; + int max_batch_dim = batch_token_num * hidden_size; + + float scaleCtx = linear_in_amax / _quant_range; + float scaleArrange = arrange_qkv_amax / _quant_range; + float scaleSoftin = softmax_in_amax / _quant_range; + float scaleSoftout = softmax_out_amax / _quant_range; + + float scaleBmm1 = scaleArrange * scaleArrange / scaleSoftin * sqrt(1.f / head_dim); + float scaleBmm2 = scaleSoftout * scaleArrange / scaleCtx; + + IxinferArrangeEncselfQkvI8II8ONoBias<<>>( + qkv_buffer, q_buffer, max_batch_dim, batch_seq_len, head_dim, head_num); + + switch (head_dim) { + case 64: + case 128: + case 192: + case 256: + cuinferFlashAttnConfigInfo flashAttnInfo; + flashAttnInfo.scaling = sqrt(1.f / (head_dim * 1.0)); + flashAttnInfo.quantParam.q_amax = arrange_qkv_amax; + flashAttnInfo.quantParam.k_amax = arrange_qkv_amax; + flashAttnInfo.quantParam.v_amax = arrange_qkv_amax; + flashAttnInfo.quantParam.p_amax = softmax_out_amax; + flashAttnInfo.quantParam.o_amax = linear_in_amax; + + cuinferTensorDescriptor_t qDesc, kDesc, vDesc, maskDesc, oDesc; + CUINFER_CHECK(cuinferCreateTensorDescriptor(&qDesc)); + CUINFER_CHECK(cuinferCreateTensorDescriptor(&kDesc)); + CUINFER_CHECK(cuinferCreateTensorDescriptor(&vDesc)); + CUINFER_CHECK(cuinferCreateTensorDescriptor(&maskDesc)); + CUINFER_CHECK(cuinferCreateTensorDescriptor(&oDesc)); + + CUINFER_CHECK(cuinferSetTensor4dDescriptor(qDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, + CUINFER_DATA_INT8, batch_size, head_num, batch_seq_len, + head_dim)); + CUINFER_CHECK(cuinferSetTensor4dDescriptor(kDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, + CUINFER_DATA_INT8, batch_size, head_num, batch_seq_len, + head_dim)); + CUINFER_CHECK(cuinferSetTensor4dDescriptor(vDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, + CUINFER_DATA_INT8, batch_size, head_num, batch_seq_len, + 
head_dim)); + CUINFER_CHECK(cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, + CUINFER_DATA_INT8, batch_size, 1, 1, batch_seq_len)); + CUINFER_CHECK(cuinferSetTensor4dDescriptor(oDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, + CUINFER_DATA_INT8, batch_size, head_num, batch_seq_len, + head_dim)); + + CUINFER_CHECK(cuinferFMHAForwardEx(cuinfer_handle, flashAttnInfo, qDesc, q_buffer, kDesc, k_buffer, vDesc, + v_buffer, maskDesc, mask, oDesc, qk_buffer)); + break; + default: + cuinfer_i8_gemm(k_buffer, q_buffer, nullptr, qkv_buffer, batch_size * head_num, batch_seq_len, + batch_seq_len, head_dim, batch_seq_len * head_dim, batch_seq_len * head_dim, + batch_seq_len * batch_seq_len, scaleBmm1, 0.0, 0, cuinfer_handle, stream); + + IxinferCorrelationSoftmaxEncselfI8II8O(batch_size, batch_seq_len, head_num, stream, qkv_buffer, mask, + 1.0 / scaleSoftout, scaleSoftin); + + cuinfer_nn_i8_gemm(v_buffer, qkv_buffer, q_buffer, batch_size * head_num, head_dim, batch_seq_len, + batch_seq_len, batch_seq_len * head_dim, batch_seq_len * batch_seq_len, + batch_seq_len * head_dim, scaleBmm2, cuinfer_handle, stream); + break; + } + + IxinferArrangeAttenOutputI8II8O(batch_token_num, hidden_size, stream, qk_buffer, qkv_out, batch_seq_len, head_dim, + head_num, _max_thread_per_block, 1.f, 1.f); + return cudaSuccess; +} +#else +template +__global__ void quant_qkv_gemm(const int32_t *input, int8_t *output, int hidden_size, float quant_scale, + int num_per_tca) { + float4 val[THREAD_DATA_LEN]; + + int block_id = blockIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z; + int block_start = block_id * hidden_size; + input += block_start; + output += block_start; + + int4 *p_input = (int4 *)input; + char4 *p_output = (char4 *)output; + + float4 bias_val; +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * num_per_tca; + char4 q_input; + q_input.x = float2int8(p_input[element_index].x * 1.0, quant_scale); + q_input.y = float2int8(p_input[element_index].y * 1.0, quant_scale); + q_input.z = float2int8(p_input[element_index].z * 1.0, quant_scale); + q_input.w = float2int8(p_input[element_index].w * 1.0, quant_scale); + + p_output[element_index] = q_input; + } +} + +void quantQKVGemm(int32_t *input, int8_t *output, int batch_size, int head_num, int batch_seq_len, int hidden_size, + float dequant_scale, cudaStream_t stream) { + if (hidden_size > 4096) { + throw std::runtime_error("hidden_size should <= 4096"); + } + int num_per_tca = min(hidden_size / 4, C10_WARP_SIZE); + dim3 gridSize(batch_size, head_num, batch_seq_len); + dim3 blockSize(num_per_tca); + + int num_warp = hidden_size / num_per_tca / 4; + switch (num_warp) { + case 1: + quant_qkv_gemm<1> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 2: + quant_qkv_gemm<2> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 3: + quant_qkv_gemm<3> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 4: + quant_qkv_gemm<4> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 5: + quant_qkv_gemm<5> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 6: + quant_qkv_gemm<6> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 7: + quant_qkv_gemm<7> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 8: + quant_qkv_gemm<8> + <<>>(input, output, hidden_size, dequant_scale, 
num_per_tca); + break; + case 9: + quant_qkv_gemm<9> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 10: + quant_qkv_gemm<10> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 11: + quant_qkv_gemm<11> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 12: + quant_qkv_gemm<12> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 13: + quant_qkv_gemm<13> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 14: + quant_qkv_gemm<14> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 15: + quant_qkv_gemm<15> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + case 16: + quant_qkv_gemm<16> + <<>>(input, output, hidden_size, dequant_scale, num_per_tca); + break; + default: + throw std::runtime_error("quantQKVGemm"); + break; + } +} + +cudaError_t fused_multihead_attetion_int8(int8_t *qkv_buffer, int8_t *mask, int8_t *q_buffer, int8_t *k_buffer, + int8_t *v_buffer, int32_t *qk_out, int8_t *qkv_out, int8_t *qk_buffer, + int batch_size, int batch_seq_len, int head_dim, int head_num, + int hidden_size, float arrange_qkv_amax, float softmax_in_amax, + float softmax_out_amax, float linear_in_amax, + cublasLtHandle_t &cublas_lt_handle, cudaStream_t &stream) { + int batch_token_num = batch_size * batch_seq_len; + int max_batch_dim = batch_token_num * hidden_size; + + float scaleCtx = linear_in_amax / _quant_range; + float scaleArrange = arrange_qkv_amax / _quant_range; + float scaleSoftin = softmax_in_amax / _quant_range; + float scaleSoftout = softmax_out_amax / _quant_range; + + float scaleBmm1 = scaleArrange * scaleArrange / scaleSoftin * sqrt(1.f / head_dim); + float scaleBmm2 = scaleSoftout * scaleArrange / scaleCtx; + + IxinferArrangeEncselfQkvI8II8ONoBias<<>>( + qkv_buffer, q_buffer, max_batch_dim, batch_seq_len, head_dim, head_num); + + cublaslt_gemm(k_buffer, q_buffer, qk_out, batch_size * head_num, batch_seq_len, batch_seq_len, head_dim, + batch_seq_len * head_dim, batch_seq_len * head_dim, batch_seq_len * batch_seq_len, 1, + cublas_lt_handle, stream); + + quantQKVGemm(qk_out, qk_buffer, batch_size, head_num, batch_seq_len, batch_seq_len, scaleBmm1, stream); + + IxinferCorrelationSoftmaxEncselfI8II8O(batch_size, batch_seq_len, head_num, stream, qk_buffer, mask, + 1.0 / scaleSoftout, scaleSoftin); + + cublaslt_gemm_nn(v_buffer, qk_buffer, qk_out, batch_size * head_num, head_dim, batch_seq_len, batch_seq_len, + batch_seq_len * head_dim, batch_seq_len * batch_seq_len, batch_seq_len * head_dim, 1, + cublas_lt_handle, stream); + + quantQKVGemm(qk_out, q_buffer, batch_size, head_num, batch_seq_len, head_dim, scaleBmm2, stream); + + IxinferArrangeAttenOutputI8II8O(batch_token_num, hidden_size, stream, q_buffer, qkv_out, batch_seq_len, head_dim, + head_num, _max_thread_per_block, 1.f, 1.f); + return cudaSuccess; +} +#endif +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..8d92b00a3c6b943ef7701db6278a7043af18be7d --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextInt8Plugin.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. 
+* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include + +#include +#include +#include + +#include "NvInferRuntime.h" +#include "bertCommon.h" + +#ifdef __ILUVATAR__ +#include "ixinfer.h" +#endif + +namespace nvinfer1::plugin { +namespace bert { +#ifdef __ILUVATAR__ +cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8_t* q_buffer, int8_t* k_buffer, + int8_t* v_buffer, int32_t* qk_out, int8_t* qkv_out, int8_t* qk_buffer, + int batch_size, int batch_seq_len, int head_dim, int head_num, + int hidden_size, float arrange_qkv_amax, float softmax_in_amax, + float softmax_out_amax, float linear_in_amax, cuinferHandle_t& cuinfer_handle, + cudaStream_t& stream); +#else +cudaError_t fused_multihead_attetion_int8(int8_t* qkv_buffer, int8_t* mask, int8_t* q_buffer, int8_t* k_buffer, + int8_t* v_buffer, int32_t* qk_out, int8_t* qkv_out, int8_t* qk_buffer, + int batch_size, int batch_seq_len, int head_dim, int head_num, + int hidden_size, float arrange_qkv_amax, float softmax_in_amax, + float softmax_out_amax, float linear_in_amax, + cublasLtHandle_t& cublas_lt_handle, cudaStream_t& stream); +#endif + +class QKVToContextInt8PluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + QKVToContextInt8PluginDynamic(std::string const& name, int32_t const hiddenSize, int32_t const numHeads, + vector const dqProbs); + + QKVToContextInt8PluginDynamic(std::string const& name, void const* data, size_t length); + + // It doesn't make sense to make QKVToContextInt8PluginDynamic without arguments, so we + // delete default constructor. 
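// For reading enqueue() in qkvToContextInt8Plugin.cpp: the workspace is carved up
// roughly as sketched below (B = batch, S = sequence length, H = mHiddenSize,
// N = mNumHeads); this mirrors the pointer arithmetic there and is illustrative only.
//   int8_t*  q_buffer  = (int8_t*)workspace;             // B * S * H
//   int8_t*  k_buffer  = q_buffer  + B * S * H;          // B * S * H
//   int8_t*  v_buffer  = k_buffer  + B * S * H;          // B * S * H
//   int8_t*  qk_buffer = v_buffer  + B * S * H;          // B * N * S * S
//   int32_t* qk_out    = (int32_t*)(qk_buffer + B * N * S * S);
// getWorkspaceSize() sizes the allocation from the fused QKV input, whose last
// dimension is E = 3 * H.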
+ QKVToContextInt8PluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + protected: + void createMHARunner() noexcept; + int32_t getSMVersion() const noexcept; + + private: + std::string const& mLayerName; + std::string mNamespace; + + int32_t mS; + int32_t mB; + int32_t mSM; + int32_t mHeadSize; + int32_t mHiddenSize; + int32_t mNumHeads; + + cuda_unique_ptr mQkvBias; + + vector mDqProbs; + bool mUseInt8ScaleMax{true}; +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif +}; + +class QKVToContextInt8PluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + QKVToContextInt8PluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9483ddb53219511ba6aa7ed56b1b737a0a046bf --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cpp @@ -0,0 +1,425 @@ +/* 
Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "qkvToContextPlugin.h" + +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "common_def.cuh" +#include "cuda_runtime_api.h" +#include "plugin.h" +#include "serialize.h" +// #include "backend/transformer/transformer_utils.h" +// #include "backend/transformer/transformer_attention.h" +#include +#include +#include + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; +// using namespace nvinfer1::inferrt::transformer; + +namespace { +char const* const kQKV_TO_CONTEXT_IXRT_PLUGIN_VERSION{"1"}; +char const* const kQKV_TO_CONTEXT_VAR_SEQLEN_IXRT_PLUGIN_VERSION{"2"}; +char const* const kQKV_TO_CONTEXT_IXRT_PLUGIN_NAME{"CustomQKVToContextPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection QKVToContextPluginDynamicCreator::mFC{}; +std::vector QKVToContextPluginDynamicCreator::mPluginAttributes; + +constexpr uint32_t IIDX = 0; // index of the input tensor +constexpr uint32_t MIDX = 1; // index of the mask + +QKVToContextPluginDynamicCreator::QKVToContextPluginDynamicCreator() { + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("hidden_size", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("has_mask", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("has_qk_bias", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("is_t5_mode", nullptr, PluginFieldType::kINT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* QKVToContextPluginDynamicCreator::getPluginName() const noexcept { + return kQKV_TO_CONTEXT_IXRT_PLUGIN_NAME; +} + +char const* QKVToContextPluginDynamicCreator::getPluginVersion() const noexcept { + return kQKV_TO_CONTEXT_IXRT_PLUGIN_VERSION; +} + +PluginFieldCollection const* QKVToContextPluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* QKVToContextPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "Creating QKV2ContextPlugin..." << endl; + IXRT_PLUGIN_ASSERT(fc != nullptr); + int32_t hiddenSize = 0; + // Since numHeads must always exist or validateRequiredAttributes will fail, + // we can set numHeads to -1 so that static analysis tools don't warn about + // a division by zero in QKVToContextPluginDynamic constructor. 
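        // Illustrative only: example field values for a BERT-base style build; the numbers
        // are placeholders. Since enqueue() only services the half-precision path, type_id
        // would normally be 1 (kHALF).
        //   int32_t typeIdVal = 1, hidden = 768, heads = 12, hasMaskVal = 1;
        //   PluginField fields[] = {{"type_id", &typeIdVal, PluginFieldType::kINT32, 1},
        //                           {"hidden_size", &hidden, PluginFieldType::kINT32, 1},
        //                           {"num_heads", &heads, PluginFieldType::kINT32, 1},
        //                           {"has_mask", &hasMaskVal, PluginFieldType::kINT32, 1}};
        //   PluginFieldCollection fcExample{4, fields};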
+ int32_t numHeads{-1}; + bool hasMask = false; + bool hasQkBias = false; + int32_t typeId = -1; + int32_t isT5Mode = 0; + + float dqProbs = -1; + + IXRT_PLUGIN_ASSERT(fc->fields != nullptr); + plugin::validateRequiredAttributesExist({"type_id", "hidden_size", "num_heads", "has_mask"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) { + IXRT_PLUGIN_ASSERT(fc->fields[i].name != nullptr || fc->fields[i].data != nullptr); + std::string field_name(fc->fields[i].name); + + if (field_name.compare("type_id") == 0) { + typeId = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(typeId >= 0 && typeId <= 2, + ("QKV: Invalid TypeId " + std::to_string(typeId)).c_str()); + gLogInfo << "Building typeId: " << typeId << endl; + } + if (field_name.compare("hidden_size") == 0) { + hiddenSize = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(hiddenSize > 0, + ("QKV: Invalid hiddenSize " + std::to_string(hiddenSize)).c_str()); + gLogInfo << "Building hiddenSize: " << hiddenSize << endl; + } + if (field_name.compare("num_heads") == 0) { + numHeads = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(numHeads > 0, ("QKV: Invalid numHeads " + std::to_string(numHeads)).c_str()); + gLogInfo << "Building numHeads: " << numHeads << endl; + } + if (field_name.compare("has_mask") == 0) { + auto hasMaskValue = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(hasMaskValue == 0 || hasMaskValue == 1, + ("QKV: Invalid hasMask " + std::to_string(hasMaskValue)).c_str()); + hasMask = static_cast(hasMaskValue); + gLogInfo << "Building hasMask: " << hasMask << endl; + } + if (field_name.compare("has_qk_bias") == 0) { + auto hasQKBiasValue = *static_cast(fc->fields[i].data); + IXRT_PLUGIN_CHECK_VALUE(hasQKBiasValue == 0 || hasQKBiasValue == 1, + ("QKV: Invalid has_qk_bias " + std::to_string(hasQKBiasValue)).c_str()); + hasQkBias = static_cast(hasQKBiasValue); + gLogInfo << "Building hasQkBias: " << hasQkBias << endl; + } + if (field_name.compare("is_t5_mode") == 0) { + if (fc->fields[i].data != nullptr) { + isT5Mode = *static_cast(fc->fields[i].data); + } + IXRT_PLUGIN_CHECK_VALUE(isT5Mode == 0 || isT5Mode == 1, + ("QKV: Invalid isT5Mode " + std::to_string(isT5Mode)).c_str()); + gLogInfo << "Building isT5Mode: " << isT5Mode << endl; + } + } + + gLogInfo << "Building the Plugin..." 
<< endl; + auto type = static_cast(typeId); + auto* p = + new QKVToContextPluginDynamic(name, type, hiddenSize, numHeads, dqProbs, hasMask, hasQkBias, isT5Mode); + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* QKVToContextPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call QKVToContextPluginDynamic::destroy() + return new QKVToContextPluginDynamic(name, serialData, serialLength); +} + +void QKVToContextPluginDynamicCreator::setPluginNamespace(char const* libNamespace) noexcept { + mNamespace = libNamespace; +} + +char const* QKVToContextPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// REGISTER_TENSORRT_PLUGIN(QKVToContextPluginDynamicCreator); +//#########################################################################// +QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, const DataType type, + const int32_t hiddenSize, const int32_t numHeads, + float const dqProbs, bool hasImask, bool hasQkBias, + int32_t isT5Mode) + : mLayerName(name), + mS(0), + mB(0), + mHeadSize(hiddenSize / numHeads), + mHiddenSize(hiddenSize), + mNumHeads(numHeads), + mHasImask(hasImask), + mHasQKBias(hasQkBias), + mType(type), + mT5Mode(isT5Mode) + +{ + // +} + +QKVToContextPluginDynamic::QKVToContextPluginDynamic(const std::string name, void const* data, size_t length) + : mLayerName(name) { + gLogInfo << "QKV Deser Start" << endl; + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mNumHeads); + deserialize_value(&data, &length, &mHeadSize); + deserialize_value(&data, &length, &mHasImask); + deserialize_value(&data, &length, &mHasQKBias); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mS); + deserialize_value(&data, &length, &mB); + deserialize_value(&data, &length, &mT5Mode); + + gLogInfo << "QKV Deser done" << endl; +} + +// IPluginV2 Methods +char const* QKVToContextPluginDynamic::getPluginType() const noexcept { return kQKV_TO_CONTEXT_IXRT_PLUGIN_NAME; } + +char const* QKVToContextPluginDynamic::getPluginVersion() const noexcept { return kQKV_TO_CONTEXT_IXRT_PLUGIN_VERSION; } + +int32_t QKVToContextPluginDynamic::getNbOutputs() const noexcept { return 1; } + +int32_t QKVToContextPluginDynamic::initialize() noexcept { return 0; } + +void QKVToContextPluginDynamic::terminate() noexcept {} + +size_t QKVToContextPluginDynamic::getSerializationSize() const noexcept { + return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(DataType) + sizeof(mHasImask) + sizeof(mHasQKBias) + + sizeof(mHiddenSize) + sizeof(mS) + sizeof(mB) + sizeof(mT5Mode); +} + +void QKVToContextPluginDynamic::serialize(void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mNumHeads); + serialize_value(&buffer, mHeadSize); + serialize_value(&buffer, mHasImask); + serialize_value(&buffer, mHasQKBias); + serialize_value(&buffer, mHiddenSize); + serialize_value(&buffer, mS); + serialize_value(&buffer, mB); + serialize_value(&buffer, mT5Mode); +} + +void QKVToContextPluginDynamic::destroy() noexcept { delete this; } + +void QKVToContextPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } + +char const* QKVToContextPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// IPluginV2Ext Methods 
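// A small worked example of the shape contract implemented below: the fused input is
// [B, S, 3 * E] and the single output is [B, S, E] with E = mHiddenSize, so for a
// hypothetical BERT-base shape
//   input  : [8, 384, 2304]   // 3 * 768, fused Q/K/V
//   output : [8, 384, 768]
// getOutputDimensions() copies the input dims and overwrites the hidden dimension
// with mHiddenSize rather than dividing it by three.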
+DataType QKVToContextPluginDynamic::getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t /*nbInputs*/) const noexcept { + IXRT_PLUGIN_ASSERT(index == 0); + IXRT_PLUGIN_ASSERT(inputTypes[0] == DataType::kFLOAT || inputTypes[0] == DataType::kHALF || + inputTypes[0] == DataType::kINT8); + return inputTypes[0]; +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* QKVToContextPluginDynamic::clone() const noexcept { + gLogInfo << "QKV Clone" << endl; + + QKVToContextPluginDynamic* ret = nullptr; + ret = new QKVToContextPluginDynamic(mLayerName, mType, mHiddenSize, mNumHeads, mDqProbs, mHasImask, mHasQKBias, + mT5Mode); + + ret->setPluginNamespace(mNamespace.c_str()); + gLogInfo << "QKV Clone done" << endl; + return ret; +} + +DimsExprs QKVToContextPluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, + int32_t /*nbInputs*/, IExprBuilder& exprBuilder) noexcept { + gLogInfo << "QKV getOutputDimensions" << endl; + // Input is BxSx3*N*H, output should be BxSxN*H + IXRT_PLUGIN_ASSERT(outputIndex == 0); + // Copy over everything + DimsExprs output(inputs[IIDX]); + // Divide last dim by three + auto const* three = exprBuilder.constant(3); + output.d[HDIM] = exprBuilder.constant(mHiddenSize); + // output.d[HDIM] = exprBuilder.operation(DimensionOperation::kFLOOR_DIV, *inputs[IIDX].d[HDIM], *three); + return output; +} +bool QKVToContextPluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t /*nbOutputs*/) noexcept { + IXRT_PLUGIN_ASSERT(pos >= 0); + IXRT_PLUGIN_ASSERT(pos < 2 + mHasImask); + IXRT_PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + auto const* in = inOut; + auto const* out = inOut + nbInputs; + + if (pos == 0) { + return (in->type == mType) && (in->format == TensorFormat::kLINEAR); + } + + // pos==1 + if ((mHasImask && pos == 1)) // pos 1 is the mask + { + auto const* inMask = &inOut[1]; + // detect full mask and check that it was produced + if (mHasQKBias) { + return (inMask->type == DataType::kFLOAT) && // precision + (inMask->format == TensorFormat::kLINEAR); // format + } else { + return (inMask->type == DataType::kINT32) && // precision + (inMask->format == TensorFormat::kLINEAR); // format + } + } + + if (!mHasImask || pos == 2) // output pos + { + return (in->type == out->type) && (out->format == TensorFormat::kLINEAR); + } + + return false; +} +void QKVToContextPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, + DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { + IXRT_PLUGIN_ASSERT(nbInputs == 1 + mHasImask); + IXRT_PLUGIN_ASSERT(nbOutputs == 1); + PluginTensorDesc const& inDesc = in[IIDX].desc; + TRT_UNUSED inDesc; + PluginTensorDesc const& outDesc = out->desc; + TRT_UNUSED outDesc; + IXRT_PLUGIN_ASSERT(mType == inDesc.type); + IXRT_PLUGIN_ASSERT(mType == outDesc.type); + IXRT_PLUGIN_ASSERT(inDesc.dims.nbDims == 3 || inDesc.dims.nbDims == 5) + IXRT_PLUGIN_ASSERT(inDesc.dims.d[HDIM] == 3 * mHiddenSize); + // IXRT_PLUGIN_ASSERT(inDesc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(inDesc.dims.d[4] == 1); + + if (mHasImask) { + PluginTensorDesc const& maskDesc = in[MIDX].desc; + TRT_UNUSED maskDesc; + if (!mHasQKBias) { + IXRT_PLUGIN_ASSERT(maskDesc.dims.nbDims == 2); + IXRT_PLUGIN_ASSERT(maskDesc.dims.d[0] == inDesc.dims.d[0]); + IXRT_PLUGIN_ASSERT(maskDesc.dims.d[1] == inDesc.dims.d[1]); + } else { + IXRT_PLUGIN_ASSERT(maskDesc.dims.nbDims <= 4 && maskDesc.dims.nbDims >= 2); + } + } + + const int32_t S = 
inDesc.dims.d[SDIM] <= 0 ? in->max.d[SDIM] : inDesc.dims.d[SDIM]; + const int32_t B = inDesc.dims.d[BDIM] <= 0 ? in->max.d[BDIM] : inDesc.dims.d[BDIM]; + mS = S; + mB = B; + + IXRT_PLUGIN_ASSERT(outDesc.dims.nbDims == 3 || outDesc.dims.nbDims == 5); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[BDIM] == inDesc.dims.d[BDIM]); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[SDIM] == mS); + IXRT_PLUGIN_ASSERT(outDesc.dims.d[HDIM] == mHiddenSize); + // IXRT_PLUGIN_ASSERT(outDesc.dims.d[3] == 1); + // IXRT_PLUGIN_ASSERT(outDesc.dims.d[4] == 1); + +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferCreate(&cuinfer_handle)); +#else + CHECK_GPU_ERROR(cublasLtCreate(&blaslt_handle)); +#endif +} + +size_t QKVToContextPluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { + const int32_t B = inputs->dims.d[0]; + const int32_t S = inputs->dims.d[1]; + const int32_t E = inputs->dims.d[2]; + int32_t fmha_S = S; + int64_t buffer_size = B * fmha_S * E; +#ifndef __ILUVATAR__ + buffer_size += B * S * S * mNumHeads; +#endif + return buffer_size * sizeof(mType); +} +/* +mS(0) +mB(0) +mHeadSize(hiddenSize / numHeads) +mHiddenSize(hiddenSize) +mNumHeads(numHeads) +mHasImask(hasImask) +mType(type) +*/ + +inline void print_element(half* x, int num, string name) { + printf("%s: \n", name.c_str()); + half* out = (half*)malloc(num * sizeof(half)); + cudaMemcpy(out, x, num * sizeof(half), cudaMemcpyDeviceToHost); + for (auto i = 0; i < num; i++) { + printf("%f\n", __half2float(out[i])); + } + printf("\n"); +} + +int32_t QKVToContextPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept { + gLogInfo << "in QKVToContextPluginDynamic.." << endl; + int32_t S = inputDesc->dims.d[SDIM]; + int32_t B = inputDesc->dims.d[BDIM]; + int32_t status = STATUS_SUCCESS; +#ifdef __ILUVATAR__ + CUINFER_CHECK(cuinferSetStream(cuinfer_handle, stream)); +#endif + try { + if (mType != DataType::kHALF) { + gLogError << "embLayerNormPlugin infer type{" << int(mType) << "} not supported!" << endl; + return STATUS_NOT_SUPPORTED; + } + half* qkv_buffer_ = (half*)inputs[0]; + half* qkv_out_ = (half*)outputs[0]; + // [B, fmha_S] + int32_t* mask_ = (mHasImask && !mHasQKBias) ? 
(int32_t*)inputs[1] : nullptr; + int fmha_seq_len = S; + + int64_t buffer_size = B * fmha_seq_len * mHiddenSize; + half* q_buffer_ = reinterpret_cast(workspace); + half* k_buffer_ = q_buffer_ + buffer_size; + half* v_buffer_ = k_buffer_ + buffer_size; + half* qk_out_ = v_buffer_ + buffer_size; + + // [B, S, 3*E, 1, 1] [B, fmha_S] +#ifdef __ILUVATAR__ + auto status = cudaSuccess; + if (mHasQKBias) { + status = fused_multihead_attetion_with_posbias( + qkv_buffer_, (float*)inputs[1], q_buffer_, k_buffer_, v_buffer_, qk_out_, qkv_out_, B, mHeadSize, + mNumHeads, mHiddenSize, S, fmha_seq_len, cuinfer_handle, stream, inputDesc[1].dims, mT5Mode); + } else { + status = + fused_multihead_attetion(qkv_buffer_, mask_, q_buffer_, k_buffer_, v_buffer_, qk_out_, qkv_out_, B, + mHeadSize, mNumHeads, mHiddenSize, S, fmha_seq_len, cuinfer_handle, stream); + } +#else + auto status = + fused_multihead_attetion(qkv_buffer_, mask_, q_buffer_, k_buffer_, v_buffer_, qk_out_, qkv_out_, B, + mHeadSize, mNumHeads, mHiddenSize, S, fmha_seq_len, blaslt_handle, stream); +#endif + return status; + + } catch (std::exception const& e) { + caughtError(e); + return STATUS_FAILURE; + } +} diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..76df7213ef77ed9d8827026adfe098bd19c999fc --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.cu @@ -0,0 +1,256 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include // 1.包含头文件 +#include // 二进制写文件 + +#include "backend/cublas/cublas_helper.h" +#include "backend/ixinfer/ixinfer_gemm_helper.h" +#include "backend/transformer/transformer_arrange.h" +#include "backend/transformer/transformer_helper.cuh" +#include "backend/transformer/transformer_softmax.h" +#include "qkvToContextPlugin.h" +using namespace std; + +using namespace nvinfer1::plugin::backend; + +namespace nvinfer1::plugin { +namespace bert { + +void __global__ IxinferArrangeEncQkvKernel(half *ori_qkv, half *new_q, half *new_k, half *new_v, int head_dim, + int head_num, int batch_seq_len, int fmha_seq_len) { + int hidden_size = head_dim * head_num; + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + + int i = threadIdx.x; // 1个线程处理2个数据 + int head_id = (i * 2) / head_dim; + int dim_id = (i * 2) % head_dim; + + half2 *p_ori_qkv = (half2 *)(ori_qkv + batch_id * batch_seq_len * hidden_size * 3 + token_id * hidden_size * 3); + half2 *p_new_qkv; + + int target_id = batch_id * head_num * fmha_seq_len * head_dim + head_id * fmha_seq_len * head_dim + + token_id * head_dim + dim_id; + /* q */ + p_new_qkv = (half2 *)(new_q + target_id); + p_new_qkv[0] = p_ori_qkv[i]; + /* k */ + p_ori_qkv += hidden_size / 2; + p_new_qkv = (half2 *)(new_k + target_id); + p_new_qkv[0] = p_ori_qkv[i]; + /* v */ + p_ori_qkv += hidden_size / 2; + p_new_qkv = (half2 *)(new_v + target_id); + p_new_qkv[0] = p_ori_qkv[i]; +} + +void IxinferArrangeEncQkv(half *ori_qkv, half *new_q, half *new_k, half *new_v, int bsz, int head_num, int head_dim, + int ori_seq_len, int fmha_seq_len, cudaStream_t stream) { + int hsz = head_num * head_dim; + if (hsz / 2 > 4096) { + throw std::runtime_error("hidden_size / 2 > 4096"); + } + if (hsz % 2 != 0) { + throw std::runtime_error("hsz % 2 != 0"); + } + if (head_dim % 2 != 0) { + throw std::runtime_error("head_dim %2 != 0"); + } + // std::cout << "ori_seq_len: " << ori_seq_len << std::endl; + // std::cout << "fmha_seq_len: " << fmha_seq_len << std::endl; + dim3 blockSize(bsz, ori_seq_len); + IxinferArrangeEncQkvKernel<<>>(ori_qkv, new_q, new_k, new_v, head_dim, head_num, + ori_seq_len, fmha_seq_len); +} + +#ifdef __ILUVATAR__ +cudaError_t fused_multihead_attetion_with_posbias(half *qkv_buffer, float *mask, half *q_buffer, half *k_buffer, + half *v_buffer, half *qk_out, half *qkv_out, int bsz, int head_dim, + int head_num, int hsz, int ori_seq_len, int fmha_seq_len, + cuinferHandle_t &cuinfer_handle, cudaStream_t &stream, + const Dims &mask_dim, int32_t is_t5_mode) { + /* qkv arrange*/ + // bsz,ori_seq_len,3*hsz -> 3*(bsz,head_num,fmha_seq_len,head_dim) + IxinferArrangeEncQkv(qkv_buffer, q_buffer, k_buffer, v_buffer, bsz, head_num, head_dim, ori_seq_len, fmha_seq_len, + stream); + + if (is_t5_mode > 0) { + cuinferTensorDescriptor_t qDesc, kDesc, vDesc, maskDesc, oDesc; + cuinferDataType_t _cuinferCompType = cuinferDataType_t::CUINFER_DATA_FLOAT; + cuinferDataType_t _cuinferDataType = cuinferDataType_t::CUINFER_DATA_HALF; + cuinferDataType_t _cuinferMaskType = cuinferDataType_t::CUINFER_DATA_FLOAT; + cuinferCreateTensorDescriptor(&qDesc); + cuinferCreateTensorDescriptor(&kDesc); + cuinferCreateTensorDescriptor(&vDesc); + cuinferCreateTensorDescriptor(&maskDesc); + cuinferCreateTensorDescriptor(&oDesc); + + cuinferFlashAttnConfigInfo flashAttnInfo; + flashAttnInfo.layout = cuinferFlashAttnLayout_t::CUINFER_FATTN_BHSD; + flashAttnInfo.kvSeqStart = 0; + flashAttnInfo.kvSeqEnd = fmha_seq_len; + flashAttnInfo.kvHeadNum = head_num; + flashAttnInfo.scaling = 1.0f; + + 
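        // The NCHW descriptors below describe [bsz, head_num, fmha_seq_len, head_dim]
        // tensors, i.e. the layout IxinferArrangeEncQkv just produced. For a hypothetical
        // bsz=2, head_num=12, fmha_seq_len=384, head_dim=64 each of q/k/v is read as a
        // contiguous (2, 12, 384, 64) block:
        //   cuinferSetTensor4dDescriptor(qDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW,
        //                                _cuinferDataType, 2, 12, 384, 64);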
cuinferSetTensor4dDescriptor(qDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(kDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(vDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + + switch (mask_dim.nbDims) { + case 4: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, + mask_dim.d[0], mask_dim.d[1], mask_dim.d[2], mask_dim.d[3]); + break; + case 3: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, 1, + mask_dim.d[0], mask_dim.d[1], mask_dim.d[2]); + break; + case 2: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, 1, + 1, mask_dim.d[0], mask_dim.d[1]); + break; + default: + break; + } + cuinferSetTensor4dDescriptor(oDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + + cuinferFMHAForwardEx(cuinfer_handle, flashAttnInfo, qDesc, q_buffer, kDesc, k_buffer, vDesc, v_buffer, maskDesc, + mask, oDesc, q_buffer); + } else { + cuinferTensorDescriptor_t qDesc, kDesc, vDesc, maskDesc, oDesc; + cuinferDataType_t _cuinferCompType = cuinferDataType_t::CUINFER_DATA_FLOAT; + cuinferDataType_t _cuinferDataType = cuinferDataType_t::CUINFER_DATA_HALF; + cuinferDataType_t _cuinferMaskType = cuinferDataType_t::CUINFER_DATA_FLOAT; + cuinferCreateTensorDescriptor(&qDesc); + cuinferCreateTensorDescriptor(&kDesc); + cuinferCreateTensorDescriptor(&vDesc); + cuinferCreateTensorDescriptor(&maskDesc); + cuinferCreateTensorDescriptor(&oDesc); + + cuinferSetTensor4dDescriptor(qDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(kDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(vDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + switch (mask_dim.nbDims) { + case 4: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, + mask_dim.d[0], mask_dim.d[1], mask_dim.d[2], mask_dim.d[3]); + break; + case 3: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, 1, + mask_dim.d[0], mask_dim.d[1], mask_dim.d[2]); + break; + case 2: + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, 1, + 1, mask_dim.d[0], mask_dim.d[1]); + break; + default: + break; + } + cuinferSetTensor4dDescriptor(oDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + + cuinferFMHAParam fmha_param; + cuinferFMHAForward(cuinfer_handle, fmha_param, _cuinferCompType, _cuinferDataType, _cuinferMaskType, qDesc, + q_buffer, kDesc, k_buffer, vDesc, v_buffer, maskDesc, mask, oDesc, q_buffer, true); + } + IxinferEncAttnOutArrange(q_buffer, qkv_out, bsz, ori_seq_len, fmha_seq_len, head_num, head_dim, stream); + return cudaSuccess; +} + +cudaError_t fused_multihead_attetion(half *qkv_buffer, int32_t *mask, half *q_buffer, half *k_buffer, half *v_buffer, + half *qk_out, half *qkv_out, int bsz, int head_dim, int head_num, int hsz, + int ori_seq_len, int fmha_seq_len, 
cuinferHandle_t &cuinfer_handle, + cudaStream_t &stream, bool use_fmha, bool is_casual_mask) { + /* qkv arrange*/ + // bsz,ori_seq_len,3*hsz -> 3*(bsz,head_num,fmha_seq_len,head_dim) + IxinferArrangeEncQkv(qkv_buffer, q_buffer, k_buffer, v_buffer, bsz, head_num, head_dim, ori_seq_len, fmha_seq_len, + stream); + if (use_fmha) { + cuinferTensorDescriptor_t qDesc, kDesc, vDesc, maskDesc, oDesc; + cuinferDataType_t _cuinferCompType = cuinferDataType_t::CUINFER_DATA_FLOAT; + cuinferDataType_t _cuinferDataType = cuinferDataType_t::CUINFER_DATA_HALF; + cuinferDataType_t _cuinferMaskType = cuinferDataType_t::CUINFER_DATA_INT32; + cuinferCreateTensorDescriptor(&qDesc); + cuinferCreateTensorDescriptor(&kDesc); + cuinferCreateTensorDescriptor(&vDesc); + cuinferCreateTensorDescriptor(&maskDesc); + cuinferCreateTensorDescriptor(&oDesc); + + cuinferSetTensor4dDescriptor(qDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(kDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(vDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + cuinferSetTensor4dDescriptor(maskDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferMaskType, bsz, 1, 1, + fmha_seq_len); + cuinferSetTensor4dDescriptor(oDesc, cuinferTensorFormat_t::CUINFER_TENSOR_NCHW, _cuinferDataType, bsz, head_num, + fmha_seq_len, head_dim); + + cuinferFMHAParam fmha_param; + cuinferFMHAForward(cuinfer_handle, fmha_param, _cuinferCompType, _cuinferDataType, _cuinferMaskType, qDesc, + q_buffer, kDesc, k_buffer, vDesc, v_buffer, maskDesc, mask, oDesc, q_buffer, true); + } else { + cuinfer_gemm(k_buffer, q_buffer, nullptr, qk_out, bsz * head_num, fmha_seq_len, fmha_seq_len, head_dim, + fmha_seq_len * head_dim, fmha_seq_len * head_dim, fmha_seq_len * fmha_seq_len, + 1.0 / sqrt(head_dim * 1.0), -1, stream, cuinfer_handle); + if (is_casual_mask) { + IxinferCorrelationSoftmaxEncself(bsz, fmha_seq_len, head_num, stream, qk_out, mask, true); + } else { + IxinferCorrelationSoftmaxEncself(bsz, fmha_seq_len, head_num, stream, qk_out, mask); + } + + cuinfer_nn_gemm(v_buffer, qk_out, nullptr, q_buffer, bsz * head_num, head_dim, fmha_seq_len, fmha_seq_len, + fmha_seq_len * head_dim, fmha_seq_len * fmha_seq_len, fmha_seq_len * head_dim, 1.0f, -1, stream, + cuinfer_handle); + } + IxinferEncAttnOutArrange(q_buffer, qkv_out, bsz, ori_seq_len, fmha_seq_len, head_num, head_dim, stream); + return cudaSuccess; +} +#else +cudaError_t fused_multihead_attetion(half *qkv_buffer, int32_t *mask, half *q_buffer, half *k_buffer, half *v_buffer, + half *qk_out, half *qkv_out, int bsz, int head_dim, int head_num, int hsz, + int ori_seq_len, int fmha_seq_len, cublasLtHandle_t &blaslt_handle, + cudaStream_t &stream) { + /* qkv arrange*/ + // bsz,ori_seq_len,3*hsz -> 3*(bsz,head_num,fmha_seq_len,head_dim) + IxinferArrangeEncQkv(qkv_buffer, q_buffer, k_buffer, v_buffer, bsz, head_num, head_dim, ori_seq_len, fmha_seq_len, + stream); + + cublaslt_gemm(k_buffer, q_buffer, qk_out, bsz * head_num, fmha_seq_len, fmha_seq_len, head_dim, + fmha_seq_len * head_dim, fmha_seq_len * head_dim, fmha_seq_len * fmha_seq_len, + 1.0 / sqrt(head_dim * 1.0), blaslt_handle, stream); + + IxinferCorrelationSoftmaxEncself(bsz, fmha_seq_len, head_num, stream, qk_out, mask); + + cublaslt_gemm_nn(v_buffer, qk_out, q_buffer, bsz * head_num, head_dim, fmha_seq_len, fmha_seq_len, + 
fmha_seq_len * head_dim, fmha_seq_len * fmha_seq_len, fmha_seq_len * head_dim, 1.0f, blaslt_handle, + stream); + + IxinferEncAttnOutArrange(q_buffer, qkv_out, bsz, ori_seq_len, fmha_seq_len, head_num, head_dim, stream); + return cudaSuccess; +} +#endif + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..eb8dbb6c118bc092c48280d74a44270e7fd0827d --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/qkv_to_context/qkvToContextPlugin.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#pragma once +#include +#include + +#include + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { +#ifdef __ILUVATAR__ +cudaError_t fused_multihead_attetion(half* qkv_buffer, int32_t* mask, half* q_buffer, half* k_buffer, half* v_buffer, + half* qk_out, half* qkv_out, int bsz, int head_dim, int head_num, int hsz, + int ori_seq_len, int fmha_seq_len, cuinferHandle_t& cuinfer_handle, + cudaStream_t& stream, bool use_fmha = true, bool is_casual_mask = false); +cudaError_t fused_multihead_attetion_with_posbias(half* qkv_buffer, float* mask, half* q_buffer, half* k_buffer, + half* v_buffer, half* qk_out, half* qkv_out, int bsz, int head_dim, + int head_num, int hsz, int ori_seq_len, int fmha_seq_len, + cuinferHandle_t& cuinfer_handle, cudaStream_t& stream, + const Dims& mask_dim, int32_t is_t5_mode); +#else +cudaError_t fused_multihead_attetion(half* qkv_buffer, int32_t* mask, half* q_buffer, half* k_buffer, half* v_buffer, + half* qk_out, half* qkv_out, int bsz, int head_dim, int head_num, int hsz, + int ori_seq_len, int fmha_seq_len, cublasLtHandle_t& blaslt_handle, + cudaStream_t& stream); +#endif + +void IxinferArrangeEncQkv(half* ori_qkv, half* new_q, half* new_k, half* new_v, int bsz, int head_num, int head_dim, + int ori_seq_len, int fmha_seq_len, cudaStream_t stream); + +class QKVToContextPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + QKVToContextPluginDynamic(const std::string name, const nvinfer1::DataType type, const int32_t hiddenSize, + const int32_t numHeads, float const dqProbs, bool hasImask = false, + bool hasQkBias = false, int32_t isT5Mode = -1); + + QKVToContextPluginDynamic(const std::string name, void const* data, size_t length); + + // It doesn't make sense to make QKVToContextPluginDynamic without arguments, so we + // delete default constructor. 
+ QKVToContextPluginDynamic() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; + int32_t getNbOutputs() const noexcept override; + int32_t initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + private: + const std::string mLayerName; + std::string mNamespace; + + int32_t mS; + int32_t mB; + int32_t mSM; + int32_t mHeadSize; + int32_t mHiddenSize; + int32_t mNumHeads; + bool mHasImask; + bool mHasQKBias; + nvinfer1::DataType mType; + float mDqProbs; + int32_t mT5Mode; + +#ifdef __ILUVATAR__ + cuinferHandle_t cuinfer_handle; +#else + cublasLtHandle_t blaslt_handle; +#endif + cudaStream_t stream; + + half* query_; +}; + +class QKVToContextPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + QKVToContextPluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + char const* getPluginVersion() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static vector mPluginAttributes; + std::string mNamespace; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fd59035b1da3ef8f7476193d51c8a77b8c91189 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cpp @@ -0,0 +1,421 @@ +/* Copyright (c) 2024, Shanghai 
Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +#include "skipLayerNormInt8Plugin.h" + +#include +#include + +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "backend/bert/bert_layer_kernel.h" +#include "checkMacrosPlugin.h" +#include "cuda_fp16.hpp" +#include "cuda_runtime_api.h" +#include "driver_types.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +// Clip plugin specific constants +namespace { +char const* kSKIP_LAYER_NORM_INT8_VERSION_HFACE{"3"}; +char const* kSKIP_LAYER_NORM_INT8_VERSION_MTRON{"4"}; +char const* kSKIP_LAYER_NORM_INT8_NAME{"CustomSkipLayerNormPluginDynamic_IxRT"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection SkipLayerNormInt8PluginBaseCreator::mFC{}; +std::vector SkipLayerNormInt8PluginBaseCreator::mPluginAttributes; + +constexpr auto param_type = DataType::kHALF; + +SkipLayerNormInt8PluginBaseCreator::SkipLayerNormInt8PluginBaseCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("beta")); + mPluginAttributes.emplace_back(PluginField("gamma")); + mPluginAttributes.emplace_back(PluginField("bias")); + mPluginAttributes.emplace_back(PluginField("output_fp32")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +SkipLayerNormInt8PluginHFaceCreator::SkipLayerNormInt8PluginHFaceCreator() : SkipLayerNormInt8PluginBaseCreator() {} + +char const* SkipLayerNormInt8PluginBaseCreator::getPluginName() const noexcept { return kSKIP_LAYER_NORM_INT8_NAME; } + +PluginFieldCollection const* SkipLayerNormInt8PluginBaseCreator::getFieldNames() noexcept { return &mFC; } + +void SkipLayerNormInt8PluginBaseCreator::setPluginNamespace(char const* libNamespace) noexcept { + mNamespace = libNamespace; +} + +char const* SkipLayerNormInt8PluginBaseCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +char const* SkipLayerNormInt8PluginHFaceCreator::getPluginVersion() const noexcept { + return kSKIP_LAYER_NORM_INT8_VERSION_HFACE; +} + +bool buildBetaAndGamma(PluginFieldCollection const* fc, Weights& beta, Weights& gamma, Weights& bias) { + plugin::validateRequiredAttributesExist({"beta", "gamma"}, fc); + + bool output_fp32 = false; + + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + + if (field_name.compare("beta") == 0) { + gLogInfo << "Building beta..." << endl; + beta.values = fc->fields[i].data; + beta.count = fc->fields[i].length; + beta.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("gamma") == 0) { + gLogInfo << "Building gamma..." << endl; + gamma.values = fc->fields[i].data; + gamma.count = fc->fields[i].length; + gamma.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bias") == 0) { + gLogInfo << "Building bias..." 
<< endl; + bias.values = fc->fields[i].data; + bias.count = fc->fields[i].length; + bias.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("output_fp32") == 0) { + IXRT_PLUGIN_ASSERT(fc->fields[i].type == PluginFieldType::kINT32); + output_fp32 = (static_cast(fc->fields[i].data)[0] == 1); + gLogInfo << "Building output_fp32" << output_fp32 << endl; + } + } + + IXRT_PLUGIN_CHECK_VALUE(beta.values != nullptr, "SkipLayerNorm: invalid beta"); + IXRT_PLUGIN_CHECK_VALUE(beta.count > 0, "SkipLayerNorm: invalid beta"); + + IXRT_PLUGIN_CHECK_VALUE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma"); + IXRT_PLUGIN_CHECK_VALUE(gamma.count > 0, "SkipLayerNorm: invalid gamma"); + return output_fp32; +} + +IPluginV2* SkipLayerNormInt8PluginHFaceCreator::createPlugin(char const* name, + PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "SkipLayerNormInt8PluginHFaceCreator createPlugin" << endl; + + Weights beta{DataType::kFLOAT, nullptr, 0}; + Weights gamma{DataType::kFLOAT, nullptr, 0}; + Weights bias{DataType::kFLOAT, nullptr, 0}; + bool output_fp32 = buildBetaAndGamma(fc, beta, gamma, bias); + return new SkipLayerNormInt8PluginHFace(name, beta, gamma, bias, output_fp32); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +IPluginV2* SkipLayerNormInt8PluginHFaceCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call SkipLayerNormInterleavedPlugin::destroy() + try { + gLogInfo << "SkipLayerNormInterleavedPluginHFaceCreator deserializePlugin" << endl; + return new SkipLayerNormInt8PluginHFace(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +// REGISTER_TENSORRT_PLUGIN(SkipLayerNormInt8PluginHFaceCreator); +//#########################################################################// +SkipLayerNormInt8PluginBase::SkipLayerNormInt8PluginBase(std::string const& name, Weights const& beta, + Weights const& gamma, Weights const& bias, bool output_fp32) + : mLayerName(name), + mGammaDev(nullptr), + mBetaDev(nullptr), + mBiasDev(nullptr), + mLd(beta.count), + mParamsOnDevice(false), + output_fp32(output_fp32) { + IXRT_PLUGIN_ASSERT(mLd > 0); + IXRT_PLUGIN_ASSERT(beta.count == gamma.count); + // dataType for beta, gamma weights is always fp16 + mParamWordsize = getElementSize(param_type); + + mBeta.convertAndCopy(beta, param_type); + mGamma.convertAndCopy(gamma, param_type); + + mHasBias = (bias.values != nullptr); + if (mHasBias) { + mBias.convertAndCopy(bias, param_type); + } + + copyToDevice(mGamma, getWeightsSize(mGamma, param_type), mGammaDev); + copyToDevice(mBeta, getWeightsSize(mBeta, param_type), mBetaDev); + if (mHasBias) { + copyToDevice(mBias, getWeightsSize(mBias, param_type), mBiasDev); + } +} + +SkipLayerNormInt8PluginBase::SkipLayerNormInt8PluginBase(std::string const& name, void const* data, size_t length) + : mLayerName(name), mGammaDev(nullptr), mBetaDev(nullptr), mParamsOnDevice(false) { + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mLd); + deserialize_value(&data, &length, &mHasBias); + deserialize_value(&data, &length, &output_fp32); + + mParamWordsize = getElementSize(param_type); + + char const* d = static_cast(data); + mBeta.convertAndCopy(d, mLd, param_type); + mGamma.convertAndCopy(d, mLd, param_type); + + if (mHasBias) { + mBias.convertAndCopy(d, mLd, param_type); + 
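+            // Optional residual bias follows beta/gamma in the serialized blob; it is copied to device memory
+            // below together with gamma and beta, mirroring the weights-based constructor above.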
} + + copyToDevice(mGamma, getWeightsSize(mGamma, param_type), mGammaDev); + copyToDevice(mBeta, getWeightsSize(mBeta, param_type), mBetaDev); + if (mHasBias) { + copyToDevice(mBias, getWeightsSize(mBias, param_type), mBiasDev); + } +} + +SkipLayerNormInt8PluginHFace::SkipLayerNormInt8PluginHFace(std::string const& name, Weights const& beta, + Weights const& gamma, Weights const& bias, bool output_fp32) + : SkipLayerNormInt8PluginBase(name, beta, gamma, bias, output_fp32) {} + +SkipLayerNormInt8PluginHFace::SkipLayerNormInt8PluginHFace(std::string const& name, void const* data, size_t length) + : SkipLayerNormInt8PluginBase(name, data, length) { + gLogInfo << "SkipLayerNormInt8PluginHFace deserialize" << endl; +} + +// IPluginV2 Methods +char const* SkipLayerNormInt8PluginBase::getPluginType() const noexcept { return kSKIP_LAYER_NORM_INT8_NAME; } + +size_t SkipLayerNormInt8PluginBase::getSerializationSize() const noexcept { + const size_t biasSize = mHasBias ? (mLd * mParamWordsize) : 0; + return 2 * mParamWordsize * mLd + sizeof(mLd) + sizeof(mHasBias) + sizeof(output_fp32) + biasSize; +} + +void SkipLayerNormInt8PluginBase::serialize(void* buffer) const noexcept { + try { + serialize_value(&buffer, mLd); + serialize_value(&buffer, mHasBias); + serialize_value(&buffer, output_fp32); + + char* d = static_cast(buffer); + serFromDev(d, static_cast(mBetaDev.get()), mLd * mParamWordsize); + serFromDev(d, static_cast(mGammaDev.get()), mLd * mParamWordsize); + if (mHasBias) { + serFromDev(d, static_cast(mBiasDev.get()), mLd * mParamWordsize); + } + } catch (std::exception const& e) { + caughtError(e); + } +} + +void SkipLayerNormInt8PluginBase::destroy() noexcept { + try { + // This gets called when the network containing plugin is destroyed + mGammaDev.reset(nullptr); + mBetaDev.reset(nullptr); + if (mHasBias) { + mBiasDev.reset(nullptr); + } + delete this; + } catch (std::exception const& e) { + caughtError(e); + } +} + +void SkipLayerNormInt8PluginBase::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } + +char const* SkipLayerNormInt8PluginBase::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +// HFace +int32_t SkipLayerNormInt8PluginHFace::initialize() noexcept { + gLogInfo << "SkipLayerNormInterleavedPluginHFace initialize" << endl; + return 0; +} + +void SkipLayerNormInt8PluginHFace::terminate() noexcept { + gLogInfo << "SkipLayerNormInterleavedPluginHFace terminate" << endl; +} + +void SkipLayerNormInt8PluginHFace::destroy() noexcept { + gLogInfo << "SkipLayerNormInterleavedPluginHFace destroy" << endl; + SkipLayerNormInt8PluginBase::destroy(); +} + +char const* SkipLayerNormInt8PluginHFace::getPluginVersion() const noexcept { + return kSKIP_LAYER_NORM_INT8_VERSION_HFACE; +} + +int32_t SkipLayerNormInt8PluginHFace::getNbOutputs() const noexcept { return 2; } + +// IPluginV2Ext Methods +DataType SkipLayerNormInt8PluginBase::getOutputDataType(int32_t index, DataType const* inputTypes, + int32_t nbInputs) const noexcept { + try { + IXRT_PLUGIN_ASSERT(inputTypes != nullptr); + IXRT_PLUGIN_ASSERT(index >= 0 && index < getNbOutputs()); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + if (index == 0) { + return output_fp32 ? 
DataType::kHALF : DataType::kINT8; + } + return DataType::kHALF; + } catch (std::exception const& e) { + caughtError(e); + } + return DataType{}; +} + +// IPluginV2DynamicExt Methods +DimsExprs SkipLayerNormInt8PluginBase::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, + int32_t nbInputs, IExprBuilder& exprBuilder) noexcept { + try { + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + IXRT_PLUGIN_ASSERT(outputIndex >= 0 && outputIndex < getNbOutputs()); + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims); + return inputs[0]; + } catch (std::exception const& e) { + caughtError(e); + } + return DimsExprs{}; +} + +bool SkipLayerNormInt8PluginBase::supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, + int32_t nbInputs, int32_t nbOutputs) noexcept { + try { + IXRT_PLUGIN_ASSERT(inOut != nullptr); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + IXRT_PLUGIN_ASSERT(nbOutputs == getNbOutputs()); + IXRT_PLUGIN_ASSERT(pos >= 0 && pos < (nbInputs + nbOutputs)); + + PluginTensorDesc const& desc = inOut[pos]; + if (pos == 2 || pos == 4 || (output_fp32 && pos == 3)) { + return desc.type == DataType::kHALF && desc.format == TensorFormat::kLINEAR; + } + return desc.type == DataType::kINT8 && desc.format == TensorFormat::kLINEAR; + } catch (std::exception const& e) { + caughtError(e); + } + return false; +} + +void SkipLayerNormInt8PluginBase::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs, + DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept { + try { + // Validate input arguments + IXRT_PLUGIN_ASSERT(inputs != nullptr); + IXRT_PLUGIN_ASSERT(outputs != nullptr); + IXRT_PLUGIN_ASSERT(nbOutputs == getNbOutputs()); + IXRT_PLUGIN_ASSERT(nbInputs == 3); + IXRT_PLUGIN_ASSERT(DataType::kINT8 == inputs[0].desc.type); + IXRT_PLUGIN_ASSERT(DataType::kINT8 == inputs[1].desc.type); + IXRT_PLUGIN_ASSERT(DataType::kHALF == inputs[2].desc.type); + + auto const& inDims0 = inputs[0].desc.dims; + auto const& inDims1 = inputs[1].desc.dims; + auto const& inDims2 = inputs[2].desc.dims; + TRT_UNUSED inDims1; + TRT_UNUSED inDims2; + + IXRT_PLUGIN_ASSERT(inDims0.nbDims == inDims1.nbDims); + IXRT_PLUGIN_ASSERT(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d)); + IXRT_PLUGIN_ASSERT(inDims0.nbDims == inDims2.nbDims); + IXRT_PLUGIN_ASSERT(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims2.d)); + + mParamWordsize = getElementSize(param_type); + + // if (!mParamsOnDevice) { + // copyToDevice(mGamma, getWeightsSize(mGamma, param_type), mGammaDev); + // copyToDevice(mBeta, getWeightsSize(mBeta, param_type), mBetaDev); + // copyToDevice(mBias, getWeightsSize(mBias, param_type), mBiasDev); + // mParamsOnDevice = true; + // } + } catch (std::exception const& e) { + caughtError(e); + } +} + +size_t SkipLayerNormInt8PluginBase::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, + PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + return 0; +} + +// HFace IPluginV2DynamicExt Methods +IPluginV2DynamicExt* SkipLayerNormInt8PluginHFace::clone() const noexcept { + try { + gLogInfo << "SkipLayerNormInterleavedPluginHFace clone" << endl; + auto* p = new SkipLayerNormInt8PluginHFace(mLayerName, mBeta, mGamma, mBias, output_fp32); + p->initialize(); + p->setPluginNamespace(mNamespace.c_str()); + return p; + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +int32_t 
SkipLayerNormInt8PluginHFace::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
+                                              void const* const* inputs, void* const* outputs, void* workspace,
+                                              cudaStream_t stream) noexcept {
+    try {
+        IXRT_PLUGIN_ASSERT(inputs != nullptr);
+        IXRT_PLUGIN_ASSERT(outputs != nullptr);
+        auto const iDesc = inputDesc[0];
+        auto const oDesc = outputDesc[0];
+
+        const int32_t B = iDesc.dims.d[0];
+        const int32_t S = iDesc.dims.d[1];
+        const int32_t E = iDesc.dims.d[2];
+        int batch_token_num = B * S;
+        float const dqScaleIn = iDesc.scale;
+        IXRT_PLUGIN_ASSERT(dqScaleIn > 1e-9);
+        float const qScale = oDesc.scale;
+        int8_t const* input = static_cast<int8_t const*>(inputs[0]);
+        int8_t const* skip = static_cast<int8_t const*>(inputs[1]);
+        half* residual = (half*)inputs[2];
+        half const* gamma = static_cast<half const*>(mGammaDev.get());
+        half const* beta = static_cast<half const*>(mBetaDev.get());
+        half const* bias = static_cast<half const*>(mBiasDev.get());
+        half* residual_out = static_cast<half*>(outputs[1]);
+
+        if (!output_fp32) {
+            int8_t* output = static_cast<int8_t*>(outputs[0]);
+            skipLayerNormI8II8O(input, gamma, beta, bias, output, residual, residual_out, batch_token_num, E, dqScaleIn,
+                                1.0 / qScale, 1024, stream, true);
+        } else {
+            half* output = static_cast<half*>(outputs[0]);
+            skipLayerNormI8IF16O(input, gamma, beta, bias, output, residual, residual_out, batch_token_num, E,
+                                 1.0 / dqScaleIn, 1.0 / qScale, 1024, stream, true);
+        }
+        return cudaSuccess;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return -1;
+}
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..adda3b02ccd6b322156add03963d0a7e36c4f4a6
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.cu
@@ -0,0 +1,529 @@
+/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+* All Rights Reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License"); you may
+* not use this file except in compliance with the License. You may obtain
+* a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/ +#include "backend/bert/bert_helper.h" +#include "skipLayerNormInt8Plugin.h" +using namespace nvinfer1::plugin::backend; + +namespace nvinfer1::plugin { +namespace bert { + +template +__global__ void skipLayernormI8II8OKernel(const int8_t *input, const half *scale, const half *bias, + const half *residual_bias, int8_t *output, half *residual, half *residual_out, + int hidden_size, float dequant_scale, float quant_scale, bool is_post_ln) { + // register + // process 2 data + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size / 4; + char4 *p_input = (char4 *)input; + char4 *p_output = (char4 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_residual_out = (half2 *)residual_out; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start * 2; + p_residual_out += block_start * 2; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + // vals = dequant(input) + residual + bias + half2 p_residual_value1; + half2 p_residual_value2; + p_residual_value1.x = __hadd(p_residual[element_index << 1].x, p_residual_bias[element_index << 1].x); + p_residual_value1.y = __hadd(p_residual[element_index << 1].y, p_residual_bias[element_index << 1].y); + p_residual_value2.x = __hadd(p_residual[element_index << 1 | 1].x, p_residual_bias[element_index << 1 | 1].x); + p_residual_value2.y = __hadd(p_residual[element_index << 1 | 1].y, p_residual_bias[element_index << 1 | 1].y); + vals[it] = char4addhalf2_dequant(p_input[element_index], p_residual_value1, p_residual_value2, dequant_scale); + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_size, epsilon, + p_scale[element_index << 1], p_scale[element_index << 1 | 1], + p_bias[element_index << 1], p_bias[element_index << 1 | 1]); + + p_residual_out[element_index << 1].x = __float2half(norm_value.x); + p_residual_out[element_index << 1].y = __float2half(norm_value.y); + p_residual_out[element_index << 1 | 1].x = __float2half(norm_value.z); + p_residual_out[element_index << 1 | 1].y = __float2half(norm_value.w); + + char4 res = float42char4(norm_value, quant_scale); + p_output[element_index] = res; + } +} + +template +__global__ void skipLayernormI8IF16OKernel(const int8_t *input, const __half *scale, const __half *bias, + const __half *residual_bias, half *output, __half *residual, + half *residual_out, int hidden_size, float dequant_scale, float quant_scale, + bool is_post_ln) { + // register + // process 2 data + float4 vals[THREAD_DATA_LEN]; + int 
block_start = blockIdx.x * hidden_size / 4; + char4 *p_input = (char4 *)input; + half2 *p_output = (half2 *)output; + half2 *p_residual = (half2 *)residual; + half2 *p_residual_out = (half2 *)residual_out; + half2 *p_scale = (half2 *)scale; + half2 *p_bias = (half2 *)bias; + half2 *p_residual_bias = (half2 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start * 2; + p_residual += block_start * 2; + p_residual_out += block_start * 2; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + // vals = dequant(input) + residual + bias + half2 p_residual_value1, p_residual_value2; + p_residual_value1 = p_residual[element_index * 2] + p_residual_bias[element_index * 2]; + p_residual_value2 = p_residual[element_index * 2 + 1] + p_residual_bias[element_index * 2 + 1]; + vals[it] = char4addhalf2_dequant(p_input[element_index], p_residual_value1, p_residual_value2, dequant_scale); + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_size, epsilon, + p_scale[element_index * 2], p_scale[element_index * 2 + 1], + p_bias[element_index * 2], p_bias[element_index * 2 + 1]); + + p_output[element_index << 1].x = __float2half(norm_value.x); + p_output[element_index << 1].y = __float2half(norm_value.y); + p_output[element_index << 1 | 1].x = __float2half(norm_value.z); + p_output[element_index << 1 | 1].y = __float2half(norm_value.w); + } +} + +template +__global__ void skipLayernormI8IF32OKernel(const int8_t *input, const float *scale, const float *bias, + const float *residual_bias, float *output, float *residual, + float *residual_out, int hidden_size, float dequant_scale, float quant_scale, + bool is_post_ln) { + // register + // process 2 data + float4 vals[THREAD_DATA_LEN]; + int block_start = blockIdx.x * hidden_size / 4; + char4 *p_input = (char4 *)input; + float4 *p_output = (float4 *)output; + float4 *p_residual = (float4 *)residual; + float4 *p_residual_out = (float4 *)residual_out; + float4 *p_scale = (float4 *)scale; + float4 *p_bias = (float4 *)bias; + float4 *p_residual_bias = (float4 *)residual_bias; + // one line start + p_input += block_start; + p_output += block_start; + p_residual += block_start; + p_residual_out += block_start; + + float thread_m2 = 0; + float thread_mean = 0; + float thread_count = 0; + + // load data from global memory +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + // vals = dequant(input) + residual + bias + + float4 p_residual_value; + p_residual_value.x = p_residual[element_index].x + p_residual_bias[element_index].x; + p_residual_value.y = 
p_residual[element_index].y + p_residual_bias[element_index].y; + p_residual_value.z = p_residual[element_index].z + p_residual_bias[element_index].z; + p_residual_value.w = p_residual[element_index].w + p_residual_bias[element_index].w; + vals[it] = char4addfloat4_dequant(p_input[element_index], p_residual_value, dequant_scale); + WelfordCombine(vals[it].x, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].y, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].z, &thread_mean, &thread_m2, &thread_count); + WelfordCombine(vals[it].w, &thread_mean, &thread_m2, &thread_count); + } + + // mean var + float mean = 0; + float m2 = 0; + float count = 0; + WelfordWarpReduce(thread_mean, thread_m2, thread_count, &mean, &m2, &count); + mean = __shfl_sync(0xffffffff, mean, 0, C10_WARP_SIZE); + m2 = __shfl_sync(0xffffffff, m2, 0, C10_WARP_SIZE); + count = __shfl_sync(0xffffffff, count, 0, C10_WARP_SIZE); + +#pragma unroll + for (int it = 0; it < THREAD_DATA_LEN; ++it) { + int element_index = threadIdx.x + it * C10_WARP_SIZE; + float4 norm_value = compute_float4_norm_value(vals[it], mean, m2, hidden_size, epsilon, p_scale[element_index], + p_bias[element_index]); + + p_output[element_index].x = norm_value.x; + p_output[element_index].y = norm_value.y; + p_output[element_index].z = norm_value.z; + p_output[element_index].w = norm_value.w; + } +} + +void skipLayerNormI8II8O(const int8_t *input, const half *scale, const half *bias, const half *residual_bias, + int8_t *output, half *residual, half *residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE / 4; + + switch (num_warp) { + case 1: + skipLayernormI8II8OKernel<1><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 2: + skipLayernormI8II8OKernel<2><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 3: + skipLayernormI8II8OKernel<3><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 4: + skipLayernormI8II8OKernel<4><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 5: + skipLayernormI8II8OKernel<5><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 6: + skipLayernormI8II8OKernel<6><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 7: + skipLayernormI8II8OKernel<7><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 8: + skipLayernormI8II8OKernel<8><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 9: + skipLayernormI8II8OKernel<9><<>>(input, scale, bias, residual_bias, 
output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 10: + skipLayernormI8II8OKernel<10><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 11: + skipLayernormI8II8OKernel<11><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 12: + skipLayernormI8II8OKernel<12><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 13: + skipLayernormI8II8OKernel<13><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 14: + skipLayernormI8II8OKernel<14><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 15: + skipLayernormI8II8OKernel<15><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 16: + skipLayernormI8II8OKernel<16><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + default: + throw std::runtime_error("skipLayernormI8II8OKernel"); + break; + } +} + +void skipLayerNormI8IF16O(const int8_t *input, const half *scale, const half *bias, const half *residual_bias, + half *output, half *residual, half *residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE / 4; + + switch (num_warp) { + case 1: + skipLayernormI8IF16OKernel<1><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 2: + skipLayernormI8IF16OKernel<2><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 3: + skipLayernormI8IF16OKernel<3><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 4: + skipLayernormI8IF16OKernel<4><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 5: + skipLayernormI8IF16OKernel<5><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 6: + skipLayernormI8IF16OKernel<6><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 7: + skipLayernormI8IF16OKernel<7><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 8: + skipLayernormI8IF16OKernel<8><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 9: + 
skipLayernormI8IF16OKernel<9><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 10: + skipLayernormI8IF16OKernel<10> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 11: + skipLayernormI8IF16OKernel<11> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 12: + skipLayernormI8IF16OKernel<12> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 13: + skipLayernormI8IF16OKernel<13> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 14: + skipLayernormI8IF16OKernel<14> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 15: + skipLayernormI8IF16OKernel<15> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 16: + skipLayernormI8IF16OKernel<16> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + default: + throw std::runtime_error("skipLayernormI8II8OKernel"); + break; + } +} + +void skipLayerNormI8IF32O(const int8_t *input, const float *scale, const float *bias, const float *residual_bias, + float *output, float *residual, float *residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln) { + if (hidden_size > 1024) { + throw std::runtime_error("hidden_size should <= 1024"); + } + if (hidden_size % C10_WARP_SIZE != 0) { + throw std::runtime_error("hidden_size // C10_WARP_SIZE != 0"); + } + dim3 gridSize(batch_tokens); + dim3 blockSize(C10_WARP_SIZE); + + int num_warp = hidden_size / C10_WARP_SIZE / 4; + + switch (num_warp) { + case 1: + skipLayernormI8IF32OKernel<1><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 2: + skipLayernormI8IF32OKernel<2><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 3: + skipLayernormI8IF32OKernel<3><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 4: + skipLayernormI8IF32OKernel<4><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 5: + skipLayernormI8IF32OKernel<5><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 6: + skipLayernormI8IF32OKernel<6><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 7: + skipLayernormI8IF32OKernel<7><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 8: + skipLayernormI8IF32OKernel<8><<>>(input, scale, bias, residual_bias, output, + residual, 
residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 9: + skipLayernormI8IF32OKernel<9><<>>(input, scale, bias, residual_bias, output, + residual, residual_out, hidden_size, + dequant_scale, quant_scale, is_post_ln); + break; + case 10: + skipLayernormI8IF32OKernel<10> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 11: + skipLayernormI8IF32OKernel<11> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 12: + skipLayernormI8IF32OKernel<12> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 13: + skipLayernormI8IF32OKernel<13> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 14: + skipLayernormI8IF32OKernel<14> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 15: + skipLayernormI8IF32OKernel<15> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + case 16: + skipLayernormI8IF32OKernel<16> + <<>>(input, scale, bias, residual_bias, output, residual, residual_out, + hidden_size, dequant_scale, quant_scale, is_post_ln); + break; + default: + throw std::runtime_error("skipLayernormI8II8OKernel"); + break; + } +} + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..82a873cdc0e1bce0a67c633870fd9d91c26db723 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormInt8Plugin.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#pragma once + +#include +#include + +#include "NvInferRuntime.h" +#include "bertCommon.h" + +namespace nvinfer1::plugin { +namespace bert { + +void skipLayerNormI8II8O(const int8_t* input, const half* scale, const half* bias, const half* residual_bias, + int8_t* output, half* residual, half* residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln); + +void skipLayerNormI8IF16O(const int8_t* input, const half* scale, const half* bias, const half* residual_bias, + half* output, half* residual, half* residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln); + +void skipLayerNormI8IF32O(const int8_t* input, const float* scale, const float* bias, const float* residual_bias, + float* output, float* residual, float* residual_out, int batch_tokens, int hidden_size, + float dequant_scale, float quant_scale, int max_thread_per_block, cudaStream_t stream, + bool is_post_ln); + +class SkipLayerNormInt8PluginBase : public nvinfer1::IPluginV2DynamicExt { + public: + SkipLayerNormInt8PluginBase(std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, + nvinfer1::Weights const& bias, bool output_fp32); + + SkipLayerNormInt8PluginBase(std::string const& name, void const* data, size_t length); + + // It doesn't make sense to make SkipLayerNormInterleavedPlugin without + // arguments, so we delete default constructor. + SkipLayerNormInt8PluginBase() = delete; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int32_t index, nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2DynamicExt Methods + nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex, nvinfer1::DimsExprs const* inputs, int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept override; + + protected: + std::string const& mLayerName; + std::string mNamespace; + + bert::cuda_unique_ptr mGammaDev; + bert::cuda_unique_ptr mBetaDev; + size_t mLd{}; // leading dim + bert::WeightsWithOwnership mGamma; + bert::WeightsWithOwnership mBeta; + + size_t mParamWordsize{}; + bool mParamsOnDevice{}; + bool mHasBias{}; + cuda_unique_ptr mBiasDev; + WeightsWithOwnership mBias; + bool output_fp32{}; +}; + +class SkipLayerNormInt8PluginHFace : public SkipLayerNormInt8PluginBase { + public: + SkipLayerNormInt8PluginHFace(std::string const& name, nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma, + nvinfer1::Weights const& bias, bool output_fp32); + + SkipLayerNormInt8PluginHFace(std::string const& name, 
void const* data, size_t length); + + // It doesn't make sense to make SkipLayerNormInterleavedPlugin without + // arguments, so we delete default constructor. + SkipLayerNormInt8PluginHFace() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + // IPluginV2 Methods + int32_t initialize() noexcept override; + void terminate() noexcept override; + void destroy() noexcept override; + int32_t getNbOutputs() const noexcept override; + char const* getPluginVersion() const noexcept override; +}; + +class SkipLayerNormInt8PluginBaseCreator : public nvinfer1::IPluginCreator { + public: + SkipLayerNormInt8PluginBaseCreator(); + + char const* getPluginName() const noexcept override; + + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +class SkipLayerNormInt8PluginHFaceCreator : public SkipLayerNormInt8PluginBaseCreator { + public: + SkipLayerNormInt8PluginHFaceCreator(); + + char const* getPluginVersion() const noexcept override; + + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; + nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept override; +}; + +} // namespace bert +} // namespace nvinfer1::plugin diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cpp b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12d9ef284d0836f3706756db6740ed20aed070e8 --- /dev/null +++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cpp @@ -0,0 +1,416 @@ +/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +* All Rights Reserved. +* +* Licensed under the Apache License, Version 2.0 (the "License"); you may +* not use this file except in compliance with the License. You may obtain +* a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +#include "skipLayerNormPlugin.h" + +#include "NvInferRuntimeCommon.h" +#include "bertCommon.h" +#include "checkMacrosPlugin.h" +#include "plugin.h" +#include "serialize.h" + +using namespace nvinfer1; +using namespace nvinfer1::plugin; +using namespace nvinfer1::plugin::bert; + +namespace { +char const* kSKIP_LAYER_NORM_VERSION{"1"}; +char const* kSKIP_LAYER_NORM_NAME{"CustomSkipLayerNormPluginDynamic_IxRT"}; +char const* kSKIP_LAYER_NORM_VAR_SEQLEN_VERSION{"2"}; +} // namespace + +// Static class fields initialization +PluginFieldCollection SkipLayerNormPluginDynamicCreator::mFC{}; +std::vector SkipLayerNormPluginDynamicCreator::mPluginAttributes; + +// REGISTER_TENSORRT_PLUGIN(SkipLayerNormPluginDynamicCreator); + +static inline DataType getParamWordType(DataType cfgType) noexcept { + if (cfgType == DataType::kINT8) { + return DataType::kHALF; + } + + return cfgType; +} + +SkipLayerNormPluginDynamicCreator::SkipLayerNormPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("ld", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("beta", nullptr, PluginFieldType::kFLOAT32)); + mPluginAttributes.emplace_back(PluginField("gamma", nullptr, PluginFieldType::kFLOAT32)); + mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32)); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +char const* SkipLayerNormPluginDynamicCreator::getPluginName() const noexcept { return kSKIP_LAYER_NORM_NAME; } + +char const* SkipLayerNormPluginDynamicCreator::getPluginVersion() const noexcept { return kSKIP_LAYER_NORM_VERSION; } + +PluginFieldCollection const* SkipLayerNormPluginDynamicCreator::getFieldNames() noexcept { return &mFC; } + +IPluginV2* SkipLayerNormPluginDynamicCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { + try { + gLogInfo << "SkipLayerNormPluginDynamicCreator createPlugin" << endl; + + int32_t ld = 0; + Weights beta{DataType::kFLOAT, nullptr, 0}; + Weights gamma{DataType::kFLOAT, nullptr, 0}; + Weights bias{DataType::kFLOAT, nullptr, 0}; + int32_t typeId = -1; + + IXRT_PLUGIN_ASSERT(fc != nullptr); + + plugin::validateRequiredAttributesExist({"type_id", "beta", "ld", "gamma"}, fc); + + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + if (field_name.compare("ld") == 0) { + ld = *static_cast(fc->fields[i].data); + gLogInfo << "Building ld: " << ld << endl; + } + + if (field_name.compare("type_id") == 0) { + typeId = *static_cast(fc->fields[i].data); + gLogInfo << "Building typeId: " << typeId << endl; + } + + if (field_name.compare("beta") == 0) { + gLogInfo << "Building beta..." << endl; + beta.values = fc->fields[i].data; + beta.count = fc->fields[i].length; + beta.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("gamma") == 0) { + gLogInfo << "Building gamma..." << endl; + gamma.values = fc->fields[i].data; + gamma.count = fc->fields[i].length; + gamma.type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bias") == 0) { + gLogInfo << "Building bias..." 
<< endl; + bias.values = fc->fields[i].data; + bias.count = fc->fields[i].length; + bias.type = fieldTypeToDataType(fc->fields[i].type); + } + } + gLogInfo << "Type " << typeId << endl; + + IXRT_PLUGIN_CHECK_VALUE(typeId >= 0 && typeId <= 3, + ("SkipLayerNorm: Invalid type ID: " + std::to_string(typeId)).c_str()); + + IXRT_PLUGIN_CHECK_VALUE(beta.values != nullptr, "SkipLayerNorm: invalid beta"); + IXRT_PLUGIN_CHECK_VALUE(beta.count > 0, "SkipLayerNorm: invalid beta"); + + IXRT_PLUGIN_CHECK_VALUE(gamma.values != nullptr, "SkipLayerNorm: invalid gamma"); + IXRT_PLUGIN_CHECK_VALUE(gamma.count > 0, "SkipLayerNorm: invalid gamma"); + + IXRT_PLUGIN_CHECK_VALUE(typeId == (int)DataType::kHALF, "typeId != DataType::kHALF error"); + + return new SkipLayerNormPluginDynamic(name, static_cast(typeId), ld, beta, gamma, bias); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +nvinfer1::IPluginV2* SkipLayerNormPluginDynamicCreator::deserializePlugin(char const* name, void const* serialData, + size_t serialLength) noexcept { + try { + return new SkipLayerNormPluginDynamic(name, serialData, serialLength); + } catch (std::exception const& e) { + caughtError(e); + } + return nullptr; +} + +void SkipLayerNormPluginDynamicCreator::setPluginNamespace(char const* pluginNamespace) noexcept { + try { + mNamespace = pluginNamespace; + } catch (std::exception const& e) { + caughtError(e); + } +} + +char const* SkipLayerNormPluginDynamicCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } + +//#########################################################################// +SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string name, const DataType type, int32_t const ld, + Weights const& beta, Weights const& gamma, Weights const& bias) + : mLayerName(name), mGammaDev(nullptr), mBetaDev(nullptr), mHiddenSize(ld), mType(type), mBiasDev(nullptr) { + IXRT_PLUGIN_ASSERT(mType == nvinfer1::DataType::kFLOAT || mType == nvinfer1::DataType::kHALF); + + mCfgType = DataType::kHALF; + mParamWordsize = getElementSize(mCfgType); + + mBeta.convertAndCopy(beta, mCfgType); + mGamma.convertAndCopy(gamma, mCfgType); + + mHasBias = (bias.values != nullptr); + if (mHasBias) { + mBias.convertAndCopy(bias, mCfgType); + } + + copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev); + copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev); + if (mHasBias) { + copyToDevice(mBias, getWeightsSize(mBias, mCfgType), mBiasDev); + } +} + +SkipLayerNormPluginDynamic::SkipLayerNormPluginDynamic(const std::string& name, void const* data, size_t length) + : mLayerName(name), mGammaDev(nullptr), mBetaDev(nullptr), mBiasDev(nullptr) { + gLogInfo << "SkipLayerNormPluginDynamic deserialize" << endl; + + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mCfgType); + deserialize_value(&data, &length, &mHiddenSize); + deserialize_value(&data, &length, &mHasBias); + + IXRT_PLUGIN_ASSERT(mCfgType == nvinfer1::DataType::kFLOAT || mCfgType == nvinfer1::DataType::kHALF); + mParamWordsize = getElementSize(mCfgType); + + char const* d = static_cast(data); + mBeta.convertAndCopy(d, mHiddenSize, mCfgType); + mGamma.convertAndCopy(d, mHiddenSize, mCfgType); + if (mHasBias) { + mBias.convertAndCopy(d, mHiddenSize, mCfgType); + } + + copyToDevice(mGamma, getWeightsSize(mGamma, mCfgType), mGammaDev); + copyToDevice(mBeta, getWeightsSize(mBeta, mCfgType), mBetaDev); + if (mHasBias) { + copyToDevice(mBias, 
+                     getWeightsSize(mBias, mCfgType), mBiasDev);
+    }
+}
+
+// IPluginV2Ext Methods
+DataType SkipLayerNormPluginDynamic::getOutputDataType(int32_t index, DataType const* inputTypes,
+                                                       int32_t nbInputs) const noexcept {
+    try {
+        IXRT_PLUGIN_ASSERT(inputTypes != nullptr);
+        IXRT_PLUGIN_ASSERT(index == 0);
+        IXRT_PLUGIN_ASSERT(nbInputs == 2);
+        return inputTypes[0];
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return DataType{};
+}
+
+// IPluginV2 Methods
+char const* SkipLayerNormPluginDynamic::getPluginType() const noexcept { return kSKIP_LAYER_NORM_NAME; }
+
+char const* SkipLayerNormPluginDynamic::getPluginVersion() const noexcept { return kSKIP_LAYER_NORM_VERSION; }
+
+int32_t SkipLayerNormPluginDynamic::getNbOutputs() const noexcept { return 1; }
+int32_t SkipLayerNormPluginDynamic::initialize() noexcept {
+    gLogInfo << "SkipLayerNormPluginDynamic initialize" << endl;
+    return 0;
+}
+
+void SkipLayerNormPluginDynamic::terminate() noexcept { gLogInfo << "SkipLayerNormPluginDynamic terminate" << endl; }
+
+size_t SkipLayerNormPluginDynamic::getSerializationSize() const noexcept {
+    const size_t biasSize = mHasBias ? (mHiddenSize * mParamWordsize) : 0;
+    return 2 * mParamWordsize * mHiddenSize + 2 * sizeof(DataType) + sizeof(mHiddenSize) + biasSize + sizeof(mHasBias);
+}
+
+void SkipLayerNormPluginDynamic::serialize(void* buffer) const noexcept {
+    try {
+        serialize_value(&buffer, mType);
+        serialize_value(&buffer, mCfgType);
+        serialize_value(&buffer, mHiddenSize);
+        serialize_value(&buffer, mHasBias);
+
+        char* d = static_cast<char*>(buffer);
+        serFromDev(d, static_cast<char*>(mBetaDev.get()), mHiddenSize * mParamWordsize);
+        serFromDev(d, static_cast<char*>(mGammaDev.get()), mHiddenSize * mParamWordsize);
+        if (mHasBias) {
+            serFromDev(d, static_cast<char*>(mBiasDev.get()), mHiddenSize * mParamWordsize);
+        }
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+}
+
+void SkipLayerNormPluginDynamic::destroy() noexcept {
+    try {
+        gLogInfo << "SkipLayerNormPluginDynamic destroy" << endl;
+        // This gets called when the network containing plugin is destroyed
+        mGammaDev.reset(nullptr);
+        mBetaDev.reset(nullptr);
+        if (mHasBias) {
+            mBiasDev.reset(nullptr);
+        }
+        delete this;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+}
+
+void SkipLayerNormPluginDynamic::setPluginNamespace(char const* libNamespace) noexcept {
+    try {
+        mNamespace = libNamespace;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+}
+
+char const* SkipLayerNormPluginDynamic::getPluginNamespace() const noexcept { return mNamespace.c_str(); }
+
+// IPluginV2DynamicExt Methods
+IPluginV2DynamicExt* SkipLayerNormPluginDynamic::clone() const noexcept {
+    try {
+        gLogInfo << "SkipLayerNormPluginDynamic clone" << endl;
+
+        auto* p = new SkipLayerNormPluginDynamic(mLayerName, mType, mHiddenSize, mBeta, mGamma, mBias);
+        p->initialize();
+        p->setPluginNamespace(mNamespace.c_str());
+        return p;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return nullptr;
+}
+
+DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs,
+                                                          int32_t nbInputs, IExprBuilder& exprBuilder) noexcept {
+    try {
+        IXRT_PLUGIN_ASSERT(inputs != nullptr);
+        IXRT_PLUGIN_ASSERT(nbInputs == 2);
+        IXRT_PLUGIN_ASSERT(outputIndex == 0);
+        IXRT_PLUGIN_ASSERT(inputs[0].nbDims == inputs[1].nbDims);
+        return inputs[0];
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return DimsExprs{};
+}
+
+bool SkipLayerNormPluginDynamic::supportsFormatCombination(int32_t pos, PluginTensorDesc
+                                                           const* inOut, int32_t nbInputs,
+                                                           int32_t nbOutputs) noexcept {
+    try {
+        IXRT_PLUGIN_ASSERT(inOut != nullptr);
+        IXRT_PLUGIN_ASSERT(nbInputs == 2);
+        IXRT_PLUGIN_ASSERT(nbOutputs == 1);
+        IXRT_PLUGIN_ASSERT(pos >= 0 && pos < (nbInputs + nbOutputs));
+
+        PluginTensorDesc const& in = inOut[pos];
+        if (pos == 0) {
+            return (in.type == DataType::kHALF or in.type == DataType::kFLOAT) && (in.format == TensorFormat::kLINEAR);
+        }
+        PluginTensorDesc const& prev = inOut[pos - 1];
+
+        return in.type == prev.type && in.format == prev.format;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return false;
+}
+
+void SkipLayerNormPluginDynamic::configurePlugin(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
+                                                 DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) noexcept {
+    try {
+        gLogInfo << "SkipLayerNormPluginDynamic configurePlugin" << endl;
+
+        // Validate input arguments
+        IXRT_PLUGIN_ASSERT(inputs != nullptr);
+        IXRT_PLUGIN_ASSERT(outputs != nullptr);
+        IXRT_PLUGIN_ASSERT(nbOutputs == 1);
+        IXRT_PLUGIN_ASSERT(nbInputs == 2);
+
+        auto const& inDims0 = inputs[0].desc.dims;
+        auto const& inDims1 = inputs[1].desc.dims;
+        IXRT_PLUGIN_ASSERT(inDims0.nbDims == inDims1.nbDims);
+
+        IXRT_PLUGIN_ASSERT(std::equal(inDims0.d, inDims0.d + inDims0.nbDims, inDims1.d));
+
+        IXRT_PLUGIN_ASSERT(inDims0.nbDims == 3 || inDims0.nbDims == 5);
+        mHiddenSize = inDims0.d[HDIM];  // hiddensize
+        IXRT_PLUGIN_ASSERT(mHiddenSize != 0U);
+        // IXRT_PLUGIN_ASSERT(inDims0.d[3] == 1);
+        // IXRT_PLUGIN_ASSERT(inDims0.d[4] == 1);
+        // IXRT_PLUGIN_ASSERT(outputs[0].desc.type == DataType::kHALF);
+
+        mCfgType = inputs[0].desc.type == DataType::kINT8 ? DataType::kHALF : DataType::kHALF;
+
+        auto const paramType = getParamWordType(mCfgType);
+        mParamWordsize = getElementSize(paramType);
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+}
+
+size_t SkipLayerNormPluginDynamic::getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs,
+                                                    PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept {
+    return 0;
+}
+
+int32_t SkipLayerNormPluginDynamic::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
+                                            void const* const* inputs, void* const* outputs, void* workspace,
+                                            cudaStream_t stream) noexcept {
+    gLogInfo << "in SkipLayerNormPluginDynamic.."
+             << endl;
+    int32_t status = -1;
+    try {
+        IXRT_PLUGIN_ASSERT(inputs != nullptr);
+        IXRT_PLUGIN_ASSERT(outputs != nullptr);
+        int32_t const inputVolume = volume(inputDesc[0].dims);
+        DataType iType = inputDesc->type;
+
+        // Our plugin outputs only one tensor
+        // Launch CUDA kernel wrapper and save its return value
+        if (iType == DataType::kFLOAT) {
+            auto const* input = static_cast<float const*>(inputs[0]);
+            auto skip = (float*)(inputs[1]);
+            auto* output = static_cast<float*>(outputs[0]);
+            auto const* const bias = static_cast<half const*>(mBiasDev.get());
+            auto const* const beta = static_cast<half const*>(mBetaDev.get());
+            auto const* const gamma = static_cast<half const*>(mGammaDev.get());
+            if (mHasBias) {
+                status = computeSkipLayerNorm<float, true>(stream, static_cast<int32_t>(mHiddenSize), inputVolume,
+                                                           input, gamma, beta, bias, skip, output);
+            } else {
+                status = computeSkipLayerNorm<float, false>(stream, static_cast<int32_t>(mHiddenSize), inputVolume,
+                                                            input, gamma, beta, bias, skip, output);
+            }
+        } else if (iType == DataType::kHALF) {
+            auto const* input = static_cast<half const*>(inputs[0]);
+            auto skip = (half*)(inputs[1]);
+            auto* output = static_cast<half*>(outputs[0]);
+            auto const* const bias = static_cast<half const*>(mBiasDev.get());
+            auto const* const beta = static_cast<half const*>(mBetaDev.get());
+            auto const* const gamma = static_cast<half const*>(mGammaDev.get());
+            if (mHasBias) {
+                status = computeSkipLayerNorm<half, true>(stream, static_cast<int32_t>(mHiddenSize), inputVolume,
+                                                          input, gamma, beta, bias, skip, output);
+            } else {
+                status = computeSkipLayerNorm<half, false>(stream, static_cast<int32_t>(mHiddenSize), inputVolume,
+                                                           input, gamma, beta, bias, skip, output);
+            }
+        } else {
+            IXRT_PLUGIN_CHECK_VALUE(false, "Unsupported type error, expected [kHALF,kFLOAT], but received " +
+                                               std::to_string(static_cast<int32_t>(iType)));
+        }
+        return status;
+    } catch (std::exception const& e) {
+        caughtError(e);
+    }
+    return STATUS_FAILURE;
+}
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cu b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e1d3ffcb3a18508c6e1748c4e0f2ae851ab0b065
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.cu
@@ -0,0 +1,50 @@
+/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+* All Rights Reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License"); you may
+* not use this file except in compliance with the License. You may obtain
+* a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/
+#include <cuda.h>
+#include <cuda_fp16.h>
+
+#include <cassert>
+#include <cstdint>
+
+#include "backend/transformer/transformer_add_norm.h"
+#include "skipLayerNormPlugin.h"
+
+using namespace nvinfer1::plugin::backend;
+
+namespace nvinfer1::plugin {
+namespace bert {
+
+template <typename T, bool hasBias>
+int32_t computeSkipLayerNorm(cudaStream_t stream, int32_t E, int32_t volume, const T* input, const half* gamma,
+                             const half* beta, const half* bias, T* skip, T* output) {
+    assert(volume % E == 0);
+    int32_t batch_tokens = volume / E;
+    // Fused residual-add + bias + layernorm over rows of E elements (batch_tokens rows in total)
+    IxinferResidualBiasLn(input, gamma, beta, bias, output, skip, batch_tokens, E, stream, true);
+    return 0;
+}
+
+template int32_t computeSkipLayerNorm<half, true>(cudaStream_t, int32_t, int32_t, const half*, const half*,
+                                                  const half*, const half*, half*, half*);
+template int32_t computeSkipLayerNorm<half, false>(cudaStream_t, int32_t, int32_t, const half*, const half*,
+                                                   const half*, const half*, half*, half*);
+
+template int32_t computeSkipLayerNorm<float, true>(cudaStream_t, int32_t, int32_t, const float*, const half*,
+                                                   const half*, const half*, float*, float*);
+template int32_t computeSkipLayerNorm<float, false>(cudaStream_t, int32_t, int32_t, const float*, const half*,
+                                                    const half*, const half*, float*, float*);
+
+} // namespace bert
+} // namespace nvinfer1::plugin
diff --git a/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.h b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..e369947acf87e4e406d76434e4ed5969786a0ccf
--- /dev/null
+++ b/models/nlp/language_model/bert_base_squad/ixrt/src_ixrt/skip_layernorm/skipLayerNormPlugin.h
@@ -0,0 +1,117 @@
+/* Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+* All Rights Reserved.
+*
+* Licensed under the Apache License, Version 2.0 (the "License"); you may
+* not use this file except in compliance with the License. You may obtain
+* a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/
+#pragma once
+#include <cuda.h>
+#include <cuda_fp16.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "NvInferRuntime.h"
+#include "NvInferRuntimeCommon.h"
+#include "bertCommon.h"
+
+namespace nvinfer1::plugin {
+namespace bert {
+
+template <typename T, bool hasBias>
+int32_t computeSkipLayerNorm(cudaStream_t stream, int32_t E, int32_t volume, const T* input, const half* gamma,
+                             const half* beta, const half* bias, T* skip, T* output);
+
+class SkipLayerNormPluginDynamic : public IPluginV2DynamicExt {
+ public:
+    SkipLayerNormPluginDynamic(const std::string name, const nvinfer1::DataType type, int32_t const ld,
+                               nvinfer1::Weights const& beta, nvinfer1::Weights const& gamma,
+                               nvinfer1::Weights const& bias);
+    SkipLayerNormPluginDynamic(const std::string& name, void const* data, size_t length);
+    SkipLayerNormPluginDynamic() noexcept = delete;
+    ~SkipLayerNormPluginDynamic() override = default;
+
+    // IPluginV2 methods
+    char const* getPluginType() const noexcept override;
+    char const* getPluginVersion() const noexcept override;
+    int32_t getNbOutputs() const noexcept override;
+    int32_t initialize() noexcept override;
+    void terminate() noexcept override;
+    size_t getSerializationSize() const noexcept override;
+    void serialize(void* buffer) const noexcept override;
+    void destroy() noexcept override;
+    void setPluginNamespace(char const* libNamespace) noexcept override;
+    char const* getPluginNamespace() const noexcept override;
+
+    // IPluginV2Ext methods
+    DataType getOutputDataType(int32_t index, DataType const* inputType, int32_t nbInputs) const noexcept override;
+
+    // IPluginV2DynamicExt methods
+    IPluginV2DynamicExt* clone() const noexcept override;
+    DimsExprs getOutputDimensions(int32_t outputIndex, DimsExprs const* inputs, int32_t nbInputs,
+                                  IExprBuilder& exprBuilder) noexcept override;
+    bool supportsFormatCombination(int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs,
+                                   int32_t nbOutputs) noexcept override;
+    void configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out,
+                         int32_t nbOutputs) noexcept override;
+    size_t getWorkspaceSize(PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs,
+                            int32_t nbOutputs) const noexcept override;
+    int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs,
+                    void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+ private:
+    const std::string mLayerName;
+    std::string mNamespace;
+    cuda_unique_ptr<void> mGammaDev;
+    cuda_unique_ptr<void> mBetaDev;
+    WeightsWithOwnership mGamma;
+    WeightsWithOwnership mBeta;
+    size_t mHiddenSize{};
+    size_t mParamWordsize{};
+    DataType mType;
+    DataType mCfgType;
+    // mCfgType is the dataType for beta, gamma and bias weights, always fp16 or fp32
+    // mType is the plugin IO datatype, can be int8
+
+    bool mHasBias{};
+    cuda_unique_ptr<void> mBiasDev;
+    WeightsWithOwnership mBias;
+};
+
+class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+    SkipLayerNormPluginDynamicCreator();
+
+    char const* getPluginName() const noexcept override;
+
+    char const* getPluginVersion() const noexcept override;
+
+    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
+
+    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
+
+    nvinfer1::IPluginV2* deserializePlugin(char const* name, void const* serialData,
+                                           size_t serialLength) noexcept override;
+
+    void setPluginNamespace(char const* pluginNamespace) noexcept
+        override;
+
+    char const* getPluginNamespace() const noexcept override;
+
+ private:
+    static nvinfer1::PluginFieldCollection mFC;
+    static std::vector<nvinfer1::PluginField> mPluginAttributes;
+    std::string mNamespace;
+};
+
+} // namespace bert
+} // namespace nvinfer1::plugin
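
Note (not part of the patch): the sketch below illustrates one way a network-building program might instantiate the skip-layernorm plugin added in this diff through TensorRT's plugin registry. The shared-library name, the plugin name/version strings, and the field names (`type_id`, `ld`, `beta`, `gamma`, `bias`) are assumptions based on common BERT-plugin conventions; the authoritative values are whatever `kSKIP_LAYER_NORM_NAME`, `kSKIP_LAYER_NORM_VERSION`, and the creator's field list define, and the library name is whatever `${SHARED_TARGET}` resolves to in the CMake build.

```cpp
// Hypothetical usage sketch; verify names against kSKIP_LAYER_NORM_NAME/_VERSION,
// the creator's PluginFieldCollection, and the built plugin library.
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <vector>
#include "NvInfer.h"

nvinfer1::IPluginV2* buildSkipLayerNormPlugin(const std::vector<half>& beta, const std::vector<half>& gamma,
                                              const std::vector<half>& bias, int32_t hiddenSize) {
    // Loading the plugin .so lets its static creator objects register themselves (assumed library name).
    void* handle = dlopen("libixrt_plugin.so", RTLD_NOW);
    if (handle == nullptr) return nullptr;

    // Look the creator up in the global registry (assumed plugin name/version strings).
    auto* creator = getPluginRegistry()->getPluginCreator("CustomSkipLayerNormPluginDynamic", "1", "");
    if (creator == nullptr) return nullptr;

    // createPlugin() in this diff only accepts kHALF as the IO type.
    int32_t typeId = static_cast<int32_t>(nvinfer1::DataType::kHALF);
    std::vector<nvinfer1::PluginField> fields{
        {"type_id", &typeId, nvinfer1::PluginFieldType::kINT32, 1},
        {"ld", &hiddenSize, nvinfer1::PluginFieldType::kINT32, 1},
        {"beta", beta.data(), nvinfer1::PluginFieldType::kFLOAT16, static_cast<int32_t>(beta.size())},
        {"gamma", gamma.data(), nvinfer1::PluginFieldType::kFLOAT16, static_cast<int32_t>(gamma.size())},
        {"bias", bias.data(), nvinfer1::PluginFieldType::kFLOAT16, static_cast<int32_t>(bias.size())},
    };
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    // The returned plugin can then be attached to a network with INetworkDefinition::addPluginV2().
    return creator->createPlugin("skip_layernorm", &fc);
}
```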