From efd7bc9314ce6c034935df3b27503c095d45bfe6 Mon Sep 17 00:00:00 2001 From: may Date: Mon, 22 Jul 2024 15:31:58 +0800 Subject: [PATCH 1/2] Add transformer_asr --- .../transformer_asr/ixrt/README.md | 56 ++ .../transformer_asr/ixrt/aishell_prepare.py | 141 ++++ .../transformer_asr/ixrt/beam_search.py | 381 +++++++++++ .../transformer_asr/ixrt/build.sh | 23 + .../transformer_asr/ixrt/builder.py | 466 ++++++++++++++ .../transformer_asr/ixrt/convert.py | 95 +++ .../transformer_asr/ixrt/ctc.py | 394 ++++++++++++ .../ixrt/faster_cat/__init__.py | 13 + .../transformer_asr/ixrt/faster_cat/build.sh | 22 + .../transformer_asr/ixrt/faster_cat/kernel.cu | 79 +++ .../transformer_asr/ixrt/faster_cat/setup.py | 48 ++ .../transformer_asr/ixrt/faster_cat/test.cpp | 21 + .../transformer_asr/ixrt/faster_cat/test.py | 37 ++ .../ixrt/faster_layer_norm/__init__.py | 16 + .../ixrt/faster_layer_norm/build.sh | 22 + .../ixrt/faster_layer_norm/kernel.cu | 168 +++++ .../ixrt/faster_layer_norm/setup.py | 48 ++ .../ixrt/faster_layer_norm/test.cpp | 22 + .../faster_layer_norm/transformer_helper.cuh | 295 +++++++++ .../ixrt/faster_logsumexp/__init__.py | 38 ++ .../ixrt/faster_logsumexp/build.sh | 22 + .../ixrt/faster_logsumexp/kernel.cu | 155 +++++ .../ixrt/faster_logsumexp/setup.py | 48 ++ .../ixrt/faster_logsumexp/test.cpp | 27 + .../ixrt/faster_logsumexp/test.py | 50 ++ .../ixrt/faster_stack/__init__.py | 33 + .../ixrt/faster_stack/build.sh | 22 + .../ixrt/faster_stack/kernel.cu | 146 +++++ .../ixrt/faster_stack/setup.py | 48 ++ .../ixrt/faster_stack/test.cpp | 29 + .../transformer_asr/ixrt/faster_stack/test.py | 74 +++ .../ixrt/hparams/train_ASR_transformer.yaml | 253 ++++++++ .../transformer_asr/ixrt/inference.py | 606 ++++++++++++++++++ .../transformer_asr/ixrt/load_ixrt_plugin.py | 26 + 34 files changed, 3924 insertions(+) create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/README.md create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/builder.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/convert.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/ctc.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/inference.py create mode 100644 models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/README.md b/models/speech/speech_recognition/transformer_asr/ixrt/README.md new file mode 100644 index 00000000..7560b5eb --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/README.md @@ -0,0 +1,56 @@ +# Asr transformer fp16 inference (BeamSearch) + +## Description + +Beam search allows us to exert control over the output of text generation. This is useful because we sometimes know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know which words must be included in the final translation with a dictionary lookup. + + +## Setup + +### Install + +``` +pip3 install speechbrain==0.5.13 +``` + +* ixrt 4.0.1_MR release + +### Download + +Pretrained model: + +Dataset: to download the Aishell dataset. + +``` +# Make sure the checkpoint path is results/transformer/8886/save +mkdir -p results/transformer/8886/save +# Make sure the dataset path is results/transformer/8886/save +mkdir -p /home/data/speechbrain +``` + +## Inference + +### Build faster kernels + +```bash +bash build.sh +``` + +### Build engine + +max_batch_size and max_seq_len depend on the situation. + +``` +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +``` + +### Run engine + +``` +python3 inference.py hparams/train_ASR_transformer.yaml --data_folder=/home/data/speechbrain/aishell --engine_path transformer.engine +``` \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py new file mode 100644 index 00000000..ba319394 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/aishell_prepare.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import shutil +import logging +from speechbrain.dataio.dataio import read_audio +from speechbrain.utils.data_utils import download_file +import glob +import csv +import argparse + +logger = logging.getLogger(__name__) + + +def prepare_aishell(data_folder, save_folder, skip_prep=False): + """ + This function prepares the AISHELL-1 dataset. + If the folder does not exist, the zip file will be extracted. If the zip file does not exist, it will be downloaded. + + data_folder : path to AISHELL-1 dataset. + save_folder: path where to store the manifest csv files. + skip_prep: If True, skip data preparation. + + """ + if skip_prep: + return + + # If the data folders do not exist, we need to extract the data + if not os.path.isdir(os.path.join(data_folder, "data_aishell/wav")): + # # Check for zip file and download if it doesn't exist + # zip_location = os.path.join(data_folder, "data_aishell.tgz") + # if not os.path.exists(zip_location): + # url = "https://www.openslr.org/resources/33/data_aishell.tgz" + # download_file(url, zip_location, unpack=True) + # logger.info("Extracting data_aishell.tgz...") + # shutil.unpack_archive(zip_location, data_folder) + + wav_dir = os.path.join(data_folder, "data_aishell/wav") + tgz_list = glob.glob(wav_dir + "/*.tar.gz") + for tgz in tgz_list: + shutil.unpack_archive(tgz, wav_dir) + os.remove(tgz) + + # Create filename-to-transcript dictionary + filename2transcript = {} + with open( + os.path.join( + data_folder, "data_aishell/transcript/aishell_transcript_v0.8.txt" + ), + "r", + ) as f: + lines = f.readlines() + for line in lines: + key = line.split()[0] + value = " ".join(line.split()[1:]) + filename2transcript[key] = value + + splits = [ + # "train", + "dev", + "test", + ] + ID_start = 0 # needed to have a unique ID for each audio + for split in splits: + new_filename = os.path.join(save_folder, split) + ".csv" + if os.path.exists(new_filename): + continue + logger.info("Preparing %s..." % new_filename) + + csv_output = [["ID", "duration", "wav", "transcript"]] + entry = [] + + all_wavs = glob.glob( + os.path.join(data_folder, "data_aishell/wav") + "/" + split + "/*/*.wav" + ) + for i in range(len(all_wavs)): + filename = all_wavs[i].split("/")[-1].split(".wav")[0] + if filename not in filename2transcript: + continue + signal = read_audio(all_wavs[i]) + duration = signal.shape[0] / 16000 + transcript_ = filename2transcript[filename] + csv_line = [ + ID_start + i, + str(duration), + all_wavs[i], + transcript_, + ] + entry.append(csv_line) + + csv_output = csv_output + entry + + with open(new_filename, mode="w") as csv_f: + csv_writer = csv.writer( + csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL + ) + for line in csv_output: + csv_writer.writerow(line) + + msg = "\t%s successfully created!" % (new_filename) + logger.info(msg) + + ID_start += len(all_wavs) + + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_folder", + type=str, + default="/home/data/speechbrain/aishell", + help="data folder", + ) + parser.add_argument( + "--save_folder", + type=str, + default="/home/data/speechbrain/aishell/csv_data", + help="csv save folder", + ) + + config = parser.parse_args() + print("Config:", config) + return config + + +if __name__ == "__main__": + + config = parse_config() + prepare_aishell(config.data_folder, config.save_folder, skip_prep=False) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py new file mode 100644 index 00000000..61e5c794 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/beam_search.py @@ -0,0 +1,381 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +from ctc import CTCPrefixScorer +import time + +def forward(self, enc_states, wav_len): # noqa: C901 + """Applies beamsearch and returns the predicted tokens.""" + enc_lens = torch.round(enc_states.shape[1] * wav_len).int() + device = enc_states.device + batch_size = enc_states.shape[0] + + memory = self.reset_mem(batch_size * self.beam_size, device=device) + + if self.lm_weight > 0: + lm_memory = self.reset_lm_mem(batch_size * self.beam_size, device) + + if self.ctc_weight > 0: + # (batch_size * beam_size, L, vocab_size) + ctc_outputs = self.ctc_forward_step(enc_states) + ctc_scorer = CTCPrefixScorer( + ctc_outputs, + enc_lens, + batch_size, + self.beam_size, + self.blank_index, + self.eos_index, + self.ctc_window_size, + ) + ctc_memory = None + + # Inflate the enc_states and enc_len by beam_size times + enc_states = inflate_tensor(enc_states, times=self.beam_size, dim=0) + enc_lens = inflate_tensor(enc_lens, times=self.beam_size, dim=0) + + # Using bos as the first input + inp_tokens = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.bos_index) + .long() + ) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=device) * self.beam_size + ) + + # initialize sequence scores variables. + sequence_scores = torch.empty( + batch_size * self.beam_size, device=device + ) + sequence_scores.fill_(float("-inf")) + + # keep only the first to make sure no redundancy. + sequence_scores.index_fill_(0, self.beam_offset, 0.0) + + # keep the hypothesis that reaches eos and their corresponding score and log_probs. + hyps_and_scores = [[] for _ in range(batch_size)] + + # keep the sequences that still not reaches eos. + alived_seq = torch.empty( + batch_size * self.beam_size, 0, device=device + ).long() + + # Keep the log-probabilities of alived sequences. + alived_log_probs = torch.empty( + batch_size * self.beam_size, 0, device=device + ) + + min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio) + max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio) + + # Initialize the previous attention peak to zero + # This variable will be used when using_max_attn_shift=True + prev_attn_peak = torch.zeros(batch_size * self.beam_size, device=device) + + for t in range(max_decode_steps): + # terminate condition + if self._check_full_beams(hyps_and_scores, self.beam_size): + break + + log_probs, memory, attn = self.forward_step( + inp_tokens, memory, enc_states, enc_lens + ) + log_probs = self.att_weight * log_probs + + # Keep the original value + log_probs_clone = log_probs.clone().reshape(batch_size, -1) + vocab_size = log_probs.shape[-1] + + if self.using_max_attn_shift: + # Block the candidates that exceed the max shift + cond, attn_peak = self._check_attn_shift(attn, prev_attn_peak) + log_probs = mask_by_condition( + log_probs, cond, fill_value=self.minus_inf + ) + prev_attn_peak = attn_peak + + # Set eos to minus_inf when less than minimum steps. + if t < min_decode_steps: + log_probs[:, self.eos_index] = self.minus_inf + + # Set the eos prob to minus_inf when it doesn't exceed threshold. + if self.using_eos_threshold: + cond = self._check_eos_threshold(log_probs) + log_probs[:, self.eos_index] = mask_by_condition( + log_probs[:, self.eos_index], + cond, + fill_value=self.minus_inf, + ) + + # adding LM scores to log_prob if lm_weight > 0 + if self.lm_weight > 0: + lm_log_probs, lm_memory = self.lm_forward_step( + inp_tokens, lm_memory + ) + log_probs = log_probs + self.lm_weight * lm_log_probs + + # adding CTC scores to log_prob if ctc_weight > 0 + if self.ctc_weight > 0: + g = alived_seq + # block blank token + log_probs[:, self.blank_index] = self.minus_inf + if self.ctc_weight != 1.0 and self.ctc_score_mode == "partial": + # pruning vocab for ctc_scorer + _, ctc_candidates = log_probs.topk( + self.beam_size * 2, dim=-1 + ) + else: + ctc_candidates = None + + ctc_log_probs, ctc_memory = ctc_scorer.forward_step( + g, ctc_memory, ctc_candidates, attn + ) + log_probs = log_probs + self.ctc_weight * ctc_log_probs + + scores = sequence_scores.unsqueeze(1).expand(-1, vocab_size) + scores = scores + log_probs + + # length normalization + if self.length_normalization: + scores = scores / (t + 1) + + # keep topk beams + scores, candidates = scores.view(batch_size, -1).topk( + self.beam_size, dim=-1 + ) + + # The input for the next step, also the output of current step. + inp_tokens = (candidates % vocab_size).view( + batch_size * self.beam_size + ) + + scores = scores.view(batch_size * self.beam_size) + sequence_scores = scores + + # recover the length normalization + if self.length_normalization: + sequence_scores = sequence_scores * (t + 1) + + # The index of which beam the current top-K output came from in (t-1) timesteps. + predecessors = ( + torch.div(candidates, vocab_size, rounding_mode="floor") + + self.beam_offset.unsqueeze(1).expand_as(candidates) + ).view(batch_size * self.beam_size) + + # Permute the memory to synchoronize with the output. + memory = self.permute_mem(memory, index=predecessors) + if self.lm_weight > 0: + lm_memory = self.permute_lm_mem(lm_memory, index=predecessors) + + if self.ctc_weight > 0: + ctc_memory = ctc_scorer.permute_mem(ctc_memory, candidates) + + # If using_max_attn_shift, then the previous attn peak has to be permuted too. + if self.using_max_attn_shift: + prev_attn_peak = torch.index_select( + prev_attn_peak, dim=0, index=predecessors + ) + + # Add coverage penalty + if self.coverage_penalty > 0: + cur_attn = torch.index_select(attn, dim=0, index=predecessors) + + # coverage: cumulative attention probability vector + if t == 0: + # Init coverage + self.coverage = cur_attn + + # the attn of transformer is [batch_size*beam_size, current_step, source_len] + if len(cur_attn.size()) > 2: + self.converage = torch.sum(cur_attn, dim=1) + else: + # Update coverage + self.coverage = torch.index_select( + self.coverage, dim=0, index=predecessors + ) + self.coverage = self.coverage + cur_attn + + # Compute coverage penalty and add it to scores + penalty = torch.max( + self.coverage, self.coverage.clone().fill_(0.5) + ).sum(-1) + penalty = penalty - self.coverage.size(-1) * 0.5 + penalty = penalty.view(batch_size * self.beam_size) + penalty = ( + penalty / (t + 1) if self.length_normalization else penalty + ) + scores = scores - penalty * self.coverage_penalty + + # Update alived_seq + alived_seq = torch.cat( + [ + torch.index_select(alived_seq, dim=0, index=predecessors), + inp_tokens.unsqueeze(1), + ], + dim=-1, + ) + + # Takes the log-probabilities + beam_log_probs = log_probs_clone[ + torch.arange(batch_size).unsqueeze(1), candidates + ].reshape(batch_size * self.beam_size) + alived_log_probs = torch.cat( + [ + torch.index_select( + alived_log_probs, dim=0, index=predecessors + ), + beam_log_probs.unsqueeze(1), + ], + dim=-1, + ) + + is_eos = self._update_hyp_and_scores( + inp_tokens, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=t, + ) + + # Block the paths that have reached eos. + sequence_scores.masked_fill_(is_eos, float("-inf")) + + if not self._check_full_beams(hyps_and_scores, self.beam_size): + # Using all eos to fill-up the hyps. + eos = ( + torch.zeros(batch_size * self.beam_size, device=device) + .fill_(self.eos_index) + .long() + ) + _ = self._update_hyp_and_scores( + eos, + alived_seq, + alived_log_probs, + hyps_and_scores, + scores, + timesteps=max_decode_steps, + ) + + ( + topk_hyps, + topk_scores, + topk_lengths, + log_probs, + ) = self._get_top_score_prediction(hyps_and_scores, topk=self.topk,) + # pick the best hyp + predictions = topk_hyps[:, 0, :] + predictions = batch_filter_seq2seq_output( + predictions, eos_id=self.eos_index + ) + + if self.return_log_probs: + return predictions, topk_scores, log_probs + else: + return predictions, topk_scores + + +def inflate_tensor(tensor, times, dim): + """This function inflates the tensor for times along dim. + + Arguments + --------- + tensor : torch.Tensor + The tensor to be inflated. + times : int + The tensor will inflate for this number of times. + dim : int + The dim to be inflated. + + Returns + ------- + torch.Tensor + The inflated tensor. + + Example + ------- + >>> tensor = torch.Tensor([[1,2,3], [4,5,6]]) + >>> new_tensor = inflate_tensor(tensor, 2, dim=0) + >>> new_tensor + tensor([[1., 2., 3.], + [1., 2., 3.], + [4., 5., 6.], + [4., 5., 6.]]) + """ + return torch.repeat_interleave(tensor, times, dim=dim) + +def batch_filter_seq2seq_output(prediction, eos_id=-1): + """Calling batch_size times of filter_seq2seq_output. + + Arguments + --------- + prediction : list of torch.Tensor + A list containing the output ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. + + Example + ------- + >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])] + >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) + >>> predictions + [[1, 2, 3], [2, 3]] + """ + outputs = [] + for p in prediction: + res = filter_seq2seq_output(p.tolist(), eos_id=eos_id) + outputs.append(res) + return outputs + +def filter_seq2seq_output(string_pred, eos_id=-1): + """Filter the output until the first eos occurs (exclusive). + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------ + list + The output predicted by seq2seq model. + + Example + ------- + >>> string_pred = ['a','b','c','d','eos','e'] + >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos') + >>> string_out + ['a', 'b', 'c', 'd'] + """ + if isinstance(string_pred, list): + try: + eos_index = next( + i for i, v in enumerate(string_pred) if v == eos_id + ) + except StopIteration: + eos_index = len(string_pred) + string_out = string_pred[:eos_index] + else: + raise ValueError("The input must be a list.") + return string_out \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh new file mode 100644 index 00000000..a8991234 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/build.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +for i in fast* +do + cd $i + bash build.sh + cd .. +done diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/builder.py b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py new file mode 100644 index 00000000..5c19a9f4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/builder.py @@ -0,0 +1,466 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse +import torch +from tensorrt.deploy.api import GraphTransform, create_source, create_target +from tensorrt.deploy.ir.data_type import DataType +from tensorrt.deploy.ir.variable import Variable, VariableOptions +from tensorrt.deploy.ir.graph import Graph +from collections import OrderedDict +import math +import re +import glob +import os +from onnx import numpy_helper +import subprocess + + +def parse_args(): + parser = argparse.ArgumentParser( + description="build ixrt engine", usage="" + ) + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="", + ) + parser.add_argument( + "--head_num", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_batch_size", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--max_seq_len", + type=int, + required=True, + help="", + ) + parser.add_argument( + "--onnx_path", + type=str, + default=".tmp.onnx", + help="", + ) + parser.add_argument( + "--engine_path", + type=str, + required=True, + help="", + ) + args = parser.parse_args() + return args + + +def add_make_mask_op(graph, state_dict, args): + attributes = {} + + t = graph + inputs = [ + graph.make_variable('length_radio', dtype=DataType.FLOAT16), + graph.make_variable('input', dtype=DataType.FLOAT16), + ] + + outputs = [t.make_variable("attention_mask", dtype=DataType.INT32)] + + t.make_operator( + "MakeMaskByRadio_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_custom_linear_op(graph, state_dict, args): + linear_keys = [ + "1.custom_src_module.layers.0.w.weight", + "1.custom_src_module.layers.0.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['1.custom_src_module.layers.0.w.weight'].size( + 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('input'), + ] + + outputs = [t.make_variable("custom_src_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_custom_linear_op(graph, state_dict, args): +# linear_keys = [ +# "1.custom_src_module.layers.0.w.weight", +# "1.custom_src_module.layers.0.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["1.custom_src_module.layers.0.w.weight"].size(0), +# "hidden_size": state_dict["1.custom_src_module.layers.0.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['1.custom_src_module.layers.0.w.weight'].size( +# 0) == state_dict["1.custom_src_module.layers.0.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('input'), +# ] +# +# outputs = [t.make_variable("custom_src_output",dtype=DataType.FLOAT16)] +# for key in linear_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_pos_encode_op(graph, state_dict, args): + attributes = {} + t = graph + inputs = [ + graph.get_variable('custom_src_output'), + ] + outputs = [t.make_variable("hidden_state", dtype=DataType.FLOAT16)] + t.make_operator( + "PosEncodeSinCos_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_transformer_op(graph, state_dict, args): + enc_tensor_layer_fp16_keys = OrderedDict([ + ["1.encoder.layers.{}.norm1.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm1.norm.bias", [args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_weight", + [args.hidden_size * 3, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.in_proj_bias", [args.hidden_size * 3]], + ["1.encoder.layers.{}.self_att.att.out_proj.weight", + [args.hidden_size, args.hidden_size]], + ["1.encoder.layers.{}.self_att.att.out_proj.bias", [args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.weight", + [args.inner_size, args.hidden_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.0.bias", [args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.weight", + [args.hidden_size, args.inner_size]], + ["1.encoder.layers.{}.pos_ffn.ffn.3.bias", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.weight", [args.hidden_size]], + ["1.encoder.layers.{}.norm2.norm.bias", [args.hidden_size]], + ]) + attributes_legcy = { + "hidden_size": args.hidden_size, + "num_layers": args.num_layers, + "head_num": args.head_num, + "head_dim": args.head_dim, + "inner_size": args.inner_size, + "act_type": "gelu", + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1 / math.sqrt(args.head_dim) + } + + + attributes = { + "hidden_size": int(args.hidden_size), + "num_layers": int(args.num_layers), + "head_num": int(args.head_num), + "head_dim": int(args.head_dim), + "inner_size": int(args.inner_size), + "act_type": 12, #gelu + "normalize_before": 1, + "is_fmha": 1, + "atten_scaler": 1.0 / math.sqrt(args.head_dim), + "max_seq_len": int(args.max_seq_len), + "max_batch_size": int(args.max_batch_size), + + } + + t = graph + inputs = [ + graph.get_variable('hidden_state'), + graph.get_variable('attention_mask'), + ] + outputs = [t.make_variable("encoder_out", dtype=DataType.FLOAT16)] + for layer_id in range(args.num_layers): + for key, shape in enc_tensor_layer_fp16_keys.items(): + # we need cat qkv gemm's weight and bias + new_key = key.format(layer_id) + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "TransformerEncoderFp16_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_layer_norm_op(graph, state_dict, args): + enc_ln_tensor_fp16_keys = OrderedDict([ + ["1.encoder.norm.norm.weight", [args.hidden_size]], + ["1.encoder.norm.norm.bias", [args.hidden_size]], + ]) + attributes = { + "epsilon": 1e-5, + "axis": -1, + "stash_type": 1 + } + t = graph + inputs = [ + graph.get_variable('encoder_out'), + ] + outputs = [t.make_variable("encoder_ln_out")] + for key, shape in enc_ln_tensor_fp16_keys.items(): + new_key = key + w = state_dict[new_key] + if list(w.shape) != shape: + print("weights shape error!") + print("key: ", key) + print("need shape: ", shape) + print("weight shape: ", w.shape) + exit(1) + inputs.append(t.make_variable(name=new_key, value=w.half())) + t.make_operator( + "LayerNormalization", inputs=inputs, outputs=outputs, **attributes + ) + + +# def add_layer_norm_op(graph, state_dict, args): +# enc_ln_tensor_fp16_keys = OrderedDict([ +# ["1.encoder.norm.norm.weight", [args.hidden_size]], +# ["1.encoder.norm.norm.bias", [args.hidden_size]], +# ]) +# attributes = { +# "hidden_size": args.hidden_size, +# } +# t = graph +# inputs = [ +# graph.get_variable('encoder_out'), +# ] +# outputs = [t.make_variable("encoder_ln_out",dtype=DataType.FLOAT16)] +# for key, shape in enc_ln_tensor_fp16_keys.items(): +# new_key = key +# w = state_dict[new_key] +# if list(w.shape) != shape: +# print("weights shape error!") +# print("key: ", key) +# print("need shape: ", shape) +# print("weight shape: ", w.shape) +# exit(1) +# inputs.append(t.make_variable(name=new_key, value=w.half())) +# t.make_operator( +# "LayerNormFp16", inputs=inputs, outputs=outputs, **attributes +# ) + +def add_linear_op(graph, state_dict, args): + linear_keys = [ + "3.w.weight", + "3.w.bias" + ] + W = numpy_helper.from_array(state_dict[linear_keys[0]].cpu().numpy(), name="W") + B = numpy_helper.from_array(state_dict[linear_keys[1]].cpu().numpy(), name="B") + attributes = { + "out_dims": state_dict["3.w.weight"].size(0), + "type_id": 1, + "W": W, + "B": B, + } + assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) + + t = graph + inputs = [ + graph.get_variable('encoder_ln_out'), + ] + + outputs = [t.make_variable("lin_output")] + t.make_operator( + "CustomFCPluginDynamic_IxRT", inputs=inputs, outputs=outputs, **attributes + ) + + +# +# def add_linear_op(graph, state_dict, args): +# lin_keys = [ +# "3.w.weight", +# "3.w.bias" +# ] +# attributes = { +# "linear_dim": state_dict["3.w.weight"].size(0), +# "hidden_size": state_dict["3.w.weight"].size(1), +# "has_bias": 1, +# "act_type": "none", +# } +# assert state_dict['3.w.weight'].size(0) == state_dict["3.w.bias"].size(0) +# +# t = graph +# inputs = [ +# graph.get_variable('encoder_ln_out'), +# ] +# +# outputs = [t.make_variable("lin_output",dtype=DataType.FLOAT16)] +# for key in lin_keys: +# inputs.append(t.make_variable(name=key, value=state_dict[key].half())) +# t.make_operator( +# "LinearFP16", inputs=inputs, outputs=outputs, **attributes +# ) + + +def add_log_softmax_op(graph, state_dict, args): + attributes = { + "axis": "-1", + } + + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("log_softmax_output", dtype=DataType.FLOAT16)] + + t.make_operator( + "LogSoftmax", inputs=inputs, outputs=outputs, **attributes + ) + + +def add_search_node(graph, state_dict, args): + attributes = { + "vocab_size": args.vocab_size, + "eos_id": args.vocab_size, + "pad_id": -10000, + "beam_size": 1, + "attr1": 1.0, + "min_decode_ratio": 0.0, + "max_decode_ratio": 1.0, + "ctc_weight": 0.40, + "using_eos_threshold": 0, + "length_normalization": 1, + } + t = graph + inputs = [ + graph.get_variable('lin_output'), + ] + + outputs = [t.make_variable("output_tokens", dtype=DataType.INT32)] + list_value_half = [] + list_key_half = [] + for key in state_dict.keys(): + if "decoder" in key or "custom_tgt_module" in key or "2.w.weight" in key or "2.w.bias" in key: + list_key_half.append(key) + list_value_half.append(state_dict[key].half()) + for i, item in enumerate(list_key_half): + inputs.append(t.make_variable(name=list_key_half[i], value=list_value_half[i])) + t.make_operator( + "Search", inputs=inputs, outputs=outputs, **attributes + ) + + +def get_num_layers(state_dict): + num_layers = -1 + for key in state_dict: + layer_id = re.search( + "1.encoder.layers.([0-9]+).pos_ffn.ffn.0.bias", key) + if layer_id: + layer_id = layer_id.group(1) + num_layers = max(num_layers, int(layer_id) + 1) + assert num_layers > 0 + return num_layers + + +def build_engine(onnx_file, engine_file, max_batch_size,max_seq_len): + cmd = f"ixrtexec --onnx {onnx_file} --min_shape input:1x32x5120,length_radio:1 --opt_shape input:8x64x5120,length_radio:8 --max_shape input:{max_batch_size}x{max_seq_len}x5120,length_radio:64 --plugins ixrt_plugin --save_engine {engine_file}" + subprocess.run(cmd.split(), check=True) + + +def main(args): + graph = Graph() + transform = GraphTransform(graph) + ckpt_path = glob.glob(os.path.join(args.ckpt_path, "*/model.ckpt"))[0] + print("load ckpt from: ", ckpt_path) + state_dict = torch.load(ckpt_path) + + # print([i for i in state_dict ]) + # print(state_dict['3.w.bias']) + args.hidden_size = state_dict['1.encoder.layers.0.norm1.norm.weight'].size( + 0) + args.head_dim = args.hidden_size / args.head_num + args.inner_size = state_dict['1.encoder.layers.0.pos_ffn.ffn.0.bias'].size( + 0) + args.vocab_size = state_dict['3.w.weight'].size(0) + + args.num_layers = get_num_layers(state_dict) + + args.src_len = state_dict["1.custom_src_module.layers.0.w.weight"].size(1) + + # args.num_layers = 1 + add_make_mask_op(transform, state_dict, args) + add_custom_linear_op(transform, state_dict, args) + add_pos_encode_op(transform, state_dict, args) + add_transformer_op(transform, state_dict, args) + add_layer_norm_op(transform, state_dict, args) + # add_linear_op(transform, state_dict, args) + # add_log_softmax_op(transform, state_dict, args) + # add_search_node(transform, state_dict, args) + + # IO attributes + length_radio = graph.get_variable('length_radio') + length_radio.set_shape(["batch_size"]) + length_radio.dtype = "float16" + graph.add_input(length_radio) + + input = graph.get_variable('input') + input.set_shape(["batch_size", "seq_len", "src_len"]) + input.dtype = "float16" + graph.add_input(input) + + output = graph.get_variable('encoder_ln_out') + output.dtype = "float16" + graph.add_output(output) + + create_target(saved_path=args.onnx_path).export(graph) + + build_engine(args.onnx_path, args.engine_path, args.max_batch_size, args.max_seq_len) + print("save engine: ", args.engine_path) + + +if __name__ == "__main__": + args = parse_args() + ckpt_path = args.ckpt_path + + main(args) + +""" +python3 builder.py \ +--ckpt_path results/transformer/8886/save \ +--head_num 4 \ +--max_batch_size 64 \ +--max_seq_len 1024 \ +--engine_path transformer.engine +""" diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/convert.py b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py new file mode 100644 index 00000000..11d71a56 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/convert.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +from faster_layer_norm import FasterLayerNorm + +def replace_layer_norm(model): + module_output = model + + if isinstance(model, torch.nn.modules.normalization.LayerNorm): + return FasterLayerNorm(model.weight, model.bias) + + for name, child in model.named_children(): + module_output.add_module( + name, replace_layer_norm(child) + ) + return module_output + + +def convert_decoder_model(model): + model = replace_layer_norm(model) + # for layer in model.layers: + # norm = layer.norm1.norm + # print(type(norm)) + # exit() + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm1.norm = new_norm + + # norm = layer.norm2.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm2.norm = new_norm + + # norm = layer.norm3.norm + # new_norm = FasterLayerNorm(norm.weight, norm.bias) + # layer.norm3.norm = new_norm + return model + +# def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +# if type(module) in layers: +# return {name: module} +# res = {} +# for name1, child in module.named_children(): +# res.update(find_layers( +# child, layers=layers, name=name + '.' + name1 if name != '' else name1 +# )) +# return res + +def find_node(module): + if type(module) in [torch.nn.LayerNorm]: + print(module) + return + res = {} + for name1, child in module.named_children(): + find_node(child) + return + + +def patch_get_lookahead_mask(padded_input): + """Creates a binary mask for each sequence which maskes future frames. + + Arguments + --------- + padded_input: torch.Tensor + Padded input tensor. + + Example + ------- + >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]]) + >>> get_lookahead_mask(a) + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0.]]) + """ + seq_len = padded_input.shape[1] + mask = ( + torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device)) + == 1 + ).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask.detach().to(padded_input.device).to(torch.float16) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py new file mode 100644 index 00000000..9db6ab7e --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/ctc.py @@ -0,0 +1,394 @@ +"""Decoders and output normalization for CTC. + +Authors + * Mirco Ravanelli 2020 + * Aku Rouhe 2020 + * Sung-Lin Yeh 2020 +""" +import torch +from itertools import groupby +from speechbrain.dataio.dataio import length_to_mask +from faster_logsumexp import FasterLogSumExp +from faster_stack import FasterStack +from faster_cat import FastCat + + +class CTCPrefixScorer: + """This class implements the CTC prefix scorer of Algorithm 2 in + reference: https://www.merl.com/publications/docs/TR2017-190.pdf. + Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py + + Arguments + --------- + x : torch.Tensor + The encoder states. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + batch_size : int + The size of the batch. + beam_size : int + The width of beam. + blank_index : int + The index of the blank token. + eos_index : int + The index of the end-of-sequence (eos) token. + ctc_window_size: int + Compute the ctc scores over the time frames using windowing based on attention peaks. + If 0, no windowing applied. + """ + + def __init__( + self, + x, + enc_lens, + batch_size, + beam_size, + blank_index, + eos_index, + ctc_window_size=0, + ): + self.blank_index = blank_index + self.eos_index = eos_index + self.max_enc_len = x.size(1) + self.batch_size = batch_size + self.beam_size = beam_size + self.vocab_size = x.size(-1) + self.device = x.device + self.minus_inf = -1e4 + self.last_frame_index = enc_lens - 1 + self.ctc_window_size = ctc_window_size + + # mask frames > enc_lens + mask = 1 - length_to_mask(enc_lens) + mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1) + x.masked_fill_(mask, self.minus_inf) + x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0) + + # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors + xnb = x.transpose(0, 1) + xb = ( + xnb[:, :, self.blank_index] + .unsqueeze(2) + .expand(-1, -1, self.vocab_size) + ) + + # (2, L, batch_size * beam_size, vocab_size) + # self.x = torch.stack([xnb, xb]) + self.x = FasterStack([xnb.contiguous(), xb.contiguous()]) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(batch_size, device=self.device) * self.beam_size + ) + # The first index of each candidates. + self.cand_offset = ( + torch.arange(batch_size, device=self.device) * self.vocab_size + ) + + def forward_step(self, g, state, candidates=None, attn=None): + """This method if one step of forwarding operation + for the prefix ctc scorer. + + Arguments + --------- + g : torch.Tensor + The tensor of prefix label sequences, h = g + c. + state : tuple + Previous ctc states. + candidates : torch.Tensor + (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring. + The ctc_beam_size is set as 2 * beam_size. If given, performing partial ctc scoring. + """ + + prefix_length = g.size(1) + last_char = [gi[-1] for gi in g] if prefix_length > 0 else [0] * len(g) + self.num_candidates = ( + self.vocab_size if candidates is None else candidates.size(-1) + ) + if state is None: + # r_prev: (L, 2, batch_size * beam_size) + r_prev = torch.full( + (self.max_enc_len, 2, self.batch_size, self.beam_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + + # Accumulate blank posteriors at each step + r_prev[:, 1] = torch.cumsum( + self.x[0, :, :, self.blank_index], 0 + ).unsqueeze(2) + r_prev = r_prev.view(-1, 2, self.batch_size * self.beam_size) + psi_prev = 0.0 + else: + r_prev, psi_prev = state + r_prev = r_prev.half() + + # for partial search + if candidates is not None: + scoring_table = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + -1, + dtype=torch.long, + device=self.device, + ) + # Assign indices of candidates to their positions in the table + col_index = torch.arange( + self.batch_size * self.beam_size, device=self.device + ).unsqueeze(1) + scoring_table[col_index, candidates] = torch.arange( + self.num_candidates, device=self.device + ) + # Select candidates indices for scoring + scoring_index = ( + candidates + + self.cand_offset.unsqueeze(1) + .repeat(1, self.beam_size) + .view(-1, 1) + ).view(-1) + x_inflate = torch.index_select( + self.x.view(2, -1, self.batch_size * self.vocab_size), + 2, + scoring_index, + ).view(2, -1, self.batch_size * self.beam_size, self.num_candidates) + # for full search + else: + scoring_table = None + x_inflate = ( + self.x.unsqueeze(3) + .repeat(1, 1, 1, self.beam_size, 1) + .view( + 2, -1, self.batch_size * self.beam_size, self.num_candidates + ) + ) + + # Prepare forward probs + r = torch.full( + ( + self.max_enc_len, + 2, + self.batch_size * self.beam_size, + self.num_candidates, + ), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + r.fill_(self.minus_inf) + + # (Alg.2-6) + if prefix_length == 0: + r[0, 0] = x_inflate[0, 0] + # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g) + r_sum = FasterLogSumExp(r_prev, 1) + phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates) + + # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0 + if candidates is not None: + for i in range(self.batch_size * self.beam_size): + pos = scoring_table[i, last_char[i]] + if pos != -1: + phi[:, i, pos] = r_prev[:, 1, i] + else: + for i in range(self.batch_size * self.beam_size): + phi[:, i, last_char[i]] = r_prev[:, 1, i] + + # Start, end frames for scoring (|g| < |h|). + # Scoring based on attn peak if ctc_window_size > 0 + if self.ctc_window_size == 0 or attn is None: + start = max(1, prefix_length) + end = self.max_enc_len + else: + _, attn_peak = torch.max(attn, dim=1) + max_frame = torch.max(attn_peak).item() + self.ctc_window_size + min_frame = torch.min(attn_peak).item() - self.ctc_window_size + start = max(max(1, prefix_length), int(min_frame)) + end = min(self.max_enc_len, int(max_frame)) + + # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)): + for t in range(start, end): + # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c) + rnb_prev = r[t - 1, 0] + # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank) + rb_prev = r[t - 1, 1] + # r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + # 2, 2, self.batch_size * self.beam_size, self.num_candidates + # ) + r_ = FasterStack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + 2, 2, self.batch_size * self.beam_size, self.num_candidates + ) + r[t] = FasterLogSumExp(r_, 1) + x_inflate[:, t] + + # Compute the predix prob, psi + psi_init = r[start - 1, 0].unsqueeze(0) + # phi is prob at t-1 step, shift one frame and add it to the current prob p(c) + phix = FastCat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0] + + # (Alg.2-13): psi = psi + phi * p(c) + if candidates is not None: + psi = torch.full( + (self.batch_size * self.beam_size, self.vocab_size), + self.minus_inf, + device=self.device, + dtype=torch.float16 + ) + psi_ = FasterLogSumExp( + FastCat((phix[start:end], psi_init), dim=0), dim=0 + ) + # only assign prob to candidates + for i in range(self.batch_size * self.beam_size): + psi[i, candidates[i]] = psi_[i] + else: + psi = FastCat((phix[start:end], psi_init), dim=0) + psi = FasterLogSumExp(psi, dim=0) + + # (Alg.2-3): if c = , psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames + for i in range(self.batch_size * self.beam_size): + psi[i, self.eos_index] = r_sum[ + self.last_frame_index[i // self.beam_size], i + ] + + # Exclude blank probs for joint scoring + psi[:, self.blank_index] = self.minus_inf + + return psi - psi_prev, (r, psi, scoring_table) + + def permute_mem(self, memory, index): + """This method permutes the CTC model memory + to synchronize the memory index with the current output. + + Arguments + --------- + memory : No limit + The memory variable to be permuted. + index : torch.Tensor + The index of the previous path. + + Return + ------ + The variable of the memory being permuted. + + """ + r, psi, scoring_table = memory + # The index of top-K vocab came from in (t-1) timesteps. + best_index = ( + index + + (self.beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size) + ).view(-1) + # synchronize forward prob + psi = torch.index_select(psi.view(-1), dim=0, index=best_index) + psi = ( + psi.view(-1, 1) + .repeat(1, self.vocab_size) + .view(self.batch_size * self.beam_size, self.vocab_size) + ) + + # synchronize ctc states + if scoring_table is not None: + effective_index = ( + index // self.vocab_size + self.beam_offset.view(-1, 1) + ).view(-1) + selected_vocab = (index % self.vocab_size).view(-1) + score_index = scoring_table[effective_index, selected_vocab] + score_index[score_index == -1] = 0 + best_index = score_index + effective_index * self.num_candidates + + r = torch.index_select( + r.view( + -1, 2, self.batch_size * self.beam_size * self.num_candidates + ), + dim=-1, + index=best_index, + ) + r = r.view(-1, 2, self.batch_size * self.beam_size) + + return r, psi + + +def filter_ctc_output(string_pred, blank_id=-1): + """Apply CTC output merge and filter rules. + + Removes the blank symbol and output repetitions. + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the CTC system. + blank_id : int, string + The id of the blank. + + Returns + ------- + list + The output predicted by CTC without the blank symbol and + the repetitions. + + Example + ------- + >>> string_pred = ['a','a','blank','b','b','blank','c'] + >>> string_out = filter_ctc_output(string_pred, blank_id='blank') + >>> print(string_out) + ['a', 'b', 'c'] + """ + + if isinstance(string_pred, list): + # Filter the repetitions + string_out = [ + v + for i, v in enumerate(string_pred) + if i == 0 or v != string_pred[i - 1] + ] + + # Remove duplicates + string_out = [i[0] for i in groupby(string_out)] + + # Filter the blank symbol + string_out = list(filter(lambda elem: elem != blank_id, string_out)) + else: + raise ValueError("filter_ctc_out can only filter python lists") + return string_out + + +def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1): + """Greedy decode a batch of probabilities and apply CTC rules. + + Arguments + --------- + probabilities : torch.tensor + Output probabilities (or log-probabilities) from the network with shape + [batch, probabilities, time] + seq_lens : torch.tensor + Relative true sequence lengths (to deal with padded inputs), + the longest sequence has length 1.0, others a value between zero and one + shape [batch, lengths]. + blank_id : int, string + The blank symbol/index. Default: -1. If a negative number is given, + it is assumed to mean counting down from the maximum possible index, + so that -1 refers to the maximum possible index. + + Returns + ------- + list + Outputs as Python list of lists, with "ragged" dimensions; padding + has been removed. + + Example + ------- + >>> import torch + >>> probs = torch.tensor([[[0.3, 0.7], [0.0, 0.0]], + ... [[0.2, 0.8], [0.9, 0.1]]]) + >>> lens = torch.tensor([0.51, 1.0]) + >>> blank_id = 0 + >>> ctc_greedy_decode(probs, lens, blank_id) + [[1], [1]] + """ + if isinstance(blank_id, int) and blank_id < 0: + blank_id = probabilities.shape[-1] + blank_id + batch_max_len = probabilities.shape[1] + batch_outputs = [] + for seq, seq_len in zip(probabilities, seq_lens): + actual_size = int(torch.round(seq_len * batch_max_len)) + scores, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1) + out = filter_ctc_output(predictions.tolist(), blank_id=blank_id) + batch_outputs.append(out) + return batch_outputs diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py new file mode 100644 index 00000000..537d35c5 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/__init__.py @@ -0,0 +1,13 @@ +import torch +from faster_cat import sp_opt + +def FastCat(inputs,dim=0): + if len(inputs) == 2 and dim==0: + a,b = inputs + in_shape = a.shape + if len(in_shape)>1: + res, = sp_opt.test_opt_2(a.view(a.shape[0],-1),b.view(b.shape[0],-1)) + new_shape = (a.shape[0]+b.shape[0],) + in_shape[1:] + res = res.view(*new_shape) + return res + return torch.cat(inputs,dim=dim) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu new file mode 100644 index 00000000..022fac39 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/kernel.cu @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Cat(half* a, half* b, half* output, int m1, int m2, int k) { + int i = blockIdx.y * blockDim.x + threadIdx.x; + // a + if (blockIdx.x < m1) { + half2* h2_a = reinterpret_cast(a + blockIdx.x * k); + half2* h2_out_a = reinterpret_cast(output + blockIdx.x * k); + if (i < k / 2) { + h2_out_a[i] = h2_a[i]; + } + } + // b + if (blockIdx.x < m2) { + half2* h2_b = reinterpret_cast(b + blockIdx.x * k); + half2* h2_out_b = + reinterpret_cast(output + blockIdx.x * k + m1 * k); + if (i < k / 2) { + h2_out_b[i] = h2_b[i]; + } + } +} + +void IxinferCatLauncher(half* a, half* b, half* output, int m1, int m2, int k, + cudaStream_t stream) { + if (k % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int m = std::max(m1, m2); + int num_threads = 1024; + int half_k = k / 2; + int num_roll = (half_k - 1 + num_threads) / num_threads; + dim3 grid(m, num_roll); + dim3 block(num_threads); + Cat<<>>(a, b, output, m1, m2, k); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(b.dim() == 2); + + int m1 = a.size(0); + int m2 = b.size(0); + + int k = a.size(1); + + TORCH_CHECK(b.size(1) == k); + + at::Tensor output = a.new_empty({(m1 + m2), k}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferCatLauncher(p_a, p_b, p_out, m1, m2, k, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp new file mode 100644 index 00000000..11720811 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py new file mode 100644 index 00000000..2713dae2 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_cat/test.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + m1 = 320 + m2 = 321 + hidden_size = 5000 + + a = torch.randn([m1,hidden_size]).cuda().half() + b = torch.randn([m2,hidden_size]).cuda().half() + + + res_pt = torch.cat([a,b],dim=0) + + res_cu, = sp_opt.test_opt_2(a,b) + + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt_2(a,b) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py new file mode 100644 index 00000000..20603650 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/__init__.py @@ -0,0 +1,16 @@ +import torch +from faster_layer_norm import sp_opt + +class FasterLayerNorm(torch.nn.Module): + def __init__(self, weight, bias): + super(FasterLayerNorm, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, inputs, *args, **kwargs): + hidden_size = self.weight.size(0) + in_shape = inputs.shape + inputs = inputs.view(-1,hidden_size) + output, = sp_opt.test_opt(inputs,self.weight,self.bias) + output = output.view(*in_shape) + return output diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu new file mode 100644 index 00000000..852db917 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/kernel.cu @@ -0,0 +1,168 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "transformer_helper.cuh" + +namespace iluvatar::inferrt::transformer { + +template +__global__ void LnOpt2Kernel(half* input, half* ln_weight, half* ln_bias, + half* output, int hidden_size, + float layernorm_eps) { + input += blockIdx.x * hidden_size; + output += blockIdx.x * hidden_size; + + half2* p_in = reinterpret_cast(input); + half2* p_out = reinterpret_cast(output); + half2* p_wei = reinterpret_cast(ln_weight); + half2* p_bias = reinterpret_cast(ln_bias); + int half_hidden_size = hidden_size / 2; + + extern __shared__ half2 shmem[]; + + float s_mean; + float s_variance; + float x_sum = 0.0f; + float x2_sum = 0.0f; +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 value = p_in[index]; + shmem[index] = value; + float val_1 = __half2float(value.x); + float val_2 = __half2float(value.y); + x_sum += val_1 + val_2; + x2_sum += val_1 * val_1 + val_2 * val_2; + } + } + float sums[2]; // 和,平方和 + sums[0] = x_sum; + sums[1] = x2_sum; + blockReduceSumV2(sums); + + s_mean = sums[0] / hidden_size; + s_variance = rsqrtf(sums[1] / hidden_size - s_mean * s_mean + layernorm_eps); + +#pragma unroll UNROLL_FACTOR + for (int i = 0; i < UNROLL_FACTOR; ++i) { + int index = i * blockDim.x + threadIdx.x; + if (index < half_hidden_size) { + half2 wei_value = p_wei[index]; + half2 bias_value = p_bias[index]; + half2 vals_value = shmem[index]; + + float2 norm_value; + norm_value.x = (__half2float(vals_value.x) - s_mean) * s_variance * + __half2float(wei_value.x) + + __half2float(bias_value.x); + norm_value.y = (__half2float(vals_value.y) - s_mean) * s_variance * + __half2float(wei_value.y) + + __half2float(bias_value.y); + + __half2 res; + res.x = __float2half(norm_value.x); + res.y = __float2half(norm_value.y); + + p_out[index] = res; + } + } +} + +// FasterTransformer/src/fastertransformer/kernels/layernorm_kernels.cu +void IxinferLnLauncherOpt2(__half* input, __half* ln_weight, __half* ln_bias, + __half* output, int batch_tokens, int hidden_size, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (hidden_size % 2 != 0) { + throw std::runtime_error("layer norm error: hidden_size % 2 != 0"); + } + dim3 grid(batch_tokens); + int half_n = hidden_size / 2; + int half_n_warp = (half_n + warpSize - 1) / warpSize * warpSize; + dim3 block(std::min(half_n_warp, 1024)); + int rolls_per_thread = (half_n + block.x - 1) / block.x; + switch (rolls_per_thread) { + case 1: + LnOpt2Kernel<1><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 2: + LnOpt2Kernel<2><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 3: + LnOpt2Kernel<3><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 4: + LnOpt2Kernel<4><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 5: + LnOpt2Kernel<5><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 6: + LnOpt2Kernel<6><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 7: + LnOpt2Kernel<7><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + case 8: + LnOpt2Kernel<8><<>>( + input, ln_weight, ln_bias, output, hidden_size, layernorm_eps); + break; + default: + std::cout << "hidden_size: " << hidden_size << std::endl; + throw std::runtime_error("layer norm error, unsupport hidden size! "); + break; + } +} +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(ln_weight.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_weight.is_cuda()); + TORCH_CHECK(ln_weight.is_contiguous()); + + TORCH_CHECK(ln_bias.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(ln_bias.is_cuda()); + TORCH_CHECK(ln_bias.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(ln_weight.dim() == 1); + TORCH_CHECK(ln_bias.dim() == 1); + + int batch_tokens = input.size(0); + int hidden_size = input.size(1); + + TORCH_CHECK(ln_weight.size(0) == hidden_size); + TORCH_CHECK(ln_bias.size(0) == hidden_size); + + at::Tensor output = at::empty_like(input); + + half* p_in = (half*)input.data_ptr(); + half* p_wei = (half*)ln_weight.data_ptr(); + half* p_bias = (half*)ln_bias.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLnLauncherOpt2( + p_in, p_wei, p_bias, p_out, batch_tokens, hidden_size, stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp new file mode 100644 index 00000000..f925c1b4 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/test.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias); + +std::vector test_opt(at::Tensor input, at::Tensor ln_weight, + at::Tensor ln_bias) { + return one_test_opt(input, ln_weight, ln_bias); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, "fast depthwise conv1d forward"); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh new file mode 100644 index 00000000..f8a57622 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_layer_norm/transformer_helper.cuh @@ -0,0 +1,295 @@ +#pragma once +#include +#include + +namespace iluvatar { +namespace inferrt { +namespace transformer { + +__forceinline__ int nearest_4(int x) { + if (x % 4 == 0) { + return x; + } else { + int padding = 4 - x % 4; + return x + padding; + } +} + +__forceinline__ int nearest_2(int x) { + if (x % 2 == 0) { + return x; + } else { + int padding = 2 - x % 2; + return x + padding; + } +} + +__forceinline__ int nearest_num(int x, int value) { + if (x % value == 0) { + return x; + } else { + int padding = value - x % value; + return x + padding; + } +} + +__device__ int8_t float2int8(float x, float quant_scale) { + float i8_f = x * quant_scale; + int32_t i8 = floorf(i8_f + 0.5); + i8 = i8 < -127 ? -127 : (i8 > 127 ? 127 : i8); + return int8_t(i8); +} + +__device__ void WelfordCombine(float val, float *mean, float *m2, + float *count) { + // Use Welford Online algorithem to compute mean and variance + // For more details you can refer to: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + *count += 1; + float delta1 = val - *mean; + *mean += delta1 / *count; + float delta2 = val - *mean; + *m2 += delta1 * delta2; +} + +__device__ void WelfordCombine(float b_mean, float b_m2, float b_count, + float *mean, float *m2, float *count) { + if (b_count == 0) { + return; + } + float new_count = *count + b_count; + float nb_over_n = b_count / new_count; + float delta = b_mean - *mean; + *mean += delta * nb_over_n; + *m2 += b_m2 + delta * delta * (*count) * nb_over_n; + *count = new_count; +} + +__device__ void WelfordWarpReduce(float thread_mean, float thread_m2, + float thread_count, float *mean, float *m2, + float *count) { + *mean = thread_mean; + *m2 = thread_m2; + *count = thread_count; + for (int mask = warpSize / 2; mask > 0; mask /= 2) { + float b_mean = __shfl_down_sync(0xffffffff, *mean, mask); + float b_m2 = __shfl_down_sync(0xffffffff, *m2, mask); + float b_count = __shfl_down_sync(0xffffffff, *count, mask); + WelfordCombine(b_mean, b_m2, b_count, mean, m2, count); + } +} + +// load 两个 half2, 保存到 float4 +__device__ void load_float4_from_half(float4 &vals, __half2 *input, int index) { + __half2 i1 = input[index * 2]; + __half2 i2 = input[index * 2 + 1]; + + vals.x = __half2float(i1.x); + vals.y = __half2float(i1.y); + vals.z = __half2float(i2.x); + vals.w = __half2float(i2.y); +} + +__device__ char4 float42char4(float4 vals, float quant_scale) { + char4 res; + res.x = float2int8(vals.x, quant_scale); + res.y = float2int8(vals.y, quant_scale); + res.z = float2int8(vals.z, quant_scale); + res.w = float2int8(vals.w, quant_scale); + return res; +} + +__device__ float4 char4addhalf2_dequant(char4 input_4, half2 residual_1, + half2 residual_2, float dequant_scale) { + float4 res; + res.x = + __int2float_rn(input_4.x) * dequant_scale + __half2float(residual_1.x); + res.y = + __int2float_rn(input_4.y) * dequant_scale + __half2float(residual_1.y); + res.z = + __int2float_rn(input_4.z) * dequant_scale + __half2float(residual_2.x); + res.w = + __int2float_rn(input_4.w) * dequant_scale + __half2float(residual_2.y); + return res; +} + +__device__ float4 compute_float4_norm_value(float4 vals, float mean, float m2, + int hidden_size, float epsilon, + half2 scale_1, half2 scale_2, + half2 bias_1, half2 bias_2) { + float4 norm_value; + norm_value.x = (vals.x - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.x) + + __half2float(bias_1.x); + norm_value.y = (vals.y - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_1.y) + + __half2float(bias_1.y); + norm_value.z = (vals.z - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.x) + + __half2float(bias_2.x); + norm_value.w = (vals.w - mean) * rsqrtf(m2 / hidden_size + epsilon) * + __half2float(scale_2.y) + + __half2float(bias_2.y); + return norm_value; +} + +// softmax +__forceinline__ __host__ __device__ int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} +template +__device__ T WARP_SHFL_XOR(T value, int laneMask, int width) { + unsigned int mask = 0xffffffff; +#if !(defined(__HIP_PLATFORM_HCC__) || defined(__ILUVATAR__)) + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template +struct Add { + __device__ T operator()(T a, T b) const { return a + b; } +}; + +template +struct Max { + __device__ T operator()(T a, T b) const { return a < b ? b : a; } +}; +template class ReduceOp> +__device__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = REDUCE_WARP_SIZE / 2; offset > 0; offset /= 2) { + acc_t b = WARP_SHFL_XOR(*sum, offset, REDUCE_WARP_SIZE); + *sum = r(*sum, b); + } +} + +__device__ void warp_argmax(float &value, int32_t &idx) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + float next_value = WARP_SHFL_XOR(value, offset, warpSize); + float next_idx = WARP_SHFL_XOR(idx, offset, warpSize); + if (next_value > value) { + value = next_value; + idx = next_idx; + } + } +} + +// gelu +// IxinferBiasGeluI8II8OKernel +template +__device__ T tanhf_exp(T x) { + // float e1 = __expf(x); + // float e2 = 1.0f / e1; + // return (e1 - e2) / (e1 + e2); + + return (2.f / (1.f + __expf(-2.f * x)) - 1.f); +} + +template +__device__ T gelu(T x) { + float cdf = + 0.5f * + (1.0f + tanhf_exp((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +/* fp16 gelu */ +template <> +__forceinline__ __device__ __half2 gelu<__half2>(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +template +__inline__ __device__ T warpReduceSumV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, warpSize); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T *val) { + static __shared__ T shared[NUM][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = lane < (blockDim.x / warpSize); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSumV2(val); + return (T)0.0f; +} + +__inline__ __device__ void warpReduceSum2Number(float *x, float *y) { +#pragma unroll + for (int mask = warpSize / 2; mask > 0; mask >>= 1) { + *x += __shfl_xor_sync(0xffffffff, *x, mask, warpSize); + *y += __shfl_xor_sync(0xffffffff, *y, mask, warpSize); + } +} + +__inline__ __device__ void blockReduceSum2Number(float *x, float *y) { + static __shared__ float shared[2][warpSize + 1]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + warpReduceSum2Number(x, y); + if (lane == 0) { + shared[0][wid] = *x; + shared[1][wid] = *y; + } + __syncthreads(); + bool is_mask = lane < (blockDim.x / warpSize); + *x = is_mask ? shared[0][lane] : 0.0f; + *y = is_mask ? shared[0][lane] : 0.0f; + + warpReduceSum2Number(x, y); +} + +} // namespace transformer + +} // namespace inferrt +} // namespace iluvatar diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py new file mode 100644 index 00000000..d50b3758 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/__init__.py @@ -0,0 +1,38 @@ +import torch +from faster_logsumexp import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterLogSumExp(inputs,dim): + # print(inputs.shape, dim) + if dim == 1 and len(inputs.shape)>2 and inputs.size(1)==2: + in_shape = inputs.shape + inputs = inputs.view(in_shape[0],in_shape[1],-1) + res, = sp_opt.test_opt(inputs) + new_shape = (in_shape[0],) + in_shape[2:] + res = res.view(*new_shape) + return res + # dim==0 现在的实现会有bug? + # if dim == 0 and len(inputs.shape)>=2: + # in_shape = inputs.shape + # inputs = inputs.view(in_shape[0],-1) + # res, = sp_opt.test_opt_dim0(inputs) + # new_shape = in_shape[1:] + # res = res.view(*new_shape) + # return res + # print(f"not support shape: {inputs.shape} dim: {dim}") + res = torch.logsumexp(inputs, dim=dim) + return res + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu new file mode 100644 index 00000000..56eb0810 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/kernel.cu @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void LogSumExpWith2(half* input, half* output, int H) { + half2* h2_in1 = reinterpret_cast(input + blockIdx.x * 2 * H); + half2* h2_in2 = reinterpret_cast(input + blockIdx.x * 2 * H + H); + half2* h2_out = reinterpret_cast(output + blockIdx.x * H); + + int i = blockIdx.y * blockDim.x + threadIdx.x; + if (i < H / 2) { + float2 res; + half2 value1 = h2_in1[i]; + half2 value2 = h2_in2[i]; + + res.x = std::log(__expf(__half2float(value1.x)) + + __expf(__half2float(value2.x))); + res.y = std::log(__expf(__half2float(value1.y)) + + __expf(__half2float(value2.y))); + + half2 res_h2; + res_h2.x = __float2half(res.x); + res_h2.y = __float2half(res.y); + h2_out[i] = res_h2; + } +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int C, int H, + cudaStream_t stream) { + const float layernorm_eps = 1e-5; + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(N, num_roll); + dim3 block(num_threads); + switch (C) { + case 2: + LogSumExpWith2<<>>(input, output, H); + break; + default: + throw std::runtime_error( + "IxinferLogSumExpLauncher error, unsupport size! "); + break; + } +} + +// https://zhuanlan.zhihu.com/p/153535799 +__global__ void LogSumExpDim0(half* input, half* output, int N, int H) { + half2* h2_out = reinterpret_cast(output); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + float2 res; + res.x = 0.f; + res.y = 0.f; + + float2 max_values; + max_values.x = -1000.f; + max_values.y = -1000.f; + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + if (max_values.x < __half2float(value.x)) { + max_values.x = __half2float(value.x); + } + if (max_values.y < __half2float(value.y)) { + max_values.y = __half2float(value.y); + } + } + + for (int batch_idx = 0; batch_idx < N; batch_idx++) { + half2* h2_in = reinterpret_cast(input + batch_idx * H); + half2 value = h2_in[i]; + + res.x += __expf(__half2float(value.x) - max_values.x); + res.y += __expf(__half2float(value.y) - max_values.y); + } + + half2 res_h2; + res_h2.x = __float2half(std::log(res.x) + max_values.x); + res_h2.y = __float2half(std::log(res.y) + max_values.y); + + h2_out[i] = res_h2; +} + +void IxinferLogSumExpLauncher(half* input, half* output, int N, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferLogSumExpLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + LogSumExpDim0<<>>(input, output, N, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 3); + + int N = input.size(0); + int C = input.size(1); + int H = input.size(2); + + at::Tensor output = input.new_empty({N, H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, C, H, + stream); + return {output}; +} + +std::vector one_test_dim0(at::Tensor input) { + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(input.is_cuda()); + TORCH_CHECK(input.is_contiguous()); + + TORCH_CHECK(input.dim() == 2); + + int N = input.size(0); + int H = input.size(1); + + at::Tensor output = input.new_empty({H}); + + half* p_in = (half*)input.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferLogSumExpLauncher(p_in, p_out, N, H, + stream); + return {output}; +} \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp new file mode 100644 index 00000000..5eaf6fe1 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor input); + +std::vector test_opt(at::Tensor input) { + return one_test_opt(input); +} + +std::vector one_test_dim0(at::Tensor input); + +std::vector test_opt_dim0(at::Tensor input) { + return one_test_dim0(input); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_dim0", &test_opt_dim0, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py new file mode 100644 index 00000000..7b22dbdd --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_logsumexp/test.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 2 + c = 2 + hidden_size = 320*5000 + + inputs = torch.randn([batch_tokens,c, hidden_size]).cuda().half() + + # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # res2 = torch.logsumexp(inputs,dim=-1) + # diff = torch.abs(res1-res2) + # print(diff.max()) + + res_pt = torch.logsumexp(inputs,dim=1) + + res_cu, = sp_opt.test_opt(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(inputs) + + batch_tokens = 55 + hidden_size = 320*5000 + inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + res_pt = torch.logsumexp(inputs,dim=0) + res_cu, = sp_opt.test_opt_dim0(inputs) + + diff = torch.abs(res_pt - res_cu) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py new file mode 100644 index 00000000..48d0cf5b --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/__init__.py @@ -0,0 +1,33 @@ +import torch +from faster_stack import sp_opt + +# class FasterLogSumExp(torch.nn.Module): +# def __init__(self, weight, bias): +# super(FasterLogSumExp, self).__init__() +# self.weight = weight +# self.bias = bias + +# def forward(self, inputs, *args, **kwargs): +# hidden_size = self.weight.size(0) +# in_shape = inputs.shape +# inputs = inputs.view(-1,hidden_size) +# output, = sp_opt.test_opt(inputs,self.weight,self.bias) +# output = output.view(*in_shape) +# return output + +def FasterStack(inputs): + if len(inputs) == 4: + a,b,c,d = inputs + in_shape = a.shape + res, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + new_shape = (4,) + in_shape + res = res.view(*new_shape) + return res + if len(inputs) == 2: + a,b = inputs + in_shape = a.shape + res, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + new_shape = (2,) + in_shape + res = res.view(*new_shape) + return res + return torch.stack(inputs) \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh new file mode 100644 index 00000000..f679258d --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/build.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +set -euox pipefail + +rm -rf build +rm -rf *.so + +python3 setup.py build + +cp build/lib*/*.so . \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu new file mode 100644 index 00000000..0fdff649 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/kernel.cu @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace iluvatar::inferrt::transformer { + +__global__ void Stack(half* a, half* b, half* c, half* d, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + half2* h2_c = reinterpret_cast(c); + half2* h2_d = reinterpret_cast(d); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + half2* h2_out_c = reinterpret_cast(output + H * 2); + half2* h2_out_d = reinterpret_cast(output + H * 3); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + h2_out_c[i] = h2_c[i]; + h2_out_d[i] = h2_d[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* c, half* d, half* output, + int H, cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, c, d, output, H); +} + +__global__ void Stack(half* a, half* b, half* output, int H) { + half2* h2_a = reinterpret_cast(a); + half2* h2_b = reinterpret_cast(b); + + half2* h2_out_a = reinterpret_cast(output); + half2* h2_out_b = reinterpret_cast(output + H); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < H / 2) { + h2_out_a[i] = h2_a[i]; + h2_out_b[i] = h2_b[i]; + } +} + +void IxinferStackLauncher(half* a, half* b, half* output, int H, + cudaStream_t stream) { + if (H % 2 != 0) { + throw std::runtime_error("IxinferStackLauncher: size error!"); + } + int num_threads = 1024; + int half_h = H / 2; + int num_roll = (half_h - 1 + num_threads) / num_threads; + dim3 grid(num_roll); + dim3 block(num_threads); + Stack<<>>(a, b, output, H); +} + +} // namespace iluvatar::inferrt::transformer + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(c.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(c.is_cuda()); + TORCH_CHECK(c.is_contiguous()); + + TORCH_CHECK(d.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(d.is_cuda()); + TORCH_CHECK(d.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + TORCH_CHECK(c.dim() == 1); + TORCH_CHECK(d.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + TORCH_CHECK(c.size(0) == N); + TORCH_CHECK(d.size(0) == N); + + at::Tensor output = a.new_empty({N * 4}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_c = (half*)c.data_ptr(); + half* p_d = (half*)d.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_c, p_d, + p_out, N, stream); + return {output}; +} + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(a.is_cuda()); + TORCH_CHECK(a.is_contiguous()); + + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(b.is_cuda()); + TORCH_CHECK(b.is_contiguous()); + + TORCH_CHECK(a.dim() == 1); + TORCH_CHECK(b.dim() == 1); + + int N = a.size(0); + + TORCH_CHECK(b.size(0) == N); + + at::Tensor output = a.new_empty({N * 2}); + + half* p_a = (half*)a.data_ptr(); + half* p_b = (half*)b.data_ptr(); + half* p_out = (half*)output.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + iluvatar::inferrt::transformer::IxinferStackLauncher(p_a, p_b, p_out, N, + stream); + return {output}; +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py new file mode 100644 index 00000000..a031577c --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/setup.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import glob +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +CUR_DIR = os.path.dirname(os.path.abspath(__file__)) + +# cpp_files = glob.glob(os.path.join(CUR_DIR,"*.cpp")) +# cu_files = glob.glob(os.path.join(CUR_DIR,'*.cu')) +# source_files = cpp_files + cu_files +# print("source files:") +# for i in source_files: +# print(i) +source_files = [ + os.path.join(CUR_DIR,'test.cpp'), + os.path.join(CUR_DIR,'kernel.cu'), +] + +for i in source_files: + assert os.path.isfile(i) + print(i) + +setup( + name="test", + ext_modules=[ + CUDAExtension( + name="sp_opt", + libraries=["cuinfer"], + sources=source_files) + ], + cmdclass={ + "build_ext": BuildExtension + } +) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp new file mode 100644 index 00000000..08703064 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.cpp @@ -0,0 +1,29 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector one_test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d); + +std::vector test_opt(at::Tensor a, at::Tensor b, at::Tensor c, + at::Tensor d) { + return one_test_opt(a, b, c, d); +} + +std::vector one_test_opt_2(at::Tensor a, at::Tensor b); + +std::vector test_opt_2(at::Tensor a, at::Tensor b) { + return one_test_opt_2(a, b); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("test_opt", &test_opt, ""); + m.def("test_opt_2", &test_opt_2, ""); +} diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py new file mode 100644 index 00000000..185b829b --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/faster_stack/test.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import torch +import sp_opt + +if __name__ == "__main__": + batch_tokens = 320 + hidden_size = 5000 + + a = torch.randn([batch_tokens,hidden_size]).cuda().half() + b = torch.randn([batch_tokens,hidden_size]).cuda().half() + c = torch.randn([batch_tokens,hidden_size]).cuda().half() + d = torch.randn([batch_tokens,hidden_size]).cuda().half() + + res_pt = torch.stack([a,b,c,d]) + + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + res_cu = res_cu.view(4,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + + for i in range(20): + res_cu, = sp_opt.test_opt(a.view(-1),b.view(-1),c.view(-1),d.view(-1)) + + res_pt = torch.stack([a,b]) + + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + res_cu = res_cu.view(2,batch_tokens,hidden_size) + + diff = torch.abs(res_pt-res_cu) + print(diff) + print(diff.max()) + for i in range(20): + res_cu, = sp_opt.test_opt_2(a.view(-1),b.view(-1)) + # # res1 = torch.log(torch.sum(torch.exp(inputs),dim=-1)) + # # res2 = torch.logsumexp(inputs,dim=-1) + # # diff = torch.abs(res1-res2) + # # print(diff.max()) + + # res_pt = torch.logsumexp(inputs,dim=1) + + # res_cu, = sp_opt.test_opt(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + + # for i in range(20): + # res_cu, = sp_opt.test_opt(inputs) + + # batch_tokens = 55 + # hidden_size = 320*5000 + # inputs = torch.randn([batch_tokens,hidden_size]).cuda().half() + # res_pt = torch.logsumexp(inputs,dim=0) + # res_cu, = sp_opt.test_opt_dim0(inputs) + + # diff = torch.abs(res_pt - res_cu) + # print(diff.max()) + # for i in range(20): + # res_cu, = sp_opt.test_opt_dim0(inputs) + diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml new file mode 100644 index 00000000..859d09f3 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/hparams/train_ASR_transformer.yaml @@ -0,0 +1,253 @@ +# ############################################################################ +# Model: E2E ASR with Transformer +# Encoder: Transformer Encoder +# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch +# Tokens: BPE with unigram +# losses: CTC + KLdiv (Label Smoothing loss) +# Training: AISHELL-1 +# Authors: Jianyuan Zhong, Titouan Parcollet +# ############################################################################ +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 8886 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/transformer/ +cer_file: !ref /cer.txt +save_folder: !ref /save +train_log: !ref /train_log.txt + +# Data files +data_folder: !PLACEHOLDER # e,g./path/to/aishell +# noise/ris dataset will automatically be downloaded +data_folder_rirs: !ref # Change this is needed +skip_prep: False +ckpt_interval_minutes: 15 # save checkpoint every N min +train_data: !ref /csv_data/train.csv +valid_data: !ref /csv_data/dev.csv +test_data: !ref /csv_data/test.csv +tokenizer_file: speechbrain/asr-transformer-aishell/tokenizer.ckpt + +# Training parameters +number_of_epochs: 50 +batch_size: 64 +ctc_weight: 0.3 +gradient_accumulation: 4 +loss_reduction: 'batchmean' +sorting: ascending + +dynamic_batching: False +dynamic_batch_sampler: + feats_hop_size: 0.01 + max_batch_len: 15 # in terms of "duration" in annotations by default, second here + left_bucket_len: 200 # old implementation attributs + multiplier: 1.1 # old implementation attributs + shuffle_ex: False # if true re-creates batches at each epoch shuffling examples. + num_buckets: 10 # floor(log(max_batch_len/left_bucket_len, multiplier)) + 1 + batch_ordering: ascending + +num_workers: 6 + +# stages related parameters +stage_one_epochs: 40 +lr_adam: 1.0 +lr_sgd: 0.000025 + +# Feature parameters +sample_rate: 16000 +n_fft: 400 +n_mels: 80 + +# Dataloader options +train_dataloader_opts: + batch_size: !ref + shuffle: True + +valid_dataloader_opts: + batch_size: !ref + +test_dataloader_opts: + batch_size: !ref + +####################### Model parameters ########################### +# Transformer +d_model: 256 +nhead: 4 +num_encoder_layers: 12 +num_decoder_layers: 6 +d_ffn: 2048 +transformer_dropout: 0.1 +activation: !name:torch.nn.GELU +output_neurons: 5000 + +# Outputs +blank_index: 0 +label_smoothing: 0.1 +pad_index: 0 +bos_index: 1 +eos_index: 2 + +# Decoding parameters +min_decode_ratio: 0.0 +max_decode_ratio: 1.0 # 1.0 +valid_search_interval: 10 +valid_beam_size: 10 +test_beam_size: 1 +ctc_weight_decode: 0.40 + +############################## models ################################ + +CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd + input_shape: (8, 10, 80) + num_blocks: 2 + num_layers_per_block: 1 + out_channels: (256, 256) + kernel_sizes: (3, 3) + strides: (2, 2) + residuals: (False, False) + +Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length + input_size: 5120 + tgt_vocab: !ref + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: !ref + d_ffn: !ref + dropout: !ref + activation: !ref + normalize_before: True + +tokenizer: !new:sentencepiece.SentencePieceProcessor + +ctc_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +seq_lin: !new:speechbrain.nnet.linear.Linear + input_size: !ref + n_neurons: !ref + +env_corrupt: !new:speechbrain.lobes.augment.EnvCorrupt + openrir_folder: !ref + babble_prob: 0.0 + reverb_prob: 0.0 + noise_prob: 1.0 + noise_snr_low: 0 + noise_snr_high: 15 + +modules: + CNN: !ref + Transformer: !ref + seq_lin: !ref + ctc_lin: !ref + env_corrupt: !ref + +model: !new:torch.nn.ModuleList + - [!ref , !ref , !ref , !ref ] + +# define two optimizers here for two-stage training +Adam: !name:torch.optim.Adam + lr: 0 + betas: (0.9, 0.98) + eps: 0.000000001 + +SGD: !name:torch.optim.SGD + lr: !ref + momentum: 0.99 + nesterov: True + + +valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +test_search: !new:speechbrain.decoders.S2STransformerBeamSearch + modules: [!ref , !ref , !ref ] + bos_index: !ref + eos_index: !ref + blank_index: !ref + min_decode_ratio: !ref + max_decode_ratio: !ref + beam_size: !ref + ctc_weight: !ref + using_eos_threshold: False + length_normalization: True + +log_softmax: !new:torch.nn.LogSoftmax + dim: -1 + +ctc_cost: !name:speechbrain.nnet.losses.ctc_loss + blank_index: !ref + reduction: !ref + +seq_cost: !name:speechbrain.nnet.losses.kldiv_loss + label_smoothing: !ref + reduction: !ref + +noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler + lr_initial: !ref + n_warmup_steps: 25000 + model_size: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + noam_scheduler: !ref + normalizer: !ref + counter: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +normalize: !new:speechbrain.processing.features.InputNormalization + norm_type: global + update_until_epoch: 4 + +augmentation: !new:speechbrain.lobes.augment.SpecAugment + time_warp: True + time_warp_window: 5 + time_warp_mode: bicubic + freq_mask: True + n_freq_mask: 2 + time_mask: True + n_time_mask: 2 + replace_with_zero: False + freq_mask_width: 30 + time_mask_width: 40 + +compute_features: !new:speechbrain.lobes.features.Fbank + sample_rate: !ref + n_fft: !ref + n_mels: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +# AISHELL-1 has spaces between words in the transcripts, +# which Chinese writing normally does not do. +# If remove_spaces, spaces are removed +# from the transcript before computing CER. +# (e.g., 祝 可爱 的 你 —> 祝可爱的你) +remove_spaces: True +split_tokens: !apply:operator.not_ [!ref ] + +cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats + split_tokens: !ref +acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats + +pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer + collect_in: !ref + loadables: + tokenizer: !ref + paths: + tokenizer: !ref +engine_path: transformer.engine +ckpt_path: /home/data/speechbrain/results \ No newline at end of file diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/inference.py b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py new file mode 100644 index 00000000..68ef0e40 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/inference.py @@ -0,0 +1,606 @@ +#!/usr/bin/env/python3 +""" + +AISHELL-1 transformer model recipe. (Adapted from the LibriSpeech recipe.) + +""" +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import sys +import time +import torch +import logging +import speechbrain as sb +from speechbrain import Stage +from speechbrain.dataio.dataloader import LoopedLoader +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml +from speechbrain.utils.checkpoints import Checkpointer +import numpy as np +from speechbrain.utils import data_utils +import tensorrt +from torch.utils.data import DataLoader +from tqdm import tqdm +import convert +import beam_search +from load_ixrt_plugin import load_ixrt_plugin +from tensorrt import Dims +from speechbrain.lobes.models.transformer import Transformer +Transformer.get_lookahead_mask = convert.patch_get_lookahead_mask +load_ixrt_plugin() +logger = logging.getLogger(__name__) + + +def volume(shape): + result = 1 + for i in shape: + result *= i + return result + + +class ASR(sb.core.Brain): + def __init__(self, engine_path, *args, **kwargs): + super().__init__(*args, **kwargs) + # + self.forward_time = 0 + # ixrt + self.logger = tensorrt.Logger(tensorrt.Logger.ERROR) + with open(engine_path, "rb") as f, tensorrt.Runtime(self.logger) as self.runtime: + self.engine = self.runtime.deserialize_cuda_engine(f.read()) + assert self.engine + self.context = self.engine.create_execution_context() + assert self.context + self.encoder_ln_out = torch.zeros((64,2048,256), dtype=torch.float16).cuda() + self.infer_time = 0 + self.hparams.valid_search.return_log_probs = True + self.modules.CNN = self.modules.CNN.half() + self.hparams.valid_search = self.hparams.valid_search.half() + self.hparams.valid_search.model = self.hparams.valid_search.model.half() + self.hparams.valid_search.fc = self.hparams.valid_search.fc.half() + self.hparams.valid_search.ctc_fc = self.hparams.valid_search.ctc_fc.half() + self.hparams.valid_search.minus_inf = -10000 + self.hparams.valid_search.softmax = self.hparams.valid_search.softmax.half() + self.hparams.valid_search.model.decoder = convert.convert_decoder_model(self.hparams.valid_search.model.decoder) + # Given all input/output bindings, run in a dynamic shape way + def ixrt_infer(self, engine, context, bindings): + assert engine.num_bindings == len(bindings) + io_buffers = [0] * engine.num_bindings + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + io_buffers[idx] = arr.data_ptr() + # dynamic input + if engine.binding_is_input(idx): + context.set_binding_shape(idx, Dims(arr.shape)) + + forward_start_time = time.time() + assert context.execute_v2(io_buffers) + + torch.cuda.synchronize() + self.forward_time += time.time() - forward_start_time + outputs = {} + for name, arr in bindings.items(): + idx = engine.get_binding_index(name) + if not engine.binding_is_input(idx): + # dynamic output + shape = context.get_binding_shape(idx) + outputs[name] = arr.view(-1)[:volume(shape)].view(*shape) + return outputs + + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + tokens_bos, _ = batch.tokens_bos + + # Add augmentation if specified + if stage == sb.Stage.TRAIN: + if hasattr(self.modules, "env_corrupt"): + wavs_noise = self.modules.env_corrupt(wavs, wav_lens) + wavs = torch.cat([wavs, wavs_noise], dim=0) + wav_lens = torch.cat([wav_lens, wav_lens]) + tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0) + + torch.cuda.synchronize() + start_time = time.time() + + # compute features + feats = self.hparams.compute_features(wavs) + current_epoch = self.hparams.epoch_counter.current + feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch) + + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "augmentation"): + feats = self.hparams.augmentation(feats) + + # forward modules + src = self.modules.CNN(feats.half()) + + # Orignal PyTorch implementation, comment this to compare + # enc_out, _ = self.modules.Transformer( + # src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index + # ) + # logits = self.modules.ctc_lin(enc_out) + # p_ctc = self.hparams.log_softmax(logits) + # hyps, _ = self.hparams.test_search( + # enc_out.detach(), wav_lens + # ) + # return p_ctc, wav_lens, hyps + + # transformer + if src.ndim == 4: + bz, t, ch1, ch2 = src.shape + src = src.reshape(bz, t, ch1 * ch2) + + # ixrt inference + t1 = time.time() + bindings = {"input": src.half(), "length_radio": wav_lens.half(), + "encoder_ln_out": self.encoder_ln_out} + + infer_result = self.ixrt_infer(self.engine, self.context, bindings) + encoder_ln_out = infer_result["encoder_ln_out"] + t2 = time.time() + + hyps, _, p_ctc = beam_search.forward(self.hparams.valid_search, encoder_ln_out.half(), wav_lens.half()) + torch.cuda.synchronize() + infer_time = time.time() - start_time + + self.infer_time += infer_time + + return p_ctc, wav_lens, hyps + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (CTC+NLL) given predictions and targets.""" + + # ( + # p_ctc, + # p_seq, + # wav_lens, + # hyps, + # ) = predictions + + # 去除 seq2seq log-probabilities + ( + p_ctc, + wav_lens, + hyps, + ) = predictions + + ids = batch.id + tokens_eos, tokens_eos_lens = batch.tokens_eos + tokens, tokens_lens = batch.tokens + + if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN: + tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0) + tokens_eos_lens = torch.cat( + [tokens_eos_lens, tokens_eos_lens], dim=0) + tokens = torch.cat([tokens, tokens], dim=0) + tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0) + + # 去除 seq2seq 部分 loss + # loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens) + + if stage != sb.Stage.TRAIN: + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + + if current_epoch % valid_search_interval == 0 or (stage == sb.Stage.TEST): + # Decode token terms to words + predicted_words = [ + tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps + ] + target_words = [wrd.split(" ") for wrd in batch.wrd] + if self.hparams.remove_spaces: + predicted_words = ["".join(p) for p in predicted_words] + target_words = ["".join(t) for t in target_words] + self.cer_metric.append(ids, predicted_words, target_words) + + # 不计算 acc 部分 + # # compute the accuracy of the one-step-forward prediction + # self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens) + return -torch.ones([1]) + + def fit_batch(self, batch): + """Train the parameters given a single batch in input""" + # check if we need to switch optimizer + # if so change the optimizer from Adam to SGD + self.check_and_reset_optimizer() + + predictions = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN) + + # normalize the loss by gradient_accumulation step + (loss / self.hparams.gradient_accumulation).backward() + + if self.step % self.hparams.gradient_accumulation == 0: + # gradient clipping & early stop if loss is not fini + self.check_gradients(loss) + + self.optimizer.step() + self.optimizer.zero_grad() + + # anneal lr every update + self.hparams.noam_annealing(self.optimizer) + + return loss.detach() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + with torch.no_grad(): + predictions = self.compute_forward(batch, stage=stage) + loss = self.compute_objectives(predictions, batch, stage=stage) + return loss + + def on_stage_start(self, stage, epoch): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + # self.acc_metric = self.hparams.acc_computer() + self.cer_metric = self.hparams.cer_computer() + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"forward time": self.forward_time} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + # stage_stats["ACC"] = self.acc_metric.summarize() + current_epoch = self.hparams.epoch_counter.current + valid_search_interval = self.hparams.valid_search_interval + if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST: + stage_stats["CER"] = self.cer_metric.summarize("error_rate") + + # log stats and save checkpoint at end-of-epoch + if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process(): + + # report different epoch stages according current stage + current_epoch = self.hparams.epoch_counter.current + if current_epoch <= self.hparams.stage_one_epochs: + lr = self.hparams.noam_annealing.current_lr + steps = self.hparams.noam_annealing.n_steps + optimizer = self.optimizer.__class__.__name__ + else: + lr = self.hparams.lr_sgd + steps = -1 + optimizer = self.optimizer.__class__.__name__ + + epoch_stats = { + "epoch": epoch, + "lr": lr, + "steps": steps, + "optimizer": optimizer, + } + self.hparams.train_logger.log_stats( + stats_meta=epoch_stats, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"ACC": stage_stats["ACC"], "epoch": epoch}, + max_keys=["ACC"], + num_to_keep=10, + ) + + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={ + "Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + with open(self.hparams.cer_file, "w") as w: + self.cer_metric.write_stats(w) + + def check_and_reset_optimizer(self): + """reset the optimizer if training enters stage 2""" + current_epoch = self.hparams.epoch_counter.current + if not hasattr(self, "switched"): + self.switched = False + if isinstance(self.optimizer, torch.optim.SGD): + self.switched = True + + if self.switched is True: + return + + if current_epoch > self.hparams.stage_one_epochs: + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("optimizer", self.optimizer) + + self.switched = True + + def on_fit_start(self): + """Initialize the right optimizer on the training start""" + super().on_fit_start() + + # if the model is resumed from stage two, reinitialize the optimizer + current_epoch = self.hparams.epoch_counter.current + current_optimizer = self.optimizer + if current_epoch > self.hparams.stage_one_epochs: + del self.optimizer + self.optimizer = self.hparams.SGD(self.modules.parameters()) + + # Load latest checkpoint to resume training if interrupted + if self.checkpointer is not None: + + # do not reload the weights if training is interrupted right before stage 2 + group = current_optimizer.param_groups[0] + if "momentum" not in group: + return + + self.checkpointer.recover_if_possible( + device=torch.device(self.device)) + + def on_evaluate_start(self, max_key=None, min_key=None): + """perform checkpoint averge if needed""" + super().on_evaluate_start() + + ckpts = self.checkpointer.find_checkpoints( + max_key=max_key, min_key=min_key) + ckpt = sb.utils.checkpoints.average_checkpoints( + ckpts, recoverable_name="model", device=self.device + ) + + self.hparams.model.load_state_dict(ckpt, strict=True) + self.hparams.model.eval() + + def evaluate( + self, + test_set, + max_key=None, + min_key=None, + progressbar=None, + test_loader_kwargs={}, + ): + self.debug = False + self.debug_batches = 1 + if progressbar is None: + progressbar = not self.noprogressbar + + if not ( + isinstance(test_set, DataLoader) + or isinstance(test_set, LoopedLoader) + ): + test_loader_kwargs["ckpt_prefix"] = None + test_set = self.make_dataloader( + test_set, Stage.TEST, **test_loader_kwargs + ) + self.on_evaluate_start(max_key=max_key, min_key=min_key) + self.on_stage_start(Stage.TEST, epoch=None) + self.modules.eval() + avg_test_loss = 0.0 + self.step = 0 + with torch.no_grad(): + for batch in tqdm( + test_set, dynamic_ncols=True, disable=not progressbar + ): + self.step += 1 + loss = self.evaluate_batch(batch, stage=Stage.TEST) + avg_test_loss = self.update_average(loss, avg_test_loss) + + # Profile only if desired (steps allow the profiler to know when all is warmed up) + if self.profiler is not None: + if self.profiler.record_steps: + self.profiler.step() + + # Debug mode only runs a few batches + if self.debug and self.step == self.debug_batches: + break + + # Only run evaluation "on_stage_end" on main process + run_on_main( + self.on_stage_end, args=[Stage.TEST, avg_test_loss, None] + ) + self.step = 0 + return avg_test_loss + + +def dataio_prepare(hparams): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions.""" + data_folder = hparams["data_folder"] + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["train_data"], + replacements={"data_root": data_folder}, + ) + + if hparams["sorting"] == "ascending": + # we sort training data to speed up training and get better results. + train_data = train_data.filtered_sorted(sort_key="duration") + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "descending": + train_data = train_data.filtered_sorted( + sort_key="duration", reverse=True) + # when sorting do not shuffle in dataloader ! otherwise is pointless + hparams["train_dataloader_opts"]["shuffle"] = False + + elif hparams["sorting"] == "random": + pass + + else: + raise NotImplementedError( + "sorting must be random, ascending or descending") + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["valid_data"], + replacements={"data_root": data_folder}, + ) + valid_data = valid_data.filtered_sorted(sort_key="duration") + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=hparams["test_data"], + replacements={"data_root": data_folder}, + ) + test_data = test_data.filtered_sorted(sort_key="duration") + + datasets = [train_data, valid_data, test_data] + + # Defining tokenizer and loading it + tokenizer = hparams["tokenizer"] + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("transcript") + @sb.utils.data_pipeline.provides( + "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens" + ) + def text_pipeline(wrd): + yield wrd + tokens_list = tokenizer.encode_as_ids(wrd) + yield tokens_list + tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list)) + yield tokens_bos + tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]]) + yield tokens_eos + tokens = torch.LongTensor(tokens_list) + yield tokens + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 4. Set output: + sb.dataio.dataset.set_output_keys( + datasets, + ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"], + ) + + # 5. If Dynamic Batching is used, we instantiate the needed samplers. + train_batch_sampler = None + valid_batch_sampler = None + if hparams["dynamic_batching"]: + from speechbrain.dataio.sampler import DynamicBatchSampler # noqa + + dynamic_hparams = hparams["dynamic_batch_sampler"] + num_buckets = dynamic_hparams["num_buckets"] + + train_batch_sampler = DynamicBatchSampler( + train_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + valid_batch_sampler = DynamicBatchSampler( + valid_data, + dynamic_hparams["max_batch_len"], + num_buckets=num_buckets, + length_func=lambda x: x["duration"], + shuffle=dynamic_hparams["shuffle_ex"], + batch_ordering=dynamic_hparams["batch_ordering"], + ) + + return ( + train_data, + valid_data, + test_data, + tokenizer, + train_batch_sampler, + valid_batch_sampler, + ) + + +if __name__ == "__main__": + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # If --distributed_launch then + # create ddp_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # 1. # Dataset prep (parsing Librispeech) + from aishell_prepare import prepare_aishell # noqa + + # multi-gpu (ddp) save data preparation + run_on_main( + prepare_aishell, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["output_folder"], + "skip_prep": hparams["skip_prep"], + }, + ) + + # here we create the datasets objects as well as tokenization and encoding + ( + train_data, + valid_data, + test_data, + tokenizer, + train_bsampler, + valid_bsampler, + ) = dataio_prepare(hparams) + + hparams["pretrainer"].collect_files(default_source=hparams['ckpt_path']) + hparams["pretrainer"].load_collected(device=run_opts["device"]) + + # Trainer initialization + asr_brain = ASR( + modules=hparams["modules"], + opt_class=hparams["Adam"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + engine_path=hparams['engine_path'] + ) + + asr_brain.tokenizer = tokenizer + + # Changing the samplers if dynamic batching is activated + train_dataloader_opts = hparams["train_dataloader_opts"] + valid_dataloader_opts = hparams["valid_dataloader_opts"] + + if train_bsampler is not None: + train_dataloader_opts = { + "batch_sampler": train_bsampler, + "num_workers": hparams["num_workers"], + } + if valid_bsampler is not None: + valid_dataloader_opts = {"batch_sampler": valid_bsampler} + + # evaluation + print("*** start evaluation ***") + start_time = time.time() + asr_brain.evaluate( + test_data, test_loader_kwargs=hparams["test_dataloader_opts"]) + eval_time = asr_brain.infer_time + + ## 统计数据总音频时长 + duration = 0.0 + for value in test_data.data.values(): + duration = duration + value['duration'] + num_samples = len(test_data) + print(f"samples: {num_samples}, QPS: {num_samples / eval_time} ") + print(f"infer time :{eval_time},RTF: {eval_time / duration} ") diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py new file mode 100644 index 00000000..2bb0abc2 --- /dev/null +++ b/models/speech/speech_recognition/transformer_asr/ixrt/load_ixrt_plugin.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import ctypes +import tensorrt +from os.path import join, dirname, exists +def load_ixrt_plugin(logger=tensorrt.Logger(tensorrt.Logger.INFO), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = join(dirname(tensorrt.__file__), "lib", "libixrt_plugin.so") + if not exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!") + ctypes.CDLL(dynamic_path) + tensorrt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") -- Gitee From 2199561f26570783d1b479bf276584f4243abe67 Mon Sep 17 00:00:00 2001 From: majorli Date: Wed, 7 Aug 2024 13:36:29 +0800 Subject: [PATCH 2/2] update transformer results and format Signed-off-by: majorli --- .../transformer_asr/ixrt/README.md | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/models/speech/speech_recognition/transformer_asr/ixrt/README.md b/models/speech/speech_recognition/transformer_asr/ixrt/README.md index 7560b5eb..0c2e1b45 100644 --- a/models/speech/speech_recognition/transformer_asr/ixrt/README.md +++ b/models/speech/speech_recognition/transformer_asr/ixrt/README.md @@ -1,31 +1,52 @@ -# Asr transformer fp16 inference (BeamSearch) +# Transformer ASR(BeamSearch) ## Description Beam search allows us to exert control over the output of text generation. This is useful because we sometimes know exactly what we want inside the output. For example, in a Neural Machine Translation task, we might know which words must be included in the final translation with a dictionary lookup. - ## Setup ### Install -``` +```bash pip3 install speechbrain==0.5.13 ``` -* ixrt 4.0.1_MR release - ### Download Pretrained model: Dataset: to download the Aishell dataset. -``` +```bash # Make sure the checkpoint path is results/transformer/8886/save mkdir -p results/transformer/8886/save +# The data path like below: +results/transformer/8886 +├── cer.txt +├── dev.csv +├── env.log +├── hyperparams.yaml +├── inference_encoder_ctc.py +├── inference.py +├── log.txt +├── save +│ ├── CKPT+2023-03-29+06-31-40+00 +│ │ ├── brain.ckpt +│ │ ├── CKPT.yaml +│ │ ├── counter.ckpt +│ │ ├── model.ckpt +│ │ ├── noam_scheduler.ckpt +│ │ └── normalizer.ckpt +│ └── tokenizer.ckpt +├── test.csv +├── train.csv +└── train_log.txt + # Make sure the dataset path is results/transformer/8886/save -mkdir -p /home/data/speechbrain +mkdir -p /home/data/speechbrain/aishell/csv_data +ln -s /PATH/to/data_aishell /home/data/speechbrain/aishell/ +cp results/transformer/8886/*.csv /home/data/speechbrain/aishell/csv_data ``` ## Inference @@ -40,7 +61,7 @@ bash build.sh max_batch_size and max_seq_len depend on the situation. -``` +```bash python3 builder.py \ --ckpt_path results/transformer/8886/save \ --head_num 4 \ @@ -51,6 +72,12 @@ python3 builder.py \ ### Run engine -``` +```bash python3 inference.py hparams/train_ASR_transformer.yaml --data_folder=/home/data/speechbrain/aishell --engine_path transformer.engine -``` \ No newline at end of file +``` + +## Results + +| Model | BatchSize | Precision | QPS | CER | +| --------------- | --------- | --------- | ----- | ---- | +| Transformer ASR | 32 | FP16 | 15.64 | 5.95 | -- Gitee