diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/README.md b/AscendIE/TorchAIE/built-in/audio/Wenet/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..c697a3dbf28a475c957e2be049a5ebc642e24dfc
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/README.md
@@ -0,0 +1,298 @@
# Wenet Model - Inference Guide


- [Overview](#ZH-CN_TOPIC_0000001172161501)

  - [Input and Output Data](#section540883920406)

- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)

- [Quick Start](#ZH-CN_TOPIC_0000001126281700)

  - [Get the Source Code](#section4622531142816)
  - [Prepare the Dataset](#section183221994411)
  - [Model Inference](#section741711594517)

- [Model Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)


# Overview

Wenet is an ASR (automatic speech recognition) model built on the Conformer architecture. It offers good end-to-end inference accuracy and performance.

- Reference implementation:

  ```
  url=https://github.com/wenet-e2e/wenet.git
  branch=v2.0.1
  model_name=Wenet
  ```


## Input and Output Data

- Online (streaming) encoder inputs

  | Input | Shape | Data Type | Format |
  | -------- | -------- | ------------------------- | ------------ |
  | chunk_xs | B x 67 x 80 | FLOAT32 | ND |
  | chunk_lens | B | INT32 | ND |
  | offset | B x 1 | INT64 | ND |
  | att_cache | B x 12 x 4 x 64 x 128 | FLOAT32 | ND |
  | cnn_cache | B x 12 x 256 x 7 | FLOAT32 | ND |
  | cache_mask | B x 1 x 64 | FLOAT32 | ND |


- Online (streaming) encoder outputs

  | Output | Shape | Data Type | Format |
  | -------- | -------- | -------- | ------------ |
  | log_probs | batchsize x Castlog_probs_dim_1 x 10 | FLOAT32 | ND |
  | log_probs_idx | batchsize x Castlog_probs_dim_1 x 10 | INT64 | ND |
  | chunk_out | batchsize x Castlog_probs_dim_1 x 256 | FLOAT32 | ND |
  | chunk_out_lens | batchsize | INT32 | ND |
  | r_offset | batchsize x 1 | INT64 | ND |
  | r_att_cache | batchsize x 12 x dim_2 x dim_3 x dim_4 | FLOAT32 | ND |
  | r_cnn_cache | batchsize x 12 x 256 x dim_5 | FLOAT32 | ND |
  | r_cache_mask | batchsize x 1 x 64 | FLOAT32 | ND |

- Offline encoder inputs

  | Input | Shape | Data Type | Format |
  | -------- | -------- | ------------------------- | ------------ |
  | speech | batchsize x T x 80 | FLOAT32 | ND |
  | speech_lengths | batchsize | INT32 | ND |


- Offline encoder outputs

  | Output | Shape | Data Type | Format |
  | -------- | -------- | -------- | ------------ |
  | encoder_out | batchsize x T_out x 256 | FLOAT32 | ND |
  | encoder_out_lens | batchsize | INT32 | ND |
  | ctc_log_probs | batchsize x T_out x 4233 | FLOAT32 | ND |
  | beam_log_probs | batchsize x T_out x 10 | FLOAT32 | ND |
  | beam_log_probs_idx | batchsize x T_out x 10 | INT64 | ND |


- Decoder inputs

  | Input | Shape | Data Type | Format |
  | -------- | -------- | ------------------------- | ------------ |
  | encoder_out | batchsize x T x 256 | FLOAT32 | ND |
  | encoder_out_lens | batchsize | INT32 | ND |
  | hyps_pad_sos_eos | batchsize x 10 x T2 | INT64 | ND |
  | hyps_lens_sos | batchsize x 10 | INT32 | ND |
  | r_hyps_pad_sos_eos | batchsize x 10 x T2 | INT64 | ND |
  | ctc_score | batchsize x 10 | FLOAT32 | ND |


- Decoder outputs

  | Output | Shape | Data Type | Format |
  | -------- | -------- | -------- | ------------ |
  | best_index | batchsize | INT64 | ND |
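For illustration only, the minimal sketch below builds random offline-encoder inputs with the shapes and dtypes documented in the tables above (the 1478-frame upper bound mirrors the sequence length used by trace.py; the batch size is an arbitrary example).

```
import torch

batch_size, max_frames, feat_dim = 4, 1478, 80  # shapes from the offline encoder input table

# speech: batchsize x T x 80, FLOAT32 (padded fbank features)
speech = torch.randn(batch_size, max_frames, feat_dim, dtype=torch.float32)
# speech_lengths: batchsize, INT32 (number of valid frames per utterance)
speech_lengths = torch.randint(low=10, high=max_frames, size=(batch_size,), dtype=torch.int32)

print(speech.shape, speech_lengths)
```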
# Inference Environment Setup

- The model requires the following plugins and drivers.

  **Table 1** Version compatibility

  | Dependency | Version | Environment Setup Guide |
  |---------| ------- | ------------------------------------------------------------ |
  | Firmware & driver | 22.0.4 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
  | CANN | 6.3.RC1 | - |
  | Python | 3.7.5 | - |
  | PyTorch | 1.8.0 | - |
  | Torch_AIE | 6.3.rc2 | - |

  Note: for Atlas 300I Duo inference cards, select the firmware and driver version that matches the installed CANN version.


# Quick Start

## Get the Source Code

1. Get the source code.

   In your working directory, run the following commands to clone the repository and check out the required branch.

   ```
   git clone https://github.com/wenet-e2e/wenet.git
   cd wenet
   git checkout v2.0.1
   export wenet_path=$(pwd)
   cd ..
   ```

2. Install the dependencies.

   ```
   pip3 install -r requirements.txt
   ```

3. Install ctc_decoder.

   ```
   git clone https://github.com/Slyne/ctc_decoder.git
   apt-get update
   apt-get install swig
   apt-get install python3-dev
   cd ctc_decoder/swig && bash setup.sh
   ```



## Prepare the Dataset

1. Get the raw dataset.

   ```
   cd wenet/examples/aishell/s0/
   mkdir -p /export/data/asr-data/OpenSLR/33
   bash run.sh --stage -1 --stop_stage -1  # download the dataset
   ```

2. Process the dataset.

   ```
   bash run.sh --stage 0 --stop_stage 0
   bash run.sh --stage 1 --stop_stage 1
   bash run.sh --stage 2 --stop_stage 2
   bash run.sh --stage 3 --stop_stage 3
   ```

3. Copy the generated data.list file to the ModelZoo Wenet working directory.
   ```
   cp data/test/data.list <path-to-ModelZoo-Wenet-working-dir>/
   ```


## Model Inference

1. Convert the model.

   Trace the model with PyTorch, then compile it with Torch AIE, serialize the compiled module, and run inference with it.

   1. Get the weight file.

      Download the WeNet pretrained model from the [pretrained models page](https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.md): choose the Checkpoint Model trained on the aishell dataset, place the archive in your cloned wenet repository, and extract it there.
      ```
      # Unpack the pretrained checkpoint and copy it into the working directory
      cd $wenet_path
      tar -zvxf aishell_u2++_conformer_exp.tar.gz
      cd <path-to-ModelZoo-Wenet-working-dir>
      mkdir -p exp/20210601_u2++_conformer_exp
      cp -r $wenet_path/aishell_u2++_conformer_exp/* exp/20210601_u2++_conformer_exp
      ```

   2. Trace the model.

      Use trace.py to trace the encoder and decoder and save them as .pth files.

      Set the Python path:
      ```
      export PYTHONPATH=$wenet_path
      ```

      ```
      python trace.py --config=exp/20210601_u2++_conformer_exp/train.yaml --batch_size=${batch_size} --checkpoint=exp/20210601_u2++_conformer_exp/final.pt --encoder_out=encoder.pth --decoder_out=decoder.pth --reverse_weight 0.3
      ```

      Add the `--streaming` flag to obtain the online (streaming) encoder model.

   3. Compile the models with Torch AIE.

      ```
      python compile_model.py --encoder=encoder.pth --encoder_out=encoder_aie.pth --decoder=decoder.pth --decoder_out=decoder_aie.pth --batch_size=${batch_size} --encoder_gears="262, 326, 390, 454, 518, 582, 646, 710, 774, 838, 902, 966, 1028, 1284, 1478" --decoder_gears="5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44"
      ```

      - Parameters:

        - --encoder: path to the input traced encoder model
        - --encoder_out: path to store the compiled encoder model
        - --encoder_gears: encoder dynamic dimensions
        - --decoder: path to the input traced decoder model
        - --decoder_out: path to store the compiled decoder model
        - --decoder_gears: decoder dynamic dimensions

      Add the `--streaming` command-line argument to compile the online (streaming) encoder model instead.

2. Run inference and verification.

   1. Accuracy and performance verification in the non-streaming gear-based scenario.

      End-to-end encoder + decoder, multi-process:

      ```
      python recognize_aie.py --config=exp/20210601_u2++_conformer_exp/train.yaml --test_data=data.list --dict=exp/20210601_u2++_conformer_exp/units.txt --mode=attention_rescoring --result_file=aie_result${batch_size}.log --encoder_aie=encoder_aie.pth --decoder_aie=decoder_aie.pth --batch_size=${batch_size} --device_id=0 --test_file=static_test_result_bs32_1.txt --encoder_gears="646, 710, 774, 838, 902, 966, 1028, 1284, 1478"
      ```

      - Parameters:
        - --config: path to the aishell pretrained model configuration file.
        - --test_data: path to the test data.
        - --dict: path to the aishell pretrained model dictionary.
        - --mode: decoding mode; one of ctc_greedy_search, ctc_prefix_beam_search, and attention_rescoring.
        - --result_file: decoding result file.
        - --encoder_aie: path to the compiled encoder model.
        - --decoder_aie: path to the compiled decoder model.
        - --batch_size: batch size.
        - --device_id: NPU device index.
        - --test_file: performance result file.
        - --encoder_gears: encoder dynamic dimensions.

      ### Note
      The `max_length` setting must be adjusted when running with batch size 4 or 1. Adjust it by changing the config value at `recognize_aie.py:142` to:
      ```
      test_conf_['filter_conf']['max_length'] = 700
      ```

      ```
      # Accuracy verification
      python tools/compute-wer.py --char=1 --v=1 exp/20210601_u2++_conformer_exp/text aie_result${batch_size}.log
      ```

      - Parameters:
        - --char: comparison granularity; 0 compares whole sentences, 1 compares word by word.
        - --v: whether to print the comparison results; 0 = no, 1 = yes.
        - exp/20210601_u2++_conformer_exp/text: path to the reference label file.
        - aie_result${batch_size}.log: path to the recognition result file to score.

   2. Accuracy and performance verification in the streaming pure-inference scenario.
      Use the following script to measure the performance of the streaming encoder model and check its accuracy. The reported accuracy is the cosine similarity between the Torch and AIE encoder outputs (see the sketch after this section).
      ```
      python cosine_similarity.py --encoder_torch=torch_online_encoder_b${batch_size}.pth --encoder_aie=aie_online_encoder_b${batch_size}.pth --batch_size=${batch_size} --device_id=1
      ```
      - Parameters:
        - --encoder_torch: path to the traced Torch streaming encoder model.
        - --encoder_aie: path to the AIE-compiled streaming encoder model.
        - --batch_size: batch size.
        - --device_id: NPU device index.
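For reference, the similarity figure produced by cosine_similarity.py (and reported in the table below) is a plain cosine between the flattened first outputs of the Torch-traced and AIE-compiled encoders. The minimal sketch below reproduces that computation on stand-in tensors; the random tensors are placeholders, not real model outputs.

```
import numpy as np
import torch

def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity between two flattened output tensors, as in cosine_similarity.py."""
    a = a.flatten().numpy()
    b = b.flatten().numpy()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Illustrative stand-ins for the first encoder output (log_probs) of both models
torch_out = torch.randn(4, 16, 10)                          # Torch model output
aie_out = torch_out + 1e-6 * torch.randn_like(torch_out)    # compiled model output
print(cosine_similarity(torch_out, aie_out))                # values near 1.0 mean matching outputs
```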
# Model Inference Performance & Accuracy

Refer to the following data for performance.

Non-streaming gear-based (encoder + decoder) scenario, 1 process.

| SoC | Batch Size | Dataset | Accuracy (WER) | End-to-end throughput (fps) |
|-------------|------------|-----------|---------- |------------------|
| Ascend310P3 | 1 | aishell | 4.60% | 39.55 |
| Ascend310P3 | 4 | aishell | 4.60% | 53.38 |
| Ascend310P3 | 8 | aishell | 4.60% | 51.87 |
| Ascend310P3 | 16 | aishell | 4.60% | 49.66 |
| Ascend310P3 | 32 | aishell | 4.60% | 44.11 |
| Ascend310P3 | 64 | aishell | 4.60% | 41.23 |

Streaming pure-inference scenario

| SoC | Batch Size | Cosine similarity (AIE vs. Torch) | Throughput |
|---------------|------------|-------------------|----------------|
| Ascend310P3 | 1 | 0.9999992 | 593.74 fps |
| Ascend310P3 | 4 | 0.9999999 | 1219.60 fps |
| Ascend310P3 | 8 | 0.9999999 | 1635.93 fps |
| Ascend310P3 | 16 | 0.9999998 | 1966.44 fps |
| Ascend310P3 | 32 | 0.99999946 | 2098.52 fps |
| Ascend310P3 | 64 | 0.99999976 | 2079.44 fps |
diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py b/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dda1405370d8a939e3e3bf97c2834e8b523dbde
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py
@@ -0,0 +1,146 @@
# BSD 3-Clause License
#
# All rights reserved.
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + +""" +This script is for testing exported ascend encoder and decoder from +export_onnx_npu.py. The exported ascend models only support batch offline ASR inference. +It requires a python wrapped c++ ctc decoder. +Please install it from ctc decoder in github +""" +from __future__ import print_function + +import argparse +import logging +import sys + +import torch_aie +import torch + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--batch_size', + type=int, + default=32, + help='used batch size') + parser.add_argument('--streaming', + action='store_true', + help="whether to export streaming encoder, default false") + parser.add_argument('--encoder', + type=str, + required=True, + help='path to input traced encoder .pth model') + parser.add_argument('--encoder_out', + type=str, + required=True, + help='path to store compiled encoder .pth model') + parser.add_argument('--decoder', + type=str, + required=True, + help='path to input traced decoder .pth model') + parser.add_argument('--decoder_out', + type=str, + required=True, + help='path to store compiled decoder .pth model') + parser.add_argument('--encoder_gears', + type=str, + required=True, + help='dynamic dims gears info for encoder, please input split by ","') + parser.add_argument('--decoder_gears', + type=str, + required=True, + help='dynamic dims gears info for encoder, please input split by ","') + args_ = parser.parse_args() + print(args_) + return args_ + +def get_dynamic_dims_encoder(batch_size, args): + inputs = [] + for dim in list(map(int, args.encoder_gears.split(','))): + inputs.append([torch_aie.Input((batch_size, dim, 80)), + torch_aie.Input([batch_size], dtype=torch_aie.dtype.INT32)]) + return inputs + +def compile_offline_encoder(args): + bz = args.batch_size + traced = torch.jit.load(args.encoder) + print('[INFO] Started to compile the offline encoder using TorchAIE') + encoder_compiled = torch_aie.compile(traced, inputs=get_dynamic_dims_encoder(bz, args), soc_version="Ascend310P3") + print('[INFO] Success! The offline encoder was compiled!') + encoder_compiled.save(args.encoder_out) + +def compile_online_encoder(args): + batch_size = args.batch_size + traced = torch.jit.load(args.encoder) + + inputs = [ + torch_aie.Input((batch_size, 67, 80)), # chunk_xs + torch_aie.Input([batch_size], dtype=torch_aie.dtype.INT32), # chunk_lens + torch_aie.Input((batch_size, 1), dtype=torch_aie.dtype.INT32), # offset + torch_aie.Input((batch_size, 12, 4, 64, 128)), # to_cache + torch_aie.Input((batch_size, 12, 256, 7)), # cnn_cache + torch_aie.Input((batch_size, 1, 64)), # cache_mask + ] + + print('[INFO] Started to compile the online encoder using TorchAIE') + encoder_compiled = torch_aie.compile(traced, inputs=inputs, soc_version="Ascend310P3") + print('[INFO] Success! 
The online encoder was compiled!') + encoder_compiled.save(args.encoder_out) + +def get_dynamic_dims_decoder(batch_size, args): + inputs = [] + + for dim in list(map(int, args.decoder_gears.split(','))): + inputs.append( + [torch_aie.Input(shape=[batch_size, 384, 256], dtype=torch_aie.dtype.FLOAT), + torch_aie.Input(shape=[batch_size], dtype=torch_aie.dtype.INT32), + torch_aie.Input(shape=[batch_size, 10, dim], dtype=torch_aie.dtype.INT64), + torch_aie.Input(shape=[batch_size, 10], dtype=torch_aie.dtype.INT32), + torch_aie.Input(shape=[batch_size, 10, dim], dtype=torch_aie.dtype.INT64), + torch_aie.Input(shape=[batch_size, 10], dtype=torch_aie.dtype.FLOAT)]) + return inputs + +def compile_decoder(args): + bz = args.batch_size + decoder = torch.jit.load(args.decoder) + print('[INFO] Started to compile the decoder using TorchAIE') + decoder_compiled = torch_aie.compile(decoder, inputs=get_dynamic_dims_decoder(bz, args), soc_version="Ascend310P3") + print('[INFO] Success! The decoder was compiled!') + decoder_compiled.save(args.decoder_out) + + +if __name__ == '__main__': + args = get_args() + if args.streaming: + compile_online_encoder(args) + else: + compile_offline_encoder(args) + compile_decoder(args) + diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py b/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..fad4e657255a7557706dca356fe9d0178e2144e5 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py @@ -0,0 +1,97 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import argparse +import time + +import torch +import torch_aie +import numpy as np + +def cos_sim(a, b): + b_norm = np.linalg.norm(b.numpy()) + a_norm = np.linalg.norm(a.numpy()) + cos = np.dot(a, b)/(a_norm * b_norm) + return cos + +def measure(model, dummpy_input): + chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask = dummpy_input + npu_str = 'npu:' + str(args.device_id) + chunk_xs = chunk_xs.to(npu_str) + chunk_lens = chunk_lens.to(npu_str) + offset = offset.to(npu_str) + att_cache = att_cache.to(npu_str) + cnn_cache = cnn_cache.to(npu_str) + cache_mask = cache_mask.to(npu_str) + + + # Do a warmup + print('Doing a warmup') + _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask) + _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask) + _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask) + + print('Run inference...') + s = time.time_ns() + n = 1000 + for i in range(n): + _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask) + ms = (time.time_ns()-s) / n / 10**6 + print('Mean compute time : {} (ms)'.format(ms)) + return ms, _ + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='get cosine similaryty with your model') + parser.add_argument('--encoder_torch', required=True, help='encoder onnx file') + parser.add_argument('--encoder_aie', required=True, help='encoder om file') + parser.add_argument('--batch_size', required=True, type=int, help='batch size') + parser.add_argument('--device_id', default=0, type=int, help='device id') + parser.add_argument('--decoding_chunk_size', default=16, type=int, + help='decoding chunk size, <=0 is not supported') + parser.add_argument('--num_decoding_left_chunks', + default=4, + type=int, + required=False, + help="number of left chunks, <= 0 is not supported") + args = parser.parse_args() + print(args) + + torch_aie.set_device(args.device_id) + required_cache_size = args.decoding_chunk_size * args.num_decoding_left_chunks + + + chunk_xs = torch.from_numpy(np.random.random((args.batch_size, 67, 80)).astype("float32")) + chunk_lens = torch.from_numpy(np.array([600]*args.batch_size).astype("int32")) + offset = torch.from_numpy(np.array([0]*args.batch_size).reshape((args.batch_size, 1)).astype("int32")) + att_cache = torch.from_numpy(np.random.random((args.batch_size, 12, 4, required_cache_size, 128)).astype("float32")) + cnn_cache = torch.from_numpy(np.random.random((args.batch_size, 12, 256, 7)).astype("float32")) + cache_mask = torch.from_numpy(np.random.random((args.batch_size, 1, required_cache_size)).astype("float32")) + + aie_net = torch.jit.load(args.encoder_aie) + torch_net = torch.jit.load(args.encoder_torch) + + aie_net.eval() + torch_net.eval() + time_ms = 0 + with torch.no_grad(): + time_ms, output_data = measure(aie_net, (chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)) + output_data = output_data[0].to('cpu') + y = output_data.flatten() + torch_output = torch_net(chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask) + torch_y = torch_output[0].flatten() + cos_1 = cos_sim(y, torch_y) + print("acc: ", cos_1) + + throughput = 1000 * args.batch_size / time_ms + print("throughput 1000*batchsize.mean({})/NPU_compute_time.mean({}): {}".format(args.batch_size, time_ms, throughput)) diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py b/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py new file mode 
100644 index 0000000000000000000000000000000000000000..1907218187b898e20bd8fafe4d0461f8c01ba208 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py @@ -0,0 +1,400 @@ +# BSD 3-Clause License +# +# All rights reserved. +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + +""" +This script is for testing exported ascend encoder and decoder from +export_onnx_npu.py. The exported ascend models only support batch offline ASR inference. +It requires a python wrapped c++ ctc decoder. 
+Please install it from ctc decoder in github +""" +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys +import time +import stat + +import multiprocessing +import numpy as np +import torch +import yaml +from torch.utils.data import DataLoader +import torch_aie + +from wenet.dataset.dataset import Dataset +from wenet.utils.common import IGNORE_ID +from wenet.utils.file_utils import read_symbol_table +from wenet.utils.config import override_config +from swig_decoders import map_batch, \ + ctc_beam_search_decoder_batch, \ + TrieVector, PathTrie + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--encoder_aie', required=True, + help='encoder om file') + parser.add_argument('--decoder_aie', required=True, + help='decoder om file') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--test_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=32, + help='asr result file') + parser.add_argument('--device_id', + type=int, + default=0, + help='npu device id') + parser.add_argument('--mode', + choices=[ + 'ctc_greedy_search', 'ctc_prefix_beam_search', + 'attention_rescoring'], + default='attention_rescoring', + help='decoding mode') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument('--fp16', + action='store_true', + help='whether to export fp16 model, default false') + parser.add_argument('--static', + action='store_true', + help='whether to run static model') + parser.add_argument('--num_process', + type=int, + default=2, + help='number of mutiprocesses') + parser.add_argument('--encoder_gears', + type=str, + help='dynamic dims gears info for encoder, please input split by ","') + parser.add_argument('--decoder_gears', + type=str, + help='dynamic dims gears info for encoder, please input split by ","') + parser.add_argument('--output_size', + type=str, + help='only effect in dynamic shapes mode,\ + outputs size info for encoder, please input split by ","') + args_ = parser.parse_args() + print(args_) + return args_ + +def get_dict(dict_path): + vocabulary_ = [] + char_dict_ = {} + with open(dict_path, 'r') as fin_: + for line in fin_: + arr = line.strip().split() + if len(arr) != 2: + print('dict format is incorrect') + exit(0) + char_dict_[int(arr[1])] = arr[0] + vocabulary_.append(arr[0]) + return vocabulary_, char_dict_ + +def adjust_test_conf(test_conf_, batch_size_): + # adjust dataset parameters for om + # reserved suitable memory + test_conf_['filter_conf']['max_length'] = 1028 + test_conf_['filter_conf']['min_length'] = 0 + test_conf_['filter_conf']['token_max_length'] = 1028 + test_conf_['filter_conf']['token_min_length'] = 0 + test_conf_['filter_conf']['max_output_input_ratio'] = 1028 + test_conf_['filter_conf']['min_output_input_ratio'] = 0 + test_conf_['speed_perturb'] = False + test_conf_['spec_aug'] = False + test_conf_['shuffle'] = False + test_conf_['sort'] = 
False + test_conf_['fbank_conf']['dither'] = 0.0 + test_conf_['batch_conf']['batch_type'] = "static" + test_conf_['batch_conf']['batch_size'] = batch_size_ + return test_conf_ + +class AsrModel: + def __init__(self, encoder, decoder, args_, reverse_weight_) -> None: + self.encoder, self.decoder = encoder, decoder + self.vocabulary, self.char_dict = get_dict(args_.dict) + self.reverse_weight = reverse_weight_ + self.mul_shape = list(map(int, args.encoder_gears.split(','))) + self.mul_shape_decoder = [384] + self.output_size = None + if args.output_size: + self.output_size = list(map(int, args.output_size.split(','))) + self.args = args_ + + def forward(self, data): + nums_ = 0 + eos = sos = len(self.char_dict) - 1 + mode = "dymdims" + + keys, feats, _, feats_lengths, _ = data + feats, feats_lengths = feats.numpy(), feats_lengths.numpy() + ort_outs = None + pad_size = 0 + pad_batch = 0 + for n in self.mul_shape: + if n > feats.shape[1]: + pad_size = n - feats.shape[1] + break + if feats.shape[0] < self.args.batch_size: + pad_batch = self.args.batch_size - feats.shape[0] + feats_lengths = np.pad(feats_lengths, [(0, pad_batch)], 'constant') + feats_pad = np.pad(feats, [(0, pad_batch), (0, pad_size), (0, 0)], 'constant') + + feats_pad = torch.from_numpy(feats_pad).to("npu:1") + feats_lengths = torch.from_numpy(feats_lengths).to("npu:1") + ort_outs = self.encoder.forward( + feats_pad, feats_lengths) + encoder_out, encoder_out_lens, _, \ + beam_log_probs, beam_log_probs_idx = ort_outs + beam_size = beam_log_probs.shape[-1] + batch_size = beam_log_probs.shape[0] + num_processes = min(multiprocessing.cpu_count(), batch_size) + if self.args.mode == 'ctc_greedy_search': + if beam_size != 1: + log_probs_idx = beam_log_probs_idx[:, :, 0] + batch_sents = [] + for idx_, seq in enumerate(log_probs_idx): + batch_sents.append(seq[0:encoder_out_lens[idx_]].tolist()) + hyps = map_batch(batch_sents, self.vocabulary, num_processes, + True, 0) + elif self.args.mode in ('ctc_prefix_beam_search', "attention_rescoring"): + batch_log_probs_seq_list = beam_log_probs.tolist() + batch_log_probs_idx_list = beam_log_probs_idx.tolist() + batch_len_list = encoder_out_lens.tolist() + batch_log_probs_seq = [] + batch_log_probs_ids = [] + batch_start = [] # only effective in streaming deployment + batch_root = TrieVector() + root_dict = {} + for i in range(len(batch_len_list)): + num_sent = batch_len_list[i] + batch_log_probs_seq.append( + batch_log_probs_seq_list[i][0:num_sent]) + batch_log_probs_ids.append( + batch_log_probs_idx_list[i][0:num_sent]) + root_dict[i] = PathTrie() + batch_root.append(root_dict[i]) + batch_start.append(True) + score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq, + batch_log_probs_ids, + batch_root, + batch_start, + beam_size, + num_processes, + 0, -2, 0.99999) + if self.args.mode == 'ctc_prefix_beam_search': + hyps = [] + for cand_hyps in score_hyps: + hyps.append(cand_hyps[0][1]) + hyps = map_batch(hyps, self.vocabulary, num_processes, False, 0) + if self.args.mode == 'attention_rescoring': + ctc_score, all_hyps = [], [] + max_len = 0 + for hyps in score_hyps: + cur_len = len(hyps) + if len(hyps) < beam_size: + hyps += (beam_size - cur_len) * ((-float("INF"), (0,)),) + cur_ctc_score = [] + for hyp in hyps: + cur_ctc_score.append(hyp[0]) + all_hyps.append(list(hyp[1])) + if len(hyp[1]) > max_len: + max_len = len(hyp[1]) + ctc_score.append(cur_ctc_score) + if self.args.fp16: + ctc_score = np.array(ctc_score, dtype=np.float16) + else: + ctc_score = np.array(ctc_score, dtype=np.float32) 
+ hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID + r_hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID + hyps_lens_sos = np.ones( + (batch_size, beam_size), dtype=np.int32) + k = 0 + for i in range(batch_size): + for j in range(beam_size): + cand = all_hyps[k] + l = len(cand) + 2 + hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos] + r_hyps_pad_sos_eos[i][j][0:l] = [ + sos] + cand[::-1] + [eos] + hyps_lens_sos[i][j] = len(cand) + 1 + k += 1 + best_index = None + encoder_out_lens = encoder_out_lens.to("npu:1") + hyps_pad_sos_eos = torch.from_numpy(hyps_pad_sos_eos).to("npu:1") + hyps_lens_sos = torch.from_numpy(hyps_lens_sos).to("npu:1") + r_hyps_pad_sos_eos = torch.from_numpy(r_hyps_pad_sos_eos).to("npu:1") + ctc_score = torch.from_numpy(ctc_score).to("npu:1") + if not self.args.static: + output_size = 100000 + + if self.reverse_weight > 0: + best_index = self.decoder.forward( + encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, + r_hyps_pad_sos_eos, ctc_score) + else: + best_index = self.decoder.forward( + encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, + ctc_score) + else: + pad_size = 0 + for n in self.mul_shape_decoder: + if n > encoder_out.shape[1]: + pad_size = n - encoder_out.shape[1] + break + encoder_out = np.pad(encoder_out.cpu(), ((0, 0), (0, pad_size), (0, 0)), 'constant') + encoder_out = torch.from_numpy(encoder_out).to("npu:1") + if self.reverse_weight > 0: + best_index = self.decoder.forward( + encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score) + else: + best_index = self.decoder.infer( + [encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, ctc_score]) + best_index = best_index.to("cpu") + best_sents = [] + k = 0 + for idx_ in best_index: + cur_best_sent = all_hyps[k: k + beam_size][idx_] + best_sents.append(cur_best_sent) + k += beam_size + hyps = map_batch(best_sents, self.vocabulary, num_processes) + + for i, key in enumerate(keys): + nums_ += 1 + content = hyps[i] + logger.info('{} {}'.format(key, content)) + return nums_ + +def infer_process(list_args): + idx_, encoder, decoder = list_args + batches = packed_data[idx_] + model = AsrModel(encoder, decoder, args, reverse_weight) + + nums_ = 0 + infer_s_t = time.time() + for data in batches: + num = model.forward(data) + nums_ += num + + infer_e_t = time.time() + # sync_num.pop() + return nums_, infer_e_t - infer_s_t + +if __name__ == '__main__': + args = get_args() + logger = logging.getLogger(__name__) + logger.setLevel(level=logging.DEBUG) + handler = logging.FileHandler(args.result_file, mode='w') + formatter = logging.Formatter('%(message)s') + handler.setFormatter(formatter) + + console = logging.StreamHandler() + console.setLevel(level=logging.DEBUG) + console_format = logging.Formatter('Recognize: %(message)s') + console.setFormatter(console_format) + logger.addHandler(handler) + logger.addHandler(console) + with open(args.config, 'r') as fin: + configs = yaml.safe_load(fin) + if len(args.override_config) > 0: + print('override!!!') + exit(0) + configs = override_config(configs, args.override_config) + + reverse_weight = configs["model_conf"].get("reverse_weight", 0.0) + symbol_table = read_symbol_table(args.dict) + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf = adjust_test_conf(test_conf, args.batch_size) + + test_dataset = Dataset(args.data_type, + args.test_data, + symbol_table, + test_conf, + args.bpe_model, + 
partition=False) + + test_data_loader = DataLoader(test_dataset, batch_size=None, + num_workers=multiprocessing.cpu_count() // 2) + manager = multiprocessing.Manager() + num_process = args.num_process + + # lots of file will be open, need to change sharing strategy + torch.multiprocessing.set_sharing_strategy('file_system') + # packed data for mitiple processes + packed_data = [[] for _ in range(num_process)] + idx = 0 + pre_s_t = time.time() + for batch in test_data_loader: + packed_data[idx].append(batch) + idx = (idx + 1) % num_process + pre_e_t = time.time() + + sync_num = manager.list() + data_cnt = 0 + infer_times = 0 + + encoder = torch.jit.load(args.encoder_aie) + decoder = torch.jit.load(args.decoder_aie) + + torch_aie.set_device(args.device_id) + + data_cnt, infer_times = infer_process([0, encoder, decoder]) + + fps = float((data_cnt) / (infer_times + pre_e_t - pre_s_t)) + fps_str = "fps: {}\n".format(fps) + resstr = "total time: {}\n".format(infer_times + pre_e_t - pre_s_t) + print(fps_str) + print(resstr) + + flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL + modes = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(args.test_file, flags, modes), 'w') as f: + f.write(fps_str) + f.write(resstr) diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py b/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..203f391f734ec28b2aefc19fb8ded9733ebd56a8 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py @@ -0,0 +1,427 @@ +# BSD 3-Clause License +# +# All rights reserved. +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + +""" +This script is for testing exported ascend encoder and decoder from +export_onnx_npu.py. The exported ascend models only support batch offline ASR inference. +It requires a python wrapped c++ ctc decoder. 
+Please install it from ctc decoder in github +""" +from __future__ import print_function + +import argparse +import logging +import sys + +import torch +import yaml + +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.mask import make_pad_mask + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--encoder_out', required=True, help='Encoder traced .pth path') + parser.add_argument('--decoder_out', required=True, help='Decoder traced .pth path') + parser.add_argument('--batch_size', + type=int, + default=32, + help='used batch size') + parser.add_argument('--num_decoding_left_chunks', + default=5, + type=int, + required=False, + help="number of left chunks, <= 0 is not supported") + parser.add_argument('--streaming', + action='store_true', + help="whether to export streaming encoder, default false") + parser.add_argument('--decoding_chunk_size', + default=16, + type=int, + required=False, + help='the decoding chunk size, <=0 is not supported') + parser.add_argument('--reverse_weight', default=-1.0, type=float, + required=False, + help='reverse weight for bitransformer,' + + 'default value is in config file') + parser.add_argument('--beam_size', default=10, type=int, required=False, + help="beam size would be ctc output size") + parser.add_argument('--output_size', + type=str, + help='only effect in dynamic shapes mode,\ + outputs size info for encoder, please input split by ","') + args_ = parser.parse_args() + print(args_) + return args_ + +class StreamingEncoder(torch.nn.Module): + def __init__(self, model, required_cache_size, beam_size, transformer=False): + super().__init__() + self.ctc = model.ctc + self.subsampling_rate = model.encoder.embed.subsampling_rate + self.embed = model.encoder.embed + self.global_cmvn = model.encoder.global_cmvn + self.required_cache_size = required_cache_size + self.beam_size = beam_size + self.encoder = model.encoder + self.transformer = transformer + + def forward(self, chunk_xs, chunk_lens, offset, + att_cache, cnn_cache, cache_mask): + """Streaming Encoder + Args: + xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (torch.Tensor): offset with shape (b, 1) + 1 is retained for triton deployment + required_cache_size (int): cache size required for next chunk + compuation + > 0: actual cache size + <= 0: not allowed in streaming gpu encoder ` + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (b, elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (b, elayers, b, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) + in a batch of request, each request may have different + history cache. 
Cache mask is used to indidate the effective + cache for each request + Returns: + torch.Tensor: log probabilities of ctc output and cutoff by beam size + with shape (b, chunk_size, beam) + torch.Tensor: index of top beam size probabilities for each timestep + with shape (b, chunk_size, beam) + torch.Tensor: output of current input xs, + with shape (b, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + same shape (b, elayers, head, cache_t1, d_k * 2) + as the original att_cache + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + torch.Tensor: new cache mask, with same shape as the original + cache mask + """ + offset = offset.squeeze(1) + T = chunk_xs.size(1) + chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) + # B X 1 X T + chunk_mask = chunk_mask.to(chunk_xs.dtype) + # transpose batch & num_layers dim + att_cache = torch.transpose(att_cache, 0, 1) + cnn_cache = torch.transpose(cnn_cache, 0, 1) + + # rewrite encoder.forward_chunk + # <---------forward_chunk START---------> + xs = self.global_cmvn(chunk_xs) + # chunk mask is important for batch inferencing since + # different sequence in a batch has different length + xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) + cache_size = att_cache.size(3) # required cache size + masks = torch.cat((cache_mask, chunk_mask), dim=2) + index = offset - cache_size + + pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) + pos_emb = pos_emb.to(dtype=xs.dtype) + + next_cache_start = -self.required_cache_size + r_cache_mask = masks[:, :, next_cache_start:] + + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoder.encoders): + xs, _, new_att_cache, new_cnn_cache = layer( + xs, masks, pos_emb, + att_cache=att_cache[i], + cnn_cache=cnn_cache[i]) + # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (B, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :].unsqueeze(1)) + if not self.transformer: + r_cnn_cache.append(new_cnn_cache.unsqueeze(1)) + if self.encoder.normalize_before: + chunk_out = self.encoder.after_norm(xs) + + r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx + if not self.transformer: + r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers + + # <---------forward_chunk END---------> + + log_ctc_probs = self.ctc.log_softmax(chunk_out) + log_probs, log_probs_idx = torch.topk(log_ctc_probs, + self.beam_size, + dim=2) + log_probs = log_probs.to(chunk_xs.dtype) + + r_offset = offset + chunk_out.shape[1] + # the below ops not supported in Tensorrt + # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, + # rounding_mode='floor') + chunk_out_lens = chunk_lens // self.subsampling_rate + r_offset = r_offset.unsqueeze(1) + + return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \ + r_offset, r_att_cache, r_cnn_cache, r_cache_mask + +class Encoder(torch.nn.Module): + def __init__(self, + encoder: BaseEncoder, + ctc: CTC, + beam_size: int = 10): + super().__init__() + self.encoder = encoder + self.ctc = ctc + self.beam_size = beam_size + + def forward(self, speech: torch.Tensor, + speech_lengths: torch.Tensor,): + """Encoder + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + Returns: + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size + """ + encoder_out, encoder_mask = self.encoder(speech, + speech_lengths, + -1, -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_log_probs = self.ctc.log_softmax(encoder_out) + encoder_out_lens = encoder_out_lens.int() + beam_log_probs, beam_log_probs_idx = torch.topk( + ctc_log_probs, self.beam_size, dim=2) + return encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx + +class Decoder(torch.nn.Module): + def __init__(self, + decoder: TransformerDecoder, + ctc_weight: float = 0.5, + reverse_weight: float = 0.3, + beam_size: int = 10): + super().__init__() + self.decoder = decoder + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + self.beam_size = beam_size + + def forward(self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos_eos: torch.Tensor, + hyps_lens_sos: torch.Tensor, + r_hyps_pad_sos_eos: torch.Tensor, + ctc_score: torch.Tensor): + """Encoder + Args: + encoder_out: B x T x F + encoder_lens: B + hyps_pad_sos_eos: B x beam x (T2+1), + hyps with sos & eos and padded by ignore id + hyps_lens_sos: B x beam, length for each hyp with sos + r_hyps_pad_sos_eos: B x beam x (T2+1), + reversed hyps with sos & eos and padded by ignore id + ctc_score: B x beam, ctc score for each hyp + Returns: + decoder_out: B x beam x T2 x V + r_decoder_out: B x beam x T2 x V + best_index: B + """ + B, T, F = encoder_out.shape + bz = self.beam_size + B2 = B * bz + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) + encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) + T2 = hyps_pad_sos_eos.shape[2] - 1 + hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) + hyps_lens = hyps_lens_sos.view(B2,) + hyps_pad_sos = hyps_pad[:, :-1].contiguous() + hyps_pad_eos = hyps_pad[:, 1:].contiguous() + + r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) + r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() + r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos, + self.reverse_weight) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + V = decoder_out.shape[-1] + decoder_out = decoder_out.view(B2, T2, V) + mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 + # mask index, remove ignore id + index = torch.unsqueeze(hyps_pad_eos * mask, 2) + score = decoder_out.gather(2, index).squeeze(2) # B2 X T2 + # mask padded part + score = score * mask + decoder_out = decoder_out.view(B, bz, T2, V) + if self.reverse_weight > 0: + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.view(B2, T2, V) + index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) + r_score = r_decoder_out.gather(2, index).squeeze(2) + r_score = r_score * mask + score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score + r_decoder_out = r_decoder_out.view(B, bz, T2, V) + score = torch.sum(score, axis=1) # B2 + score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score + best_index = torch.argmax(score, dim=1) + return best_index + +def export_online_encoder(args, model): + batch_size = args.batch_size + decoding_chunk_size = args.decoding_chunk_size + subsampling = model.encoder.embed.subsampling_rate + context = 
model.encoder.embed.right_context + 1 + decoding_window = (decoding_chunk_size - 1) * subsampling + context + audio_len = decoding_window + feature_size = configs["input_dim"] + output_size = configs["encoder_conf"]["output_size"] + num_layers = configs["encoder_conf"]["num_blocks"] + # in transformer the cnn module will not be available + transformer = False + cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1 + if not cnn_module_kernel: + transformer = True + num_decoding_left_chunks = args.num_decoding_left_chunks + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + encoder = StreamingEncoder(model, required_cache_size, args.beam_size, transformer) + encoder.eval() + + chunk_xs = torch.randn(batch_size, audio_len, feature_size, dtype=torch.float32) + chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len + + offset = torch.arange(0, batch_size).unsqueeze(1) + # (elayers, b, head, cache_t1, d_k * 2) + head = configs["encoder_conf"]["attention_heads"] + d_k = configs["encoder_conf"]["output_size"] // head + att_cache = torch.randn(batch_size, num_layers, head, + required_cache_size, d_k * 2, + dtype=torch.float32) + cnn_cache = torch.randn(batch_size, num_layers, output_size, + cnn_module_kernel, dtype=torch.float32) + + cache_mask = torch.ones(batch_size, 1, required_cache_size, dtype=torch.float32) + + inputs = [chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask] + print('[INFO] Started to trace the online encoder model') + traced = torch.jit.trace(encoder, inputs) + + print('[INFO] Success! Online encoder was traced') + traced.save(args.encoder_out) + +def export_offline_encoder(args, model): + bz = args.batch_size + seq_len = 1478 + beam_size = 10 + feature_size = configs["input_dim"] + + speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32) + speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32) + encoder = Encoder(model.encoder, model.ctc, beam_size) + encoder.eval() + + inputs = [speech, speech_lens] + print('[INFO] Started to trace the offline encoder model') + traced = torch.jit.trace(encoder, inputs) + + print('[INFO] Success! Offline encoder was traced') + traced.save(args.encoder_out) + + +def export_decoder(args, model): + bz, seq_len = args.batch_size, 100 + beam_size = 10 + decoder = Decoder(model.decoder, + model.ctc_weight, + model.reverse_weight, + beam_size) + decoder.eval() + + hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len)) + hyps_lens_sos = torch.randint(low=3, high=seq_len, size=(bz, beam_size), + dtype=torch.int32) + r_hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len)) + + output_size = configs["encoder_conf"]["output_size"] + encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32) + encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,), dtype=torch.int32) + ctc_score = torch.randn(bz, beam_size, dtype=torch.float32) + + inputs = [encoder_out, + encoder_out_lens, + hyps_pad_sos_eos, + hyps_lens_sos, + r_hyps_pad_sos_eos, + ctc_score] + print('[INFO] Started to trace the decoder model') + traced = torch.jit.trace(decoder, inputs) + print('[INFO] Success! 
The decoder was traced') + traced.save(args.decoder_out) + + +if __name__ == '__main__': + args = get_args() + torch.manual_seed(0) + torch.set_printoptions(precision=10) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']: + configs['model_conf']['reverse_weight'] = args.reverse_weight + print("Update reverse weight to", args.reverse_weight) + configs["encoder_conf"]["use_dynamic_chunk"] = False + + model = init_asr_model(configs) + load_checkpoint(model, args.checkpoint) + model.eval() + if args.streaming: + export_online_encoder(args, model) + else: + export_offline_encoder(args, model) + export_decoder(args, model) +