diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/README.md b/AscendIE/TorchAIE/built-in/audio/Wenet/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..c697a3dbf28a475c957e2be049a5ebc642e24dfc
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/README.md
@@ -0,0 +1,298 @@
+# Wenet Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Obtaining the Source Code](#section4622531142816)
+  - [Preparing the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+Wenet is an ASR (automatic speech recognition) model built on the Conformer architecture, offering good end-to-end inference accuracy and performance.
+
+- Reference implementation:
+
+ ```
+ url=https://github.com/wenet-e2e/wenet.git
+ branch=v2.0.1
+ model_name=Wenet
+ ```
+
+
+## Input and Output Data
+
+- Online encoder inputs
+
+  | Input | Shape | Data Type | Format |
+  | -------- | -------- | ------------------------- | ------------ |
+ | chunk_xs | B x 67 x 80 | FLOAT32 | ND |
+ | chunk_lens | B | INT32 | ND |
+ | offset | B x 1 | INT64 | ND |
+ | att_cache | B x 12 x 4 x 64 x 128 | FLOAT32 | ND |
+ | cnn_cache | B x 12 x 256 x 7 | FLOAT32 | ND |
+ | cache_mask | B x 1 x 64 | FLOAT32 | ND |
+
+
+- Online encoder outputs
+
+  | Output | Shape | Data Type | Format |
+  | -------- | -------- | -------- | ------------ |
+ | log_probs | batchsize x Castlog_probs_dim_1 x 10 | FLOAT32 | ND |
+ | log_probs_idx | batchsize x Castlog_probs_dim_1 x 10 | INT64 | ND |
+ | chunk_out | batchsize x Castlog_probs_dim_1 x 256 | FLOAT32 | ND |
+ | chunk_out_lens | batchsize | INT32 | ND |
+ | r_offset | batchsize x 1 | INT64 | ND |
+ | r_att_cache | batchsize x 12 x dim_2 x dim_3 x dim_4 | FLOAT32 | ND |
+ | r_cnn_cache | batchsize x 12 x 256 x dim_5 | FLOAT32 | ND |
+ | r_cache_mask | batchsize x 1 x 64 | FLOAT32 | ND |
+
+- Offline encoder inputs
+
+  | Input | Shape | Data Type | Format |
+  | -------- | -------- | ------------------------- | ------------ |
+ | speech | batchsize x T x 80 | FLOAT32 | ND |
+ | speech_lengths | batchsize | INT32 | ND |
+
+
+- Offline encoder outputs
+
+  | Output | Shape | Data Type | Format |
+  | -------- | -------- | -------- | ------------ |
+ | encoder_out | batchsize x T_out x 256 | FLOAT32 | ND |
+ | encoder_out_lens | batchsize | INT32 | ND |
+ | ctc_log_probs | batchsize x T_OUT x 4233 | FLOAT32 | ND |
+ | beam_log_probs | batchsize x T_OUT x 10 | FLOAT32 | ND |
+ | beam_log_probs_idx | batchsize x T_OUT x 10 | INT64 | ND |
+
+
+- Decoder inputs
+
+  | Input | Shape | Data Type | Format |
+  | -------- | -------- | ------------------------- | ------------ |
+ | encoder_out | batchsize x T x 256 | FLOAT32 | ND |
+ | encoder_out_lens | batchsize | INT32 | ND |
+ | hyps_pad_sos_eos | batchsize x 10 x T2 | INT64 | ND |
+ | hyps_lens_sos | batchsize x 10 | INT32 | ND |
+ | r_hyps_pad_sos_eos | batchsize x 10 x T2 | INT64 | ND |
+ | ctc_score | batchsize x 10 | FLOAT32 | ND |
+
+
+- Decoder outputs
+
+  | Output | Shape | Data Type | Format |
+  | -------- | -------- | -------- | ------------ |
+ | best_index | batchsize | INT64 | ND |
+
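+For reference, the sketch below builds dummy tensors that match the offline encoder and decoder interfaces listed above (a minimal illustration assuming PyTorch, batch size 1, and example lengths `T`/`T2`; these are not fixed values):
+
+```
+# Hedged sketch: dummy inputs with the shapes and dtypes from the tables above.
+import torch
+
+batch, T, T2, beam = 1, 646, 30, 10
+# offline encoder inputs
+speech = torch.randn(batch, T, 80, dtype=torch.float32)
+speech_lengths = torch.full((batch,), T, dtype=torch.int32)
+# decoder inputs
+encoder_out = torch.randn(batch, T, 256, dtype=torch.float32)
+encoder_out_lens = torch.full((batch,), T, dtype=torch.int32)
+hyps_pad_sos_eos = torch.zeros(batch, beam, T2, dtype=torch.int64)
+hyps_lens_sos = torch.ones(batch, beam, dtype=torch.int32)
+r_hyps_pad_sos_eos = torch.zeros(batch, beam, T2, dtype=torch.int64)
+ctc_score = torch.zeros(batch, beam, dtype=torch.float32)
+```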
+
+# Inference Environment
+
+- The following firmware, drivers, and software are required.
+
+  **Table 1** Version compatibility
+
+  | Component | Version | Setup Guide |
+  |---------| ------- | ------------------------------------------------------------ |
+  | Firmware & drivers | 22.0.4 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 6.3.RC1 | - |
+  | Python | 3.7.5 | - |
+  | PyTorch | 1.8.0 | - |
+  | Torch_AIE | 6.3.rc2 | - |
+
+  Note: For Atlas 300I Duo inference cards, choose the firmware and driver version that matches the installed CANN version.
+
+
+# Quick Start
+
+## Obtaining the Source Code
+
+1. Get the source code.
+
+   Run the following commands in your working directory to clone the repository and check out the corresponding tag.
+
+ ```
+ git clone https://github.com/wenet-e2e/wenet.git
+ cd wenet
+ git checkout v2.0.1
+ export wenet_path=$(pwd)
+ cd ..
+ ```
+
+2. Install the dependencies.
+
+ ```
+ pip3 install -r requirements.txt
+ ```
+
+3. Install ctc_decoder.
+
+ ```
+ git clone https://github.com/Slyne/ctc_decoder.git
+ apt-get update
+ apt-get install swig
+ apt-get install python3-dev
+ cd ctc_decoder/swig && bash setup.sh
+ ```
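+
+   After the build completes, you can optionally verify from Python that the wrapper is importable (a minimal check; these are the symbols later used by `recognize_aie.py`):
+
+   ```
+   # Quick sanity check that the swig-wrapped CTC decoder installed correctly.
+   from swig_decoders import map_batch, ctc_beam_search_decoder_batch, TrieVector, PathTrie
+   print("swig_decoders imported successfully")
+   ```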
+
+
+
+## Preparing the Dataset
+
+1. Download the raw dataset.
+
+ ```
+ cd wenet/examples/aishell/s0/
+ mkdir -p /export/data/asr-data/OpenSLR/33
+    bash run.sh --stage -1 --stop_stage -1 # download the dataset
+ ```
+
+2. Preprocess the dataset.
+
+ ```
+ bash run.sh --stage 0 --stop_stage 0
+ bash run.sh --stage 1 --stop_stage 1
+ bash run.sh --stage 2 --stop_stage 2
+ bash run.sh --stage 3 --stop_stage 3
+ ```
+
+3. Copy the generated data.list file to the ModelZoo Wenet working directory.
+   ```
+   cp data/test/data.list <path-to-ModelZoo-Wenet-working-directory>/
+   ```
+
+
+## Model Inference
+
+1. Model conversion.
+
+   Use PyTorch to trace the model, then use Torch AIE to compile and serialize it for inference.
+
+   1. Download the pretrained weights.
+
+      Download the Checkpoint Model for the aishell dataset from the WeNet pretrained models page ([download link](https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.md)), place the archive under the path of your cloned wenet repository, and extract it.
+ ```
+ # Unpack and copy pretrained checkpoint model to working directory
+ cd
+ tar -zvxf aishell_u2++_conformer_exp.tar.gz
+ mkdir -p /exp/20210601_u2++_conformer_exp
+ cp -r aishell_u2++_conformer_exp/* /exp/20210601_u2++_conformer_exp
+ cd
+ ```
+
+   2. Trace the model.
+
+      Use trace.py to trace the encoder and decoder and save them as .pth files.
+
+      Configure the Python path:
+ ```
+ export PYTHONPATH=$wenet_path
+ ```
+
+ ```
+ python trace.py --config=exp/20210601_u2++_conformer_exp/train.yaml --batch_size=${batch_size} --checkpoint=exp/20210601_u2++_conformer_exp/final.pt --encoder_out=encoder.pth --decoder_out=decoder.pth --reverse_weight 0.3
+ ```
+
+      Add the `--streaming` flag to obtain the online (streaming) encoder model.
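+
+      Optionally, you can sanity-check the traced offline encoder on the CPU with random inputs of the same shape used during tracing (a minimal sketch, assuming `batch_size=1` was used above):
+
+      ```
+      # Hedged sketch: load the traced offline encoder and run a dummy forward pass on CPU.
+      import torch
+
+      encoder = torch.jit.load("encoder.pth")
+      speech = torch.randn(1, 1478, 80, dtype=torch.float32)
+      speech_lens = torch.randint(low=10, high=1478, size=(1,), dtype=torch.int32)
+      with torch.no_grad():
+          outs = encoder(speech, speech_lens)
+      print([tuple(o.shape) for o in outs])
+      ```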
+
+ 3. Use the Torch AIE to compile the models.
+
+ ```
+ python compile_model.py --encoder=encoder.pth --encoder_out=encoder_aie.pth --decoder=decoder.pth --decoder_out=decoder_aie.pth --batch_size=${batch_size} --encoder_gears="262, 326, 390, 454, 518, 582, 646, 710, 774, 838, 902, 966, 1028, 1284, 1478" --decoder_gears="5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44"
+ ```
+
+      - Parameters:
+
+        - --encoder: path to the input traced encoder model
+        - --encoder_out: path to store the compiled encoder model
+        - --encoder_gears: encoder dynamic dimensions
+        - --decoder: path to the input traced decoder model
+        - --decoder_out: path to store the compiled decoder model
+        - --decoder_gears: decoder dynamic dimensions
+
+ Add `--streaming` command line argument to compile the online encoder model.
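+
+      To confirm that a compiled model loads and runs on the NPU, the following is a minimal sketch (assuming the offline encoder, batch size 1, device 0, and a sequence length of 646, which is one of the compiled gears above):
+
+      ```
+      # Hedged sketch: run the Torch AIE compiled offline encoder on the NPU.
+      import torch
+      import torch_aie
+
+      torch_aie.set_device(0)
+      encoder = torch.jit.load("encoder_aie.pth")
+      speech = torch.randn(1, 646, 80, dtype=torch.float32).to("npu:0")
+      speech_lens = torch.full((1,), 646, dtype=torch.int32).to("npu:0")
+      with torch.no_grad():
+          outs = encoder(speech, speech_lens)
+      print(tuple(outs[0].shape), outs[1].to("cpu"))
+      ```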
+
+2. Run inference and validation.
+
+    1. Accuracy and performance validation for the non-streaming (gear-based) scenario.
+
+       End-to-end encoder + decoder, multi-process:
+
+ ```
+ python recognize_aie.py --config=exp/20210601_u2++_conformer_exp/train.yaml --test_data=data.list --dict=exp/20210601_u2++_conformer_exp/units.txt --mode=attention_rescoring --result_file=aie_result${batch_size}.log --encoder_aie=encoder_aie.pth --decoder_aie=decoder_aie.pth --batch_size=${batch_size} --device_id=0 --test_file=static_test_result_bs32_1.txt --encoder_gears="646, 710, 774, 838, 902, 966, 1028, 1284, 1478"
+ ```
+
+     - Parameters:
+       - --config: path to the aishell pretrained model configuration file.
+       - --test_data: path to the test data list.
+       - --dict: path to the aishell pretrained model dictionary.
+       - --mode: decoding mode; one of ctc_greedy_search, ctc_prefix_beam_search, or attention_rescoring.
+       - --result_file: decoding result file.
+       - --encoder_aie: path to the compiled encoder model.
+       - --decoder_aie: path to the compiled decoder model.
+       - --batch_size: batch size.
+       - --device_id: NPU device ID.
+       - --test_file: performance result file.
+       - --encoder_gears: encoder dynamic dims.
+
+    ### Note
+    The `max_length` setting must be adjusted when using a batch size of 4 or 1. To do so, change the config value at `recognize_aie.py:142` to:
+ ```
+ test_conf_['filter_conf']['max_length'] = 700
+ ```
+
+ ```
+    # accuracy verification
+ python tools/compute-wer.py --char=1 --v=1 exp/20210601_u2++_conformer_exp/text aie_result${batch_size}.log
+ ```
+
+    - Parameters:
+      - --char: comparison granularity; 0 compares whole sentences, 1 compares character by character.
+      - --v: whether to print the detailed comparison; 0 = do not print, 1 = print.
+      - exp/20210601_u2++_conformer_exp/text: path to the label (reference) file.
+      - aie_result${batch_size}.log: path to the recognition result file.
+
+    2. Accuracy and performance validation for the streaming (pure inference) scenario.
+       Use the following script to measure the performance of the streaming encoder model and check its accuracy.
+ ```
+ python cosine_similarity.py --encoder_torch=torch_online_encoder_b${batch_size}.pth --encoder_aie=aie_online_encoder_b${batch_size}.pth --batch_size=${batch_size} --device_id=1
+ ```
+    - Parameters:
+      - --encoder_torch: path to the traced (Torch) streaming encoder model.
+      - --encoder_aie: path to the compiled (AIE) streaming encoder model.
+      - --batch_size: batch size.
+      - --device_id: NPU device ID.
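+
+    For reference, the accuracy value printed by `cosine_similarity.py` is the cosine similarity between the flattened first outputs of the AIE and Torch encoders; a minimal equivalent of that metric:
+
+    ```
+    # Hedged sketch of the cosine-similarity metric used by cosine_similarity.py.
+    import numpy as np
+
+    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
+        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+    ```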
+
+# Inference Performance & Accuracy
+
+The following data is for reference.
+
+Non-streaming gear-based (encoder + decoder) scenario, 1 process.
+
+| SoC | Batch Size | Dataset | Accuracy (WER) | End-to-End Throughput (fps) |
+|-------------|------------|-----------|---------- |------------------|
+| Ascend310P3 | 1 | aishell | 4.60% | 39.55 |
+| Ascend310P3 | 4 | aishell | 4.60% | 53.38 |
+| Ascend310P3 | 8 | aishell | 4.60% | 51.87 |
+| Ascend310P3 | 16 | aishell | 4.60% | 49.66 |
+| Ascend310P3 | 32 | aishell | 4.60% | 44.11 |
+| Ascend310P3 | 64 | aishell | 4.60% | 41.23 |
+
+Streaming (pure inference) scenario.
+
+| SoC | Batch Size | Cosine Similarity (AIE vs. Torch) | Throughput |
+|---------------|------------|-------------------|----------------|
+| Ascend310P3 | 1 | 0.9999992 | 593.74 fps |
+| Ascend310P3 | 4 | 0.9999999 | 1219.60 fps |
+| Ascend310P3 | 8 | 0.9999999 | 1635.93 fps |
+| Ascend310P3 | 16 | 0.9999998 | 1966.44 fps |
+| Ascend310P3 | 32 | 0.99999946 | 2098.52 fps |
+| Ascend310P3 | 64 | 0.99999976 | 2079.44 fps |
diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py b/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dda1405370d8a939e3e3bf97c2834e8b523dbde
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/compile_model.py
@@ -0,0 +1,146 @@
+# BSD 3-Clause License
+#
+# All rights reserved.
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+"""
+This script is for testing exported ascend encoder and decoder from
+export_onnx_npu.py. The exported ascend models only support batch offline ASR inference.
+It requires a python wrapped c++ ctc decoder.
+Please install it from ctc decoder in github
+"""
+from __future__ import print_function
+
+import argparse
+
+import torch_aie
+import torch
+
+def get_args():
+ parser = argparse.ArgumentParser(description='recognize with your model')
+ parser.add_argument('--batch_size',
+ type=int,
+ default=32,
+ help='used batch size')
+ parser.add_argument('--streaming',
+ action='store_true',
+ help="whether to export streaming encoder, default false")
+ parser.add_argument('--encoder',
+ type=str,
+ required=True,
+ help='path to input traced encoder .pth model')
+ parser.add_argument('--encoder_out',
+ type=str,
+ required=True,
+ help='path to store compiled encoder .pth model')
+ parser.add_argument('--decoder',
+ type=str,
+ required=True,
+ help='path to input traced decoder .pth model')
+ parser.add_argument('--decoder_out',
+ type=str,
+ required=True,
+ help='path to store compiled decoder .pth model')
+    parser.add_argument('--encoder_gears',
+                        type=str,
+                        required=True,
+                        help='dynamic dims gears info for the encoder, split by ","')
+    parser.add_argument('--decoder_gears',
+                        type=str,
+                        required=True,
+                        help='dynamic dims gears info for the decoder, split by ","')
+ args_ = parser.parse_args()
+ print(args_)
+ return args_
+
+def get_dynamic_dims_encoder(batch_size, args):
+ inputs = []
+ for dim in list(map(int, args.encoder_gears.split(','))):
+ inputs.append([torch_aie.Input((batch_size, dim, 80)),
+ torch_aie.Input([batch_size], dtype=torch_aie.dtype.INT32)])
+ return inputs
+
+def compile_offline_encoder(args):
+ bz = args.batch_size
+ traced = torch.jit.load(args.encoder)
+ print('[INFO] Started to compile the offline encoder using TorchAIE')
+ encoder_compiled = torch_aie.compile(traced, inputs=get_dynamic_dims_encoder(bz, args), soc_version="Ascend310P3")
+ print('[INFO] Success! The offline encoder was compiled!')
+ encoder_compiled.save(args.encoder_out)
+
+def compile_online_encoder(args):
+ batch_size = args.batch_size
+ traced = torch.jit.load(args.encoder)
+
+ inputs = [
+ torch_aie.Input((batch_size, 67, 80)), # chunk_xs
+ torch_aie.Input([batch_size], dtype=torch_aie.dtype.INT32), # chunk_lens
+ torch_aie.Input((batch_size, 1), dtype=torch_aie.dtype.INT32), # offset
+        torch_aie.Input((batch_size, 12, 4, 64, 128)), # att_cache
+ torch_aie.Input((batch_size, 12, 256, 7)), # cnn_cache
+ torch_aie.Input((batch_size, 1, 64)), # cache_mask
+ ]
+
+ print('[INFO] Started to compile the online encoder using TorchAIE')
+ encoder_compiled = torch_aie.compile(traced, inputs=inputs, soc_version="Ascend310P3")
+ print('[INFO] Success! The online encoder was compiled!')
+ encoder_compiled.save(args.encoder_out)
+
+def get_dynamic_dims_decoder(batch_size, args):
+ inputs = []
+
+ for dim in list(map(int, args.decoder_gears.split(','))):
+ inputs.append(
+ [torch_aie.Input(shape=[batch_size, 384, 256], dtype=torch_aie.dtype.FLOAT),
+ torch_aie.Input(shape=[batch_size], dtype=torch_aie.dtype.INT32),
+ torch_aie.Input(shape=[batch_size, 10, dim], dtype=torch_aie.dtype.INT64),
+ torch_aie.Input(shape=[batch_size, 10], dtype=torch_aie.dtype.INT32),
+ torch_aie.Input(shape=[batch_size, 10, dim], dtype=torch_aie.dtype.INT64),
+ torch_aie.Input(shape=[batch_size, 10], dtype=torch_aie.dtype.FLOAT)])
+ return inputs
+
+def compile_decoder(args):
+ bz = args.batch_size
+ decoder = torch.jit.load(args.decoder)
+ print('[INFO] Started to compile the decoder using TorchAIE')
+ decoder_compiled = torch_aie.compile(decoder, inputs=get_dynamic_dims_decoder(bz, args), soc_version="Ascend310P3")
+ print('[INFO] Success! The decoder was compiled!')
+ decoder_compiled.save(args.decoder_out)
+
+
+if __name__ == '__main__':
+ args = get_args()
+ if args.streaming:
+ compile_online_encoder(args)
+ else:
+ compile_offline_encoder(args)
+ compile_decoder(args)
+
diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py b/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad4e657255a7557706dca356fe9d0178e2144e5
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/cosine_similarity.py
@@ -0,0 +1,97 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import argparse
+import time
+
+import torch
+import torch_aie
+import numpy as np
+
+def cos_sim(a, b):
+ b_norm = np.linalg.norm(b.numpy())
+ a_norm = np.linalg.norm(a.numpy())
+ cos = np.dot(a, b)/(a_norm * b_norm)
+ return cos
+
+def measure(model, dummy_input):
+    chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask = dummy_input
+ npu_str = 'npu:' + str(args.device_id)
+ chunk_xs = chunk_xs.to(npu_str)
+ chunk_lens = chunk_lens.to(npu_str)
+ offset = offset.to(npu_str)
+ att_cache = att_cache.to(npu_str)
+ cnn_cache = cnn_cache.to(npu_str)
+ cache_mask = cache_mask.to(npu_str)
+
+
+ # Do a warmup
+ print('Doing a warmup')
+ _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask)
+ _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask)
+ _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask)
+
+ print('Run inference...')
+ s = time.time_ns()
+ n = 1000
+ for i in range(n):
+ _ = model(chunk_xs , chunk_lens, offset, att_cache, cnn_cache, cache_mask)
+ ms = (time.time_ns()-s) / n / 10**6
+ print('Mean compute time : {} (ms)'.format(ms))
+ return ms, _
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='compute cosine similarity between the torch and AIE streaming encoders')
+    parser.add_argument('--encoder_torch', required=True, help='traced torch streaming encoder .pth file')
+    parser.add_argument('--encoder_aie', required=True, help='compiled AIE streaming encoder .pth file')
+ parser.add_argument('--batch_size', required=True, type=int, help='batch size')
+ parser.add_argument('--device_id', default=0, type=int, help='device id')
+ parser.add_argument('--decoding_chunk_size', default=16, type=int,
+ help='decoding chunk size, <=0 is not supported')
+ parser.add_argument('--num_decoding_left_chunks',
+ default=4,
+ type=int,
+ required=False,
+ help="number of left chunks, <= 0 is not supported")
+ args = parser.parse_args()
+ print(args)
+
+ torch_aie.set_device(args.device_id)
+ required_cache_size = args.decoding_chunk_size * args.num_decoding_left_chunks
+
+
+ chunk_xs = torch.from_numpy(np.random.random((args.batch_size, 67, 80)).astype("float32"))
+ chunk_lens = torch.from_numpy(np.array([600]*args.batch_size).astype("int32"))
+ offset = torch.from_numpy(np.array([0]*args.batch_size).reshape((args.batch_size, 1)).astype("int32"))
+ att_cache = torch.from_numpy(np.random.random((args.batch_size, 12, 4, required_cache_size, 128)).astype("float32"))
+ cnn_cache = torch.from_numpy(np.random.random((args.batch_size, 12, 256, 7)).astype("float32"))
+ cache_mask = torch.from_numpy(np.random.random((args.batch_size, 1, required_cache_size)).astype("float32"))
+
+ aie_net = torch.jit.load(args.encoder_aie)
+ torch_net = torch.jit.load(args.encoder_torch)
+
+ aie_net.eval()
+ torch_net.eval()
+ time_ms = 0
+ with torch.no_grad():
+ time_ms, output_data = measure(aie_net, (chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask))
+ output_data = output_data[0].to('cpu')
+ y = output_data.flatten()
+ torch_output = torch_net(chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)
+ torch_y = torch_output[0].flatten()
+ cos_1 = cos_sim(y, torch_y)
+ print("acc: ", cos_1)
+
+ throughput = 1000 * args.batch_size / time_ms
+ print("throughput 1000*batchsize.mean({})/NPU_compute_time.mean({}): {}".format(args.batch_size, time_ms, throughput))
diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py b/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py
new file mode 100644
index 0000000000000000000000000000000000000000..1907218187b898e20bd8fafe4d0461f8c01ba208
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/recognize_aie.py
@@ -0,0 +1,400 @@
+# BSD 3-Clause License
+#
+# All rights reserved.
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+"""
+This script is for testing exported ascend encoder and decoder from
+export_onnx_npu.py. The exported ascend models only support batch offline ASR inference.
+It requires a python wrapped c++ ctc decoder.
+Please install it from ctc decoder in github
+"""
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+import sys
+import time
+import stat
+
+import multiprocessing
+import numpy as np
+import torch
+import yaml
+from torch.utils.data import DataLoader
+import torch_aie
+
+from wenet.dataset.dataset import Dataset
+from wenet.utils.common import IGNORE_ID
+from wenet.utils.file_utils import read_symbol_table
+from wenet.utils.config import override_config
+from swig_decoders import map_batch, \
+ ctc_beam_search_decoder_batch, \
+ TrieVector, PathTrie
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='recognize with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--test_data', required=True, help='test data file')
+ parser.add_argument('--data_type',
+ default='raw',
+ choices=['raw', 'shard'],
+ help='train and cv data type')
+
+ parser.add_argument('--dict', required=True, help='dict file')
+    parser.add_argument('--encoder_aie', required=True,
+                        help='compiled encoder .pth file')
+    parser.add_argument('--decoder_aie', required=True,
+                        help='compiled decoder .pth file')
+    parser.add_argument('--result_file', required=True, help='asr result file')
+    parser.add_argument('--test_file', required=True, help='performance result file')
+    parser.add_argument('--batch_size',
+                        type=int,
+                        default=32,
+                        help='batch size')
+ parser.add_argument('--device_id',
+ type=int,
+ default=0,
+ help='npu device id')
+ parser.add_argument('--mode',
+ choices=[
+ 'ctc_greedy_search', 'ctc_prefix_beam_search',
+ 'attention_rescoring'],
+ default='attention_rescoring',
+ help='decoding mode')
+ parser.add_argument('--bpe_model',
+ default=None,
+ type=str,
+ help='bpe model for english part')
+ parser.add_argument('--override_config',
+ action='append',
+ default=[],
+ help="override yaml config")
+ parser.add_argument('--fp16',
+ action='store_true',
+ help='whether to export fp16 model, default false')
+ parser.add_argument('--static',
+ action='store_true',
+ help='whether to run static model')
+ parser.add_argument('--num_process',
+ type=int,
+ default=2,
+ help='number of mutiprocesses')
+    parser.add_argument('--encoder_gears',
+                        type=str,
+                        help='dynamic dims gears info for the encoder, split by ","')
+    parser.add_argument('--decoder_gears',
+                        type=str,
+                        help='dynamic dims gears info for the decoder, split by ","')
+    parser.add_argument('--output_size',
+                        type=str,
+                        help='only effective in dynamic shapes mode; '
+                             'output size info for the encoder, split by ","')
+ args_ = parser.parse_args()
+ print(args_)
+ return args_
+
+def get_dict(dict_path):
+ vocabulary_ = []
+ char_dict_ = {}
+ with open(dict_path, 'r') as fin_:
+ for line in fin_:
+ arr = line.strip().split()
+ if len(arr) != 2:
+ print('dict format is incorrect')
+ exit(0)
+ char_dict_[int(arr[1])] = arr[0]
+ vocabulary_.append(arr[0])
+ return vocabulary_, char_dict_
+
+def adjust_test_conf(test_conf_, batch_size_):
+ # adjust dataset parameters for om
+ # reserved suitable memory
+ test_conf_['filter_conf']['max_length'] = 1028
+ test_conf_['filter_conf']['min_length'] = 0
+ test_conf_['filter_conf']['token_max_length'] = 1028
+ test_conf_['filter_conf']['token_min_length'] = 0
+ test_conf_['filter_conf']['max_output_input_ratio'] = 1028
+ test_conf_['filter_conf']['min_output_input_ratio'] = 0
+ test_conf_['speed_perturb'] = False
+ test_conf_['spec_aug'] = False
+ test_conf_['shuffle'] = False
+ test_conf_['sort'] = False
+ test_conf_['fbank_conf']['dither'] = 0.0
+ test_conf_['batch_conf']['batch_type'] = "static"
+ test_conf_['batch_conf']['batch_size'] = batch_size_
+ return test_conf_
+
+class AsrModel:
+ def __init__(self, encoder, decoder, args_, reverse_weight_) -> None:
+ self.encoder, self.decoder = encoder, decoder
+ self.vocabulary, self.char_dict = get_dict(args_.dict)
+ self.reverse_weight = reverse_weight_
+        self.mul_shape = list(map(int, args_.encoder_gears.split(',')))
+        self.mul_shape_decoder = [384]
+        self.output_size = None
+        if args_.output_size:
+            self.output_size = list(map(int, args_.output_size.split(',')))
+        self.args = args_
+        self.npu_device = 'npu:' + str(args_.device_id)
+
+ def forward(self, data):
+ nums_ = 0
+ eos = sos = len(self.char_dict) - 1
+
+ keys, feats, _, feats_lengths, _ = data
+ feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
+ ort_outs = None
+ pad_size = 0
+ pad_batch = 0
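+        # Pad the time dimension up to the next compiled gear and the batch
+        # dimension up to the configured batch size, since the compiled encoder
+        # only accepts the dynamic-dims gears it was built with.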
+ for n in self.mul_shape:
+ if n > feats.shape[1]:
+ pad_size = n - feats.shape[1]
+ break
+ if feats.shape[0] < self.args.batch_size:
+ pad_batch = self.args.batch_size - feats.shape[0]
+ feats_lengths = np.pad(feats_lengths, [(0, pad_batch)], 'constant')
+ feats_pad = np.pad(feats, [(0, pad_batch), (0, pad_size), (0, 0)], 'constant')
+
+        feats_pad = torch.from_numpy(feats_pad).to(self.npu_device)
+        feats_lengths = torch.from_numpy(feats_lengths).to(self.npu_device)
+ ort_outs = self.encoder.forward(
+ feats_pad, feats_lengths)
+ encoder_out, encoder_out_lens, _, \
+ beam_log_probs, beam_log_probs_idx = ort_outs
+ beam_size = beam_log_probs.shape[-1]
+ batch_size = beam_log_probs.shape[0]
+ num_processes = min(multiprocessing.cpu_count(), batch_size)
+ if self.args.mode == 'ctc_greedy_search':
+ if beam_size != 1:
+ log_probs_idx = beam_log_probs_idx[:, :, 0]
+ batch_sents = []
+ for idx_, seq in enumerate(log_probs_idx):
+ batch_sents.append(seq[0:encoder_out_lens[idx_]].tolist())
+ hyps = map_batch(batch_sents, self.vocabulary, num_processes,
+ True, 0)
+ elif self.args.mode in ('ctc_prefix_beam_search', "attention_rescoring"):
+ batch_log_probs_seq_list = beam_log_probs.tolist()
+ batch_log_probs_idx_list = beam_log_probs_idx.tolist()
+ batch_len_list = encoder_out_lens.tolist()
+ batch_log_probs_seq = []
+ batch_log_probs_ids = []
+ batch_start = [] # only effective in streaming deployment
+ batch_root = TrieVector()
+ root_dict = {}
+ for i in range(len(batch_len_list)):
+ num_sent = batch_len_list[i]
+ batch_log_probs_seq.append(
+ batch_log_probs_seq_list[i][0:num_sent])
+ batch_log_probs_ids.append(
+ batch_log_probs_idx_list[i][0:num_sent])
+ root_dict[i] = PathTrie()
+ batch_root.append(root_dict[i])
+ batch_start.append(True)
+ score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
+ batch_log_probs_ids,
+ batch_root,
+ batch_start,
+ beam_size,
+ num_processes,
+ 0, -2, 0.99999)
+ if self.args.mode == 'ctc_prefix_beam_search':
+ hyps = []
+ for cand_hyps in score_hyps:
+ hyps.append(cand_hyps[0][1])
+ hyps = map_batch(hyps, self.vocabulary, num_processes, False, 0)
+ if self.args.mode == 'attention_rescoring':
+ ctc_score, all_hyps = [], []
+ max_len = 0
+ for hyps in score_hyps:
+ cur_len = len(hyps)
+ if len(hyps) < beam_size:
+ hyps += (beam_size - cur_len) * ((-float("INF"), (0,)),)
+ cur_ctc_score = []
+ for hyp in hyps:
+ cur_ctc_score.append(hyp[0])
+ all_hyps.append(list(hyp[1]))
+ if len(hyp[1]) > max_len:
+ max_len = len(hyp[1])
+ ctc_score.append(cur_ctc_score)
+ if self.args.fp16:
+ ctc_score = np.array(ctc_score, dtype=np.float16)
+ else:
+ ctc_score = np.array(ctc_score, dtype=np.float32)
+ hyps_pad_sos_eos = np.ones(
+ (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID
+ r_hyps_pad_sos_eos = np.ones(
+ (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID
+ hyps_lens_sos = np.ones(
+ (batch_size, beam_size), dtype=np.int32)
+ k = 0
+ for i in range(batch_size):
+ for j in range(beam_size):
+ cand = all_hyps[k]
+ l = len(cand) + 2
+ hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
+ r_hyps_pad_sos_eos[i][j][0:l] = [
+ sos] + cand[::-1] + [eos]
+ hyps_lens_sos[i][j] = len(cand) + 1
+ k += 1
+ best_index = None
+            encoder_out_lens = encoder_out_lens.to(self.npu_device)
+            hyps_pad_sos_eos = torch.from_numpy(hyps_pad_sos_eos).to(self.npu_device)
+            hyps_lens_sos = torch.from_numpy(hyps_lens_sos).to(self.npu_device)
+            r_hyps_pad_sos_eos = torch.from_numpy(r_hyps_pad_sos_eos).to(self.npu_device)
+            ctc_score = torch.from_numpy(ctc_score).to(self.npu_device)
+ if not self.args.static:
+ output_size = 100000
+
+ if self.reverse_weight > 0:
+ best_index = self.decoder.forward(
+ encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos,
+ r_hyps_pad_sos_eos, ctc_score)
+ else:
+ best_index = self.decoder.forward(
+ encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos,
+ ctc_score)
+ else:
+ pad_size = 0
+ for n in self.mul_shape_decoder:
+ if n > encoder_out.shape[1]:
+ pad_size = n - encoder_out.shape[1]
+ break
+                encoder_out = np.pad(encoder_out.cpu(), ((0, 0), (0, pad_size), (0, 0)), 'constant')
+                encoder_out = torch.from_numpy(encoder_out).to(self.npu_device)
+ if self.reverse_weight > 0:
+ best_index = self.decoder.forward(
+ encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score)
+ else:
+                    best_index = self.decoder.forward(
+                        encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, ctc_score)
+ best_index = best_index.to("cpu")
+ best_sents = []
+ k = 0
+ for idx_ in best_index:
+ cur_best_sent = all_hyps[k: k + beam_size][idx_]
+ best_sents.append(cur_best_sent)
+ k += beam_size
+ hyps = map_batch(best_sents, self.vocabulary, num_processes)
+
+ for i, key in enumerate(keys):
+ nums_ += 1
+ content = hyps[i]
+ logger.info('{} {}'.format(key, content))
+ return nums_
+
+def infer_process(list_args):
+ idx_, encoder, decoder = list_args
+ batches = packed_data[idx_]
+ model = AsrModel(encoder, decoder, args, reverse_weight)
+
+ nums_ = 0
+ infer_s_t = time.time()
+ for data in batches:
+ num = model.forward(data)
+ nums_ += num
+
+ infer_e_t = time.time()
+ # sync_num.pop()
+ return nums_, infer_e_t - infer_s_t
+
+if __name__ == '__main__':
+ args = get_args()
+ logger = logging.getLogger(__name__)
+ logger.setLevel(level=logging.DEBUG)
+ handler = logging.FileHandler(args.result_file, mode='w')
+ formatter = logging.Formatter('%(message)s')
+ handler.setFormatter(formatter)
+
+ console = logging.StreamHandler()
+ console.setLevel(level=logging.DEBUG)
+ console_format = logging.Formatter('Recognize: %(message)s')
+ console.setFormatter(console_format)
+ logger.addHandler(handler)
+ logger.addHandler(console)
+ with open(args.config, 'r') as fin:
+ configs = yaml.safe_load(fin)
+    if len(args.override_config) > 0:
+        configs = override_config(configs, args.override_config)
+
+ reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
+ symbol_table = read_symbol_table(args.dict)
+ test_conf = copy.deepcopy(configs['dataset_conf'])
+ test_conf = adjust_test_conf(test_conf, args.batch_size)
+
+ test_dataset = Dataset(args.data_type,
+ args.test_data,
+ symbol_table,
+ test_conf,
+ args.bpe_model,
+ partition=False)
+
+ test_data_loader = DataLoader(test_dataset, batch_size=None,
+ num_workers=multiprocessing.cpu_count() // 2)
+ manager = multiprocessing.Manager()
+ num_process = args.num_process
+
+    # many files will be opened, so change the sharing strategy
+    torch.multiprocessing.set_sharing_strategy('file_system')
+    # pack the data into per-process buckets
+ packed_data = [[] for _ in range(num_process)]
+ idx = 0
+ pre_s_t = time.time()
+ for batch in test_data_loader:
+ packed_data[idx].append(batch)
+ idx = (idx + 1) % num_process
+ pre_e_t = time.time()
+
+ sync_num = manager.list()
+ data_cnt = 0
+ infer_times = 0
+
+ encoder = torch.jit.load(args.encoder_aie)
+ decoder = torch.jit.load(args.decoder_aie)
+
+ torch_aie.set_device(args.device_id)
+
+    # run inference over every data bucket (sequentially within this process)
+    for idx in range(num_process):
+        cnt, infer_time = infer_process([idx, encoder, decoder])
+        data_cnt += cnt
+        infer_times += infer_time
+
+ fps = float((data_cnt) / (infer_times + pre_e_t - pre_s_t))
+ fps_str = "fps: {}\n".format(fps)
+ resstr = "total time: {}\n".format(infer_times + pre_e_t - pre_s_t)
+ print(fps_str)
+ print(resstr)
+
+ flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
+ modes = stat.S_IWUSR | stat.S_IRUSR
+ with os.fdopen(os.open(args.test_file, flags, modes), 'w') as f:
+ f.write(fps_str)
+ f.write(resstr)
diff --git a/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py b/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f391f734ec28b2aefc19fb8ded9733ebd56a8
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Wenet/trace.py
@@ -0,0 +1,427 @@
+# BSD 3-Clause License
+#
+# All rights reserved.
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# ============================================================================
+
+"""
+This script is for testing exported ascend encoder and decoder from
+export_onnx_npu.py. The exported ascend models only support batch offline ASR inference.
+It requires a python wrapped c++ ctc decoder.
+Please install it from ctc decoder in github
+"""
+from __future__ import print_function
+
+import argparse
+
+import torch
+import yaml
+
+from wenet.transformer.asr_model import init_asr_model
+from wenet.utils.checkpoint import load_checkpoint
+from wenet.transformer.ctc import CTC
+from wenet.transformer.decoder import TransformerDecoder
+from wenet.transformer.encoder import BaseEncoder
+from wenet.utils.mask import make_pad_mask
+
+def get_args():
+ parser = argparse.ArgumentParser(description='recognize with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--checkpoint', required=True, help='checkpoint model')
+ parser.add_argument('--encoder_out', required=True, help='Encoder traced .pth path')
+ parser.add_argument('--decoder_out', required=True, help='Decoder traced .pth path')
+ parser.add_argument('--batch_size',
+ type=int,
+ default=32,
+ help='used batch size')
+ parser.add_argument('--num_decoding_left_chunks',
+ default=5,
+ type=int,
+ required=False,
+ help="number of left chunks, <= 0 is not supported")
+ parser.add_argument('--streaming',
+ action='store_true',
+ help="whether to export streaming encoder, default false")
+ parser.add_argument('--decoding_chunk_size',
+ default=16,
+ type=int,
+ required=False,
+ help='the decoding chunk size, <=0 is not supported')
+ parser.add_argument('--reverse_weight', default=-1.0, type=float,
+ required=False,
+ help='reverse weight for bitransformer,' +
+ 'default value is in config file')
+ parser.add_argument('--beam_size', default=10, type=int, required=False,
+ help="beam size would be ctc output size")
+    parser.add_argument('--output_size',
+                        type=str,
+                        help='only effective in dynamic shapes mode; '
+                             'output size info for the encoder, split by ","')
+ args_ = parser.parse_args()
+ print(args_)
+ return args_
+
+class StreamingEncoder(torch.nn.Module):
+ def __init__(self, model, required_cache_size, beam_size, transformer=False):
+ super().__init__()
+ self.ctc = model.ctc
+ self.subsampling_rate = model.encoder.embed.subsampling_rate
+ self.embed = model.encoder.embed
+ self.global_cmvn = model.encoder.global_cmvn
+ self.required_cache_size = required_cache_size
+ self.beam_size = beam_size
+ self.encoder = model.encoder
+ self.transformer = transformer
+
+ def forward(self, chunk_xs, chunk_lens, offset,
+ att_cache, cnn_cache, cache_mask):
+ """Streaming Encoder
+ Args:
+ xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
+ where `time == (chunk_size - 1) * subsample_rate + \
+ subsample.right_context + 1`
+ offset (torch.Tensor): offset with shape (b, 1)
+ 1 is retained for triton deployment
+ required_cache_size (int): cache size required for next chunk
+ compuation
+ > 0: actual cache size
+ <= 0: not allowed in streaming gpu encoder `
+ att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+ transformer/conformer attention, with shape
+ (b, elayers, head, cache_t1, d_k * 2), where
+ `head * d_k == hidden-dim` and
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+ (b, elayers, b, hidden-dim, cache_t2), where
+ `cache_t2 == cnn.lorder - 1`
+ cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size)
+ in a batch of request, each request may have different
+ history cache. Cache mask is used to indidate the effective
+ cache for each request
+ Returns:
+ torch.Tensor: log probabilities of ctc output and cutoff by beam size
+ with shape (b, chunk_size, beam)
+ torch.Tensor: index of top beam size probabilities for each timestep
+ with shape (b, chunk_size, beam)
+ torch.Tensor: output of current input xs,
+ with shape (b, chunk_size, hidden-dim).
+ torch.Tensor: new attention cache required for next chunk, with
+ same shape (b, elayers, head, cache_t1, d_k * 2)
+ as the original att_cache
+ torch.Tensor: new conformer cnn cache required for next chunk, with
+ same shape as the original cnn_cache.
+ torch.Tensor: new cache mask, with same shape as the original
+ cache mask
+ """
+ offset = offset.squeeze(1)
+ T = chunk_xs.size(1)
+ chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
+ # B X 1 X T
+ chunk_mask = chunk_mask.to(chunk_xs.dtype)
+ # transpose batch & num_layers dim
+ att_cache = torch.transpose(att_cache, 0, 1)
+ cnn_cache = torch.transpose(cnn_cache, 0, 1)
+
+ # rewrite encoder.forward_chunk
+ # <---------forward_chunk START--------->
+ xs = self.global_cmvn(chunk_xs)
+ # chunk mask is important for batch inferencing since
+ # different sequence in a batch has different length
+ xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
+ cache_size = att_cache.size(3) # required cache size
+ masks = torch.cat((cache_mask, chunk_mask), dim=2)
+ index = offset - cache_size
+
+ pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
+ pos_emb = pos_emb.to(dtype=xs.dtype)
+
+ next_cache_start = -self.required_cache_size
+ r_cache_mask = masks[:, :, next_cache_start:]
+
+ r_att_cache = []
+ r_cnn_cache = []
+ for i, layer in enumerate(self.encoder.encoders):
+ xs, _, new_att_cache, new_cnn_cache = layer(
+ xs, masks, pos_emb,
+ att_cache=att_cache[i],
+ cnn_cache=cnn_cache[i])
+ # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
+ # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
+ if not self.transformer:
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
+ if self.encoder.normalize_before:
+ chunk_out = self.encoder.after_norm(xs)
+
+ r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx
+ if not self.transformer:
+ r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers
+
+ # <---------forward_chunk END--------->
+
+ log_ctc_probs = self.ctc.log_softmax(chunk_out)
+ log_probs, log_probs_idx = torch.topk(log_ctc_probs,
+ self.beam_size,
+ dim=2)
+ log_probs = log_probs.to(chunk_xs.dtype)
+
+ r_offset = offset + chunk_out.shape[1]
+ # the below ops not supported in Tensorrt
+ # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
+ # rounding_mode='floor')
+ chunk_out_lens = chunk_lens // self.subsampling_rate
+ r_offset = r_offset.unsqueeze(1)
+
+ return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
+ r_offset, r_att_cache, r_cnn_cache, r_cache_mask
+
+class Encoder(torch.nn.Module):
+ def __init__(self,
+ encoder: BaseEncoder,
+ ctc: CTC,
+ beam_size: int = 10):
+ super().__init__()
+ self.encoder = encoder
+ self.ctc = ctc
+ self.beam_size = beam_size
+
+ def forward(self, speech: torch.Tensor,
+ speech_lengths: torch.Tensor,):
+ """Encoder
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ Returns:
+ encoder_out: B x T x F
+ encoder_out_lens: B
+ ctc_log_probs: B x T x V
+ beam_log_probs: B x T x beam_size
+ beam_log_probs_idx: B x T x beam_size
+ """
+ encoder_out, encoder_mask = self.encoder(speech,
+ speech_lengths,
+ -1, -1)
+ encoder_out_lens = encoder_mask.squeeze(1).sum(1)
+ ctc_log_probs = self.ctc.log_softmax(encoder_out)
+ encoder_out_lens = encoder_out_lens.int()
+ beam_log_probs, beam_log_probs_idx = torch.topk(
+ ctc_log_probs, self.beam_size, dim=2)
+ return encoder_out, encoder_out_lens, ctc_log_probs, \
+ beam_log_probs, beam_log_probs_idx
+
+class Decoder(torch.nn.Module):
+ def __init__(self,
+ decoder: TransformerDecoder,
+ ctc_weight: float = 0.5,
+ reverse_weight: float = 0.3,
+ beam_size: int = 10):
+ super().__init__()
+ self.decoder = decoder
+ self.ctc_weight = ctc_weight
+ self.reverse_weight = reverse_weight
+ self.beam_size = beam_size
+
+ def forward(self,
+ encoder_out: torch.Tensor,
+ encoder_lens: torch.Tensor,
+ hyps_pad_sos_eos: torch.Tensor,
+ hyps_lens_sos: torch.Tensor,
+ r_hyps_pad_sos_eos: torch.Tensor,
+ ctc_score: torch.Tensor):
+ """Encoder
+ Args:
+ encoder_out: B x T x F
+ encoder_lens: B
+ hyps_pad_sos_eos: B x beam x (T2+1),
+ hyps with sos & eos and padded by ignore id
+ hyps_lens_sos: B x beam, length for each hyp with sos
+ r_hyps_pad_sos_eos: B x beam x (T2+1),
+ reversed hyps with sos & eos and padded by ignore id
+ ctc_score: B x beam, ctc score for each hyp
+ Returns:
+ decoder_out: B x beam x T2 x V
+ r_decoder_out: B x beam x T2 x V
+ best_index: B
+ """
+ B, T, F = encoder_out.shape
+ bz = self.beam_size
+ B2 = B * bz
+ encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
+ encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
+ encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
+ T2 = hyps_pad_sos_eos.shape[2] - 1
+ hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
+ hyps_lens = hyps_lens_sos.view(B2,)
+ hyps_pad_sos = hyps_pad[:, :-1].contiguous()
+ hyps_pad_eos = hyps_pad[:, 1:].contiguous()
+
+ r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
+ r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
+ r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
+
+ decoder_out, r_decoder_out, _ = self.decoder(
+ encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos,
+ self.reverse_weight)
+ decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+ V = decoder_out.shape[-1]
+ decoder_out = decoder_out.view(B2, T2, V)
+ mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2
+ # mask index, remove ignore id
+ index = torch.unsqueeze(hyps_pad_eos * mask, 2)
+ score = decoder_out.gather(2, index).squeeze(2) # B2 X T2
+ # mask padded part
+ score = score * mask
+ decoder_out = decoder_out.view(B, bz, T2, V)
+ if self.reverse_weight > 0:
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+ r_decoder_out = r_decoder_out.view(B2, T2, V)
+ index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
+ r_score = r_decoder_out.gather(2, index).squeeze(2)
+ r_score = r_score * mask
+ score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score
+ r_decoder_out = r_decoder_out.view(B, bz, T2, V)
+ score = torch.sum(score, axis=1) # B2
+ score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
+ best_index = torch.argmax(score, dim=1)
+ return best_index
+
+def export_online_encoder(args, model):
+ batch_size = args.batch_size
+ decoding_chunk_size = args.decoding_chunk_size
+ subsampling = model.encoder.embed.subsampling_rate
+ context = model.encoder.embed.right_context + 1
+ decoding_window = (decoding_chunk_size - 1) * subsampling + context
+ audio_len = decoding_window
+ feature_size = configs["input_dim"]
+ output_size = configs["encoder_conf"]["output_size"]
+ num_layers = configs["encoder_conf"]["num_blocks"]
+ # in transformer the cnn module will not be available
+ transformer = False
+ cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
+ if not cnn_module_kernel:
+ transformer = True
+ num_decoding_left_chunks = args.num_decoding_left_chunks
+ required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+ encoder = StreamingEncoder(model, required_cache_size, args.beam_size, transformer)
+ encoder.eval()
+
+ chunk_xs = torch.randn(batch_size, audio_len, feature_size, dtype=torch.float32)
+ chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
+
+ offset = torch.arange(0, batch_size).unsqueeze(1)
+ # (elayers, b, head, cache_t1, d_k * 2)
+ head = configs["encoder_conf"]["attention_heads"]
+ d_k = configs["encoder_conf"]["output_size"] // head
+ att_cache = torch.randn(batch_size, num_layers, head,
+ required_cache_size, d_k * 2,
+ dtype=torch.float32)
+ cnn_cache = torch.randn(batch_size, num_layers, output_size,
+ cnn_module_kernel, dtype=torch.float32)
+
+ cache_mask = torch.ones(batch_size, 1, required_cache_size, dtype=torch.float32)
+
+ inputs = [chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask]
+ print('[INFO] Started to trace the online encoder model')
+ traced = torch.jit.trace(encoder, inputs)
+
+ print('[INFO] Success! Online encoder was traced')
+ traced.save(args.encoder_out)
+
+def export_offline_encoder(args, model):
+ bz = args.batch_size
+ seq_len = 1478
+ beam_size = 10
+ feature_size = configs["input_dim"]
+
+ speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
+ speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32)
+ encoder = Encoder(model.encoder, model.ctc, beam_size)
+ encoder.eval()
+
+ inputs = [speech, speech_lens]
+ print('[INFO] Started to trace the offline encoder model')
+ traced = torch.jit.trace(encoder, inputs)
+
+ print('[INFO] Success! Offline encoder was traced')
+ traced.save(args.encoder_out)
+
+
+def export_decoder(args, model):
+ bz, seq_len = args.batch_size, 100
+ beam_size = 10
+ decoder = Decoder(model.decoder,
+ model.ctc_weight,
+ model.reverse_weight,
+ beam_size)
+ decoder.eval()
+
+ hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))
+ hyps_lens_sos = torch.randint(low=3, high=seq_len, size=(bz, beam_size),
+ dtype=torch.int32)
+ r_hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len))
+
+ output_size = configs["encoder_conf"]["output_size"]
+ encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
+ encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,), dtype=torch.int32)
+ ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)
+
+ inputs = [encoder_out,
+ encoder_out_lens,
+ hyps_pad_sos_eos,
+ hyps_lens_sos,
+ r_hyps_pad_sos_eos,
+ ctc_score]
+ print('[INFO] Started to trace the decoder model')
+ traced = torch.jit.trace(decoder, inputs)
+ print('[INFO] Success! The decoder was traced')
+ traced.save(args.decoder_out)
+
+
+if __name__ == '__main__':
+ args = get_args()
+ torch.manual_seed(0)
+ torch.set_printoptions(precision=10)
+
+ with open(args.config, 'r') as fin:
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
+ if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']:
+ configs['model_conf']['reverse_weight'] = args.reverse_weight
+ print("Update reverse weight to", args.reverse_weight)
+ configs["encoder_conf"]["use_dynamic_chunk"] = False
+
+ model = init_asr_model(configs)
+ load_checkpoint(model, args.checkpoint)
+ model.eval()
+ if args.streaming:
+ export_online_encoder(args, model)
+ else:
+ export_offline_encoder(args, model)
+ export_decoder(args, model)
+