diff --git a/README.md b/README.md
index 90c697a015b9f333988e3a205fb5176f0b4749b8..5ad55a16581bc03d39b13b847277f0d29dda59a5 100644
--- a/README.md
+++ b/README.md
@@ -746,7 +746,7 @@ DeepSparkInference将按季度进行版本更新,后续会逐步丰富模型
         Conformer
         FP16
         Supported
-        -
+        Supported
         INT8
diff --git a/models/speech/speech_recognition/conformer/ixrt/README.md b/models/speech/speech_recognition/conformer/ixrt/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2d4d98f35a90699026b707f42619ecb368deb3c3
--- /dev/null
+++ b/models/speech/speech_recognition/conformer/ixrt/README.md
@@ -0,0 +1,49 @@
+# Conformer
+
+## Description
+
+Conformer is a speech recognition model proposed by Google in 2020. It combines the advantages of CNNs and Transformers: the convolution module efficiently extracts local features, while the Transformer captures long-range dependencies more effectively. Conformer adds convolution to the Transformer encoder layers, improving Transformer performance in the ASR (Automatic Speech Recognition) domain.
+
+## Setup
+
+### Install
+
+```bash
+pip3 install tqdm
+pip3 install onnx
+pip3 install typeguard==2.13.3
+pip3 install onnxsim
+```
+
+### Download
+
+Pretrained model: 
+
+Dataset: to download the Aishell dataset.
+
+Download the pretrained model and put it in `conformer_checkpoints`; put the dataset in `aishell_test_data`.
+
+### Prepare Data
+```bash
+# Accuracy
+DATA_DIR=./aishell_test_data
+Tool_DIR=./tools
+bash scripts/aishell_data_prepare.sh ${DATA_DIR} ${Tool_DIR}
+```
+
+### Model Conversion And Inference
+
+### FP16
+
+```bash
+# Accuracy
+bash scripts/infer_conformer_fp16_accuracy_ixrt.sh
+# Performance
+bash scripts/infer_conformer_fp16_performance_ixrt.sh
+```
+
+## Results
+
+Model |BatchSize |Precision |QPS |CER |
+-----------|-----------|----------|----------|----------|
+Conformer | 24 | FP16 | 380.00 | 0.051 |
diff --git a/models/speech/speech_recognition/conformer/ixrt/build_engine.py b/models/speech/speech_recognition/conformer/ixrt/build_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa20ee59f6ecd23d8a8cb9272ece0087ed65ab89
--- /dev/null
+++ b/models/speech/speech_recognition/conformer/ixrt/build_engine.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Build an engine from the fusion-plugin ONNX model.
+"""
+
+import os
+import ctypes
+import json
+import onnx
+import logging
+import argparse
+
+import tensorrt
+import tensorrt as trt
+from tensorrt import Dims
+
+
+TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
+def load_ixrt_plugin(logger=trt.Logger(trt.Logger.WARNING), namespace="", dynamic_path=""):
+    if not dynamic_path:
+        dynamic_path = os.path.join(os.path.dirname(trt.__file__), "lib", "libixrt_plugin.so")
+    if not os.path.exists(dynamic_path):
+        raise FileNotFoundError(
+            f"The ixrt_plugin lib {dynamic_path} does not exist, please provide a valid plugin path!"
+ ) + ctypes.CDLL(dynamic_path, mode=ctypes.RTLD_GLOBAL) + trt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") + +load_ixrt_plugin() + + + +def parse_args(): + parser = argparse.ArgumentParser(description="build tensorrt engine of conformer.", usage="") + parser.add_argument( + "--model_name", + type=str, + required=True, + help="conformer", + ) + parser.add_argument( + "--onnx_path", + type=str, + required=True, + help="onnx_path path to save", + ) + parser.add_argument( + "--engine_path", + type=str, + required=True, + help="engine path to save", + ) + parser.add_argument( + "--max_batch_size", + type=int, + required=True, + ) + parser.add_argument( + "--max_seq_len", + type=int, + required=True, + ) + args = parser.parse_args() + return args + +args = parse_args() +MaxBSZ = args.max_batch_size +MaxSeqLen = args.max_seq_len + + +def build_engine_trtapi_dynamicshape(args): + onnx_model = args.onnx_path + assert os.path.isfile(onnx_model), f"The onnx model{onnx_model} must be existed!" + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + + profile = builder.create_optimization_profile() + profile.set_shape("input", Dims([MaxBSZ, 100, 80]), Dims([MaxBSZ, 1000, 80]), Dims([MaxBSZ, 1500, 80])) + profile.set_shape("mask", Dims([MaxBSZ, 1, 25]), Dims([MaxBSZ, 1, 250]), Dims([MaxBSZ, 1, 374])) + profile.set_shape("pos_emb", Dims([1, 25, 256]), Dims([1, 250, 256]), Dims([1, 374, 256])) + build_config.add_optimization_profile(profile) + + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(onnx_model) + build_config.set_flag(tensorrt.BuilderFlag.FP16) + + # set dynamic + # input + input_tensor = network.get_input(0) + input_tensor.shape = Dims([MaxBSZ, -1, 80]) + # mask + mask_tensor = network.get_input(1) + mask_tensor.shape = Dims([MaxBSZ, 1, -1]) + # pos_emb + pos_emb_tensor = network.get_input(2) + pos_emb_tensor.shape = Dims([1, -1, 256]) + + plan = builder.build_serialized_network(network, build_config) + with open(args.engine_path, "wb") as f: + f.write(plan) + + print("Build dynamic shape engine done!") + + +def build_engine_trtapi_staticshape(args): + onnx_model = args.onnx_path + assert os.path.isfile(onnx_model), f"The onnx model{onnx_model} must be existed!" 
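+    # Static-shape build path: parses the ONNX model as-is, with no optimization profile; kept for reference (see __main__, which builds the dynamic-shape engine).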
+ IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + + parser.parse_from_file(onnx_model) + build_config.set_flag(tensorrt.BuilderFlag.FP16) + + plan = builder.build_serialized_network(network, build_config) + with open(args.engine_path, "wb") as f: + f.write(plan) + + print("Build static shape engine done!") + + +if __name__ == "__main__": + build_engine_trtapi_dynamicshape(args) + # build_engine_trtapi_staticshape(args) diff --git a/models/speech/speech_recognition/conformer/ixrt/common.py b/models/speech/speech_recognition/conformer/ixrt/common.py new file mode 100644 index 0000000000000000000000000000000000000000..89023300ddc7ca3e4f0f992f4b124d8a8c131ae5 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/common.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import ctypes +import cv2 +import glob +import torch +import tensorrt +import tensorrt as trt +import numpy as np +import pycuda.driver as cuda + +from tensorrt.hook.utils import copy_ixrt_io_tensors_as_np + + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) +def load_ixrt_plugin(logger=trt.Logger(trt.Logger.WARNING), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = os.path.join(os.path.dirname(trt.__file__), "lib", "libixrt_plugin.so") + if not os.path.exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!" 
+ ) + ctypes.CDLL(dynamic_path, mode=ctypes.RTLD_GLOBAL) + trt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") +load_ixrt_plugin() + + +def trtapi(engine_file): + datatype = tensorrt.DataType.FLOAT + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + with open(engine_file, "rb") as f, tensorrt.Runtime(logger) as runtime: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + + +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + + +def get_io_bindings(engine): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = engine.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations + + +def setup_io_bindings(engine, context): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = context.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations diff --git a/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py b/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..823ae3215f58d18a636e868668199ed3f388ee20 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/convert2onnx.py @@ -0,0 +1,529 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Build Compute Graph(Fusion Plugin Onnx) From Checkpoints. +""" + +import os +import json +import torch +import argparse +import numpy as np +from collections import OrderedDict + +from tensorrt.deploy.api import GraphTransform, create_source, create_target +from tensorrt.deploy.ir.data_type import DataType +from tensorrt.deploy.ir.variable import Variable, VariableOptions +from tensorrt.deploy.ir.graph import Graph + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Build Compute Graph From Checkpoints.", usage="" + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="conformer", + ) + parser.add_argument( + "--model_path", + type=str, + required=True, + help="checkpont of conformer", + ) + parser.add_argument( + "--onnx_path", + type=str, + required=True, + help="raw onnx path to save", + ) + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="the batch size for test.", + ) + args = parser.parse_args() + return args + + +def add_global_cmvn_op(graph, state_dict, args): + t = graph + + sub_inputs = [t.make_variable("input", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + key = "encoder.global_cmvn.mean" + sub_inputs.append(t.make_variable(name=key, value=state_dict[key])) + sub_outputs = [t.make_variable("Sub_output_0", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + t.make_operator( + "Sub", + inputs=sub_inputs, + outputs=sub_outputs, + ) + + mul_inputs = sub_outputs + key = "encoder.global_cmvn.istd" + mul_inputs.append(t.make_variable(name=key, value=state_dict[key])) + mul_outputs = [t.make_variable("Mul_output_0", dtype=DataType.FLOAT, shape=(128, 1500, 80))] + t.make_operator( + "Mul", + inputs=mul_inputs, + outputs=mul_outputs, + ) + + unsqueeze_inputs = mul_outputs + unsqueeze_inputs.append(t.make_variable("axes", value=np.array([1], dtype=np.int64))) + unsqueeze_outputs = [t.make_variable("Unsqueeze_output_0", dtype=DataType.FLOAT, shape=(128, 1, 1500, 80))] + t.make_operator( + "Unsqueeze", + inputs=unsqueeze_inputs, + outputs=unsqueeze_outputs, + ) + + +def add_first_submodule_op(graph, state_dict, args): + """ + The firt submodule part contains follows: + 1.Conv2d+ReLU; + 2.Conv2d+ReLU; + 3.Transpose+Reshape; + 4.MatMul+Add+Mul; + """ + + t = graph + conv2d0_weight_keys = [ + "encoder.embed.conv.0.weight", + "encoder.embed.conv.0.bias", + ] + conv2d0_attributes = { + "dilations": [1, 1], + "group": 1, + "kernel_shape": [3, 3], + "pads": [0, 0, 0, 0], + "strides": [2, 2], + } + conv2d0_inputs = [t.get_variable("Unsqueeze_output_0")] + conv2d0_outputs = [t.make_variable("Conv_output_0", dtype=DataType.FLOAT)] + + for key in conv2d0_weight_keys: + conv2d0_inputs.append(t.make_variable(name=key, value=state_dict[key])) + t.make_operator( + "Conv", + inputs=conv2d0_inputs, + outputs=conv2d0_outputs, + **conv2d0_attributes + ) + + relu0_inputs = conv2d0_outputs + relu0_outputs = [t.make_variable("Relu_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Relu", + inputs=relu0_inputs, + outputs=relu0_outputs + ) + + conv2d1_weight_keys = [ + "encoder.embed.conv.2.weight", + "encoder.embed.conv.2.bias", + ] + conv2d1_attributes = { + "dilations": [1, 1], + "group": 1, + "kernel_shape": [3, 3], + "pads": [0, 0, 0, 0], + "strides": [2, 2], + } + conv2d1_inputs = relu0_outputs + conv2d1_outputs = [t.make_variable("Conv_output_1", dtype=DataType.FLOAT)] + + for key in conv2d1_weight_keys: + 
conv2d1_inputs.append(t.make_variable(name=key, value=state_dict[key])) + t.make_operator( + "Conv", + inputs=conv2d1_inputs, + outputs=conv2d1_outputs, + **conv2d1_attributes + ) + + relu1_inputs = conv2d1_outputs + relu1_outputs = [t.make_variable("Relu_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Relu", + inputs=relu1_inputs, + outputs=relu1_outputs + ) + + tran_inputs = relu1_outputs + tran_outputs = [t.make_variable("Transpose_output_0", dtype=DataType.FLOAT)] + tran_attributes = {"perm": [0, 2, 1, 3]} + t.make_operator( + "Transpose", + inputs=tran_inputs, + outputs=tran_outputs, + **tran_attributes + ) + + reshape_inputs = tran_outputs + reshape_inputs.append(t.make_variable(name="constant_0", value=np.array([args.batch_size, -1, 4864]), dtype=DataType.INT64)) + reshape_outputs = [t.make_variable("Reshape_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Reshape", + inputs=reshape_inputs, + outputs=reshape_outputs, + ) + + matmul_inputs = reshape_outputs + matmul_inputs.append(t.make_variable(name="embed.out.0.weight", value=state_dict["encoder.embed.out.0.weight"].transpose(1, 0))) # (256,4864)--->(4864,256) + matmul_outputs = [t.make_variable("MatMul_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "MatMul", + inputs=matmul_inputs, + outputs=matmul_outputs, + ) + + add_inputs = matmul_outputs + add_inputs.append(t.make_variable(name="embed.out.0.bias", value=state_dict["encoder.embed.out.0.bias"])) + add_outputs = [t.make_variable("Add_output_0", dtype=DataType.FLOAT)] + t.make_operator( + "Add", + inputs=add_inputs, + outputs=add_outputs, + ) + + mul_inputs = add_outputs + mul_inputs.append(t.make_variable(name="constant_1", value=np.array([16.], dtype=np.float32), dtype=DataType.FLOAT)) + mul_outputs = [t.make_variable("Mul_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Mul", + inputs=mul_inputs, + outputs=mul_outputs, + ) + + +def add_encoder_ff_macaron_op(graph, state_dict, args, index): + + t = graph + ff_macaron_keys = [ + "encoder.encoders.{}.norm_ff_macaron.weight", + "encoder.encoders.{}.norm_ff_macaron.bias", + "encoder.encoders.{}.feed_forward_macaron.w_1.weight", + "encoder.encoders.{}.feed_forward_macaron.w_1.bias", + "encoder.encoders.{}.feed_forward_macaron.w_2.weight", + "encoder.encoders.{}.feed_forward_macaron.w_2.bias", + ] + + attributes = { + "in_feature": 256, + "hidden_size": 2048, + "act_type": 12, + "ff_scale": 0.5, + } + + if index == 0: + inputs = [graph.get_variable("Mul_output_1")] + else: + inputs = [graph.get_variable("norm_final_{}_output".format(index-1))] + + outputs = [t.make_variable("ff_macaron_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in ff_macaron_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "PositionWiseFFNPluginDynamic_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_mhsa_op(graph, state_dict, args, index): + + t = graph + mhsa_keys = [ + "encoder.encoders.{}.norm_mha.weight", + "encoder.encoders.{}.norm_mha.bias", + "encoder.encoders.{}.self_attn.linear_q.weight", + "encoder.encoders.{}.self_attn.linear_q.bias", + "encoder.encoders.{}.self_attn.linear_k.weight", + "encoder.encoders.{}.self_attn.linear_k.bias", + "encoder.encoders.{}.self_attn.linear_v.weight", + "encoder.encoders.{}.self_attn.linear_v.bias", + "encoder.encoders.{}.self_attn.linear_pos.weight", + "encoder.encoders.{}.self_attn.pos_bias_u", + "encoder.encoders.{}.self_attn.pos_bias_v", 
+ "encoder.encoders.{}.self_attn.linear_out.weight", + "encoder.encoders.{}.self_attn.linear_out.bias", + ] + + attributes = { + "bs": 128, + "seq_len": 374, + "n_head": 4, + "n_feat": 256, + } + + if index == 0: + inputs = [ + graph.get_variable("ff_macaron_{}_output".format(index)), + t.make_variable("mask", dtype=DataType.INT32, shape=(128, 1, 374)), + t.make_variable("pos_emb", dtype=DataType.FLOAT, shape=(1, 374, 256)), + ] + else: + inputs = [ + graph.get_variable("ff_macaron_{}_output".format(index)), + graph.get_variable("mask"), + graph.get_variable("pos_emb"), + ] + + outputs = [t.make_variable("mhsa_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in mhsa_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "ConformerMultiHeadSelfAttentionPlugin_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_conv_module_op(graph, state_dict, args, index): + + t = graph + conv_module_keys = [ + "encoder.encoders.{}.norm_conv.weight", + "encoder.encoders.{}.norm_conv.bias", + "encoder.encoders.{}.conv_module.pointwise_conv1.weight", + "encoder.encoders.{}.conv_module.pointwise_conv1.bias", + "encoder.encoders.{}.conv_module.depthwise_conv.weight", + "encoder.encoders.{}.conv_module.depthwise_conv.bias", + "encoder.encoders.{}.conv_module.norm.weight", + "encoder.encoders.{}.conv_module.norm.bias", + "encoder.encoders.{}.conv_module.pointwise_conv2.weight", + "encoder.encoders.{}.conv_module.pointwise_conv2.bias", + ] + + attributes = { + "kernel_size_1": 1, + "stride_1": 1, + "odim_1": 512, + "kernel_size_2": 8, + "stride_2": 1, + "odim_2": 256, + "kernel_size_3": 1, + "stride_3": 1, + "odim_3": 256, + } + + inputs = [ + graph.get_variable("mhsa_{}_output".format(index)), + graph.get_variable("mask"), + ] + outputs = [t.make_variable("conv_module_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in conv_module_keys: + key = key.format(index) + + if "conv_module.depthwise_conv.weight" in key: + inputs.append(t.make_variable(name=key, value=state_dict[key].permute(1, 2, 0).half(), dtype=DataType.FLOAT16)) + elif "bias" in key and "norm" not in key: + inputs.append(t.make_variable(name=key, value=state_dict[key], dtype=DataType.FLOAT)) + else: + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "ConformerConvModulePlugin_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_positionwise_ff_op(graph, state_dict, args, index): + + t = graph + positionwise_ff_keys = [ + "encoder.encoders.{}.norm_ff.weight", + "encoder.encoders.{}.norm_ff.bias", + "encoder.encoders.{}.feed_forward.w_1.weight", + "encoder.encoders.{}.feed_forward.w_1.bias", + "encoder.encoders.{}.feed_forward.w_2.weight", + "encoder.encoders.{}.feed_forward.w_2.bias", + ] + + attributes = { + "in_feature": 256, + "hidden_size": 2048, + "act_type": 12, + "ff_scale": 0.5, + } + + inputs = [graph.get_variable('conv_module_{}_output'.format(index))] + outputs = [t.make_variable("positionwise_ff_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in positionwise_ff_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "PositionWiseFFNPluginDynamic_IxRT", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_encoder_ln_op(graph, state_dict, args, index): + + t = graph + 
ln_keys = [ + "encoder.encoders.{}.norm_final.weight", + "encoder.encoders.{}.norm_final.bias", + ] + + attributes = { + "axis": -1, + "epsilon": 0.000009999999747378752, + "stash_type": 1, + } + + inputs = [graph.get_variable("positionwise_ff_{}_output".format(index))] + outputs = [t.make_variable("norm_final_{}_output".format(index), dtype=DataType.FLOAT)] + + for key in ln_keys: + key = key.format(index) + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_final_ln_op(graph, state_dict, args): + + t = graph + ln_keys = [ + "encoder.after_norm.weight", + "encoder.after_norm.bias", + ] + + attributes = { + "axis": -1, + "epsilon": 0.000009999999747378752, + "stash_type": 1, + } + + inputs = [graph.get_variable("norm_final_11_output")] + outputs = [t.make_variable("norm_final_output", dtype=DataType.FLOAT)] + + for key in ln_keys: + inputs.append(t.make_variable(name=key, value=state_dict[key].half(), dtype=DataType.FLOAT16)) + + t.make_operator( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + **attributes + ) + + +def add_ctc_op(graph, state_dict, args): + t = graph + # matmul + matmul_inputs = [graph.get_variable("norm_final_output")] + matmul_inputs.append(t.make_variable(name="ctc.ctc_lo.weight", value=state_dict["ctc.ctc_lo.weight"].transpose(1, 0))) # (4233,256)--->(256,4233) + matmul_outputs = [t.make_variable("MatMul_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "MatMul", + inputs=matmul_inputs, + outputs=matmul_outputs, + ) + + add_inputs = matmul_outputs + add_inputs.append(t.make_variable(name="ctc.ctc_lo.bias", value=state_dict["ctc.ctc_lo.bias"])) + add_outputs = [t.make_variable("Add_output_1", dtype=DataType.FLOAT)] + t.make_operator( + "Add", + inputs=add_inputs, + outputs=add_outputs, + ) + + logsoftmax_inputs = add_outputs + logsoftmax_outputs = [t.make_variable("output", dtype=DataType.FLOAT)] + attributes = { + "axis": 2 + } + t.make_operator( + "LogSoftmax", + inputs=logsoftmax_inputs, + outputs=logsoftmax_outputs, + **attributes + ) + + +def main(args): + graph = Graph() + transform = GraphTransform(graph) + state_dict = torch.load(args.model_path) + + # 0. Global CMVN: sub+mul+unsqueeze + add_global_cmvn_op(transform, state_dict, args) + + # 1. First Submodule: Conv2d+Relu+Transpose+MatMul + add_first_submodule_op(transform, state_dict, args) + + # 2. Second Submodule: ConformerEncoderLayer: 12 layers + for i in range(args.num_layers): + add_encoder_ff_macaron_op(transform, state_dict, args, i) + add_encoder_mhsa_op(transform, state_dict, args, i) + add_encoder_conv_module_op(transform, state_dict, args, i) + add_encoder_positionwise_ff_op(transform, state_dict, args, i) + add_encoder_ln_op(transform, state_dict, args, i) + + # 3. Third Submodule: FinalNorm + add_final_ln_op(transform, state_dict, args) + + # 4.Forth Submodule: CTC+LogSoftmax + add_ctc_op(transform, state_dict, args) + + # 5. set input and output + graph.add_input(graph.get_variable("input")) + graph.add_input(graph.get_variable("mask")) + graph.add_input(graph.get_variable("pos_emb")) + graph.add_output(graph.get_variable("output")) + # 5. 
export onnx file + create_target(saved_path=args.onnx_path).export(graph) + print("save onnx: ", args.onnx_path) + + +if __name__ == "__main__": + args = parse_args() + model_name = args.model_name.lower() + args.num_layers = 12 + args.hidden_size = 2048 + args.head_num = 4 + args.head_dim = 64 + args.pad_id = 0 + args.inner_size = 3072 + main(args) diff --git a/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..35aad9bbf24533bed27e98ddbe4e326fa897df88 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_accuracy.py @@ -0,0 +1,285 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import argparse +import yaml +import copy +import torch +import numpy as np + +from tqdm.contrib import tqdm +from torch.utils.data import DataLoader +from wenet.file_utils import read_symbol_table +from wenet.dataset import Dataset +from tools.compute_cer import Calculator, characterize, normalize, default_cluster +import tensorrt +from tensorrt import Dims +from common import create_engine_context, get_io_bindings,trtapi,setup_io_bindings +import pickle + +import pycuda.autoinit +import pycuda.driver as cuda + +from utils import make_pad_mask, RelPositionalEncoding +from postprocess import ctc_greedy_search + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def get_args(): + parser = argparse.ArgumentParser(description="recognize with your model") + parser.add_argument( + "--infer_type", + default="fp16", + choices=["fp16", "int8"], + help="inference type: fp16 or int8", + ) + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + parser.add_argument("--batch_size", type=int, default=24) + parser.add_argument("--data_dir", required=True, help="test data directory") + parser.add_argument( + "--model_dir", type=str, required=True, help="model for inference" + ) + args = parser.parse_args() + return args + + +def tensorrt_infer(engine, context, all_inputs): + + input_names = ["input", "mask", "pos_emb"] + output_names = ["output"] + + for input_name, input_data in zip(input_names, all_inputs): + input_idx = engine.get_binding_index(input_name) + input_shape = input_data.shape + context.set_binding_shape(input_idx, Dims(input_shape)) + + inputs, outputs, allocations = setup_io_bindings(engine, context) + pred_output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + for i, input_data in enumerate(all_inputs): + cuda.memcpy_htod(inputs[i]["allocation"], input_data) + + context.execute_v2(allocations) + cuda.memcpy_dtoh(pred_output, outputs[0]["allocation"]) + return pred_output + + +def engine_init(engine): + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + 
engine, context = create_engine_context(engine, logger) + + return engine,context + + +def calculate_cer(data, reference_data): + calculator = Calculator() + tochar = True + split = None + case_sensitive = False + ignore_words = set() + rec_set = {} + for line in data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + default_clusters = {} + default_words = {} + for line in reference_data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + + result = calculator.overall() + cer = float(result["ins"] + result["sub"] + result["del"]) / result["all"] + corr = result["cor"] / result["all"] + + return cer, corr + + +def main(): + args = get_args() + + # 读取配置文件 + config_fn = os.path.join(args.model_dir, "config.yaml") + with open(config_fn, "r") as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + dataset_conf = copy.deepcopy(configs["dataset_conf"]) + dataset_conf["filter_conf"]["max_length"] = 102400 + dataset_conf["filter_conf"]["min_length"] = 0 + dataset_conf["filter_conf"]["token_max_length"] = 102400 + dataset_conf["filter_conf"]["token_min_length"] = 0 + dataset_conf["filter_conf"]["max_output_input_ratio"] = 102400 + dataset_conf["filter_conf"]["min_output_input_ratio"] = 0 + dataset_conf["speed_perturb"] = False + dataset_conf["spec_aug"] = False + dataset_conf["shuffle"] = False + dataset_conf["sort"] = True + dataset_conf["fbank_conf"]["dither"] = 0.0 + dataset_conf["batch_conf"]["batch_type"] = "static" + dataset_conf["batch_conf"]["batch_size"] = args.batch_size + + # Load dict + dict_fn = os.path.join(args.model_dir, "words.txt") + char_dict = {} + with open(dict_fn, "r", encoding="utf8") as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + data_type = "raw" + test_data_fn = os.path.join(args.data_dir, "data.list") + symbol_table = read_symbol_table(dict_fn) + test_dataset = Dataset( + data_type, test_data_fn, symbol_table, dataset_conf, partition=False + ) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + data_path_pkl = os.path.join(args.data_dir, f"aishell_test_data_bs{args.batch_size}.pkl") + + print("*** 1. 
Prepare data ***") + if not os.path.isfile(data_path_pkl): + eval_samples = [] + max_batch_size = -1 + max_feature_length = -1 + for batch in test_data_loader: + keys, feats, target, feats_lengths, target_lengths = batch + max_feature_length = max(max_feature_length, feats.size(1)) + max_batch_size = max(max_batch_size, feats.size(0)) + eval_samples.append( + [ + keys, + feats.cpu().numpy().astype(np.float16), + feats_lengths.cpu().numpy().astype(np.int32), + ] + ) + with open(data_path_pkl, "wb") as f: + pickle.dump( + [ + eval_samples, + max_batch_size, + max_feature_length + ], + f, + ) + else: + print(f"load data from tmp: {data_path_pkl}") + with open(data_path_pkl, "rb") as f: + ( + eval_samples, + max_batch_size, + max_feature_length + ) = pickle.load(f) + print( + f"dataset max shape: batch_size: {max_batch_size}, feat_length: {max_feature_length}" + ) + + print("*** 2. Load engine ***") + engine_path = os.path.join(args.model_dir, f"conformer_encoder_fusion.engine") + engine, context = engine_init(engine_path) + + print("*** 3. Warm up ***") + if args.warm_up > 0: + for i in range(args.warm_up): + feats_tmp = np.ones((args.batch_size, 1500, 80)).astype(np.float32) + feats_lengths_tmp = np.ones((args.batch_size)).astype(np.int32) * 1500 + mask_tmp = make_pad_mask(feats_lengths_tmp, 1500) + mask_len_tmp = mask_tmp.shape[-1] + pos_emb_tmp = rel_positional_encoding(mask_len_tmp).numpy() + all_inputs = [feats_tmp, mask_tmp, pos_emb_tmp] + tensorrt_infer(engine, context, all_inputs) + + results = [] + for keys, feats, feats_lengths in tqdm(eval_samples): + b, seq_len, feat = feats.shape + + inputs = feats.astype(np.float32) + mask = make_pad_mask(feats_lengths, seq_len) + mask_len = mask.shape[-1] + pos_emb = rel_positional_encoding(mask_len).numpy() + + all_inputs = [inputs, mask, pos_emb] + hyps = tensorrt_infer( + engine, + context, + all_inputs + ) + + ctc_probs = torch.from_numpy(hyps) + ctc_lens = torch.from_numpy(feats_lengths) + hyps = ctc_greedy_search(ctc_probs, ctc_lens) + + for i, key in enumerate(keys): + line = f"{key} " + for w in hyps[i]: + w = w - 1 + if w == eos: + break + line += char_dict[w] + results.append(line) + + # 3. 计算 CER + reference_file = os.path.join(args.data_dir, "text") + reference_data = [] + for line in open(reference_file, "r", encoding="utf-8"): + reference_data.append(line) + + cer, corr = calculate_cer(results, reference_data) + target_cer = float(os.environ["Accuracy"]) + print("CER: ", cer, "target CER: ", target_cer) + if cer <= target_cer: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..c19233fa6813722083e1e86fbfc310dcd1370670 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/ixrt_inference_performance.py @@ -0,0 +1,273 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys +import time + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import argparse +import yaml +import copy +import torch +import numpy as np + +from tqdm.contrib import tqdm +from torch.utils.data import DataLoader +from wenet.file_utils import read_symbol_table +from wenet.dataset import Dataset +from tools.compute_cer import Calculator, characterize, normalize, default_cluster +import tensorrt +from tensorrt import Dims +from common import create_engine_context, get_io_bindings,trtapi,setup_io_bindings +import pickle + +import pycuda.autoinit +import pycuda.driver as cuda + +from utils import make_pad_mask, RelPositionalEncoding +from postprocess import ctc_greedy_search + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def get_args(): + parser = argparse.ArgumentParser(description="recognize with your model") + parser.add_argument( + "--infer_type", + default="fp16", + choices=["fp16", "int8"], + help="inference type: fp16 or int8", + ) + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + parser.add_argument("--batch_size", type=int, default=24) + parser.add_argument("--data_dir", required=True, help="test data directory") + parser.add_argument( + "--model_dir", type=str, required=True, help="model for inference" + ) + args = parser.parse_args() + return args + + +def tensorrt_infer(engine, context, all_inputs): + + input_names = ["input", "mask", "pos_emb"] + output_names = ["output"] + + for input_name, input_data in zip(input_names, all_inputs): + input_idx = engine.get_binding_index(input_name) + input_shape = input_data.shape + context.set_binding_shape(input_idx, Dims(input_shape)) + + inputs, outputs, allocations = setup_io_bindings(engine, context) + pred_output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + + for i, input_data in enumerate(all_inputs): + cuda.memcpy_htod(inputs[i]["allocation"], input_data) + + context.execute_v2(allocations) + cuda.memcpy_dtoh(pred_output, outputs[0]["allocation"]) + return pred_output + + +def engine_init(engine): + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(engine, logger) + + return engine,context + + +def calculate_cer(data, reference_data): + calculator = Calculator() + tochar = True + split = None + case_sensitive = False + ignore_words = set() + rec_set = {} + for line in data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + default_clusters = {} + default_words = {} + for line in reference_data: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: 
+ default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + + result = calculator.overall() + cer = float(result["ins"] + result["sub"] + result["del"]) / result["all"] + corr = result["cor"] / result["all"] + + return cer, corr + + +def main(): + args = get_args() + + # 读取配置文件 + config_fn = os.path.join(args.model_dir, "config.yaml") + with open(config_fn, "r") as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + dataset_conf = copy.deepcopy(configs["dataset_conf"]) + dataset_conf["filter_conf"]["max_length"] = 102400 + dataset_conf["filter_conf"]["min_length"] = 0 + dataset_conf["filter_conf"]["token_max_length"] = 102400 + dataset_conf["filter_conf"]["token_min_length"] = 0 + dataset_conf["filter_conf"]["max_output_input_ratio"] = 102400 + dataset_conf["filter_conf"]["min_output_input_ratio"] = 0 + dataset_conf["speed_perturb"] = False + dataset_conf["spec_aug"] = False + dataset_conf["shuffle"] = False + dataset_conf["sort"] = True + dataset_conf["fbank_conf"]["dither"] = 0.0 + dataset_conf["batch_conf"]["batch_type"] = "static" + dataset_conf["batch_conf"]["batch_size"] = args.batch_size + + # Load dict + dict_fn = os.path.join(args.model_dir, "words.txt") + char_dict = {} + with open(dict_fn, "r", encoding="utf8") as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + data_type = "raw" + test_data_fn = os.path.join(args.data_dir, "data.list") + symbol_table = read_symbol_table(dict_fn) + test_dataset = Dataset( + data_type, test_data_fn, symbol_table, dataset_conf, partition=False + ) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + data_path_pkl = os.path.join(args.data_dir, f"aishell_test_data_bs{args.batch_size}.pkl") + + print("*** 1. Prepare data ***") + if not os.path.isfile(data_path_pkl): + eval_samples = [] + max_batch_size = -1 + max_feature_length = -1 + for batch in test_data_loader: + keys, feats, target, feats_lengths, target_lengths = batch + max_feature_length = max(max_feature_length, feats.size(1)) + max_batch_size = max(max_batch_size, feats.size(0)) + eval_samples.append( + [ + keys, + feats.cpu().numpy().astype(np.float16), + feats_lengths.cpu().numpy().astype(np.int32), + ] + ) + with open(data_path_pkl, "wb") as f: + pickle.dump( + [ + eval_samples, + max_batch_size, + max_feature_length + ], + f, + ) + else: + print(f"load data from tmp: {data_path_pkl}") + with open(data_path_pkl, "rb") as f: + ( + eval_samples, + max_batch_size, + max_feature_length + ) = pickle.load(f) + print( + f"dataset max shape: batch_size: {max_batch_size}, feat_length: {max_feature_length}" + ) + + print("*** 2. Load engine ***") + engine_path = os.path.join(args.model_dir, f"conformer_encoder_fusion.engine") + engine, context = engine_init(engine_path) + + print("*** 3. Warm up ***") + if args.warm_up > 0: + for i in range(args.warm_up): + feats_tmp = np.ones((args.batch_size, 1500, 80)).astype(np.float32) + feats_lengths_tmp = np.ones((args.batch_size)).astype(np.int32) * 1500 + mask_tmp = make_pad_mask(feats_lengths_tmp, 1500) + mask_len_tmp = mask_tmp.shape[-1] + pos_emb_tmp = rel_positional_encoding(mask_len_tmp).numpy() + all_inputs = [feats_tmp, mask_tmp, pos_emb_tmp] + tensorrt_infer(engine, context, all_inputs) + + print("*** 4. 
Inference ***") + start_time = time.time() + num_samples = 0 + results = [] + for keys, feats, feats_lengths in tqdm(eval_samples): + b, seq_len, feat = feats.shape + num_samples += b + inputs = feats.astype(np.float32) + mask = make_pad_mask(feats_lengths, seq_len) + mask_len = mask.shape[-1] + pos_emb = rel_positional_encoding(mask_len).numpy() + + all_inputs = [inputs, mask, pos_emb] + hyps = tensorrt_infer( + engine, + context, + all_inputs + ) + + eval_time = time.time() - start_time + + QPS = num_samples / eval_time + print(f"Recognize {num_samples} sentences, {QPS} sentences/s") + target_qps = float(os.environ["Accuracy"]) + print("QPS: = ", QPS, "target QPS: ", target_qps) + if QPS >= target_qps: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py b/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33f8b0465aee011298fa9933086fbdc1c8dbd4d4 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/postprocess/__init__.py @@ -0,0 +1 @@ +from .search import ctc_greedy_search diff --git a/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py b/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ae55650539b9d0be352e78a64999606ac12fbb --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/postprocess/search.py @@ -0,0 +1,103 @@ +import math +from collections import defaultdict +from typing import List, Dict + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def remove_duplicates_and_blank(hyp: List[int], + blank_id: int = 0) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != blank_id: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. 
+ + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + mask = mask[:, 2::2][:, 2::2] + return mask + + +class DecodeResult: + + def __init__(self, + tokens: List[int], + score: float = 0.0, + confidence: float = 0.0, + tokens_confidence: List[float] = None, + times: List[int] = None, + nbest: List[List[int]] = None, + nbest_scores: List[float] = None, + nbest_times: List[List[int]] = None): + """ + Args: + tokens: decode token list + score: the total decode score of this result + confidence: the total confidence of this result, it's in 0~1 + tokens_confidence: confidence of each token + times: timestamp of each token, list of (start, end) + nbest: nbest result + nbest_scores: score of each nbest + nbest_times: + """ + self.tokens = tokens + self.score = score + self.confidence = confidence + self.tokens_confidence = tokens_confidence + self.times = times + self.nbest = nbest + self.nbest_scores = nbest_scores + self.nbest_times = nbest_times + + +def ctc_greedy_search(ctc_probs: torch.Tensor, + ctc_lens: torch.Tensor, + blank_id: int = 0) -> List[DecodeResult]: + + batch_size = ctc_probs.shape[0] + maxlen = ctc_probs.size(1) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + + mask_ctc_lens = ctc_lens[0].item() + mask = make_pad_mask(ctc_lens, mask_ctc_lens) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, blank_id) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + results = [] + for hyp in hyps: + results.append(remove_duplicates_and_blank(hyp, blank_id)) + return results + diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh new file mode 100755 index 0000000000000000000000000000000000000000..985564c2294b2a413531d6ced018029ec911fb23 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/aishell_data_prepare.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# set -euox pipefail + +data_dir=$1 +tool_dir=$2 + +wav_dir=${data_dir}/wav +aishell_text=${data_dir}/transcript/aishell_transcript_v0.8.txt + +# data directory check +if [ ! -d $wav_dir ] || [ ! -f $aishell_text ]; then + echo "Error: wav directory and aishell text not found!" + exit 1; +fi + +# find test wav file +local_dir=${data_dir}/local +mkdir -p $local_dir +find $wav_dir -iname "*.wav" > $local_dir/wav.flist || exit 1; + +# Transcriptions preparation +sed -e 's/\.wav//' $local_dir/wav.flist | awk -F '/' '{print $NF}' > $local_dir/utt.list +paste -d' ' $local_dir/utt.list $local_dir/wav.flist > $local_dir/wav.scp_all +${tool_dir}/filter_scp.pl -f 1 $local_dir/utt.list $aishell_text > $local_dir/transcripts.txt +awk '{print $1}' $local_dir/transcripts.txt > $local_dir/utt.list +${tool_dir}/filter_scp.pl -f 1 $local_dir/utt.list $local_dir/wav.scp_all | sort -u > $local_dir/wav.scp +sort -u $local_dir/transcripts.txt > $local_dir/text +echo "Preparing transcriptions succeeded!" 
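+# Assemble the test set: copy wav.scp and text into ${data_dir}/test, then remove the temporary local dir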
+ +test_dir=${data_dir}/test +mkdir -p ${test_dir} +for f in wav.scp text; do + cp $local_dir/$f ${test_dir}/$f || exit 1; +done +rm -r ${data_dir}/local + +# data_type can be `raw` or `shard`. Typically, raw is used for small dataset, +# `shard` is used for large dataset which is over 1k hours, and `shard` is +# faster on reading data and training. +data_type=raw +num_utts_per_shard=1000 + +# remove the space between the text labels for Mandarin dataset +cp $test_dir/text $test_dir/text.org +paste -d " " <(cut -f 1 -d" " ${test_dir}/text.org) \ + <(cut -f 2- -d" " ${test_dir}/text.org | tr -d " ") \ + > ${test_dir}/text +rm ${test_dir}/text.org + +# Prepare required format +if [ $data_type == "shard" ]; then + ${tool_dir}/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \ + --num_threads 16 $test_dir/wav.scp $test_dir/text \ + $(realpath $test_dir/shards) $test_dir/data.list +else + ${tool_dir}/make_raw_list.py $test_dir/wav.scp $test_dir/text \ + $test_dir/data.list +fi + +echo "AISHELL data preparation succeeded!" \ No newline at end of file diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1af4bb4e03a0c9c6084ae7a122f66f765c27c86 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_accuracy_ixrt.sh @@ -0,0 +1,49 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) + +PROJECT_DIR=${current_path}/.. +DATA_DIR=${current_path}/../aishell_test_data/test +MODEL_DIR=${current_path}/../conformer_checkpoints + +export Accuracy=${Accuracy:=0.052} + +cd ${PROJECT_DIR} + +echo "Step1.Export Onnx From Checkpoints!" +python3 convert2onnx.py \ + --model_name "Conformer" \ + --model_path=${MODEL_DIR}/final.pt \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --batch_size=8 + +echo "Step2.Build Engine!" +python3 build_engine.py \ + --model_name "Conformer" \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --engine_path=${MODEL_DIR}/conformer_encoder_fusion.engine \ + --max_batch_size=8 \ + --max_seq_len=1500 + +echo "Step3.Inference(Test ACC)!" 
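+# ixrt_inference_accuracy.py reads the Accuracy env var (default 0.052) as the target CER threshold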
+python3 ixrt_inference_accuracy.py \ + --infer_type fp16 \ + --warm_up 3 \ + --batch_size ${BATCH_SIZE:=8} \ + --data_dir ${DATA_DIR} \ + --model_dir ${MODEL_DIR} diff --git a/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh new file mode 100644 index 0000000000000000000000000000000000000000..dc02673c03fb21a4301b757a18885af81cbad31d --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/scripts/infer_conformer_fp16_performance_ixrt.sh @@ -0,0 +1,59 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +set -euo pipefail + + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + echo "fails" + EXIT_STATUS=1 + fi +} + +current_path=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) + +PROJECT_DIR=${current_path}/.. +DATA_DIR=${current_path}/../aishell_test_data/test +MODEL_DIR=${current_path}/../conformer_checkpoints + +export Accuracy=${Accuracy:=350} + +cd ${PROJECT_DIR} + + +echo "Step1.Export Onnx From Checkpoints!" +python3 convert2onnx.py \ + --model_name "Conformer" \ + --model_path=${MODEL_DIR}/final.pt \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --batch_size=24 + +echo "Step2.Build Engine!" +python3 build_engine.py \ + --model_name "Conformer" \ + --onnx_path=${MODEL_DIR}/conformer_encoder_fusion.onnx \ + --engine_path=${MODEL_DIR}/conformer_encoder_fusion.engine \ + --max_batch_size=24 \ + --max_seq_len=1500 + +echo "Step3.Inference(Test QPS)!" 
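+# Note: the Accuracy env var (default 350) here carries the target QPS that ixrt_inference_performance.py checks against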
+python3 ixrt_inference_performance.py \ + --infer_type fp16 \ + --batch_size ${BATCH_SIZE:=24} \ + --data_dir ${DATA_DIR} \ + --model_dir ${MODEL_DIR} diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/__init__.py b/models/speech/speech_recognition/conformer/ixrt/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py b/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py new file mode 100755 index 0000000000000000000000000000000000000000..a5db08979f4d31a4a2ac9e4ceb0d122537690aac --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/compute_cer.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +import sys +import unicodedata +import codecs + +remove_tag = True +spacelist = [' ', '\t', '\r', '\n'] +puncts = ['!', ',', '?', + '、', '。', '!', ',', ';', '?', + ':', '「', '」', '︰', '『', '』', '《', '》'] + +def characterize(string) : + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + # https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': + sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + +def stripoff_tags(x): + if not x: + return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + if x.isalnum(): + for k in x: + new_sentence.append(k) + else: + new_sentence.append(x) + return new_sentence + +class Calculator : + def __init__(self) : + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec) : + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab) : + self.space.append([]) + for row in self.space : + for element in row : + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec) : + row.append({'dist' : 0, 'error' : 'non'}) + for i in range(len(lab)) : + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)) : + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, + 'ins' : 0, 'del' : 0} + for token in rec : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, + 'ins' : 0, 'del' : 0} + # Computing edit distance + for i, 
lab_token in enumerate(lab) : + for j, rec_token in enumerate(rec) : + if i == 0 or j == 0 : + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist : + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist : + min_dist = dist + min_error = error + if lab_token == rec_token : + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else : + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist : + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = {'lab': [], 'rec': [], 'all': 0, 'cor': 0, 'sub': 0, + 'ins': 0, 'del': 0} + i = len(lab) - 1 + j = len(rec) - 1 + while True : + if self.space[i][j]['error'] == 'cor' : # correct + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub' : # substitution + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del' : # deletion + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins' : # insertion + if len(rec[j]) > 0 : + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non' : # starting point + break + else : # shouldn't reach here + print('this should not happen , i={i} , j={j} , \ + error={error}'. 
+ format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self) : + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data) : + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data : + if token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self) : + return list(self.data.keys()) + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + +def default_cluster(word) : + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))) : + if unicode_names[i].startswith('DIGIT') : # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')) : + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')) : + # & / ' / @ / ℃ / = / . 
/ - / _ / # / + / ; + del unicode_names[i] + else : + return 'Other' + if len(unicode_names) == 0 : + return 'Other' + if len(unicode_names) == 1 : + return unicode_names[0] + for i in range(len(unicode_names) - 1) : + if unicode_names[i] != unicode_names[i + 1] : + return 'Other' + return unicode_names[0] + +def usage() : + print("compute-wer.py : compute word error rate (WER) \ + and align recognition results and references.") + print(" usage : python compute-wer.py [--cs={0,1}] \ + [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \ + [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + +if __name__ == '__main__': + if len(sys.argv) == 1 : + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except Exception: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + # ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: + continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, + case_sensitive, split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8') : 
+ if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array) == 0: + continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab : + if word not in default_words : + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters : + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name] : + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])) : + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]) : + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]) : + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print('===================================================' + '========================') + print() + + result = calculator.overall() + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('Overall -> wer %4.2f %% Corr %4.2f %%' % (wer, result['cor']*100/result['all']), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters : + result = calculator.cluster(k for k in default_clusters[cluster_id]) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + if len(cluster_file) > 0 : # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8') : + for token in line.decode('utf-8').rstrip('\n').split() : + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0 : + wer = 
float(result['ins'] + result['sub'] + + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif (token[0] == '<' and token[len(token) - 1] == '>' and + cluster_id == ''): + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else : + cluster.append(token) + print() + print('=======================================' + '====================================') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl b/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl new file mode 100755 index 0000000000000000000000000000000000000000..b76d37f41be0886470281978bfacf97f6b8ae976 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. 
+ if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py b/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py new file mode 100755 index 0000000000000000000000000000000000000000..2f84f015542bb38da027b8ea61e8638f873cec33 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/make_raw_list.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--segments', default=None, help='segments file') + parser.add_argument('wav_file', help='wav file') + parser.add_argument('text_file', help='text file') + parser.add_argument('output_file', help='output list file') + args = parser.parse_args() + + wav_table = {} + with open(args.wav_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + wav_table[arr[0]] = arr[1] + + if args.segments is not None: + segments_table = {} + with open(args.segments, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 4 + segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) + + with open(args.text_file, 'r', encoding='utf8') as fin, \ + open(args.output_file, 'w', encoding='utf8') as fout: + for line in fin: + arr = line.strip().split(maxsplit=1) + key = arr[0] + txt = arr[1] if len(arr) > 1 else '' + if args.segments is None: + assert key in wav_table + wav = wav_table[key] + line = dict(key=key, wav=wav, txt=txt) + else: + assert key in segments_table + wav_key, start, end = segments_table[key] + wav = wav_table[wav_key] + line = dict(key=key, wav=wav, txt=txt, start=start, end=end) + json_line = json.dumps(line, ensure_ascii=False) + fout.write(json_line + '\n') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py b/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py new file mode 100755 index 0000000000000000000000000000000000000000..fcd4bcd7d62ba933cf27c34fc02e18371a6b10a6 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/make_shard_list.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Mobvoi Inc. 
(authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import io +import logging +import os +import tarfile +import time +import multiprocessing + +import torch +import torchaudio +import torchaudio.backend.sox_io_backend as sox + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def write_tar_file(data_list, + no_segments, + tar_file, + resample=16000, + index=0, + total=1): + logging.info('Processing {} {}/{}'.format(tar_file, index, total)) + read_time = 0.0 + save_time = 0.0 + write_time = 0.0 + with tarfile.open(tar_file, "w") as tar: + prev_wav = None + for item in data_list: + if no_segments: + key, txt, wav = item + else: + key, txt, wav, start, end = item + + suffix = wav.split('.')[-1] + assert suffix in AUDIO_FORMAT_SETS + if no_segments: + ts = time.time() + with open(wav, 'rb') as fin: + data = fin.read() + read_time += (time.time() - ts) + else: + if wav != prev_wav: + ts = time.time() + waveforms, sample_rate = sox.load(wav, normalize=False) + read_time += (time.time() - ts) + prev_wav = wav + start = int(start * sample_rate) + end = int(end * sample_rate) + audio = waveforms[:1, start:end] + + # resample + if sample_rate != resample: + audio = torchaudio.transforms.Resample( + sample_rate, resample)(audio) + + ts = time.time() + f = io.BytesIO() + sox.save(f, audio, resample, format="wav", bits_per_sample=16) + # Save to wav for segments file + suffix = "wav" + f.seek(0) + data = f.read() + save_time += (time.time() - ts) + + assert isinstance(txt, str) + ts = time.time() + txt_file = key + '.txt' + txt = txt.encode('utf8') + txt_data = io.BytesIO(txt) + txt_info = tarfile.TarInfo(txt_file) + txt_info.size = len(txt) + tar.addfile(txt_info, txt_data) + + wav_file = key + '.' 
+ suffix + wav_data = io.BytesIO(data) + wav_info = tarfile.TarInfo(wav_file) + wav_info.size = len(data) + tar.addfile(wav_info, wav_data) + write_time += (time.time() - ts) + logging.info('read {} save {} write {}'.format(read_time, save_time, + write_time)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--num_utts_per_shard', + type=int, + default=1000, + help='num utts per shard') + parser.add_argument('--num_threads', + type=int, + default=1, + help='num threads for make shards') + parser.add_argument('--prefix', + default='shards', + help='prefix of shards tar file') + parser.add_argument('--segments', default=None, help='segments file') + parser.add_argument('--resample', + type=int, + default=16000, + help='segments file') + parser.add_argument('wav_file', help='wav file') + parser.add_argument('text_file', help='text file') + parser.add_argument('shards_dir', help='output shards dir') + parser.add_argument('shards_list', help='output shards list file') + args = parser.parse_args() + logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)s %(message)s') + + torch.set_num_threads(1) + wav_table = {} + with open(args.wav_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + wav_table[arr[0]] = arr[1] + + no_segments = True + segments_table = {} + if args.segments is not None: + no_segments = False + with open(args.segments, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 4 + segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) + + data = [] + with open(args.text_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split(maxsplit=1) + key = arr[0] + txt = arr[1] if len(arr) > 1 else '' + if no_segments: + assert key in wav_table + wav = wav_table[key] + data.append((key, txt, wav)) + else: + wav_key, start, end = segments_table[key] + wav = wav_table[wav_key] + data.append((key, txt, wav, start, end)) + + num = args.num_utts_per_shard + chunks = [data[i:i + num] for i in range(0, len(data), num)] + os.makedirs(args.shards_dir, exist_ok=True) + + # Using thread pool to speedup + pool = multiprocessing.Pool(processes=args.num_threads) + shards_list = [] + tasks_list = [] + num_chunks = len(chunks) + for i, chunk in enumerate(chunks): + tar_file = os.path.join(args.shards_dir, + '{}_{:09d}.tar'.format(args.prefix, i)) + shards_list.append(tar_file) + pool.apply_async( + write_tar_file, + (chunk, no_segments, tar_file, args.resample, i, num_chunks)) + + pool.close() + pool.join() + + with open(args.shards_list, 'w', encoding='utf8') as fout: + for name in shards_list: + fout.write(name + '\n') diff --git a/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py b/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py new file mode 100755 index 0000000000000000000000000000000000000000..4f4dcc901d436650695f0b80e0cf99e1e99269ee --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/tools/text2token.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Copyright 2021 JD AI Lab. All Rights Reserved. (authors: Lu Fan) +# Copyright 2021 Mobvoi Inc. All Rights Reserved. 
(Di Wu) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from __future__ import print_function +from __future__ import unicode_literals + +import argparse +import codecs +import re +import sys + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + +def seg_char(sent): + pattern = re.compile(r'([\u4e00-\u9fa5])') + chars = pattern.split(sent) + chars = [w for w in chars if len(w.strip()) > 0] + return chars + +def get_parser(): + parser = argparse.ArgumentParser( + description='convert raw text to tokenized text', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--nchar', + '-n', + default=1, + type=int, + help='number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2') + parser.add_argument('--skip-ncols', + '-s', + default=0, + type=int, + help='skip first n columns') + parser.add_argument('--space', + default='', + type=str, + help='space symbol') + parser.add_argument('--bpe-model', + '-m', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--non-lang-syms', + '-l', + default=None, + type=str, + help='list of non-linguistic symobles,' + ' e.g., etc.') + parser.add_argument('text', + type=str, + default=False, + nargs='?', + help='input text') + parser.add_argument('--trans_type', + '-t', + type=str, + default="char", + choices=["char", "phn", "cn_char_en_bpe"], + help="""Transcript type. char/phn. e.g., for TIMIT + FADG0_SI1279 - + If trans_type is char, read from + SI1279.WRD file -> "bricks are an alternative" + Else if trans_type is phn, + read from SI1279.PHN file -> + "sil b r ih sil k s aa r er n aa l + sil t er n ih sil t ih v sil" """) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, 'r', encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer) + line = f.readline() + n = args.nchar + while line: + x = line.split() + print(' '.join(x[:args.skip_ncols]), end=" ") + a = ' '.join(x[args.skip_ncols:]) + + # get all matched positions + match_pos = [] + for r in rs: + i = 0 + while i >= 0: + m = r.search(a, i) + if m: + match_pos.append([m.start(), m.end()]) + i = m.end() + else: + break + + if len(match_pos) > 0: + chars = [] + i = 0 + while i < len(a): + start_pos, end_pos = exist_or_not(i, match_pos) + if start_pos is not None: + chars.append(a[start_pos:end_pos]) + i = end_pos + else: + chars.append(a[i]) + i += 1 + a = chars + + if (args.trans_type == "phn"): + a = a.split(" ") + elif args.trans_type == "cn_char_en_bpe": + b = seg_char(a) + a = [] + for j in b: + # we use "▁" to instead of blanks among english words + # warning: here is "▁", not "_" + for l in j.strip().split("▁"): + if not l.encode('UTF-8').isalpha(): + a.append(l) + else: + for k in sp.encode_as_pieces(l): + a.append(k) + else: + a = [a[j:j + n] for j in range(0, 
len(a), n)] + + a_flat = [] + for z in a: + a_flat.append("".join(z)) + + a_chars = [z.replace(' ', args.space) for z in a_flat] + if (args.trans_type == "phn"): + a_chars = [z.replace("sil", args.space) for z in a_chars] + print(' '.join(a_chars)) + line = f.readline() + + +if __name__ == '__main__': + main() diff --git a/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py b/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c57435c110fc12f39d79c1b02f4b2e83dfe1a3e3 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/utils/__init__.py @@ -0,0 +1,39 @@ +import os +import torch +import numpy as np + +from .embedding import RelPositionalEncoding + + +rel_positional_encoding = RelPositionalEncoding(256, 0.1) + + +def make_pad_mask(lengths: np.ndarray, max_len: int = 0) -> np.ndarray : + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (numpy.ndarray): Batch of lengths (B,). + Returns: + numpy.ndarray: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + + batch_size = lengths.shape[0] + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = np.arange(0, max_len, dtype=np.int64) + seq_range_expand = np.tile(seq_range, batch_size).reshape(batch_size, max_len) + seq_length_expand = lengths[..., None] + mask = seq_range_expand >= seq_length_expand + mask = np.expand_dims(mask, axis=1) + mask = ~mask + mask = mask[:, :, 2::2][:, :, 2::2] + mask = mask.astype(np.int32) + return mask diff --git a/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py b/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd65c4cdfc3fec244c88d2c47cf94b33b9088f3 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/utils/embedding.py @@ -0,0 +1,133 @@ +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +import numpy as np + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) 
+ offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, + offset: Union[int, torch.Tensor], + size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + # import ipdb;ipdb.set_trace() + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + seq_len: int, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + pos_emb = self.position_encoding(offset, seq_len, False) + # return self.dropout(pos_emb) + return pos_emb + diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/__init__.py b/models/speech/speech_recognition/conformer/ixrt/wenet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py b/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..88a8cd15aec2277a36358883b25e929b179165e8 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/dataset.py @@ -0,0 +1,179 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.processor as processor +from wenet.file_utils import read_lists + + +class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. + """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + symbol_table, + conf, + bpe_model=None, + non_lang_syms=None, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + bpe_model(str): model for english bpe part + partition(bool): whether to do data partition in terms of rank + """ + assert data_type in ['raw', 'shard'] + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + if data_type == 'shard': + dataset = Processor(dataset, processor.url_opener) + dataset = Processor(dataset, processor.tar_file_and_group) + else: + dataset = Processor(dataset, processor.parse_raw) + + dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model, + non_lang_syms, conf.get('split_with_space', False)) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + + spec_aug = conf.get('spec_aug', True) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = Processor(dataset, processor.sort, **sort_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py b/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7e516cc61f759267f4ef09309ff0b45110a0c1 --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/file_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def read_non_lang_symbols(non_lang_sym_path): + """read non-linguistic symbol from file. + + The file format is like below: + + {NOISE}\n + {BRK}\n + ... + + + Args: + non_lang_sym_path: non-linguistic symbol file path, None means no any + syms. 
+ + """ + if non_lang_sym_path is None: + return None + else: + syms = read_lists(non_lang_sym_path) + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + for sym in syms: + if non_lang_syms_pattern.fullmatch(sym) is None: + class BadSymbolFormat(Exception): + pass + raise BadSymbolFormat( + "Non-linguistic symbols should be " + "formatted in {xxx}//[xxx], consider" + " modify '%s' to meet the requirment. " + "More details can be found in discussions here : " + "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) + return syms + + +def read_symbol_table(symbol_table_file): + symbol_table = {} + with open(symbol_table_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + symbol_table[arr[0]] = int(arr[1]) + return symbol_table diff --git a/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py b/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..9a542a3d204cdb3def8cf61ce0b0fd8bb31af32e --- /dev/null +++ b/models/speech/speech_recognition/conformer/ixrt/wenet/processor.py @@ -0,0 +1,550 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import json +import random +import re +import tarfile +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def url_opener(data): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + # TODO(Binbin Zhang): support HTTP + url = sample['src'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') + # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP + else: + cmd = f'curl -s -L {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + sample.update(process=process) + stream = process.stdout + sample.update(stream=stream) + yield sample + except Exception as ex: + logging.warning('Failed to open {}'.format(url)) + + +def tar_file_and_group(data): + """ Expand a stream of open tar files into a stream of tar file contents. 
+ And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'stream' in sample + stream = tarfile.open(fileobj=sample['stream'], mode="r|*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = torchaudio.load(file_obj) + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + stream.close() + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +def parse_raw(data): + """ Parse key/wav/txt from json line + + Args: + data: Iterable[str], str is a json line has key/wav/txt + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'src' in sample + json_line = sample['src'] + obj = json.loads(json_line) + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + try: + if 'start' in obj: + assert 'end' in obj + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_file).sample_rate + start_frame = int(obj['start'] * sample_rate) + end_frame = int(obj['end'] * sample_rate) + waveform, _ = torchaudio.backend.sox_io_backend.load( + filepath=wav_file, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = torchaudio.load(wav_file) + example = dict(key=key, + txt=txt, + wav=waveform, + sample_rate=sample_rate) + yield example + except Exception as ex: + logging.warning('Failed to read {}'.format(wav_file)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. 
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_fbank(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def __tokenize_by_bpe_model(sp, txt): + tokens = [] + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + pattern = re.compile(r'([\u4e00-\u9fff])') + # Example: + # txt = "你好 ITS'S OKAY 的" + # chars = ["你", "好", " ITS'S OKAY ", "的"] + chars = pattern.split(txt.upper()) + mix_chars = [w for w in chars if len(w.strip()) > 0] + for ch_or_w in mix_chars: + # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + if pattern.fullmatch(ch_or_w) is not None: + tokens.append(ch_or_w) + # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # encode ch_or_w using bpe_model. 
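# Putting the two branches together for the example above, `tokens` ends up as
# ["你", "好", <BPE pieces of " ITS'S OKAY ">, "的"]; the exact English pieces
# (e.g. "▁ITS", "'", "S", "▁OKAY") depend on the loaded sentencepiece model.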
+ else: + for p in sp.encode_as_pieces(ch_or_w): + tokens.append(p) + + return tokens + + +def tokenize(data, symbol_table, bpe_model=None, non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in data: + assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def shuffle(data, shuffle_size=10000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500): + """ Sort the data by feature length. 
+ Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'feat' in sample + assert isinstance(sample['feat'], torch.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): + """ Wrapper for static/dynamic batch + """ + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + label_lengths = torch.tensor([x.size(0) for x in sorted_labels], + dtype=torch.int32) + + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, padding_labels, feats_lengths, + label_lengths)
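The processor stages above are plain generators over sample dicts, so they compose by nesting one inside another, which is exactly what `Dataset` in `wenet/dataset.py` does via the `Processor` wrapper. Below is a minimal sketch (not part of the patch) of the last two stages, `batch` and `padding`, on toy inputs; it assumes torch and torchaudio are installed and that it runs from the `ixrt` project directory so that `wenet.processor` is importable. The `samples` list and `utt*` keys are purely illustrative.

```python
import torch

from wenet import processor

# Three toy utterances: 80-dim fbank-like features of different lengths plus
# integer label sequences, mirroring the {key, feat, label} dicts produced by
# the earlier stages (compute_fbank / tokenize).
samples = [
    {"key": f"utt{i}", "feat": torch.randn(50 + 10 * i, 80), "label": list(range(i + 1))}
    for i in range(3)
]

# batch() groups samples into lists of at most 2; padding() sorts each group by
# descending feature length, zero-pads the features and pads the labels with -1.
batched = processor.batch(samples, batch_type="static", batch_size=2)
for keys, feats, labels, feat_lens, label_lens in processor.padding(batched):
    print(keys, feats.shape, labels.shape, feat_lens.tolist(), label_lens.tolist())
```

Each yielded tuple has the same layout as what the full `Dataset` pipeline hands to the inference scripts: keys, padded features, padded labels, and the corresponding lengths.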