From efd7bc9314ce6c034935df3b27503c095d45bfe6 Mon Sep 17 00:00:00 2001 From: may Date: Mon, 22 Jul 2024 15:31:58 +0800 Subject: [PATCH] Add transformer_asr --- .../transformer_asr/ixrt/README.md | 56 ++ .../transformer_asr/ixrt/aishell_prepare.py | 141 ++++ .../transformer_asr/ixrt/beam_search.py | 381 +++++++++++ .../transformer_asr/ixrt/build.sh | 23 + .../transformer_asr/ixrt/builder.py | 466 ++++++++++++++ .../transformer_asr/ixrt/convert.py | 95 +++ .../transformer_asr/ixrt/ctc.py | 394 ++++++++++++ .../ixrt/faster_cat/__init__.py | 13 + .../transformer_asr/ixrt/faster_cat/build.sh | 22 + .../transformer_asr/ixrt/faster_cat/kernel.cu | 79 +++ .../transformer_asr/ixrt/faster_cat/setup.py | 48 ++ .../transformer_asr/ixrt/faster_cat/test.cpp | 21 + .../transformer_asr/ixrt/faster_cat/test.py | 37 ++ .../ixrt/faster_layer_norm/__init__.py | 16 + .../ixrt/faster_layer_norm/build.sh | 22 + .../ixrt/faster_layer_norm/kernel.cu | 168 +++++ .../ixrt/faster_layer_norm/setup.py | 48 ++ .../ixrt/faster_layer_norm/test.cpp | 22 + .../faster_layer_norm/transformer_helper.cuh | 295 +++++++++ .../ixrt/faster_logsumexp/__init__.py | 38 ++ .../ixrt/faster_logsumexp/build.sh | 22 + .../ixrt/faster_logsumexp/kernel.cu | 155 +++++ .../ixrt/faster_logsumexp/setup.py | 48 ++ .../ixrt/faster_logsumexp/test.cpp | 27 + .../ixrt/faster_logsumexp/test.py | 50 ++ .../ixrt/faster_stack/__init__.py | 33 + .../ixrt/faster_stack/build.sh | 22 + .../ixrt/faster_stack/kernel.cu | 146 +++++ .../ixrt/faster_stack/setup.py | 48 ++ .../ixrt/faster_stack/test.cpp | 29 + .../transformer_asr/ixrt/faster_stack/test.py | 74 +++ .../ixrt/hparams/train_ASR_transformer.yaml | 253 ++++++++ .../transformer_asr/ixrt/inference.py | 606 ++++++++++++++++++ .../transformer_asr/ixrt/load_ixrt_plugin.py | 26 + 34 files changed, 3924 insertions(+) create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/README.md create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/builder.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/convert.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/ctc.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp create mode 100644 
models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/inference.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/README.md b/models/speech/speech_recognition/transformer_asr/ixrt/README.md new file mode 100644 index 00000000..7560b5eb --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/README.md @@ -0,0 +1,56 @@ +# Asr transformer fp16 inference (BeamSearch) + +## Description + +Beam search allows us to exert control over the output of text generation. This is useful because we sometimes know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know which words must be included in the final translation with a dictionary lookup. + + +## Setup + +### Install + +``` +pip3 install speechbrain==0.5.13 +``` + +* ixrt 4.0.1_MR release + +### Download + +Pretrained model: + +Dataset: to download the Aishell dataset. + +``` +# Make sure the checkpoint path is results/transformer/8886/save +mkdir -p results/transformer/8886/save +# Make sure the dataset path is results/transformer/8886/save +mkdir -p /home/data/speechbrain +``` + +## Inference + +### Build faster kernels + +```bash +bash build.sh +``` + +### Build engine + +max_batch_size and max_seq_len depend on the situation. + +``` +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +``` + +### Run engine + +``` +python3 inference.py hparams/train_ASR_transformer.yaml --data_folder=/home/data/speechbrain/aishell --engine_path transformer.engine +``` \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py new file mode 100644 index 00000000..ba319394 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import shutil +import logging +from speechbrain.dataio.dataio import read_audio +from speechbrain.utils.data_utils import download_file +import glob +import csv +import argparse + +logger = logging.getLogger(__name__) + + +def prepare_aishell(data_folder, save_folder, skip_prep=False): + """ + This function prepares the AISHELL-1 dataset. + If the folder does not exist, the zip file will be extracted. If the zip file does not exist, it will be downloaded. + + data_folder : path to AISHELL-1 dataset. + save_folder: path where to store the manifest csv files. + skip_prep: If True, skip data preparation. + + """ + if skip_prep: + return + + # If the data folders do not exist, we need to extract the data + if not os.path.isdir(os.path.join(data_folder, "data_aishell/wav")): + # # Check for zip file and download if it doesn't exist + # zip_location = os.path.join(data_folder, "data_aishell.tgz") + # if not os.path.exists(zip_location): + # url = "https://www.openslr.org/resources/33/data_aishell.tgz" + # download_file(url, zip_location, unpack=True) + # logger.info("Extracting data_aishell.tgz...") + # shutil.unpack_archive(zip_location, data_folder) + + wav_dir = os.path.join(data_folder, "data_aishell/wav") + tgz_list = glob.glob(wav_dir + "/*.tar.gz") + for tgz in tgz_list: + shutil.unpack_archive(tgz, wav_dir) + os.remove(tgz) + + # Create filename-to-transcript dictionary + filename2transcript = {} + with open( + os.path.join( + data_folder, "data_aishell/transcript/aishell_transcript_v0.8.txt" + ), + "r", + ) as f: + lines = f.readlines() + for line in lines: + key = line.split()[0] + value = " ".join(line.split()[1:]) + filename2transcript[key] = value + + splits = [ + # "train", + "dev", + "test", + ] + ID_start = 0 # needed to have a unique ID for each audio + for split in splits: + new_filename = os.path.join(save_folder, split) + ".csv" + if os.path.exists(new_filename): + continue + logger.info("Preparing %s..." % new_filename) + + csv_output = [["ID", "duration", "wav", "transcript"]] + entry = [] + + all_wavs = glob.glob( + os.path.join(data_folder, "data_aishell/wav") + "/" + split + "/*/*.wav" + ) + for i in range(len(all_wavs)): + filename = all_wavs[i].split("/")[-1].split(".wav")[0] + if filename not in filename2transcript: + continue + signal = read_audio(all_wavs[i]) + duration = signal.shape[0] / 16000 + transcript_ = filename2transcript[filename] + csv_line = [ + ID_start + i, + str(duration), + all_wavs[i], + transcript_, + ] + entry.append(csv_line) + + csv_output = csv_output + entry + + with open(new_filename, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + for line in csv_output: + csv_writer.writerow(line) + + msg = "\t%s successfully created!" 
% (new_filename) + logger.info(msg) + + ID_start += len(all_wavs) + + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_folder", + type=str, + default="/home/data/speechbrain/aishell", + help="data folder", + ) + parser.add_argument( + "--save_folder", + type=str, + default="/home/data/speechbrain/aishell/csv_data", + help="csv save folder", + ) + + config = parser.parse_args() + print("Config:", config) + return config + + +if __name__ == "__main__": + + config = parse_config() + prepare_aishell(config.data_folder, config.save_folder, skip_prep=False) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py new file mode 100644 index 00000000..61e5c794 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py @@ -0,0 +1,381 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +from ctc import CTCPrefixScorer +import time + +def forward(self, enc_states, wav_len): # noqa: C901 + """Applies beamsearch and returns the predicted tokens.""" + enc_lens = torch.round(enc_states.shape[1] * wav_len).int() + device = enc_states.device + batch_size = enc_states.shape[0] + + memory = self.reset_mem(batch_size * self.beam_size, device=device) + + if self.lm_weight > 0: + lm_memory = self.reset_lm_mem(batch_size * self.beam_size, device) + + if self.ctc_weight > 0: + # (batch_size * beam_size, L, vocab_size) + ctc_outputs = self.ctc_forward_step(enc_states) + ctc_scorer = CTCPrefixScorer( + ctc_outputs, + enc_lens, + batch_size, + self.beam_size, + self.blank_index, + self.eos_index, + self.ctc_window_size, + ) + ctc_memory = None + + # Inflate the enc_states and enc_len by beam_size times + enc_states = inflate_tensor(enc_states, times=self.beam_size, dim=0) + enc_lens = inflate_tensor(enc_lens, times=self.beam_size, dim=0) + + # Using bos as the first input + inp_tokens = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.bos_index) + .long() + ) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=device) * self.beam_size + ) + + # initialize sequence scores variables. + sequence_scores = torch.empty( + batch_size * self.beam_size, device=device + ) + sequence_scores.fill_(float("-inf")) + + # keep only the first to make sure no redundancy. + sequence_scores.index_fill_(0, self.beam_offset, 0.0) + + # keep the hypothesis that reaches eos and their corresponding score and log_probs. + hyps_and_scores = [[] for _ in range(batch_size)] + + # keep the sequences that still not reaches eos. + alived_seq = torch.empty( + batch_size * self.beam_size, 0, device=device + ).long() + + # Keep the log-probabilities of alived sequences. 
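+    # [Note] Shapes of the per-beam state kept across decode steps, flattened
+    # over batch and beam (comment added for readability):
+    #   alived_seq       -> (batch_size * beam_size, t)  tokens decoded so far
+    #   alived_log_probs -> (batch_size * beam_size, t)  their per-step log-probs
+    #   sequence_scores  -> (batch_size * beam_size,)    accumulated beam scores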
+ alived_log_probs = torch.empty( + batch_size * self.beam_size, 0, device=device + ) + + min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio) + max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio) + + # Initialize the previous attention peak to zero + # This variable will be used when using_max_attn_shift=True + prev_attn_peak = torch.zeros(batch_size * self.beam_size, device=device) + + for t in range(max_decode_steps): + # terminate condition + if self._check_full_beams(hyps_and_scores, self.beam_size): + break + + log_probs, memory, attn = self.forward_step( + inp_tokens, memory, enc_states, enc_lens + ) + log_probs = self.att_weight * log_probs + + # Keep the original value + log_probs_clone = log_probs.clone().reshape(batch_size, -1) + vocab_size = log_probs.shape[-1] + + if self.using_max_attn_shift: + # Block the candidates that exceed the max shift + cond, attn_peak = self._check_attn_shift(attn, prev_attn_peak) + log_probs = mask_by_condition( + log_probs, cond, fill_value=self.minus_inf + ) + prev_attn_peak = attn_peak + + # Set eos to minus_inf when less than minimum steps. + if t < min_decode_steps: + log_probs[:, self.eos_index] = self.minus_inf + + # Set the eos prob to minus_inf when it doesn't exceed threshold. + if self.using_eos_threshold: + cond = self._check_eos_threshold(log_probs) + log_probs[:, self.eos_index] = mask_by_condition( + log_probs[:, self.eos_index], + cond, + fill_value=self.minus_inf, + ) + + # adding LM scores to log_prob if lm_weight > 0 + if self.lm_weight > 0: + lm_log_probs, lm_memory = self.lm_forward_step( + inp_tokens, lm_memory + ) + log_probs = log_probs + self.lm_weight * lm_log_probs + + # adding CTC scores to log_prob if ctc_weight > 0 + if self.ctc_weight > 0: + g = alived_seq + # block blank token + log_probs[:, self.blank_index] = self.minus_inf + if self.ctc_weight != 1.0 and self.ctc_score_mode == "partial": + # pruning vocab for ctc_scorer + _, ctc_candidates = log_probs.topk( + self.beam_size * 2, dim=-1 + ) + else: + ctc_candidates = None + + ctc_log_probs, ctc_memory = ctc_scorer.forward_step( + g, ctc_memory, ctc_candidates, attn + ) + log_probs = log_probs + self.ctc_weight * ctc_log_probs + + scores = sequence_scores.unsqueeze(1).expand(-1, vocab_size) + scores = scores + log_probs + + # length normalization + if self.length_normalization: + scores = scores / (t + 1) + + # keep topk beams + scores, candidates = scores.view(batch_size, -1).topk( + self.beam_size, dim=-1 + ) + + # The input for the next step, also the output of current step. + inp_tokens = (candidates % vocab_size).view( + batch_size * self.beam_size + ) + + scores = scores.view(batch_size * self.beam_size) + sequence_scores = scores + + # recover the length normalization + if self.length_normalization: + sequence_scores = sequence_scores * (t + 1) + + # The index of which beam the current top-K output came from in (t-1) timesteps. + predecessors = ( + torch.div(candidates, vocab_size, rounding_mode="floor") + + self.beam_offset.unsqueeze(1).expand_as(candidates) + ).view(batch_size * self.beam_size) + + # Permute the memory to synchoronize with the output. + memory = self.permute_mem(memory, index=predecessors) + if self.lm_weight > 0: + lm_memory = self.permute_lm_mem(lm_memory, index=predecessors) + + if self.ctc_weight > 0: + ctc_memory = ctc_scorer.permute_mem(ctc_memory, candidates) + + # If using_max_attn_shift, then the previous attn peak has to be permuted too. 
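+        # [Note] Like the memories permuted above, prev_attn_peak must follow
+        # `predecessors`. Worked example with hypothetical numbers: batch_size=2,
+        # beam_size=3, candidates // vocab_size = [[0, 2, 1], [1, 0, 0]] and
+        # beam_offset = [0, 3] give predecessors = [0, 2, 1, 4, 3, 3], i.e. flat
+        # row indices into the (batch_size * beam_size, ...) state tensors.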
+ if self.using_max_attn_shift: + prev_attn_peak = torch.index_select( + prev_attn_peak, dim=0, index=predecessors + ) + + # Add coverage penalty + if self.coverage_penalty > 0: + cur_attn = torch.index_select(attn, dim=0, index=predecessors) + + # coverage: cumulative attention probability vector + if t == 0: + # Init coverage + self.coverage = cur_attn + + # the attn of transformer is [batch_size*beam_size, current_step, source_len] + if len(cur_attn.size()) > 2: + self.converage = torch.sum(cur_attn, dim=1) + else: + # Update coverage + self.coverage = torch.index_select( + self.coverage, dim=0, index=predecessors + ) + self.coverage = self.coverage + cur_attn + + # Compute coverage penalty and add it to scores + penalty = torch.max( + self.coverage, self.coverage.clone().fill_(0.5) + ).sum(-1) + penalty = penalty - self.coverage.size(-1) * 0.5 + penalty = penalty.view(batch_size * self.beam_size) + penalty = ( + penalty / (t + 1) if self.length_normalization else penalty + ) + scores = scores - penalty * self.coverage_penalty + + # Update alived_seq + alived_seq = torch.cat( + [ + torch.index_select(alived_seq, dim=0, index=predecessors), + inp_tokens.unsqueeze(1), + ], + dim=-1, + ) + + # Takes the log-probabilities + beam_log_probs = log_probs_clone[ + torch.arange(batch_size).unsqueeze(1), candidates + ].reshape(batch_size * self.beam_size) + alived_log_probs = torch.cat( + [ + torch.index_select( + alived_log_probs, dim=0, index=predecessors + ), + beam_log_probs.unsqueeze(1), + ], + dim=-1, + ) + + is_eos = self._update_hyp_and_scores( + inp_tokens, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=t, + ) + + # Block the paths that have reached eos. + sequence_scores.masked_fill_(is_eos, float("-inf")) + + if not self._check_full_beams(hyps_and_scores, self.beam_size): + # Using all eos to fill-up the hyps. + eos = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.eos_index) + .long() + ) + _ = self._update_hyp_and_scores( + eos, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=max_decode_steps, + ) + + ( + topk_hyps, + topk_scores, + topk_lengths, + log_probs, + ) = self._get_top_score_prediction(hyps_and_scores, topk=self.topk,) + # pick the best hyp + predictions = topk_hyps[:, 0, :] + predictions = batch_filter_seq2seq_output( + predictions, eos_id=self.eos_index + ) + + if self.return_log_probs: + return predictions, topk_scores, log_probs + else: + return predictions, topk_scores + + +def inflate_tensor(tensor, times, dim): + """This function inflates the tensor for times along dim. + + Arguments + --------- + tensor : torch.Tensor + The tensor to be inflated. + times : int + The tensor will inflate for this number of times. + dim : int + The dim to be inflated. + + Returns + ------- + torch.Tensor + The inflated tensor. + + Example + ------- + >>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) + >>> new_tensor = inflate_tensor(tensor, 2, dim=0) + >>> new_tensor + tensor([[1., 2., 3.], + [1., 2., 3.], + [4., 5., 6.], + [4., 5., 6.]]) + """ + return torch.repeat_interleave(tensor, times, dim=dim) + +def batch_filter_seq2seq_output(prediction, eos_id=-1): + """Calling batch_size times of filter_seq2seq_output. + + Arguments + --------- + prediction : list of torch.Tensor + A list containing the output ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. 
+ + Example + ------- + >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])] + >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) + >>> predictions + [[1, 2, 3], [2, 3]] + """ + outputs = [] + for p in prediction: + res = filter_seq2seq_output(p.tolist(), eos_id=eos_id) + outputs.append(res) + return outputs + +def filter_seq2seq_output(string_pred, eos_id=-1): + """Filter the output until the first eos occurs (exclusive). + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. + + Example + ------- + >>> string_pred = ['a','b','c','d','eos','e'] + >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos') + >>> string_out + ['a', 'b', 'c', 'd'] + """ + if isinstance(string_pred, list): + try: + eos_index = next( + i for i, v in enumerate(string_pred) if v == eos_id + ) + except StopIteration: + eos_index = len(string_pred) + string_out = string_pred[:eos_index] + else: + raise ValueError("The input must be a list.") + return string_out \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh new file mode 100644 index 00000000..a8991234 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +for i in fast* +do + cd $i + bash build.sh + cd .. +done diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/builder.py b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py new file mode 100644 index 00000000..5c19a9f4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py @@ -0,0 +1,466 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
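+# [Note] Summary added for readability (derived from the code below): this script
+# loads the SpeechBrain checkpoint (<ckpt_path>/*/model.ckpt), assembles an ONNX
+# graph out of IxRT custom ops
+#   MakeMaskByRadio_IxRT -> CustomFCPluginDynamic_IxRT -> PosEncodeSinCos_IxRT
+#   -> TransformerEncoderFp16_IxRT -> LayerNormalization
+# and then shells out to `ixrtexec` to build an fp16 engine. Only the encoder
+# (inputs `input`/`length_radio`, output `encoder_ln_out`) is exported; the final
+# linear, log-softmax and Search ops remain commented out in main().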
+import argparse +import torch +from tensorrt.deploy.api import GraphTransform, create_source, create_target +from tensorrt.deploy.ir.data_type import DataType +from tensorrt.deploy.ir.variable import Variable, VariableOptions +from tensorrt.deploy.ir.graph import Graph +from collections import OrderedDict +import math +import re +import glob +import os +from onnx import numpy_helper +import subprocess + + +def parse_args(): + parser = argparse.ArgumentParser( + description="build ixrt engine", usage="" + ) + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="", + ) + parser.add_argument( + "--head_num", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_batch_size", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_seq_len", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--onnx_path", + type=str, + default=".tmp.onnx", + help="", + ) + parser.add_argument( + "--engine_path", + type=str, + required=True, + help="", + ) + args = parser.parse_args() + return args + + +def add_make_mask_op(graph, state_dict, args): + attributes = {} + + t = graph + inputs = [ + graph.make_variable('length_radio', dtype=DataType.FLOAT16), + graph.make_variable('input', dtype=DataType.FLOAT16), + ] + + outputs = [t.make_variable("attention_mask", dtype=DataType.INT32)] + + t.make_operator( + "MakeMaskByRadio_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_custom_linear_op(graph, state_dict, args): + linear_keys = [ + "1.custom_src_module.layers.0.w.weight", + "1.custom_src_module.layers.0.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['1.custom_src_module.layers.0.w.weight'].size( + 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('input'), + ] + + outputs = [t.make_variable("custom_src_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_custom_linear_op(graph, state_dict, args): +# linear_keys = [ +# "1.custom_src_module.layers.0.w.weight", +# "1.custom_src_module.layers.0.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), +# "hidden_size": state_dict["1.custom_src_module.layers.0.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['1.custom_src_module.layers.0.w.weight'].size( +# 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('input'), +# ] +# +# outputs = [t.make_variable("custom_src_output",dtype=DataType.FLOAT16)] +# for key in linear_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_pos_encode_op(graph, state_dict, args): + attributes = {} + t = graph + inputs = [ + graph.get_variable('custom_src_output'), + ] + outputs = [t.make_variable("hidden_state", dtype=DataType.FLOAT16)] + t.make_operator( + "PosEncodeSinCos_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_transformer_op(graph, state_dict, args): + enc_tensor_layer_fp16_keys = OrderedDict([ + 
["1.encoder.layers.{}.norm1.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm1.norm.bias", [args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_weight", + [args.hidden_size * 3, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_bias", [args.hidden_size * 3]], + ["1.encoder.layers.{}.self_att.att.out_proj.weight", + [args.hidden_size, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.out_proj.bias", [args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.weight", + [args.inner_size, args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.bias", [args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.weight", + [args.hidden_size, args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.bias", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.bias", [args.hidden_size]], + ]) + attributes_legcy = { + "hidden_size": args.hidden_size, + "num_layers": args.num_layers, + "head_num": args.head_num, + "head_dim": args.head_dim, + "inner_size": args.inner_size, + "act_type": "gelu", + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1 / math.sqrt(args.head_dim) + } + + + attributes = { + "hidden_size": int(args.hidden_size), + "num_layers": int(args.num_layers), + "head_num": int(args.head_num), + "head_dim": int(args.head_dim), + "inner_size": int(args.inner_size), + "act_type": 12, #gelu + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1.0 / math.sqrt(args.head_dim), + "max_seq_len": int(args.max_seq_len), + "max_batch_size": int(args.max_batch_size), + + } + + t = graph + inputs = [ + graph.get_variable('hidden_state'), + graph.get_variable('attention_mask'), + ] + outputs = [t.make_variable("encoder_out", dtype=DataType.FLOAT16)] + for layer_id in range(args.num_layers): + for key, shape in enc_tensor_layer_fp16_keys.items(): + # we need cat qkv gemm's weight and bias + new_key = key.format(layer_id) + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "TransformerEncoderFp16_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_layer_norm_op(graph, state_dict, args): + enc_ln_tensor_fp16_keys = OrderedDict([ + ["1.encoder.norm.norm.weight", [args.hidden_size]], + ["1.encoder.norm.norm.bias", [args.hidden_size]], + ]) + attributes = { + "epsilon": 1e-5, + "axis": -1, + "stash_type": 1 + } + t = graph + inputs = [ + graph.get_variable('encoder_out'), + ] + outputs = [t.make_variable("encoder_ln_out")] + for key, shape in enc_ln_tensor_fp16_keys.items(): + new_key = key + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "LayerNormalization", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_layer_norm_op(graph, state_dict, args): +# enc_ln_tensor_fp16_keys = OrderedDict([ +# ["1.encoder.norm.norm.weight", [args.hidden_size]], +# ["1.encoder.norm.norm.bias", [args.hidden_size]], +# ]) +# attributes = { +# "hidden_size": args.hidden_size, +# } +# t = graph +# inputs = [ +# graph.get_variable('encoder_out'), +# ] +# outputs = [t.make_variable("encoder_ln_out",dtype=DataType.FLOAT16)] +# for key, shape in 
enc_ln_tensor_fp16_keys.items(): +# new_key = key +# w = state_dict[new_key] +# if list(w.shape) != shape: +# print("weights shape error!") +# print("key: ", key) +# print("need shape: ", shape) +# print("weight shape: ", w.shape) +# exit(1) +# inputs.append(t.make_variable(name=new_key, value=w.half())) +# t.make_operator( +# "LayerNormFp16", inputs=inputs, outputs=outputs, **attributes +# ) + +def add_linear_op(graph, state_dict, args): + linear_keys = [ + "3.w.weight", + "3.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["3.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('encoder_ln_out'), + ] + + outputs = [t.make_variable("lin_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# +# def add_linear_op(graph, state_dict, args): +# lin_keys = [ +# "3.w.weight", +# "3.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["3.w.weight"].size(0), +# "hidden_size": state_dict["3.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('encoder_ln_out'), +# ] +# +# outputs = [t.make_variable("lin_output",dtype=DataType.FLOAT16)] +# for key in lin_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_log_softmax_op(graph, state_dict, args): + attributes = { + "axis": "-1", + } + + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("log_softmax_output", dtype=DataType.FLOAT16)] + + t.make_operator( + "LogSoftmax", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_search_node(graph, state_dict, args): + attributes = { + "vocab_size": args.vocab_size, + "eos_id": args.vocab_size, + "pad_id": -10000, + "beam_size": 1, + "attr1": 1.0, + "min_decode_ratio": 0.0, + "max_decode_ratio": 1.0, + "ctc_weight": 0.40, + "using_eos_threshold": 0, + "length_normalization": 1, + } + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("output_tokens", dtype=DataType.INT32)] + list_value_half = [] + list_key_half = [] + for key in state_dict.keys(): + if "decoder" in key or "custom_tgt_module" in key or "2.w.weight" in key or "2.w.bias" in key: + list_key_half.append(key) + list_value_half.append(state_dict[key].half()) + for i, item in enumerate(list_key_half): + inputs.append(t.make_variable(name=list_key_half[i], value=list_value_half[i])) + t.make_operator( + "Search", inputs=inputs, outputs=outputs, **attributes + ) + + +def get_num_layers(state_dict): + num_layers = -1 + for key in state_dict: + layer_id = re.search( + "1.encoder.layers.([0-9]+).pos_ffn.ffn.0.bias", key) + if layer_id: + layer_id = layer_id.group(1) + num_layers = max(num_layers, int(layer_id) + 1) + assert num_layers > 0 + return num_layers + + +def build_engine(onnx_file, engine_file, max_batch_size,max_seq_len): + cmd = f"ixrtexec --onnx {onnx_file} --min_shape input:1x32x5120,length_radio:1 --opt_shape input:8x64x5120,length_radio:8 --max_shape input:{max_batch_size}x{max_seq_len}x5120,length_radio:64 --plugins 
ixrt_plugin --save_engine {engine_file}" + subprocess.run(cmd.split(), check=True) + + +def main(args): + graph = Graph() + transform = GraphTransform(graph) + ckpt_path = glob.glob(os.path.join(args.ckpt_path, "*/model.ckpt"))[0] + print("load ckpt from: ", ckpt_path) + state_dict = torch.load(ckpt_path) + + # print([i for i in state_dict ]) + # print(state_dict['3.w.bias']) + args.hidden_size = state_dict['1.encoder.layers.0.norm1.norm.weight'].size( + 0) + args.head_dim = args.hidden_size / args.head_num + args.inner_size = state_dict['1.encoder.layers.0.pos_ffn.ffn.0.bias'].size( + 0) + args.vocab_size = state_dict['3.w.weight'].size(0) + + args.num_layers = get_num_layers(state_dict) + + args.src_len = state_dict["1.custom_src_module.layers.0.w.weight"].size(1) + + # args.num_layers = 1 + add_make_mask_op(transform, state_dict, args) + add_custom_linear_op(transform, state_dict, args) + add_pos_encode_op(transform, state_dict, args) + add_transformer_op(transform, state_dict, args) + add_layer_norm_op(transform, state_dict, args) + # add_linear_op(transform, state_dict, args) + # add_log_softmax_op(transform, state_dict, args) + # add_search_node(transform, state_dict, args) + + # IO attributes + length_radio = graph.get_variable('length_radio') + length_radio.set_shape(["batch_size"]) + length_radio.dtype = "float16" + graph.add_input(length_radio) + + input = graph.get_variable('input') + input.set_shape(["batch_size", "seq_len", "src_len"]) + input.dtype = "float16" + graph.add_input(input) + + output = graph.get_variable('encoder_ln_out') + output.dtype = "float16" + graph.add_output(output) + + create_target(saved_path=args.onnx_path).export(graph) + + build_engine(args.onnx_path, args.engine_path, args.max_batch_size, args.max_seq_len) + print("save engine: ", args.engine_path) + + +if __name__ == "__main__": + args = parse_args() + ckpt_path = args.ckpt_path + + main(args) + +""" +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +""" diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/convert.py b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py new file mode 100644 index 00000000..11d71a56 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
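+# [Note] Summary added for readability: helpers used at inference time to speed up
+# the PyTorch decoder path -- `convert_decoder_model`/`replace_layer_norm` swap
+# every torch.nn.LayerNorm for the custom FasterLayerNorm kernel, and
+# `patch_get_lookahead_mask` builds the causal mask directly in float16.
+# A minimal usage sketch (hypothetical variable name; the real call site is in
+# inference.py, not shown here):
+#   decoder = convert_decoder_model(decoder)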
+import torch +from faster_layer_norm import FasterLayerNorm + +def replace_layer_norm(model): + module_output = model + + if isinstance(model, torch.nn.modules.normalization.LayerNorm): + return FasterLayerNorm(model.weight, model.bias) + + for name, child in model.named_children(): + module_output.add_module( + name, replace_layer_norm(child) + ) + return module_output + + +def convert_decoder_model(model): + model = replace_layer_norm(model) + # for layer in model.layers: + # norm = layer.norm1.norm + # print(type(norm)) + # exit() + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm1.norm = new_norm + + # norm = layer.norm2.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm2.norm = new_norm + + # norm = layer.norm3.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm3.norm = new_norm + return model + +# def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +# if type(module) in layers: +# return {name: module} +# res = {} +# for name1, child in module.named_children(): +# res.update(find_layers( +# child, layers=layers, name=name + '.' + name1 if name != '' else name1 +# )) +# return res + +def find_node(module): + if type(module) in [torch.nn.LayerNorm]: + print(module) + return + res = {} + for name1, child in module.named_children(): + find_node(child) + return + + +def patch_get_lookahead_mask(padded_input): + """Creates a binary mask for each sequence which maskes future frames. + + Arguments + --------- + padded_input: torch.Tensor + Padded input tensor. + + Example + ------- + >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]]) + >>> get_lookahead_mask(a) + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0.]]) + """ + seq_len = padded_input.shape[1] + mask = ( + torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device)) + == 1 + ).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask.detach().to(padded_input.device).to(torch.float16) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py new file mode 100644 index 00000000..9db6ab7e --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py @@ -0,0 +1,394 @@ +"""Decoders and output normalization for CTC. + +Authors + * Mirco Ravanelli 2020 + * Aku Rouhe 2020 + * Sung-Lin Yeh 2020 +""" +import torch +from itertools import groupby +from speechbrain.dataio.dataio import length_to_mask +from faster_logsumexp import FasterLogSumExp +from faster_stack import FasterStack +from faster_cat import FastCat + + +class CTCPrefixScorer: + """This class implements the CTC prefix scorer of Algorithm 2 in + reference: https://www.merl.com/publications/docs/TR2017-190.pdf. + Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py + + Arguments + --------- + x : torch.Tensor + The encoder states. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + batch_size : int + The size of the batch. + beam_size : int + The width of beam. + blank_index : int + The index of the blank token. + eos_index : int + The index of the end-of-sequence (eos) token. + ctc_window_size: int + Compute the ctc scores over the time frames using windowing based on attention peaks. + If 0, no windowing applied. 
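+
+    Note
+    ----
+    In this fp16 port the CTC state tensors (``r``, ``psi``) are allocated as
+    float16 and masked entries are filled with ``self.minus_inf = -1e4``, a
+    value that is still representable in half precision.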
+ """ + + def __init__( + self, + x, + enc_lens, + batch_size, + beam_size, + blank_index, + eos_index, + ctc_window_size=0, + ): + self.blank_index = blank_index + self.eos_index = eos_index + self.max_enc_len = x.size(1) + self.batch_size = batch_size + self.beam_size = beam_size + self.vocab_size = x.size(-1) + self.device = x.device + self.minus_inf = -1e4 + self.last_frame_index = enc_lens - 1 + self.ctc_window_size = ctc_window_size + + # mask frames > enc_lens + mask = 1 - length_to_mask(enc_lens) + mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1) + x.masked_fill_(mask, self.minus_inf) + x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0) + + # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors + xnb = x.transpose(0, 1) + xb = ( + xnb[:, :, self.blank_index] + .unsqueeze(2) + .expand(-1, -1, self.vocab_size) + ) + + # (2, L, batch_size * beam_size, vocab_size) + # self.x = torch.stack([xnb, xb]) + self.x = FasterStack([xnb.contiguous(), xb.contiguous()]) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=self.device) * self.beam_size + ) + # The first index of each candidates. + self.cand_offset = ( + torch.arange(batch_size, device=self.device) * self.vocab_size + ) + + def forward_step(self, g, state, candidates=None, attn=None): + """This method if one step of forwarding operation + for the prefix ctc scorer. + + Arguments + --------- + g : torch.Tensor + The tensor of prefix label sequences, h = g + c. + state : tuple + Previous ctc states. + candidates : torch.Tensor + (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring. + The ctc_beam_size is set as 2 * beam_size. If given, performing partial ctc scoring. + """ + + prefix_length = g.size(1) + last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g) + self.num_candidates = ( + self.vocab_size if candidates is None else candidates.size(-1) + ) + if state is None: + # r_prev: (L, 2, batch_size * beam_size) + r_prev = torch.full( + (self.max_enc_len, 2, self.batch_size, self.beam_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + + # Accumulate blank posteriors at each step + r_prev[:, 1] = torch.cumsum( + self.x[0, :, :, self.blank_index], 0 + ).unsqueeze(2) + r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size) + psi_prev = 0.0 + else: + r_prev, psi_prev = state + r_prev = r_prev.half() + + # for partial search + if candidates is not None: + scoring_table = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + -1, + dtype=torch.long, + device=self.device, + ) + # Assign indices of candidates to their positions in the table + col_index = torch.arange( + self.batch_size * self.beam_size, device=self.device + ).unsqueeze(1) + scoring_table[col_index, candidates] = torch.arange( + self.num_candidates, device=self.device + ) + # Select candidates indices for scoring + scoring_index = ( + candidates + + self.cand_offset.unsqueeze(1) + .repeat(1, self.beam_size) + .view(-1, 1) + ).view(-1) + x_inflate = torch.index_select( + self.x.view(2, -1, self.batch_size * self.vocab_size), + 2, + scoring_index, + ).view(2, -1, self.batch_size * self.beam_size, self.num_candidates) + # for full search + else: + scoring_table = None + x_inflate = ( + self.x.unsqueeze(3) + .repeat(1, 1, 1, self.beam_size, 1) + .view( + 2, -1, self.batch_size * self.beam_size, self.num_candidates + ) + ) + + # Prepare forward probs + r = torch.full( + ( + self.max_enc_len, + 2, + self.batch_size * self.beam_size, + 
self.num_candidates, + ), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + r.fill_(self.minus_inf) + + # (Alg.2-6) + if prefix_length == 0: + r[0, 0] = x_inflate[0, 0] + # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g) + r_sum = FasterLogSumExp(r_prev, 1) + phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates) + + # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0 + if candidates is not None: + for i in range(self.batch_size * self.beam_size): + pos = scoring_table[i, last_char[i]] + if pos != -1: + phi[:, i, pos] = r_prev[:, 1, i] + else: + for i in range(self.batch_size * self.beam_size): + phi[:, i, last_char[i]] = r_prev[:, 1, i] + + # Start, end frames for scoring (|g| < |h|). + # Scoring based on attn peak if ctc_window_size > 0 + if self.ctc_window_size == 0 or attn is None: + start = max(1, prefix_length) + end = self.max_enc_len + else: + _, attn_peak = torch.max(attn, dim=1) + max_frame = torch.max(attn_peak).item() + self.ctc_window_size + min_frame = torch.min(attn_peak).item() - self.ctc_window_size + start = max(max(1, prefix_length), int(min_frame)) + end = min(self.max_enc_len, int(max_frame)) + + # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)): + for t in range(start, end): + # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c) + rnb_prev = r[t - 1, 0] + # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank) + rb_prev = r[t - 1, 1] + # r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + # 2, 2, self.batch_size * self.beam_size, self.num_candidates + # ) + r_ = FasterStack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + 2, 2, self.batch_size * self.beam_size, self.num_candidates + ) + r[t] = FasterLogSumExp(r_, 1) + x_inflate[:, t] + + # Compute the predix prob, psi + psi_init = r[start - 1, 0].unsqueeze(0) + # phi is prob at t-1 step, shift one frame and add it to the current prob p(c) + phix = FastCat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0] + + # (Alg.2-13): psi = psi + phi * p(c) + if candidates is not None: + psi = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + psi_ = FasterLogSumExp( + FastCat((phix[start:end], psi_init), dim=0), dim=0 + ) + # only assign prob to candidates + for i in range(self.batch_size * self.beam_size): + psi[i, candidates[i]] = psi_[i] + else: + psi = FastCat((phix[start:end], psi_init), dim=0) + psi = FasterLogSumExp(psi, dim=0) + + # (Alg.2-3): if c = , psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames + for i in range(self.batch_size * self.beam_size): + psi[i, self.eos_index] = r_sum[ + self.last_frame_index[i // self.beam_size], i + ] + + # Exclude blank probs for joint scoring + psi[:, self.blank_index] = self.minus_inf + + return psi - psi_prev, (r, psi, scoring_table) + + def permute_mem(self, memory, index): + """This method permutes the CTC model memory + to synchronize the memory index with the current output. + + Arguments + --------- + memory : No limit + The memory variable to be permuted. + index : torch.Tensor + The index of the previous path. + + Return + ------ + The variable of the memory being permuted. + + """ + r, psi, scoring_table = memory + # The index of top-K vocab came from in (t-1) timesteps. 
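+        # [Note] `index` holds flat positions in the (beam_size * vocab_size) score
+        # view of each batch element; adding beam_offset * vocab_size turns them
+        # into flat indices over psi.view(-1), whose length is
+        # batch_size * beam_size * vocab_size.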
+ best_index = ( + index + + (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size) + ).view(-1) + # synchronize forward prob + psi = torch.index_select(psi.view(-1), dim=0, index=best_index) + psi = ( + psi.view(-1, 1) + .repeat(1, self.vocab_size) + .view(self.batch_size * self.beam_size, self.vocab_size) + ) + + # synchronize ctc states + if scoring_table is not None: + effective_index = ( + index // self.vocab_size + self.beam_offset.view(-1, 1) + ).view(-1) + selected_vocab = (index % self.vocab_size).view(-1) + score_index = scoring_table[effective_index, selected_vocab] + score_index[score_index == -1] = 0 + best_index = score_index + effective_index * self.num_candidates + + r = torch.index_select( + r.view( + -1, 2, self.batch_size * self.beam_size * self.num_candidates + ), + dim=-1, + index=best_index, + ) + r = r.view(-1, 2, self.batch_size * self.beam_size) + + return r, psi + + +def filter_ctc_output(string_pred, blank_id=-1): + """Apply CTC output merge and filter rules. + + Removes the blank symbol and output repetitions. + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the CTC system. + blank_id : int, string + The id of the blank. + + Returns + ------- + list + The output predicted by CTC without the blank symbol and + the repetitions. + + Example + ------- + >>> string_pred = ['a','a','blank','b','b','blank','c'] + >>> string_out = filter_ctc_output(string_pred, blank_id='blank') + >>> print(string_out) + ['a', 'b', 'c'] + """ + + if isinstance(string_pred, list): + # Filter the repetitions + string_out = [ + v + for i, v in enumerate(string_pred) + if i == 0 or v != string_pred[i - 1] + ] + + # Remove duplicates + string_out = [i[0] for i in groupby(string_out)] + + # Filter the blank symbol + string_out = list(filter(lambda elem: elem != blank_id, string_out)) + else: + raise ValueError("filter_ctc_out can only filter python lists") + return string_out + + +def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1): + """Greedy decode a batch of probabilities and apply CTC rules. + + Arguments + --------- + probabilities : torch.tensor + Output probabilities (or log-probabilities) from the network with shape + [batch, probabilities, time] + seq_lens : torch.tensor + Relative true sequence lengths (to deal with padded inputs), + the longest sequence has length 1.0, others a value between zero and one + shape [batch, lengths]. + blank_id : int, string + The blank symbol/index. Default: -1. If a negative number is given, + it is assumed to mean counting down from the maximum possible index, + so that -1 refers to the maximum possible index. + + Returns + ------- + list + Outputs as Python list of lists, with "ragged" dimensions; padding + has been removed. + + Example + ------- + >>> import torch + >>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]], + ... 
[[0.2, 0.8], [0.9, 0.1]]]) + >>> lens = torch.tensor([0.51, 1.0]) + >>> blank_id = 0 + >>> ctc_greedy_decode(probs, lens, blank_id) + [[1], [1]] + """ + if isinstance(blank_id, int) and blank_id < 0: + blank_id = probabilities.shape[-1] + blank_id + batch_max_len = probabilities.shape[1] + batch_outputs = [] + for seq, seq_len in zip(probabilities, seq_lens): + actual_size = int(torch.round(seq_len * batch_max_len)) + scores, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1) + out = filter_ctc_output(predictions.tolist(), blank_id=blank_id) + batch_outputs.append(out) + return batch_outputs diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py new file mode 100644 index 00000000..537d35c5 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py @@ -0,0 +1,13 @@ +import torch +from faster_cat import sp_opt + +def FastCat(inputs,dim=0): + if len(inputs) == 2 and dim==0: + a,b = inputs + in_shape = a.shape + if len(in_shape)>1: + res, = sp_opt.test_opt_2(a.view(a.shape[0],-1),b.view(b.shape[0],-1)) + new_shape = (a.shape[0]+b.shape[0],) + in_shape[1:] + res = res.view(*new_shape) + return res + return torch.cat(inputs,dim=dim) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . 
\ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu new file mode 100644 index 00000000..022fac39 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Cat(half* a, half* b, half* output, int m1, int m2, int k) { + int i = blockIdx.y * blockDim.x + threadIdx.x; + // a + if (blockIdx.x < m1) { + half2* h2_a = reinterpret_cast(a + blockIdx.x * k); + half2* h2_out_a = reinterpret_cast(output + blockIdx.x * k); + if (i < k / 2) { + h2_out_a[i] = h2_a[i]; + } + } + // b + if (blockIdx.x < m2) { + half2* h2_b = reinterpret_cast(b + blockIdx.x * k); + half2* h2_out_b = + reinterpret_cast(output + blockIdx.x * k + m1 * k); + if (i < k / 2) { + h2_out_b[i] = h2_b[i]; + } + } +} + +void IxinferCatLauncher(half* a, half* b, half* output, int m1, int m2, int k, + cudaStream_t stream) { + if (k % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int m = std::max(m1, m2); + int num_threads = 1024; + int half_k = k / 2; + int num_roll = (half_k - 1 + num_threads) / num_threads; + dim3 grid(m, num_roll); + dim3 block(num_threads); + Cat<<>>(a, b, output, m1, m2, k); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(b.dim() == 2); + + int m1 = a.size(0); + int m2 = b.size(0); + + int k = a.size(1); + + TORCH_CHECK(b.size(1) == k); + + at::Tensor output = a.new_empty({(m1 + m2), k}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferCatLauncher(p_a, p_b, p_out, m1, m2, k, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
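+# [Note] Builds the `sp_opt` CUDA extension (test.cpp + kernel.cu, linked against
+# cuinfer) that backs faster_cat.FastCat. Typical build, mirroring build.sh:
+#   python3 setup.py build && cp build/lib*/*.so .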
+import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp new file mode 100644 index 00000000..11720811 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py new file mode 100644 index 00000000..2713dae2 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
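+# [Note] Quick correctness/benchmark check for the sp_opt concatenation kernel:
+# it compares sp_opt.test_opt_2(a, b) against torch.cat([a, b], dim=0) on fp16
+# CUDA tensors and prints the element-wise absolute difference. Run it from this
+# directory after build.sh so the freshly built sp_opt*.so is importable.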
+import torch +import sp_opt + +if __name__ == "__main__": + m1 = 320 + m2 = 321 + hidden_size = 5000 + + a = torch.randn([m1,hidden_size]).cuda().half() + b = torch.randn([m2,hidden_size]).cuda().half() + + + res_pt = torch.cat([a,b],dim=0) + + res_cu, = sp_opt.test_opt_2(a,b) + + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt_2(a,b) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py new file mode 100644 index 00000000..20603650 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py @@ -0,0 +1,16 @@ +import torch +from faster_layer_norm import sp_opt + +class FasterLayerNorm(torch.nn.Module): + def __init__(self, weight, bias): + super(FasterLayerNorm, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, inputs, *args, **kwargs): + hidden_size = self.weight.size(0) + in_shape = inputs.shape + inputs = inputs.view(-1,hidden_size) + output, = sp_opt.test_opt(inputs,self.weight,self.bias) + output = output.view(*in_shape) + return output diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . 
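`FasterLayerNorm` above is meant to wrap the weight and bias of an existing fp16 `torch.nn.LayerNorm`: its forward flattens the input to `(-1, hidden_size)`, calls the fused kernel, and restores the original shape. A usage sketch, assuming the extension has been built with `build.sh` and the `ixrt` directory is on `PYTHONPATH` (shapes are illustrative):

```python
import torch
from faster_layer_norm import FasterLayerNorm

ln = torch.nn.LayerNorm(256).cuda().half()
fast_ln = FasterLayerNorm(ln.weight.data, ln.bias.data)

x = torch.randn(8, 100, 256).cuda().half()
print((ln(x) - fast_ln(x)).abs().max())  # small fp16 rounding difference expected
```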
\ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu new file mode 100644 index 00000000..852db917 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu @@ -0,0 +1,168 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "transformer_helper.cuh" + +namespace iluvatar::inferrt::transformer { + +template +__global__ void LnOpt2Kernel(half* input, half* ln_weight, half* ln_bias, + half* output, int hidden_size, + float layernorm_eps) { + input += blockIdx.x * hidden_size; + output += blockIdx.x * hidden_size; + + half2* p_in = reinterpret_cast(input); + half2* p_out = reinterpret_cast(output); + half2* p_wei = reinterpret_cast(ln_weight); + half2* p_bias = reinterpret_cast(ln_bias); + int half_hidden_size = hidden_size / 2; + + extern __shared__ half2 shmem[]; + + float s_mean; + float s_variance; + float x_sum = 0.0f; + float x2_sum = 0.0f; +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 value = p_in[index]; + shmem[index] = value; + float val_1 = __half2float(value.x); + float val_2 = __half2float(value.y); + x_sum += val_1 + val_2; + x2_sum += val_1 * val_1 + val_2 * val_2; + } + } + float sums[2]; // 和,平方和 + sums[0] = x_sum; + sums[1] = x2_sum; + blockReduceSumV2(sums); + + s_mean = sums[0] / hidden_size; + s_variance = rsqrtf(sums[1] / hidden_size - s_mean * s_mean + layernorm_eps); + +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 wei_value = p_wei[index]; + half2 bias_value = p_bias[index]; + half2 vals_value = shmem[index]; + + float2 norm_value; + norm_value.x = (__half2float(vals_value.x) - s_mean) * s_variance * + __half2float(wei_value.x) + + __half2float(bias_value.x); + norm_value.y = (__half2float(vals_value.y) - s_mean) * s_variance * + __half2float(wei_value.y) + + __half2float(bias_value.y); + + __half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + p_out[index] = res; + } + } +} + +// FasterTransformer/src/fastertransformer/kernels/layernorm_kernels.cu +void IxinferLnLauncherOpt2(__half* input, __half* ln_weight, __half* ln_bias, + __half* output, int batch_tokens, int hidden_size, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (hidden_size % 2 != 0) { + throw std::runtime_error("layer norm error: hidden_size % 2 != 0"); + } + dim3 grid(batch_tokens); + int half_n = hidden_size / 2; + int half_n_warp = (half_n + warpSize - 1) / warpSize * warpSize; + dim3 block(std::min(half_n_warp, 1024)); + int rolls_per_thread = (half_n + block.x - 1) / block.x; + switch (rolls_per_thread) { + case 1: + LnOpt2Kernel<1><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 2: + LnOpt2Kernel<2><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 3: + LnOpt2Kernel<3><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 4: + LnOpt2Kernel<4><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 5: + LnOpt2Kernel<5><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 6: + LnOpt2Kernel<6><<>>( + input, 
ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 7: + LnOpt2Kernel<7><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 8: + LnOpt2Kernel<8><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + default: + std::cout << "hidden_size: " << hidden_size << std::endl; + throw std::runtime_error("layer norm error, unsupport hidden size! "); + break; + } +} +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(ln_weight.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_weight.is_cuda()); + TORCH_CHECK(ln_weight.is_contiguous()); + + TORCH_CHECK(ln_bias.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_bias.is_cuda()); + TORCH_CHECK(ln_bias.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(ln_weight.dim() == 1); + TORCH_CHECK(ln_bias.dim() == 1); + + int batch_tokens = input.size(0); + int hidden_size = input.size(1); + + TORCH_CHECK(ln_weight.size(0) == hidden_size); + TORCH_CHECK(ln_bias.size(0) == hidden_size); + + at::Tensor output = at::empty_like(input); + + half* p_in = (half*)input.data_ptr(); + half* p_wei = (half*)ln_weight.data_ptr(); + half* p_bias = (half*)ln_bias.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLnLauncherOpt2( + p_in, p_wei, p_bias, p_out, batch_tokens, hidden_size, stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
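For reference, per row of `hidden_size` elements `LnOpt2Kernel` accumulates the sum and the sum of squares in a single pass and derives the statistics from them, rather than using the Welford helpers from `transformer_helper.cuh` (which appears later in this patch). In float arithmetic the per-row computation reduces to the following sketch (not code from the patch):

```python
import torch

def ln_row_reference(x_row, weight, bias, eps: float = 1e-5):
    x = x_row.float()
    mean = x.sum() / x.numel()
    inv_std = torch.rsqrt((x * x).sum() / x.numel() - mean * mean + eps)
    return ((x - mean) * inv_std * weight.float() + bias.float()).half()
```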
+import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp new file mode 100644 index 00000000..f925c1b4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias); + +std::vector test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + return one_test_opt(input, ln_weight, ln_bias); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, "fast depthwise conv1d forward"); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh new file mode 100644 index 00000000..f8a57622 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh @@ -0,0 +1,295 @@ +#pragma once +#include +#include + +namespace iluvatar { +namespace inferrt { +namespace transformer { + +__forceinline__ int nearest_4(int x) { + if (x % 4 == 0) { + return x; + } else { + int padding = 4 - x % 4; + return x + padding; + } +} + +__forceinline__ int nearest_2(int x) { + if (x % 2 == 0) { + return x; + } else { + int padding = 2 - x % 2; + return x + padding; + } +} + +__forceinline__ int nearest_num(int x, int value) { + if (x % value == 0) { + return x; + } else { + int padding = value - x % value; + return x + padding; + } +} + +__device__ int8_t float2int8(float x, float quant_scale) { + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 
127 : i8); + return int8_t(i8); +} + +__device__ void WelfordCombine(float val, float *mean, float *m2, + float *count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + float delta1 = val - *mean; + *mean += delta1 / *count; + float delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +__device__ void WelfordCombine(float b_mean, float b_m2, float b_count, + float *mean, float *m2, float *count) { + if (b_count == 0) { + return; + } + float new_count = *count + b_count; + float nb_over_n = b_count / new_count; + float delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +__device__ void WelfordWarpReduce(float thread_mean, float thread_m2, + float thread_count, float *mean, float *m2, + float *count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = warpSize / 2; mask > 0; mask /= 2) { + float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); + float b_count = __shfl_down_sync(0xffffffff, *count, mask); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} + +// load 两个 half2, 保存到 float4 +__device__ void load_float4_from_half(float4 &vals, __half2 *input, int index) { + __half2 i1 = input[index * 2]; + __half2 i2 = input[index * 2 + 1]; + + vals.x = __half2float(i1.x); + vals.y = __half2float(i1.y); + vals.z = __half2float(i2.x); + vals.w = __half2float(i2.y); +} + +__device__ char4 float42char4(float4 vals, float quant_scale) { + char4 res; + res.x = float2int8(vals.x, quant_scale); + res.y = float2int8(vals.y, quant_scale); + res.z = float2int8(vals.z, quant_scale); + res.w = float2int8(vals.w, quant_scale); + return res; +} + +__device__ float4 char4addhalf2_dequant(char4 input_4, half2 residual_1, + half2 residual_2, float dequant_scale) { + float4 res; + res.x = + __int2float_rn(input_4.x) * dequant_scale + __half2float(residual_1.x); + res.y = + __int2float_rn(input_4.y) * dequant_scale + __half2float(residual_1.y); + res.z = + __int2float_rn(input_4.z) * dequant_scale + __half2float(residual_2.x); + res.w = + __int2float_rn(input_4.w) * dequant_scale + __half2float(residual_2.y); + return res; +} + +__device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, + int hidden_size, float epsilon, + half2 scale_1, half2 scale_2, + half2 bias_1, half2 bias_2) { + float4 norm_value; + norm_value.x = (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.x) + + __half2float(bias_1.x); + norm_value.y = (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.y) + + __half2float(bias_1.y); + norm_value.z = (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.x) + + __half2float(bias_2.x); + norm_value.w = (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.y) + + __half2float(bias_2.y); + return norm_value; +} + +// softmax +__forceinline__ __host__ __device__ int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} +template +__device__ T WARP_SHFL_XOR(T value, int laneMask, int width) { + unsigned int mask = 0xffffffff; +#if !(defined(__HIP_PLATFORM_HCC__) || defined(__ILUVATAR__)) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return 
__shfl_xor(value, laneMask, width); +#endif +} + +template +struct Add { + __device__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ T operator()(T a, T b) const { return a < b ? b : a; } +}; +template class ReduceOp> +__device__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = REDUCE_WARP_SIZE / 2; offset > 0; offset /= 2) { + acc_t b = WARP_SHFL_XOR(*sum, offset, REDUCE_WARP_SIZE); + *sum = r(*sum, b); + } +} + +__device__ void warp_argmax(float &value, int32_t &idx) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + float next_value = WARP_SHFL_XOR(value, offset, warpSize); + float next_idx = WARP_SHFL_XOR(idx, offset, warpSize); + if (next_value > value) { + value = next_value; + idx = next_idx; + } + } +} + +// gelu +// IxinferBiasGeluI8II8OKernel +template +__device__ T tanhf_exp(T x) { + // float e1 = __expf(x); + // float e2 = 1.0f / e1; + // return (e1 - e2) / (e1 + e2); + + return (2.f / (1.f + __expf(-2.f * x)) - 1.f); +} + +template +__device__ T gelu(T x) { + float cdf = + 0.5f * + (1.0f + tanhf_exp((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +/* fp16 gelu */ +template <> +__forceinline__ __device__ __half2 gelu<__half2>(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +template +__inline__ __device__ T warpReduceSumV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, warpSize); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T *val) { + static __shared__ T shared[NUM][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = lane < (blockDim.x / warpSize); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; +} + +__inline__ __device__ void warpReduceSum2Number(float *x, float *y) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) { + *x += __shfl_xor_sync(0xffffffff, *x, mask, warpSize); + *y += __shfl_xor_sync(0xffffffff, *y, mask, warpSize); + } +} + +__inline__ __device__ void blockReduceSum2Number(float *x, float *y) { + static __shared__ float shared[2][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSum2Number(x, y); + if (lane == 0) { + shared[0][wid] = *x; + shared[1][wid] = *y; + } + __syncthreads(); + bool is_mask = lane < (blockDim.x / warpSize); + *x = is_mask ? shared[0][lane] : 0.0f; + *y = is_mask ? 
shared[0][lane] : 0.0f; + + warpReduceSum2Number(x, y); +} + +} // namespace transformer + +} // namespace inferrt +} // namespace iluvatar diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py new file mode 100644 index 00000000..d50b3758 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py @@ -0,0 +1,38 @@ +import torch +from faster_logsumexp import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterLogSumExp(inputs,dim): + # print(inputs.shape, dim) + if dim == 1 and len(inputs.shape)>2 and inputs.size(1)==2: + in_shape = inputs.shape + inputs = inputs.view(in_shape[0],in_shape[1],-1) + res, = sp_opt.test_opt(inputs) + new_shape = (in_shape[0],) + in_shape[2:] + res = res.view(*new_shape) + return res + # dim==0 现在的实现会有bug? + # if dim == 0 and len(inputs.shape)>=2: + # in_shape = inputs.shape + # inputs = inputs.view(in_shape[0],-1) + # res, = sp_opt.test_opt_dim0(inputs) + # new_shape = in_shape[1:] + # res = res.view(*new_shape) + # return res + # print(f"not support shape: {inputs.shape} dim: {dim}") + res = torch.logsumexp(inputs, dim=dim) + return res + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . 
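`FasterLogSumExp` above is a dispatch function rather than a drop-in replacement: the custom kernel is used only when reducing over `dim == 1` on an input with more than two dimensions and exactly two slices along that dimension; the `dim == 0` path stays commented out because of a suspected bug, and every other case falls back to `torch.logsumexp`. A usage sketch under those rules, assuming the built package is importable and with illustrative shapes:

```python
import torch
from faster_logsumexp import FasterLogSumExp

x = torch.randn(4, 2, 320, 5000).cuda().half()
print((FasterLogSumExp(x, dim=1) - torch.logsumexp(x, dim=1)).abs().max())  # fast path

y = torch.randn(4, 3, 10).cuda().half()
print(FasterLogSumExp(y, dim=1).shape)  # size(1) != 2, plain torch.logsumexp is used
```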
\ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu new file mode 100644 index 00000000..56eb0810 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void LogSumExpWith2(half* input, half* output, int H) { + half2* h2_in1 = reinterpret_cast(input + blockIdx.x * 2 * H); + half2* h2_in2 = reinterpret_cast(input + blockIdx.x * 2 * H + H); + half2* h2_out = reinterpret_cast(output + blockIdx.x * H); + + int i = blockIdx.y * blockDim.x + threadIdx.x; + if (i < H / 2) { + float2 res; + half2 value1 = h2_in1[i]; + half2 value2 = h2_in2[i]; + + res.x = std::log(__expf(__half2float(value1.x)) + + __expf(__half2float(value2.x))); + res.y = std::log(__expf(__half2float(value1.y)) + + __expf(__half2float(value2.y))); + + half2 res_h2; + res_h2.x = __float2half(res.x); + res_h2.y = __float2half(res.y); + h2_out[i] = res_h2; + } +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int C, int H, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(N, num_roll); + dim3 block(num_threads); + switch (C) { + case 2: + LogSumExpWith2<<>>(input, output, H); + break; + default: + throw std::runtime_error( + "IxinferLogSumExpLauncher error, unsupport size! "); + break; + } +} + +// https://zhuanlan.zhihu.com/p/153535799 +__global__ void LogSumExpDim0(half* input, half* output, int N, int H) { + half2* h2_out = reinterpret_cast(output); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + float2 res; + res.x = 0.f; + res.y = 0.f; + + float2 max_values; + max_values.x = -1000.f; + max_values.y = -1000.f; + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + if (max_values.x < __half2float(value.x)) { + max_values.x = __half2float(value.x); + } + if (max_values.y < __half2float(value.y)) { + max_values.y = __half2float(value.y); + } + } + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + res.x += __expf(__half2float(value.x) - max_values.x); + res.y += __expf(__half2float(value.y) - max_values.y); + } + + half2 res_h2; + res_h2.x = __float2half(std::log(res.x) + max_values.x); + res_h2.y = __float2half(std::log(res.y) + max_values.y); + + h2_out[i] = res_h2; +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + LogSumExpDim0<<>>(input, output, N, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 3); + + int N = input.size(0); + int C = 
input.size(1); + int H = input.size(2); + + at::Tensor output = input.new_empty({N, H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, C, H, + stream); + return {output}; +} + +std::vector one_test_dim0(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + + int N = input.size(0); + int H = input.size(1); + + at::Tensor output = input.new_empty({H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, H, + stream); + return {output}; +} \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
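The two launchers above use different formulations: `LogSumExpWith2` evaluates `log(exp(a) + exp(b))` directly in float, while `LogSumExpDim0` first subtracts the running maximum, the standard overflow-safe form when many values are reduced. In plain Python the two variants look like this (a sketch, not code from the patch):

```python
import math

def lse_pair(a: float, b: float) -> float:
    # LogSumExpWith2: direct evaluation over exactly two values
    return math.log(math.exp(a) + math.exp(b))

def lse_stable(values) -> float:
    # LogSumExpDim0: shift by the maximum before exponentiating
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))
```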
+import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp new file mode 100644 index 00000000..5eaf6fe1 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input); + +std::vector test_opt(at::Tensor input) { + return one_test_opt(input); +} + +std::vector one_test_dim0(at::Tensor input); + +std::vector test_opt_dim0(at::Tensor input) { + return one_test_dim0(input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_dim0", &test_opt_dim0, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py new file mode 100644 index 00000000..7b22dbdd --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 2 + c = 2 + hidden_size = 320*5000 + + inputs = torch.randn([batch_tokens,c, hidden_size]).cuda().half() + + # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # res2 = torch.logsumexp(inputs,dim=-1) + # diff = torch.abs(res1-res2) + # print(diff.max()) + + res_pt = torch.logsumexp(inputs,dim=1) + + res_cu, = sp_opt.test_opt(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(inputs) + + batch_tokens = 55 + hidden_size = 320*5000 + inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + res_pt = torch.logsumexp(inputs,dim=0) + res_cu, = sp_opt.test_opt_dim0(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py new file mode 100644 index 00000000..48d0cf5b --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py @@ -0,0 +1,33 @@ +import torch +from faster_stack import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterStack(inputs): + if len(inputs) == 4: + a,b,c,d = inputs + in_shape = a.shape + res, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + new_shape = (4,) + in_shape + res = res.view(*new_shape) + return res + if len(inputs) == 2: + a,b = inputs + in_shape = a.shape + res, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + new_shape = (2,) + in_shape + res = res.view(*new_shape) + return res + return torch.stack(inputs) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . 
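`FasterStack` above takes the fused path only for lists of exactly two or four tensors: each input is flattened, copied into a new leading dimension, and the result is reshaped back, so all inputs must be contiguous fp16 CUDA tensors of one shape with an even element count. Anything else falls back to `torch.stack`. A usage sketch with illustrative shapes, assuming the built package is importable:

```python
import torch
from faster_stack import FasterStack

xs = [torch.randn(320, 5000).cuda().half() for _ in range(4)]
print((FasterStack(xs) - torch.stack(xs)).abs().max())  # fused path, expect 0

ys = [torch.randn(8, 16).cuda().half() for _ in range(3)]
print(FasterStack(ys).shape)  # three inputs, falls back to torch.stack
```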
\ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu new file mode 100644 index 00000000..0fdff649 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Stack(half* a, half* b, half* c, half* d, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + half2* h2_c = reinterpret_cast(c); + half2* h2_d = reinterpret_cast(d); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + half2* h2_out_c = reinterpret_cast(output + H * 2); + half2* h2_out_d = reinterpret_cast(output + H * 3); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + h2_out_c[i] = h2_c[i]; + h2_out_d[i] = h2_d[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* c, half* d, half* output, + int H, cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, c, d, output, H); +} + +__global__ void Stack(half* a, half* b, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* output, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, output, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(c.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(c.is_cuda()); + TORCH_CHECK(c.is_contiguous()); + + TORCH_CHECK(d.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(d.is_cuda()); + TORCH_CHECK(d.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + TORCH_CHECK(c.dim() == 1); + TORCH_CHECK(d.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + TORCH_CHECK(c.size(0) == N); + TORCH_CHECK(d.size(0) == N); + + at::Tensor output = a.new_empty({N * 4}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_c = (half*)c.data_ptr(); + half* p_d = (half*)d.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_c, p_d, + p_out, N, stream); + return {output}; +} + +std::vector one_test_opt_2(at::Tensor a, 
at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + + at::Tensor output = a.new_empty({N * 2}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_out, N, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp new file mode 100644 index 00000000..08703064 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d); + +std::vector test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + return one_test_opt(a, b, c, d); +} + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py new file mode 100644 index 00000000..185b829b --- /dev/null +++ 
b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 320 + hidden_size = 5000 + + a = torch.randn([batch_tokens,hidden_size]).cuda().half() + b = torch.randn([batch_tokens,hidden_size]).cuda().half() + c = torch.randn([batch_tokens,hidden_size]).cuda().half() + d = torch.randn([batch_tokens,hidden_size]).cuda().half() + + res_pt = torch.stack([a,b,c,d]) + + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + res_cu = res_cu.view(4,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + + res_pt = torch.stack([a,b]) + + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + res_cu = res_cu.view(2,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + # # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # # res2 = torch.logsumexp(inputs,dim=-1) + # # diff = torch.abs(res1-res2) + # # print(diff.max()) + + # res_pt = torch.logsumexp(inputs,dim=1) + + # res_cu, = sp_opt.test_opt(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + + # for i in range(20): + # res_cu, = sp_opt.test_opt(inputs) + + # batch_tokens = 55 + # hidden_size = 320*5000 + # inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + # res_pt = torch.logsumexp(inputs,dim=0) + # res_cu, = sp_opt.test_opt_dim0(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + # for i in range(20): + # res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml new file mode 100644 index 00000000..859d09f3 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml @@ -0,0 +1,253 @@ +# ############################################################################ +# Model: E2E ASR with Transformer +# Encoder: Transformer Encoder +# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch +# Tokens: BPE with unigram +# losses: CTC + KLdiv (Label Smoothing loss) +# Training: AISHELL-1 +# Authors: Jianyuan Zhong, Titouan Parcollet +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 8886 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/transformer/ +cer_file: !ref /cer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/aishell +# noise/ris dataset 
will automatically be downloaded +data_folder_rirs: !ref # Change this is needed +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +train_data: !ref /csv_data/train.csv +valid_data: !ref /csv_data/dev.csv +test_data: !ref /csv_data/test.csv +tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt + +# Training parameters +number_of_epochs: 50 +batch_size: 64 +ctc_weight: 0.3 +gradient_accumulation: 4 +loss_reduction: 'batchmean' +sorting: ascending + +dynamic_batching: False +dynamic_batch_sampler: + feats_hop_size: 0.01 + max_batch_len: 15 # in terms of "duration" in annotations by default, second here + left_bucket_len: 200 # old implementation attributs + multiplier: 1.1 # old implementation attributs + shuffle_ex: False # if true re-creates batches at each epoch shuffling examples. + num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1 + batch_ordering: ascending + +num_workers: 6 + +# stages related parameters +stage_one_epochs: 40 +lr_adam: 1.0 +lr_sgd: 0.000025 + +# Feature parameters +sample_rate: 16000 +n_fft: 400 +n_mels: 80 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +####################### Model parameters ########################### +# Transformer +d_model: 256 +nhead: 4 +num_encoder_layers: 12 +num_decoder_layers: 6 +d_ffn: 2048 +transformer_dropout: 0.1 +activation: !name:torch.nn.GELU +output_neurons: 5000 + +# Outputs +blank_index: 0 +label_smoothing: 0.1 +pad_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +min_decode_ratio: 0.0 +max_decode_ratio: 1.0 # 1.0 +valid_search_interval: 10 +valid_beam_size: 10 +test_beam_size: 1 +ctc_weight_decode: 0.40 + +############################## models ################################ + +CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd + input_shape: (8, 10, 80) + num_blocks: 2 + num_layers_per_block: 1 + out_channels: (256, 256) + kernel_sizes: (3, 3) + strides: (2, 2) + residuals: (False, False) + +Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length + input_size: 5120 + tgt_vocab: !ref + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: !ref + d_ffn: !ref + dropout: !ref + activation: !ref + normalize_before: True + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +seq_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt + openrir_folder: !ref + babble_prob: 0.0 + reverb_prob: 0.0 + noise_prob: 1.0 + noise_snr_low: 0 + noise_snr_high: 15 + +modules: + CNN: !ref + Transformer: !ref + seq_lin: !ref + ctc_lin: !ref + env_corrupt: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +# define two optimizers here for two-stage training +Adam: !name:torch.optim.Adam + lr: 0 + betas: (0.9, 0.98) + eps: 0.000000001 + +SGD: !name:torch.optim.SGD + lr: !ref + momentum: 0.99 + nesterov: True + + +valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +test_search: 
!new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +log_softmax: !new:torch.nn.LogSoftmax + dim: -1 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + reduction: !ref + +seq_cost: !name:speechbrain.nnet.losses.kldiv_loss + label_smoothing: !ref + reduction: !ref + +noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler + lr_initial: !ref + n_warmup_steps: 25000 + model_size: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + noam_scheduler: !ref + normalizer: !ref + counter: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: global + update_until_epoch: 4 + +augmentation: !new:speechbrain.lobes.augment.SpecAugment + time_warp: True + time_warp_window: 5 + time_warp_mode: bicubic + freq_mask: True + n_freq_mask: 2 + time_mask: True + n_time_mask: 2 + replace_with_zero: False + freq_mask_width: 30 + time_mask_width: 40 + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +# AISHELL-1 has spaces between words in the transcripts, +# which Chinese writing normally does not do. +# If remove_spaces, spaces are removed +# from the transcript before computing CER. +# (e.g., 祝 可爱 的 你 —> 祝可爱的你) +remove_spaces: True +split_tokens: !apply:operator.not_ [!ref ] + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: !ref +acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats + +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + loadables: + tokenizer: !ref + paths: + tokenizer: !ref +engine_path: transformer.engine +ckpt_path: /home/data/speechbrain/results \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/inference.py b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py new file mode 100644 index 00000000..68ef0e40 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py @@ -0,0 +1,606 @@ +#!/usr/bin/env/python3 +""" + +AISHELL-1 transformer model recipe. (Adapted from the LibriSpeech recipe.) + +""" +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
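A note on how the hyperparameter file above is consumed: `data_folder` is declared as `!PLACEHOLDER`, and `!ref` entries interpolate other keys, so the file only resolves once overrides are supplied, which `inference.py` does with the values parsed from its command line. A minimal, self-contained illustration using a made-up mini-config (not the file above):

```python
from hyperpyyaml import load_hyperpyyaml

# !PLACEHOLDER must be overridden at load time; !ref <key> interpolates other entries.
mini_yaml = """
data_folder: !PLACEHOLDER
engine_path: transformer.engine
train_data: !ref <data_folder>/csv_data/train.csv
"""
hparams = load_hyperpyyaml(mini_yaml, {"data_folder": "/home/data/speechbrain/aishell"})
print(hparams["train_data"])  # /home/data/speechbrain/aishell/csv_data/train.csv
```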
+import sys +import time +import torch +import logging +import speechbrain as sb +from speechbrain import Stage +from speechbrain.dataio.dataloader import LoopedLoader +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from speechbrain.utils.checkpoints import Checkpointer +import numpy as np +from speechbrain.utils import data_utils +import tensorrt +from torch.utils.data import DataLoader +from tqdm import tqdm +import convert +import beam_search +from load_ixrt_plugin import load_ixrt_plugin +from tensorrt import Dims +from speechbrain.lobes.models.transformer import Transformer +Transformer.get_lookahead_mask = convert.patch_get_lookahead_mask +load_ixrt_plugin() +logger = logging.getLogger(__name__) + + +def volume(shape): + result = 1 + for i in shape: + result *= i + return result + + +class ASR(sb.core.Brain): + def __init__(self, engine_path, *args, **kwargs): + super().__init__(*args, **kwargs) + # + self.forward_time = 0 + # ixrt + self.logger = tensorrt.Logger(tensorrt.Logger.ERROR) + with open(engine_path, "rb") as f, tensorrt.Runtime(self.logger) as self.runtime: + self.engine = self.runtime.deserialize_cuda_engine(f.read()) + assert self.engine + self.context = self.engine.create_execution_context() + assert self.context + self.encoder_ln_out = torch.zeros((64,2048,256), dtype=torch.float16).cuda() + self.infer_time = 0 + self.hparams.valid_search.return_log_probs = True + self.modules.CNN = self.modules.CNN.half() + self.hparams.valid_search = self.hparams.valid_search.half() + self.hparams.valid_search.model = self.hparams.valid_search.model.half() + self.hparams.valid_search.fc = self.hparams.valid_search.fc.half() + self.hparams.valid_search.ctc_fc = self.hparams.valid_search.ctc_fc.half() + self.hparams.valid_search.minus_inf = -10000 + self.hparams.valid_search.softmax = self.hparams.valid_search.softmax.half() + self.hparams.valid_search.model.decoder = convert.convert_decoder_model(self.hparams.valid_search.model.decoder) + # Given all input/output bindings, run in a dynamic shape way + def ixrt_infer(self, engine, context, bindings): + assert engine.num_bindings == len(bindings) + io_buffers = [0] * engine.num_bindings + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + io_buffers[idx] = arr.data_ptr() + # dynamic input + if engine.binding_is_input(idx): + context.set_binding_shape(idx, Dims(arr.shape)) + + forward_start_time = time.time() + assert context.execute_v2(io_buffers) + + torch.cuda.synchronize() + self.forward_time += time.time() - forward_start_time + outputs = {} + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + if not engine.binding_is_input(idx): + # dynamic output + shape = context.get_binding_shape(idx) + outputs[name] = arr.view(-1)[:volume(shape)].view(*shape) + return outputs + + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + tokens_bos, _ = batch.tokens_bos + + # Add augmentation if specified + if stage == sb.Stage.TRAIN: + if hasattr(self.modules, "env_corrupt"): + wavs_noise = self.modules.env_corrupt(wavs, wav_lens) + wavs = torch.cat([wavs, wavs_noise], dim=0) + wav_lens = torch.cat([wav_lens, wav_lens]) + tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0) + + torch.cuda.synchronize() + start_time = time.time() + + # compute features + feats = self.hparams.compute_features(wavs) + current_epoch = 
self.hparams.epoch_counter.current + feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch) + + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "augmentation"): + feats = self.hparams.augmentation(feats) + + # forward modules + src = self.modules.CNN(feats.half()) + + # Orignal PyTorch implementation, comment this to compare + # enc_out, _ = self.modules.Transformer( + # src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index + # ) + # logits = self.modules.ctc_lin(enc_out) + # p_ctc = self.hparams.log_softmax(logits) + # hyps, _ = self.hparams.test_search( + # enc_out.detach(), wav_lens + # ) + # return p_ctc, wav_lens, hyps + + # transformer + if src.ndim == 4: + bz, t, ch1, ch2 = src.shape + src = src.reshape(bz, t, ch1 * ch2) + + # ixrt inference + t1 = time.time() + bindings = {"input": src.half(), "length_radio": wav_lens.half(), + "encoder_ln_out": self.encoder_ln_out} + + infer_result = self.ixrt_infer(self.engine, self.context, bindings) + encoder_ln_out = infer_result["encoder_ln_out"] + t2 = time.time() + + hyps, _, p_ctc = beam_search.forward(self.hparams.valid_search, encoder_ln_out.half(), wav_lens.half()) + torch.cuda.synchronize() + infer_time = time.time() - start_time + + self.infer_time += infer_time + + return p_ctc, wav_lens, hyps + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + # ( + # p_ctc, + # p_seq, + # wav_lens, + # hyps, + # ) = predictions + + # 去除 seq2seq log-probabilities + ( + p_ctc, + wav_lens, + hyps, + ) = predictions + + ids = batch.id + tokens_eos, tokens_eos_lens = batch.tokens_eos + tokens, tokens_lens = batch.tokens + + if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN: + tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0) + tokens_eos_lens = torch.cat( + [tokens_eos_lens, tokens_eos_lens], dim=0) + tokens = torch.cat([tokens, tokens], dim=0) + tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0) + + # 去除 seq2seq 部分 loss + # loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + + if stage != sb.Stage.TRAIN: + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + + if current_epoch % valid_search_interval == 0 or (stage == sb.Stage.TEST): + # Decode token terms to words + predicted_words = [ + tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + if self.hparams.remove_spaces: + predicted_words = ["".join(p) for p in predicted_words] + target_words = ["".join(t) for t in target_words] + self.cer_metric.append(ids, predicted_words, target_words) + + # 不计算 acc 部分 + # # compute the accuracy of the one-step-forward prediction + # self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens) + return -torch.ones([1]) + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + # check if we need to switch optimizer + # if so change the optimizer from Adam to SGD + self.check_and_reset_optimizer() + + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + + # normalize the loss by gradient_accumulation step + (loss / self.hparams.gradient_accumulation).backward() + + if self.step % self.hparams.gradient_accumulation == 0: + # gradient clipping & early stop if loss is not fini + self.check_gradients(loss) + + self.optimizer.step() + self.optimizer.zero_grad() + + # anneal lr every 
update + self.hparams.noam_annealing(self.optimizer) + + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + with torch.no_grad(): + predictions = self.compute_forward(batch, stage=stage) + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + # self.acc_metric = self.hparams.acc_computer() + self.cer_metric = self.hparams.cer_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"forward time": self.forward_time} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + # stage_stats["ACC"] = self.acc_metric.summarize() + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + + # log stats and save checkpoint at end-of-epoch + if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process(): + + # report different epoch stages according current stage + current_epoch = self.hparams.epoch_counter.current + if current_epoch <= self.hparams.stage_one_epochs: + lr = self.hparams.noam_annealing.current_lr + steps = self.hparams.noam_annealing.n_steps + optimizer = self.optimizer.__class__.__name__ + else: + lr = self.hparams.lr_sgd + steps = -1 + optimizer = self.optimizer.__class__.__name__ + + epoch_stats = { + "epoch": epoch, + "lr": lr, + "steps": steps, + "optimizer": optimizer, + } + self.hparams.train_logger.log_stats( + stats_meta=epoch_stats, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"ACC": stage_stats["ACC"], "epoch": epoch}, + max_keys=["ACC"], + num_to_keep=10, + ) + + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={ + "Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.cer_file, "w") as w: + self.cer_metric.write_stats(w) + + def check_and_reset_optimizer(self): + """reset the optimizer if training enters stage 2""" + current_epoch = self.hparams.epoch_counter.current + if not hasattr(self, "switched"): + self.switched = False + if isinstance(self.optimizer, torch.optim.SGD): + self.switched = True + + if self.switched is True: + return + + if current_epoch > self.hparams.stage_one_epochs: + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("optimizer", self.optimizer) + + self.switched = True + + def on_fit_start(self): + """Initialize the right optimizer on the training start""" + super().on_fit_start() + + # if the model is resumed from stage two, reinitialize the optimizer + current_epoch = self.hparams.epoch_counter.current + current_optimizer = self.optimizer + if current_epoch > self.hparams.stage_one_epochs: + del self.optimizer + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + # Load latest checkpoint to resume training if interrupted + if self.checkpointer is not None: + + # do not reload the weights if training is interrupted right before stage 2 + group = current_optimizer.param_groups[0] + if "momentum" not in group: + return + + self.checkpointer.recover_if_possible( + 
+    def on_evaluate_start(self, max_key=None, min_key=None):
+        """Perform checkpoint averaging if needed."""
+        super().on_evaluate_start()
+
+        ckpts = self.checkpointer.find_checkpoints(
+            max_key=max_key, min_key=min_key)
+        ckpt = sb.utils.checkpoints.average_checkpoints(
+            ckpts, recoverable_name="model", device=self.device
+        )
+
+        self.hparams.model.load_state_dict(ckpt, strict=True)
+        self.hparams.model.eval()
+
+    def evaluate(
+        self,
+        test_set,
+        max_key=None,
+        min_key=None,
+        progressbar=None,
+        test_loader_kwargs={},
+    ):
+        self.debug = False
+        self.debug_batches = 1
+        if progressbar is None:
+            progressbar = not self.noprogressbar
+
+        if not (
+            isinstance(test_set, DataLoader)
+            or isinstance(test_set, LoopedLoader)
+        ):
+            test_loader_kwargs["ckpt_prefix"] = None
+            test_set = self.make_dataloader(
+                test_set, Stage.TEST, **test_loader_kwargs
+            )
+        self.on_evaluate_start(max_key=max_key, min_key=min_key)
+        self.on_stage_start(Stage.TEST, epoch=None)
+        self.modules.eval()
+        avg_test_loss = 0.0
+        self.step = 0
+        with torch.no_grad():
+            for batch in tqdm(
+                test_set, dynamic_ncols=True, disable=not progressbar
+            ):
+                self.step += 1
+                loss = self.evaluate_batch(batch, stage=Stage.TEST)
+                avg_test_loss = self.update_average(loss, avg_test_loss)
+
+                # Profile only if desired (steps allow the profiler to know when all is warmed up)
+                if self.profiler is not None:
+                    if self.profiler.record_steps:
+                        self.profiler.step()
+
+                # Debug mode only runs a few batches
+                if self.debug and self.step == self.debug_batches:
+                    break
+
+            # Only run evaluation "on_stage_end" on main process
+            run_on_main(
+                self.on_stage_end, args=[Stage.TEST, avg_test_loss, None]
+            )
+        self.step = 0
+        return avg_test_loss
+
+
+def dataio_prepare(hparams):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_data"],
+        replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True)
+        # when sorting, do not shuffle in the dataloader; otherwise sorting is pointless
+        hparams["train_dataloader_opts"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending")
+
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["valid_data"],
+        replacements={"data_root": data_folder},
+    )
+    valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+    test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["test_data"],
+        replacements={"data_root": data_folder},
+    )
+    test_data = test_data.filtered_sorted(sort_key="duration")
+
+    datasets = [train_data, valid_data, test_data]
+
+    # Defining tokenizer and loading it
+    tokenizer = hparams["tokenizer"]
+
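+    # The tokenizer object comes from the hparams file and must expose the
+    # SentencePiece-style encode_as_ids()/decode_ids() interface: encode_as_ids()
+    # is used by the text pipeline below and decode_ids() by compute_objectives()
+    # when scoring CER.
+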
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        sig = sb.dataio.dataio.read_audio(wav)
+        return sig
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("transcript")
+    @sb.utils.data_pipeline.provides(
+        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
+    )
+    def text_pipeline(wrd):
+        yield wrd
+        tokens_list = tokenizer.encode_as_ids(wrd)
+        yield tokens_list
+        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
+        yield tokens_eos
+        tokens = torch.LongTensor(tokens_list)
+        yield tokens
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets,
+        ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"],
+    )
+
+    # 5. If Dynamic Batching is used, we instantiate the needed samplers.
+    train_batch_sampler = None
+    valid_batch_sampler = None
+    if hparams["dynamic_batching"]:
+        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa
+
+        dynamic_hparams = hparams["dynamic_batch_sampler"]
+        num_buckets = dynamic_hparams["num_buckets"]
+
+        train_batch_sampler = DynamicBatchSampler(
+            train_data,
+            dynamic_hparams["max_batch_len"],
+            num_buckets=num_buckets,
+            length_func=lambda x: x["duration"],
+            shuffle=dynamic_hparams["shuffle_ex"],
+            batch_ordering=dynamic_hparams["batch_ordering"],
+        )
+
+        valid_batch_sampler = DynamicBatchSampler(
+            valid_data,
+            dynamic_hparams["max_batch_len"],
+            num_buckets=num_buckets,
+            length_func=lambda x: x["duration"],
+            shuffle=dynamic_hparams["shuffle_ex"],
+            batch_ordering=dynamic_hparams["batch_ordering"],
+        )
+
+    return (
+        train_data,
+        valid_data,
+        test_data,
+        tokenizer,
+        train_batch_sampler,
+        valid_batch_sampler,
+    )
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # If --distributed_launch then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
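+    # From here on: prepare the AISHELL manifests, build the datasets and
+    # tokenizer, load the pretrained weights from ckpt_path, then run evaluate()
+    # on the test set and report QPS and RTF.
+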
+    # 1. Dataset prep (parsing AISHELL)
+    from aishell_prepare import prepare_aishell  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_aishell,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["output_folder"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # here we create the dataset objects as well as tokenization and encoding
+    (
+        train_data,
+        valid_data,
+        test_data,
+        tokenizer,
+        train_bsampler,
+        valid_bsampler,
+    ) = dataio_prepare(hparams)
+
+    hparams["pretrainer"].collect_files(default_source=hparams['ckpt_path'])
+    hparams["pretrainer"].load_collected(device=run_opts["device"])
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        opt_class=hparams["Adam"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+        engine_path=hparams['engine_path']
+    )
+
+    asr_brain.tokenizer = tokenizer
+
+    # Changing the samplers if dynamic batching is activated
+    train_dataloader_opts = hparams["train_dataloader_opts"]
+    valid_dataloader_opts = hparams["valid_dataloader_opts"]
+
+    if train_bsampler is not None:
+        train_dataloader_opts = {
+            "batch_sampler": train_bsampler,
+            "num_workers": hparams["num_workers"],
+        }
+    if valid_bsampler is not None:
+        valid_dataloader_opts = {"batch_sampler": valid_bsampler}
+
+    # evaluation
+    print("*** start evaluation ***")
+    start_time = time.time()
+    asr_brain.evaluate(
+        test_data, test_loader_kwargs=hparams["test_dataloader_opts"])
+    eval_time = asr_brain.infer_time
+
+    # total audio duration of the test set
+    duration = 0.0
+    for value in test_data.data.values():
+        duration = duration + value['duration']
+    num_samples = len(test_data)
+    print(f"samples: {num_samples}, QPS: {num_samples / eval_time}")
+    print(f"infer time: {eval_time}, RTF: {eval_time / duration}")
diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py
new file mode 100644
index 00000000..2bb0abc2
--- /dev/null
+++ b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+import ctypes
+import tensorrt
+from os.path import join, dirname, exists
+def load_ixrt_plugin(logger=tensorrt.Logger(tensorrt.Logger.INFO), namespace="", dynamic_path=""):
+    if not dynamic_path:
+        dynamic_path = join(dirname(tensorrt.__file__), "lib", "libixrt_plugin.so")
+    if not exists(dynamic_path):
+        raise FileNotFoundError(
+            f"The ixrt_plugin lib {dynamic_path} does not exist, please provide a valid plugin path!")
+    ctypes.CDLL(dynamic_path)
+    tensorrt.init_libnvinfer_plugins(logger, namespace)
+    print(f"Loaded plugin from {dynamic_path}")
--
Gitee