diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/conformer_py.patch b/AscendIE/TorchAIE/built-in/audio/Conformer/conformer_py.patch new file mode 100644 index 0000000000000000000000000000000000000000..c7b8c2b335a8179639e5e06cd8d47b91fd790839 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/conformer_py.patch @@ -0,0 +1,11 @@ +--- ./pruned_transducer_stateless5/conformer.py 2024-03-18 17:36:34.852000000 +0800 ++++ ./pruned_transducer_stateless5/new_conformer.py 2024-03-18 17:35:14.136000000 +0800 +@@ -193,7 +193,7 @@ + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) +- return x, lengths ++ return x#, lengths + + @torch.jit.export + def get_init_state( diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/decoder_compile.py b/AscendIE/TorchAIE/built-in/audio/Conformer/decoder_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..40cfab638900ca466b06426234eef4c113a54938 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/decoder_compile.py @@ -0,0 +1,40 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import mindietorch +from mindietorch import _enums + +DECODER_Y_SHAPE = (1, 2) + +inputs = [mindietorch.Input(DECODER_Y_SHAPE, dtype=torch.int64)] + +decoder_ts_model = torch.jit.load('./exp/exported_decoder-epoch-99-avg-1.ts') +decoder_ts_model.eval() +mindietorch.set_device(0) +try: + compiled_decoder = mindietorch.compile( + decoder_ts_model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + soc_version="Ascend310P3", + ) + compiled_decoder.save("compiled_decoder.ts") +except Exception as e: + print(f"During the compilation of decoder model, an error has occured: {e}") + + import sys + sys.exit(1) diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/encoder_compile.py b/AscendIE/TorchAIE/built-in/audio/Conformer/encoder_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..31ca3d65ae1cb7cbe1495f40685deb294c50daec --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/encoder_compile.py @@ -0,0 +1,43 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import sys + +import torch +import mindietorch + +ENCODER_X_SHAPE = (1, 100, 80) +ENCODER_X_LENS_SHAPE = (1, ) + +inputs = [mindietorch.Input(ENCODER_X_SHAPE, dtype=torch.float32), mindietorch.Input(ENCODER_X_LENS_SHAPE, dtype=torch.int64)] + +encoder_ts_model = torch.jit.load('./exp/exported_encoder-epoch-99-avg-1.ts') +encoder_ts_model.eval() + +mindietorch.set_device(0) + +try: + compiled_encoder_model = mindietorch.compile( + encoder_ts_model, + inputs=inputs, + precision_policy=mindietorch.PrecisionPolicy.FP16, + truncate_long_and_double=True, + soc_version="Ascend310P3" + ) + compiled_encoder_model.save("./compiled_encoder.ts") +except Exception as e: + print("an error has occured.") + print(e) + sys.exit(1) diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/export_torchscript.patch b/AscendIE/TorchAIE/built-in/audio/Conformer/export_torchscript.patch new file mode 100644 index 0000000000000000000000000000000000000000..97f59af37e6919ebd9f27118b9b236f326f1dd67 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/export_torchscript.patch @@ -0,0 +1,333 @@ +--- ./pruned_transducer_stateless5/export-onnx.py 2024-03-26 15:07:15.360000000 +0800 ++++ ./pruned_transducer_stateless5/export_torchscript.py 2024-03-18 16:36:38.728000000 +0800 +@@ -28,7 +28,7 @@ + 2. Export the model to ONNX + + ./pruned_transducer_stateless5/export-onnx.py \ +- --lang-dir $repo/data/lang_char \ ++ --tokens $repo/data/lang_char/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ +@@ -55,6 +55,7 @@ + from pathlib import Path + from typing import Dict, Tuple + ++import k2 + import onnx + import torch + import torch.nn as nn +@@ -70,8 +71,7 @@ + find_checkpoints, + load_checkpoint, + ) +-from icefall.lexicon import Lexicon +-from icefall.utils import setup_logger, str2bool ++from icefall.utils import num_tokens, setup_logger, str2bool + + + def get_parser(): +@@ -128,10 +128,10 @@ + ) + + parser.add_argument( +- "--lang-dir", ++ "--tokens", + type=str, +- default="data/lang_char", +- help="The lang dir", ++ default="data/lang_char/tokens.txt", ++ help="Path to the tokens.txt", + ) + + parser.add_argument( +@@ -146,22 +146,22 @@ + return parser + + +-def add_meta_data(filename: str, meta_data: Dict[str, str]): +- """Add meta data to an ONNX model. It is changed in-place. ++# def add_meta_data(filename: str, meta_data: Dict[str, str]): ++# """Add meta data to an ONNX model. It is changed in-place. + +- Args: +- filename: +- Filename of the ONNX model to be changed. +- meta_data: +- Key-value pairs. +- """ +- model = onnx.load(filename) +- for key, value in meta_data.items(): +- meta = model.metadata_props.add() +- meta.key = key +- meta.value = value ++# Args: ++# filename: ++# Filename of the ONNX model to be changed. ++# meta_data: ++# Key-value pairs. 
++# """ ++# model = onnx.load(filename) ++# for key, value in meta_data.items(): ++# meta = model.metadata_props.add() ++# meta.key = key ++# meta.value = value + +- onnx.save(model, filename) ++# onnx.save(model, filename) + + + class OnnxEncoder(nn.Module): +@@ -196,12 +196,13 @@ + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ +- encoder_out, encoder_out_lens = self.encoder(x, x_lens) ++ # encoder_out, encoder_out_lens = self.encoder(x, x_lens) ++ encoder_out = self.encoder(x, x_lens) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + +- return encoder_out, encoder_out_lens ++ return encoder_out#, encoder_out_lens + + + class OnnxDecoder(nn.Module): +@@ -254,10 +255,10 @@ + return logit + + +-def export_encoder_model_onnx( ++def export_encoder_torchscript( + encoder_model: OnnxEncoder, + encoder_filename: str, +- opset_version: int = 11, ++ # opset_version: int = 11, + ) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: +@@ -278,40 +279,16 @@ + opset_version: + The opset version to use. + """ +- x = torch.zeros(1, 100, 80, dtype=torch.float32) ++ x = torch.rand(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) ++ encoder_ts = torch.jit.trace(encoder_model, (x, x_lens)) ++ # encoder_ts = torch.jit.trace(encoder_model, x) ++ torch.jit.save(encoder_ts, encoder_filename) + +- torch.onnx.export( +- encoder_model, +- (x, x_lens), +- encoder_filename, +- verbose=False, +- opset_version=opset_version, +- input_names=["x", "x_lens"], +- output_names=["encoder_out", "encoder_out_lens"], +- dynamic_axes={ +- "x": {0: "N", 1: "T"}, +- "x_lens": {0: "N"}, +- "encoder_out": {0: "N", 1: "T"}, +- "encoder_out_lens": {0: "N"}, +- }, +- ) +- +- meta_data = { +- "model_type": "conformer", +- "version": "1", +- "model_author": "k2-fsa", +- "comment": "stateless5", +- } +- logging.info(f"meta_data: {meta_data}") +- +- add_meta_data(filename=encoder_filename, meta_data=meta_data) +- +- +-def export_decoder_model_onnx( ++def export_decoder_torchscript( + decoder_model: OnnxDecoder, + decoder_filename: str, +- opset_version: int = 11, ++ # opset_version: int = 11, + ) -> None: + """Export the decoder model to ONNX format. + +@@ -334,33 +311,14 @@ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + +- y = torch.zeros(10, context_size, dtype=torch.int64) +- decoder_model = torch.jit.script(decoder_model) +- torch.onnx.export( +- decoder_model, +- y, +- decoder_filename, +- verbose=False, +- opset_version=opset_version, +- input_names=["y"], +- output_names=["decoder_out"], +- dynamic_axes={ +- "y": {0: "N"}, +- "decoder_out": {0: "N"}, +- }, +- ) +- +- meta_data = { +- "context_size": str(context_size), +- "vocab_size": str(vocab_size), +- } +- add_meta_data(filename=decoder_filename, meta_data=meta_data) +- ++ y = torch.rand(10, context_size).to(dtype=torch.int64) ++ decoder_model = torch.jit.trace(decoder_model, y) ++ torch.jit.save(decoder_model, decoder_filename) + +-def export_joiner_model_onnx( ++def export_joiner_torchscript( + joiner_model: nn.Module, + joiner_filename: str, +- opset_version: int = 11, ++ # opset_version: int = 11, + ) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: +@@ -377,28 +335,8 @@ + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) +- +- torch.onnx.export( +- joiner_model, +- (projected_encoder_out, projected_decoder_out), +- joiner_filename, +- verbose=False, +- opset_version=opset_version, +- input_names=[ +- "encoder_out", +- "decoder_out", +- ], +- output_names=["logit"], +- dynamic_axes={ +- "encoder_out": {0: "N"}, +- "decoder_out": {0: "N"}, +- "logit": {0: "N"}, +- }, +- ) +- meta_data = { +- "joiner_dim": str(joiner_dim), +- } +- add_meta_data(filename=joiner_filename, meta_data=meta_data) ++ joiner_ts = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out)) ++ torch.jit.save(joiner_ts, joiner_filename) + + + @torch.no_grad() +@@ -417,9 +355,9 @@ + + logging.info(f"device: {device}") + +- lexicon = Lexicon(params.lang_dir) +- params.blank_id = 0 +- params.vocab_size = max(lexicon.tokens) + 1 ++ token_table = k2.SymbolTable.from_file(params.tokens) ++ params.blank_id = token_table[""] ++ params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + +@@ -541,60 +479,60 @@ + opset_version = 13 + + logging.info("Exporting encoder") +- encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" +- export_encoder_model_onnx( ++ encoder_filename = params.exp_dir / f"exported_encoder-{suffix}.ts" ++ export_encoder_torchscript( + encoder, + encoder_filename, +- opset_version=opset_version, ++ # opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") +- decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" +- export_decoder_model_onnx( ++ decoder_filename = params.exp_dir / f"exported_decoder-{suffix}.ts" ++ export_decoder_torchscript( + decoder, + decoder_filename, +- opset_version=opset_version, ++ # opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") +- joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" +- export_joiner_model_onnx( ++ joiner_filename = params.exp_dir / f"exported_joiner-{suffix}.ts" ++ export_joiner_torchscript( + joiner, + joiner_filename, +- opset_version=opset_version, ++ # opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + +- logging.info("Generate int8 quantization models") ++ # logging.info("Generate int8 quantization models") + +- encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" +- quantize_dynamic( +- model_input=encoder_filename, +- model_output=encoder_filename_int8, +- op_types_to_quantize=["MatMul"], +- weight_type=QuantType.QInt8, +- ) +- +- decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" +- quantize_dynamic( +- model_input=decoder_filename, +- model_output=decoder_filename_int8, +- op_types_to_quantize=["MatMul"], +- weight_type=QuantType.QInt8, +- ) +- +- joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" +- quantize_dynamic( +- model_input=joiner_filename, +- model_output=joiner_filename_int8, +- op_types_to_quantize=["MatMul"], +- weight_type=QuantType.QInt8, +- ) ++ # encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" ++ # quantize_dynamic( ++ # model_input=encoder_filename, ++ # 
model_output=encoder_filename_int8, ++ # op_types_to_quantize=["MatMul"], ++ # weight_type=QuantType.QInt8, ++ # ) ++ ++ # decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" ++ # quantize_dynamic( ++ # model_input=decoder_filename, ++ # model_output=decoder_filename_int8, ++ # op_types_to_quantize=["MatMul"], ++ # weight_type=QuantType.QInt8, ++ # ) ++ ++ # joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" ++ # quantize_dynamic( ++ # model_input=joiner_filename, ++ # model_output=joiner_filename_int8, ++ # op_types_to_quantize=["MatMul"], ++ # weight_type=QuantType.QInt8, ++ # ) + + + if __name__ == "__main__": diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/joiner_compile.py b/AscendIE/TorchAIE/built-in/audio/Conformer/joiner_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..3a75b47a76d2a63218754fe681a5a631d5a75866 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/joiner_compile.py @@ -0,0 +1,39 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import mindietorch + +JOINER_X_SHAPE = (1, 512) + +inputs = [mindietorch.Input(JOINER_X_SHAPE, dtype=torch.float32), mindietorch.Input(JOINER_X_SHAPE, dtype=torch.float32)] + +joiner_ts_model = torch.jit.load('./exp/exported_joiner-epoch-99-avg-1.ts') +joiner_ts_model.eval() +mindietorch.set_device(0) + +try: + compiled_joiner = mindietorch.compile( + joiner_ts_model, + inputs=inputs, + precision_policy=mindietorch.PrecisionPolicy.FP16, + truncate_long_and_double=True, + soc_version="Ascend310P3", + ) + compiled_joiner.save("./compiled_joiner.ts") +except Exception as e: + print("During the compilation of joiner model, an error has occured.") + import sys + sys.exit(1) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/perf_test_aie.py b/AscendIE/TorchAIE/built-in/audio/Conformer/perf_test_aie.py new file mode 100644 index 0000000000000000000000000000000000000000..a5728bea44ba292664b0a8b0798950c9706d451f --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Conformer/perf_test_aie.py @@ -0,0 +1,143 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import time
+import argparse
+import json
+
+import numpy as np
+import torch
+import mindietorch
+from tqdm import tqdm
+
+def test_encoder(aie_path, device_id=0):
+    batch_size = 1
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    print("Start loading ts module...")
+    ts = torch.jit.load(aie_path)
+    print("Ts module loaded.")
+    ts.eval()
+    x, x_lens = np.ones((1, 100, 80), dtype=np.float32), np.array([100])
+
+    inputs = (torch.from_numpy(x).to(device), torch.from_numpy(x_lens).to(device))
+    print("Start inferring...")
+    # warmup
+    for _ in range(10):
+        with mindietorch.npu.stream(stream):
+            ts(*inputs)
+        stream.synchronize()
+
+    # performance test
+    num_infer = 100
+
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        with mindietorch.npu.stream(stream):
+            ts(*inputs)
+        stream.synchronize()
+    end = time.time()
+
+    print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def test_decoder(aie_path, device_id):
+    batch_size = 1
+    dummy_input = np.ones((batch_size, 2), dtype=np.int64)
+
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    print("Start loading ts module...")
+    model = torch.jit.load(aie_path)
+    print("Ts module loaded.")
+    model.eval()
+    dummy_input = torch.from_numpy(dummy_input).to(device)
+
+    # warmup
+    for _ in range(10):
+        with mindietorch.npu.stream(stream):
+            model(dummy_input)
+        stream.synchronize()
+
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        with mindietorch.npu.stream(stream):
+            model(dummy_input)
+        stream.synchronize()
+    end = time.time()
+
+    print(f"Decoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Decoder throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def test_joiner(aie_path, device_id):
+    batch_size = 1
+    encoder_out = np.ones((batch_size, 512), dtype=np.float32)
+    decoder_out = np.ones((batch_size, 512), dtype=np.float32)
+
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    model = torch.jit.load(aie_path)
+    model.eval()
+    encoder_out = torch.from_numpy(encoder_out).to(device)
+    decoder_out = torch.from_numpy(decoder_out).to(device)
+
+    # warmup
+    for _ in range(10):
+        with mindietorch.npu.stream(stream):
+            model(encoder_out, decoder_out)
+        stream.synchronize()
+
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in range(num_infer):
+        with mindietorch.npu.stream(stream):
+            model(encoder_out, decoder_out)
+        stream.synchronize()
+    end = time.time()
+
+    print(f"Joiner latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Joiner throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--encoder_aie_path", type=str, required=True)
+    parser.add_argument("--decoder_aie_path", type=str, required=True)
+    parser.add_argument("--joiner_aie_path", type=str, required=True)
+    parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    mindietorch.set_device(args.device_id)
+
+    test_encoder(args.encoder_aie_path, args.device_id)
+    test_decoder(args.decoder_aie_path, args.device_id)
+    test_joiner(args.joiner_aie_path, args.device_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/readme.md b/AscendIE/TorchAIE/built-in/audio/Conformer/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..4a0d7af67b12db9a606aa70d4c82fbef90e6a1d3
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Conformer/readme.md
@@ -0,0 +1,169 @@
+# Conformer Model Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Preparation](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Model Inference Performance and Precision](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+Conformer is a hybrid neural network architecture designed for sequence-to-sequence tasks such as automatic speech recognition (ASR). It combines the strengths of convolutional neural networks (CNNs) with the self-attention mechanism of the Transformer, so that both the local features and the global dependencies of sequence data are captured. By combining the two approaches in a single architecture, Conformer handles the complexity of time-series data such as speech waveforms and achieves strong performance on many tasks. In short, Conformer integrates the feature-extraction power of CNNs with the efficient sequence modeling of the Transformer, providing a powerful solution for sequence analysis.
+
+
+# Inference Environment Preparation \[All Versions\]
+
+- The model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+  | Component | Version     |
+  |-----------|-------------|
+  | CANN      | 7.0RC1      |
+  | Python    | 3.10.13     |
+  | torch     | 2.1.0       |
+  | Chip      | Ascend310P3 |
+
+# Quick Start
+
+## Environment Setup
+
+1. Install k2
+   1. (NPU) x86 environment
+      ```shell
+      wget https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+      pip install k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+      ```
+   2. (NPU/GPU) arm environment: build from source.
+      ```shell
+      git clone https://github.com/k2-fsa/k2.git
+      cd k2
+      export K2_MAKE_ARGS="-j"
+      python3 setup.py install
+      ```
+      If the commands above fail, refer to [this link](https://k2-fsa.github.io/k2/installation/from_source.html).
+   3. (GPU) x86 environment: download the wheel matching your CUDA version from [this link](https://k2-fsa.github.io/k2/cuda.html) and install it with pip.
+   4. Verify that k2 is installed correctly.
+      ```shell
+      python3 -m k2.version
+      ```
+2. Install the other dependencies.
+   ```shell
+   pip install lhotse
+   pip install kaldifeat
+   ```
+3. Install icefall.
+   ```shell
+   git clone https://github.com/k2-fsa/icefall.git
+   cd icefall
+   git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202
+   pip install -r requirements.txt
+   ```
+4. Add icefall to the environment, replacing "/path/to/icefall" with the directory where icefall was cloned.
+   **This step is important; otherwise an "icefall not found" error is reported.**
+   ```shell
+   export PYTHONPATH=/path/to/icefall:$PYTHONPATH
+   ```
+
+## Model Conversion
+1. Download the model
+
+Download the required files from the following Hugging Face repository. The required files are /exp/pretrained_epoch_9_avg_1.pt and the whole /data folder:
+https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/tree/main
+
+Go to the icefall directory from the previous step and cd into egs/wenetspeech/ASR/. Put pretrained_epoch_9_avg_1.pt under ASR/exp and **rename it to epoch-99.pt**, so that the layout matches the listing below (an example download command follows the listing):
+```
+wenetspeech/ASR/
+  --data/
+    --lang_char/
+      --Linv.pt and other files
+  --exp/
+    --epoch-99.pt
+  -- conformer_py.patch
+  -- export_torchscript.patch
+  --pruned_transducer_stateless5/
+    -- encoder_compile.py
+    -- decoder_compile.py
+    -- joiner_compile.py
+    -- perf_test_aie.py
+    -- test_precision.py
+
+```
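+
+As a reference, one possible way to fetch and arrange the checkpoint is sketched below. This is only an illustration: it assumes `wget` is available and that the standard Hugging Face `resolve` URL layout applies to this repository; adjust the commands to your own setup and download the data/ folder from the same repository as well.
+```shell
+# run inside egs/wenetspeech/ASR/ (illustrative only)
+mkdir -p exp
+wget https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/resolve/main/exp/pretrained_epoch_9_avg_1.pt
+mv pretrained_epoch_9_avg_1.pt exp/epoch-99.pt
+```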
+
+2. Export the TorchScript models
+
+2.1 Patch the original model
+
+Run the following commands in the ASR directory:
+```shell
+# Apply the patch files first
+patch ./pruned_transducer_stateless5/conformer.py conformer_py.patch
+
+patch ./pruned_transducer_stateless5/export-onnx.py ./export_torchscript.patch -o ./pruned_transducer_stateless5/export_torchscript.py
+
+# Then export the TorchScript models
+python3 ./pruned_transducer_stateless5/export_torchscript.py \
+  --tokens ./data/lang_char/tokens.txt \
+  --epoch 99 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  --exp-dir ./exp \
+  --num-encoder-layers 24 \
+  --dim-feedforward 1536 \
+  --nhead 8 \
+  --encoder-dim 384 \
+  --decoder-dim 512 \
+  --joiner-dim 512
+
+```
+2.2 Compile to MindIE Torch models
+
+Place encoder_compile.py, decoder_compile.py and joiner_compile.py under ASR/pruned_transducer_stateless5/ and run them one by one:
+```shell
+# Run in the ASR directory
+python ./pruned_transducer_stateless5/encoder_compile.py
+python ./pruned_transducer_stateless5/decoder_compile.py
+python ./pruned_transducer_stateless5/joiner_compile.py
+```
+This generates three files in the ASR directory: compiled_encoder.ts, compiled_decoder.ts and compiled_joiner.ts.
+
+### Precision Validation
+
+Validate the encoder model; "Precision test passed" on the screen means the precision is acceptable.
+```shell
+python test_precision.py encoder compiled_encoder.ts
+```
+Validate the decoder model; "Precision test passed" on the screen means the precision is acceptable.
+```shell
+python test_precision.py decoder compiled_decoder.ts
+```
+
+Validate the joiner model; "Precision test passed" on the screen means the precision is acceptable.
+```shell
+python test_precision.py joiner compiled_joiner.ts
+```
+
+### Performance Validation
+```shell
+# Place perf_test_aie.py in the ASR directory, then run
+python perf_test_aie.py \
+--encoder_aie_path ./compiled_encoder.ts \
+--decoder_aie_path ./compiled_decoder.ts \
+--joiner_aie_path ./compiled_joiner.ts \
+--device_id 0
+```
+The performance figures are printed to the screen, reported in FPS.
+
+
+### Performance Data (latency / throughput)
+| Model   | MindIE Torch     | T4               | A10             |
+|---------|------------------|------------------|-----------------|
+| encoder | 37.31ms/26.81FPS | 20.53ms/48.70FPS | 16.4ms/60.9FPS  |
+| decoder | 0.22ms/4470FPS   | 0.13ms/7443FPS   | 0.12ms/8333FPS  |
+| joiner  | 0.20ms/4913FPS   | 0.13ms/7612FPS   | 0.11ms/9212FPS  |
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Conformer/test_precision.py b/AscendIE/TorchAIE/built-in/audio/Conformer/test_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd3aee96804eb92e0619f50a40ba159b4612254a
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Conformer/test_precision.py
@@ -0,0 +1,95 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+from pruned_transducer_stateless5.onnx_pretrained import OnnxModel
+
+import numpy as np
+import torch
+import mindietorch
+
+from torch.nn.functional import cosine_similarity
+
+# Initialize the ONNX model globally
+onnxmodel = OnnxModel("./exp/encoder-epoch-99-avg-1.onnx", "./exp/decoder-epoch-99-avg-1.onnx", "./exp/joiner-epoch-99-avg-1.onnx")
+
+def is_close_to_ones(x1, atol):
+    # the cosine similarity between the two outputs should be close to 1 everywhere
+    x2 = torch.ones_like(x1)
+    return torch.allclose(x1, x2, atol=atol)
+
+def precision_test(ts_output, onnx_output, atol=1e-02):
+    result = is_close_to_ones(cosine_similarity(ts_output, onnx_output), atol)
+    print("Precision test " + ("passed" if result else "failed"))
+
+
+def run_ts_inference(ts_path, dummy_input, device_id):
+    batch_size = 1
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    model = torch.jit.load(ts_path)
+    model.eval()
+
+    with mindietorch.npu.stream(stream):
+        ts_out = model(*dummy_input)
+    stream.synchronize()
+    return ts_out
+
+
+def evaluate_model(mode, ts_path, device_id):
+    print(f"Evaluating precision of {mode} model")
+    if mode == 'encoder':
+        # dummy inputs
+        x, x_lens = np.random.rand(1, 100, 80).astype(np.float32), np.array([100])
+        x_tensor, x_lens_tensor = torch.from_numpy(x), torch.from_numpy(x_lens)
+        x_npu_tensor, x_lens_npu_tensor = x_tensor.to(f"npu:{device_id}"), x_lens_tensor.to(f"npu:{device_id}")
+
+        # GPU/NPU inference
+        ts_out = run_ts_inference(ts_path, (x_npu_tensor, x_lens_npu_tensor), device_id)
+        onnx_output, _ = onnxmodel.run_encoder(x_tensor, x_lens_tensor)
+
+    elif mode == 'decoder':
+        y = np.random.randint(0, 10, size=(1, 2)).astype(np.int64)
+        y_tensor = torch.from_numpy(y)
+        y_npu_tensor = y_tensor.to(f'npu:{device_id}')
+
+        ts_out = run_ts_inference(ts_path, (y_npu_tensor, ), device_id)
+        onnx_output = onnxmodel.run_decoder(y_tensor)
+
+    elif mode == 'joiner':
+        enc, dec = np.random.rand(1, 512).astype(np.float32), np.random.rand(1, 512).astype(np.float32)
+        enc_tensor, dec_tensor = torch.from_numpy(enc), torch.from_numpy(dec)
+        enc_npu_tensor, dec_npu_tensor = enc_tensor.to(f'npu:{device_id}'), dec_tensor.to(f'npu:{device_id}')
+
+        ts_out = run_ts_inference(ts_path, (enc_npu_tensor, dec_npu_tensor), device_id)
+        onnx_output = onnxmodel.run_joiner(enc_tensor, dec_tensor)
+
+    else:
+        raise ValueError("Invalid mode")
+
+    ts_out = ts_out.to("cpu")
+
+    precision_test(ts_out, onnx_output, atol=1e-02)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print("Usage: