diff --git a/ACL_PyTorch/README.md b/ACL_PyTorch/README.md index e417ad3e9505a15e3296f06f11bbc8873fa5e82c..51cd1c0085d5332dd35b7535d36e2e260ec72f57 100755 --- a/ACL_PyTorch/README.md +++ b/ACL_PyTorch/README.md @@ -4822,6 +4822,21 @@ python3 get_modelID.py --model your_model_name

154.8(bs1) 多尺度 + + + 100321 + + whisper + + librispeech_asr_dummy + + 8.21% + + + + 67.32(bs1) + bs x 80 x 3000 +

Knowledge
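The row added to the table above records Whisper's results on librispeech_asr_dummy (model ID 100321): a WER of 8.21% and a 67.32 (bs1) figure, with an input shape of bs x 80 x 3000, i.e. an 80-bin log-Mel spectrogram over 3000 frames (30 s of 16 kHz audio). As a hedged illustration (not part of this patch set), the sketch below shows how such an input tensor is produced with the openai-whisper API, mirroring the preprocessing done in infer.py further down; "audio.mp3" is the warm-up clip used there.

```python
# Hedged sketch: producing the "bs x 80 x 3000" log-Mel input referenced in the table.
# It mirrors the preprocessing in infer.py below; "audio.mp3" is the warm-up clip it uses.
import whisper

audio = whisper.load_audio("audio.mp3")     # 16 kHz mono waveform as a NumPy array
audio = whisper.pad_or_trim(audio)          # pad or cut to 30 s (480000 samples)
mel = whisper.log_mel_spectrogram(audio)    # torch.Tensor of shape (80, 3000)
batch = mel.unsqueeze(0)                    # (bs=1, 80, 3000), matching the table entry
print(batch.shape)
```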

diff --git a/ACL_PyTorch/built-in/audio/whisper/README.md b/ACL_PyTorch/built-in/audio/whisper/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..73573fd9334e7044c8baf7c78c675aa7a74d11d4
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/whisper/README.md
@@ -0,0 +1,121 @@
+# Whisper Model Inference Guide
+
+- [Overview](#overview)
+- [Plugins and Drivers](#plugins-and-drivers)
+- [Getting the Source Code](#getting-the-source-code)
+- [Environment Setup](#environment-setup)
+- [Dataset Preparation](#dataset-preparation)
+- [Directory Layout](#directory-layout)
+- [Running Inference](#running-inference)
+- [Performance](#performance)
+
+## Overview
+Whisper is OpenAI's open-source general-purpose speech recognition model. Built on the Transformer architecture, it supports multilingual transcription and translation and suits scenarios such as meeting notes and subtitle generation. It works out of the box, is robust, and ships in multiple model sizes that trade off speed against accuracy.
+
+## Plugins and Drivers
+
+- The model requires the following plugins and drivers:
+
+  | Component | Version | Setup Guide |
+  | ------------------------------------------------------------ |-------------| ------------------------------------------------------------ |
+  | Firmware and drivers | 25.0.RC1 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 8.1.RC1 | Includes the kernels and toolkit packages |
+  | Python | 3.10 | - |
+  | PyTorch | 2.5.1 | - |
+  | Ascend Extension PyTorch | 2.5.1 | - |
+  | Note: for Atlas 800I A2 and Atlas 300I DUO inference cards, choose the firmware and driver versions that match your CANN version. | \ | \ |
+
+
+## Getting the Source Code
+```
+git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/whisper/
+```
+
+## Environment Setup
+
+* Download and install (or upgrade to) the latest Whisper release:
+
+  `pip3 install -U openai-whisper`
+
+* Download the model weights:
+  * `base.pt`: [download link](https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt)
+
+* Install the **ffmpeg** command-line tool:
+  * On Ubuntu or Debian:
+    `sudo apt update && sudo apt install ffmpeg`
+  * On Arch Linux:
+    `sudo pacman -S ffmpeg`
+
+* Install the requirements:
+  `pip3 install -r requirements.txt`
+
+## Dataset Preparation
+* librispeech_asr_dummy dataset ([download page](https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/tree/main)): a small test set from the Hugging Face Datasets hub, used to quickly validate speech recognition. After downloading it, place it in the current folder.
+* `audio.mp3` is an ordinary (Chinese) speech clip used during warm-up and for quick sanity checks; it can be fetched from the link below. (You may also supply your own Chinese .mp3/.wav file and place it in the directory.)
+  ```TEXT
+  https://pan.baidu.com/s/1fHL0fWbGgKXQ9W1GXA2RBQ?pwd=xe2x 提取码: xe2x 复制这段内容后打开百度网盘手机App,操作更方便哦
+  ```
+
+## Directory Layout
+The directory layout is roughly as follows:
+
+```text
+📁 whisper/
+├── audio.mp3
+├── infer.py
+├── rewrited_models.py
+├── whisper_decoding.patch
+├── base.pt
+├── README.md
+├── requirements.txt
+└── 📁 librispeech_asr_dummy/
+    └── 📁 clean/
+        └── 📄 validation-00000-of-00001.parquet
+```
+
+## Running Inference
+```SHELL
+# 1. Source the environment variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust this path to your installation
+# 2. Select the NPU ID to use (defaults to 0)
+export ASCEND_RT_VISIBLE_DEVICES=0
+# 3. Start inference
+python3 infer.py
+```
+infer.py arguments:
+* --model_path: model checkpoint path, default "base.pt"
+* --audio_path: audio file path, default "audio.mp3"
+* --speech_path: librispeech_asr_dummy dataset path, default "./librispeech_asr_dummy/clean/"
+* --device: NPU device id, default 0
+* --batch_size: batch size, default 1
+* --warmup: number of warm-up runs, default 4
+
+Inference starts with a warm-up phase, which triggers the first graph compilation (the first compilation takes a while); once warm-up finishes, the script runs inference and prints the text transcribed from audio.mp3.
+
+After warm-up, the script runs inference on the librispeech_asr_dummy dataset, printing end-to-end (E2E) latency during the run and the WER accuracy score at the end.
+
+**If you want the encode and decode timings printed during inference, run the following:**
+```SHELL
+# 1. Find the environment path (referred to as ${location} below); it is the value printed after "Location"
+pip show openai-whisper | grep Location
+# 2. Note the path of the installed whisper decoding.py (referred to as ${decoding_path} below)
+#    i.e. ${decoding_path} = ${location}/whisper/decoding.py
+# 3. Apply the patch
+patch -p1 < whisper_decoding.patch
+# You may see a prompt like
+# cant find file to patch at input line 3
+# ...
+# File to patch: +# 这时候需要你手动指定文件路径,输入之前得到的 +${decoding_path} +# 按回车,提示 patching file ${decoding_path} 即成功 +``` + +## 性能数据 + 在librispeech_asr_dummy/clean数据集上的性能如下: + + | 模型 | 芯片 | 平均encode | 平均decode |平均E2E | + |---------|------------|----------|-----------------|---------| + | whisper | 800I A2 | 0.90ms | 3.25ms | 67.32ms | + 注:平均decode 指在decode阶段,生成单个token的平均耗时。 \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/whisper/infer.py b/ACL_PyTorch/built-in/audio/whisper/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5da6fa131bd9a0b799ac2c47f64b9c65f3767b --- /dev/null +++ b/ACL_PyTorch/built-in/audio/whisper/infer.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import time +import math +import argparse +from typing import Optional + +import jiwer +import numpy as np +import pandas as pd +from datasets import load_dataset + +import torch +from torch import nn, Tensor +import torch_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig + +import whisper +from whisper.model import Linear +from whisper.decoding import PyTorchInference, DecodingResult, DecodingTask +from whisper.normalizers import EnglishTextNormalizer + +from rewrited_models import PrefillTextDecoder, DecodeTextDecoder + + +class LibriSpeechDataset(torch.utils.data.Dataset): + def __init__(self, speech_path, device, audio_column="audio", text_column='text'): + self.dataset = load_dataset(speech_path, split="validation") + self.audio_column = audio_column + self.text_column = text_column + self.device = device + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + # 自动解码音频 + 重采样到 16kHz + audio = self.dataset[idx]["audio"]["array"] # 直接获取 NumPy 数组 + audio = torch.from_numpy(audio).float() + + # 统一长度 + 生成梅尔频谱 + audio = whisper.pad_or_trim(audio) + mel = whisper.log_mel_spectrogram(audio) + + return mel.contiguous().to(self.device), self.dataset[idx][self.text_column] + + +def parse_args(): + parser = argparse.ArgumentParser("Whisper infer") + parser.add_argument("--model_path", type=str, default="./base.pt", help="model checkpoint file path") + parser.add_argument("--audio_path", type=str, default="./audio.mp3", + help="warmup audio file path") + parser.add_argument("--speech_path", type=str, default="./librispeech_asr_dummy/clean/", + help="librispeech_asr_dummy english transaction speech data path") + parser.add_argument('--device', type=int, default='0', help="npu device id") + parser.add_argument('--batch_size', type=int, default=1, help="batch size") + parser.add_argument('--warmup', type=int, default=4, help="Warm up times") + args = parser.parse_args() + return args + + +def create_model(args): + model = whisper.load_model(args.model_path) + print( + f"Model is {'multilingual' if model.is_multilingual else 'English-only'} " + f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters." 
+ ) + return model + + +def rewrite_multi_head_attention_forward(model): + wk = model.key.weight + wv = model.value.weight + model.kv = Linear(in_features=wk.shape[0], out_features=wk.shape[1] + wv.shape[1]) + model.kv.weight = nn.Parameter(torch.concat([wk, wv], dim=0), requires_grad=False) + wk_bias = model.key.bias if model.key.bias is not None else torch.zeros(wk.shape[0]) + wv_bias = model.value.bias if model.value.bias is not None else torch.zeros(wv.shape[0]) + model.kv.bias = nn.Parameter(torch.concat([wk_bias, wv_bias], dim=0), requires_grad=False) + + def forward( + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + actual_seq_len: Optional[list] = None, + ): + q = model.query(x) + + # encoder + if kv_cache is None: + kv = model.kv(x) + k, v = kv.chunk(2, dim=-1) + + # decoder - cross_attention + if kv_cache is not None and xa is not None: + k_key = "key" + v_key = "value" + if k_key in kv_cache: + k = kv_cache[k_key] + v = kv_cache[v_key] + else: + kv = model.kv(xa) + k, v = kv.chunk(2, dim=-1) + kv_cache[k_key] = k.contiguous() + kv_cache[v_key] = v.contiguous() + + # decoder - self_attention + if kv_cache is not None and xa is None: + k_key = "key" + v_key = "value" + if k_key in kv_cache: + k = kv_cache[k_key] + v = kv_cache[v_key] + new_kv = model.kv(x[:, -1:]) + new_k = new_kv[..., :wk.shape[0]] + new_v = new_kv[..., wk.shape[0]:] + kv_cache[k_key] = torch.cat([k.contiguous(), new_k.contiguous()], dim=1).detach() + kv_cache[v_key] = torch.cat([v.contiguous(), new_v.contiguous()], dim=1).detach() + k, v = kv_cache[k_key], kv_cache[v_key] + else: + kv = model.kv(x) + k, v = kv.chunk(2, dim=-1) + kv_cache[k_key] = k.contiguous() + kv_cache[v_key] = v.contiguous() + + n_batch, n_ctx, n_state = q.shape + q = q.view(*q.shape[:2], model.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], model.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], model.n_head, -1).permute(0, 2, 1, 3) + + mask = mask.to(torch.bool) if mask is not None and n_ctx > 1 else None + sparse_mode = 1 if mask is not None and n_ctx > 1 else 0 + D = n_state // model.n_head + + at = torch_npu.npu_prompt_flash_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + num_heads=model.n_head, + input_layout="BNSD", + scale_value=1 / math.sqrt(D), + atten_mask=mask[:n_ctx, :n_ctx] if mask is not None else None, + sparse_mode=sparse_mode + ) + + qk = None + w_v = at.permute(0, 2, 1, 3).flatten(start_dim=2) + return model.out(w_v), qk + + model.forward = forward + + +def modify_model(model, options, args, device): + print("modify model...") + + # 修改encoder的attention forward + for block1, block2 in zip(model.encoder.blocks, model.decoder.blocks): + rewrite_multi_head_attention_forward(block1.attn) + rewrite_multi_head_attention_forward(block2.attn) + rewrite_multi_head_attention_forward(block2.cross_attn) + origin_decoder = model.decoder + + # 将原本的decoder拆分成prefill和decode2个阶段 + prefill_decoder = PrefillTextDecoder( + model.dims.n_vocab, + model.dims.n_text_ctx, + model.dims.n_text_state, + model.dims.n_text_head, + model.dims.n_text_layer + ) + prefill_decoder.load_state_dict(origin_decoder.state_dict()) + + decode_decoder = DecodeTextDecoder( + model.dims.n_vocab, + model.dims.n_text_ctx, + model.dims.n_text_state, + model.dims.n_text_head, + model.dims.n_text_layer + ) + decode_decoder.load_state_dict(origin_decoder.state_dict()) + + model.prefill_decoder = prefill_decoder + model.decode_decoder = decode_decoder + + if options.fp16: + 
model = model.half() + for module in model.modules(): + # 在Whisper源码中,LayerNorm层需要接收fp32数据,因此需要特殊处理 + if isinstance(module, nn.LayerNorm): + module = module.float() + + return model.eval().to(device) + + +def rewrite_inference_logits(): + def _patched_logits(self, tokens, audio_features) -> Tensor: + if not self.kv_cache: + self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() + self.kv_cache = [ + { + 'attn': {}, + 'cross_attn': {} + } + for _ in range(6) + ] + return self.model.prefill_decoder(tokens, audio_features, kv_cache=self.kv_cache) + + actual_seq_len = tokens.shape[-1] + updated_kv_positions = torch.tensor([actual_seq_len - 1], dtype=torch.long, device=tokens.device) + kv_padding_size = torch.tensor([448 - actual_seq_len], dtype=torch.long, device=tokens.device) + + offset = actual_seq_len - 1 + positional_embedding = self.model.decode_decoder.positional_embedding[offset: offset + 1] + tokens = tokens[:, -1:].contiguous().clone() + + torch._dynamo.mark_static(tokens) + torch._dynamo.mark_static(audio_features) + torch._dynamo.mark_static(positional_embedding) + for i in range(6): + torch._dynamo.mark_static(self.kv_cache[i]['attn']["key"]) + torch._dynamo.mark_static(self.kv_cache[i]['attn']["value"]) + torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["key"]) + torch._dynamo.mark_static(self.kv_cache[i]['cross_attn']["value"]) + torch._dynamo.mark_static(kv_padding_size) + + return self.model.decode_decoder(tokens, audio_features, positional_embedding, self.kv_cache, + actual_seq_len=[actual_seq_len], kv_padding_size=kv_padding_size, + updated_kv_positions=updated_kv_positions) + + PyTorchInference.logits = _patched_logits + + +def model_compile(): + print("torch.compile...") + wsp_model.encoder.forward = torch.compile(wsp_model.encoder.forward, dynamic=False, fullgraph=True, backend=npu_backend) + wsp_model.prefill_decoder.forward = torch.compile(wsp_model.prefill_decoder.forward, dynamic=False, fullgraph=True, backend=npu_backend) + wsp_model.decode_decoder.forward = torch.compile(wsp_model.decode_decoder.forward, dynamic=True, fullgraph=True, backend=npu_backend) + + +def libri_speech_infer(model, options, loader): + hypotheses = [] + references = [] + + for mels, texts in loader: + start_time = time.time() + results = model.decode(mels, options) + e2e_time = time.time() - start_time + print(f'Parquet infer E2E time = {e2e_time * 1000:.2f} ms') + hypotheses.extend([res.text for res in results]) + references.extend(texts) + + data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references)) + print(data) + normalizer = EnglishTextNormalizer() + data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]] + data["reference_clean"] = [normalizer(text) for text in data["reference"]] + print(data[["hypothesis_clean", "reference_clean"]]) + wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"])) + return wer + + +if __name__ == '__main__': + wsp_args = parse_args() + device = torch.device('npu:{}'.format(wsp_args.device)) + + torch_npu.npu.set_compile_mode(jit_compile=False) + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True # 使能tiling全下沉配置 + npu_backend = tng.get_npu_backend(compiler_config=config) + + dataset = LibriSpeechDataset(wsp_args.speech_path, device=device) + loader = torch.utils.data.DataLoader(dataset, batch_size=wsp_args.batch_size) + options = whisper.DecodingOptions(language='en', without_timestamps=True, fp16=True) 
+ + wsp_model = create_model(wsp_args) + wsp_model = modify_model(wsp_model, options, wsp_args, device) + + rewrite_inference_logits() + model_compile() + + with torch.inference_mode(): + audio = whisper.load_audio(wsp_args.audio_path) + audio = whisper.pad_or_trim(audio) + audio_mel = whisper.log_mel_spectrogram(audio, n_mels=wsp_model.dims.n_mels).to(wsp_model.device) + audio_mel = audio_mel.unsqueeze(0).repeat(wsp_args.batch_size, 1, 1) + w_options = whisper.DecodingOptions(language='zh', without_timestamps=True, fp16=True) + for _step in range(wsp_args.warmup): + result = whisper.decode(wsp_model, audio_mel, w_options) + for bs in range(wsp_args.batch_size): + print("{}/{} - {}".format(_step, wsp_args.warmup, result[bs].text)) + + print("LibriSpeech infer, English to English TRANSCRIBE ...") + p_wer = libri_speech_infer(wsp_model, options, loader) + print(f"LibriSpeech infer WER score = {p_wer * 100:.2f} %") diff --git a/ACL_PyTorch/built-in/audio/whisper/requirements.txt b/ACL_PyTorch/built-in/audio/whisper/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..606e975265450fb74da3461c3c9bdcb58fb858d8 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/whisper/requirements.txt @@ -0,0 +1,77 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.11.18 +aiosignal==1.3.2 +async-timeout==5.0.1 +attrs==25.3.0 +audioread==3.0.1 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +click==8.1.8 +datasets==3.5.0 +decorator==5.2.1 +dill==0.3.8 +einops==0.8.1 +filelock==3.18.0 +frozenlist==1.6.0 +fsspec==2024.12.0 +greenlet==3.2.1 +huggingface-hub==0.30.2 +idna==3.10 +ijson==3.3.0 +Jinja2==3.1.6 +jiwer==3.1.0 +joblib==1.4.2 +lazy_loader==0.4 +librosa==0.11.0 +llvmlite==0.44.0 +MarkupSafe==3.0.2 +more-itertools==10.6.0 +mpmath==1.3.0 +msgpack==1.1.0 +msprof-analyze==2.0.2 +multidict==6.4.3 +multiprocess==0.70.16 +networkx==3.4.2 +numba==0.61.2 +numpy==1.24.0 +openai-whisper==20240930 +packaging==25.0 +pandas==2.2.3 +platformdirs==4.3.7 +pooch==1.8.2 +prettytable==3.16.0 +propcache==0.3.1 +protobuf==6.30.2 +psutil==7.0.0 +pyarrow==19.0.1 +pycparser==2.22 +python-dateutil==2.9.0.post0 +pytz==2025.2 +PyYAML==6.0.2 +RapidFuzz==3.13.0 +regex==2024.11.6 +requests==2.32.3 +safetensors==0.5.3 +scikit-learn==1.6.1 +scipy==1.15.2 +six==1.17.0 +soundfile==0.13.1 +soxr==0.5.0.post1 +SQLAlchemy==2.0.40 +sympy==1.13.1 +tabulate==0.9.0 +threadpoolctl==3.6.0 +tiktoken==0.9.0 +tokenizers==0.21.1 +torch==2.5.1 +torch-npu==2.5.1 +tqdm==4.67.1 +transformers==4.51.3 +typing_extensions==4.13.2 +tzdata==2025.2 +urllib3==1.26.20 +wcwidth==0.2.13 +XlsxWriter==3.2.3 +xxhash==3.5.0 +yarl==1.20.0 \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/whisper/rewrited_models.py b/ACL_PyTorch/built-in/audio/whisper/rewrited_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4b485a853a8498c73adbb5e301b680eafcd6b5 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/whisper/rewrited_models.py @@ -0,0 +1,294 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
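rewrited_models.py, which follows, splits Whisper's text decoder into a prefill decoder and a decode decoder, swaps attention for the NPU flash-attention kernels, and keeps a pre-allocated KV cache that is updated in place as tokens are generated. The snippet below is a hedged, plain-PyTorch sketch of that cache layout only (the real code uses torch_npu.scatter_update_ and the flash-attention ops); the 448/512 sizes assume Whisper base's n_text_ctx and n_text_state.

```python
# Hedged sketch of the pre-allocated KV-cache pattern used in rewrited_models.py below:
# prefill fills positions [0, n_ctx); each decode step writes exactly one new position.
import torch

batch, max_len, n_state = 1, 448, 512            # 448 = n_text_ctx, 512 = base model n_text_state
cache_k = torch.zeros(batch, max_len, n_state)   # pre-allocated key cache (the value cache is analogous)

prefill_k = torch.randn(batch, 3, n_state)       # keys for a 3-token prompt (the SOT sequence)
cache_k[:, :3, :] = prefill_k                    # prefill phase

new_k = torch.randn(batch, 1, n_state)           # key for the newly generated token
cache_k[:, 3:4, :] = new_k                       # decode phase; the real code uses torch_npu.scatter_update_
```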
+ +import math +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +import torch_npu + +from whisper.model import Linear, LayerNorm, MultiHeadAttention, ResidualAttentionBlock + + +class MyMultiHeadSelfAttention(nn.Module): + + def __init__(self, n_state: int, n_head: int, n_ctx: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + self.kv = Linear(in_features=self.key.weight.shape[0], out_features=self.key.weight.shape[1] + self.value.weight.shape[1]) + self.n_ctx = n_ctx + + def forward( + self, + x: Tensor, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + kv_padding_size: Optional[torch.LongTensor] = None + ): + q = self.query(x) + + n_batch, n_ctx, n_state = q.shape + max_sample_len = self.n_ctx + # decoder - self_attention + k_key = "key" + v_key = "value" + # Prefill + if k_key not in kv_cache: + kv_cache[k_key] = torch.zeros(n_batch, max_sample_len, n_state, dtype=x.dtype, device=x.device) + kv_cache[v_key] = torch.zeros(n_batch, max_sample_len, n_state, dtype=x.dtype, device=x.device) + kv = self.kv(x) + k, v = kv.chunk(2, dim=-1) + kv_cache[k_key][:, :n_ctx, :] = k.detach().contiguous() + kv_cache[v_key][:, :n_ctx, :] = v.detach().contiguous() + # Decode + else: + new_kv = self.kv(x[:, -1:]) + new_k, new_v = new_kv.chunk(2, dim=-1) + tmp_ids = updated_kv_positions.expand(n_batch) + torch_npu.scatter_update_(kv_cache[k_key], tmp_ids, new_k, 1) + torch_npu.scatter_update_(kv_cache[v_key], tmp_ids, new_v, 1) + + k = kv_cache[k_key] + v = kv_cache[v_key] + + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + D = n_state // self.n_head + # Prefill用FPA + if n_ctx > 1: + mask = mask.to(torch.bool) if mask is not None and n_ctx > 1 else None + sparse_mode = 1 if mask is not None and n_ctx > 1 else 0 + attn = torch_npu.npu_prompt_flash_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + num_heads=self.n_head, + input_layout="BNSD", + scale_value=1 / math.sqrt(D), + atten_mask=mask[:n_ctx, :n_ctx] if mask is not None else None, + sparse_mode=sparse_mode + ) + # Decode用IFA + else: + attn = torch_npu.npu_incre_flash_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + num_heads=self.n_head, + input_layout="BNSD", + scale_value=1 / math.sqrt(D), + atten_mask=None, + actual_seq_lengths=actual_seq_len, + kv_padding_size=kv_padding_size + ) + + w_v = attn.permute(0, 2, 1, 3).flatten(start_dim=2) + return self.out(w_v) + + +class MyMultiHeadCrossAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + self.kv = Linear(in_features=self.key.weight.shape[0], + out_features=self.key.weight.shape[1] + self.value.weight.shape[1]) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + # decoder - cross_attention + k_key = "key" + v_key = "value" + if k_key 
in kv_cache: + k = kv_cache[k_key] + v = kv_cache[v_key] + else: + kv = self.kv(xa) + k, v = kv.chunk(2, dim=-1) + kv_cache[k_key] = k.contiguous() + kv_cache[v_key] = v.contiguous() + + n_batch, n_ctx, n_state = q.shape + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + mask = mask.to(torch.bool) if mask is not None and n_ctx > 1 else None + sparse_mode = 1 if mask is not None and n_ctx > 1 else 0 + D = n_state // self.n_head + attn = torch_npu.npu_prompt_flash_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + num_heads=self.n_head, + input_layout="BNSD", + scale_value=1 / math.sqrt(D), + atten_mask=mask[:n_ctx, :n_ctx] if mask is not None else None, + sparse_mode=sparse_mode + ) + + w_v = attn.permute(0, 2, 1, 3).flatten(start_dim=2) + return self.out(w_v) + + +class MyResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, n_ctx: int, cross_attention: bool = False): + super().__init__() + + self.attn = MyMultiHeadSelfAttention(n_state, n_head, n_ctx) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = ( + MyMultiHeadCrossAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + kv_padding_size: Optional[torch.LongTensor] = None + ): + x = x + self.attn(self.attn_ln(x), + mask=mask, + kv_cache=kv_cache['attn'], + actual_seq_len=actual_seq_len, + kv_padding_size=kv_padding_size, + updated_kv_positions=updated_kv_positions)[0] + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache['cross_attn'])[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class MyTextDecoder(nn.Module): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks = nn.ModuleList( + [ + MyResidualAttentionBlock(n_state, n_head, n_ctx, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward( + self, + x: Tensor, + xa: Tensor, + positional_embedding: Tensor = None, + kv_cache: Optional[dict] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + kv_padding_size: Optional[torch.LongTensor] = None + ): + pass + + +class PrefillTextDecoder(MyTextDecoder): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer) + + def forward( + self, + x: Tensor, + xa: Tensor, + positional_embedding: Tensor = None, + kv_cache: Optional[dict] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + kv_padding_size: Optional[torch.LongTensor] = None + ): + offset = 0 + x = ( + self.token_embedding(x) + + 
self.positional_embedding[offset: offset + x.shape[-1]] + ) + x = x.to(xa.dtype) + + for layer_index, block in enumerate(self.blocks): + x = block(x, xa, mask=self.mask, kv_cache=kv_cache[layer_index], + updated_kv_positions=updated_kv_positions) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + +class DecodeTextDecoder(MyTextDecoder): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer) + + def forward( + self, + x: Tensor, + xa: Tensor, + positional_embedding: Tensor, + kv_cache: Optional[dict] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + kv_padding_size: Optional[torch.LongTensor] = None + ): + x = (self.token_embedding(x) + positional_embedding) + x = x.to(xa.dtype) + + for layer_index, block in enumerate(self.blocks): + x = block(x, xa, mask=self.mask, kv_cache=kv_cache[layer_index], actual_seq_len=actual_seq_len, + kv_padding_size=kv_padding_size, + updated_kv_positions=updated_kv_positions) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits diff --git a/ACL_PyTorch/built-in/audio/whisper/whisper_decoding.patch b/ACL_PyTorch/built-in/audio/whisper/whisper_decoding.patch new file mode 100644 index 0000000000000000000000000000000000000000..871e972c2fb4d8d148103e7d636d7693f676332a --- /dev/null +++ b/ACL_PyTorch/built-in/audio/whisper/whisper_decoding.patch @@ -0,0 +1,34 @@ ++++ decoding.py +@@ -652,7 +652,10 @@ + # encoded audio features are given; skip audio encoding + audio_features = mel + else: ++ import time ++ time1 = time.time() + audio_features = self.model.encoder(mel) ++ print(f"encode time = {(time.time() - time1) * 1000:.2f} ms") + + if audio_features.dtype != ( + torch.float16 if self.options.fp16 else torch.float32 +@@ -683,6 +686,8 @@ + no_speech_probs = [np.nan] * n_batch + + try: ++ import time ++ time1 = time.time() + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + +@@ -703,6 +708,8 @@ + tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + + if completed or tokens.shape[-1] > self.n_ctx: ++ avg_time = (time.time() - time1) / i * 1000 ++ print(f"avg decode time = {avg_time:.2f} ms") + break + finally: + self.inference.cleanup_caching() +@@ -824,3 +831,4 @@ + result = DecodingTask(model, options).run(mel) + + return result[0] if single else result
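As a convenience (not part of the patch set above), the file that whisper_decoding.patch targets can also be located programmatically, instead of deriving it from `pip show openai-whisper | grep Location` as the README describes; a minimal sketch:

```python
# Hedged sketch: print the installed whisper/decoding.py path that whisper_decoding.patch
# modifies, i.e. the value the README refers to as ${decoding_path}.
import importlib.util

spec = importlib.util.find_spec("whisper.decoding")
print(spec.origin)   # enter this path when `patch` asks "File to patch:"
```

If `patch -p1` cannot resolve the file on its own, paste the printed path at the `File to patch:` prompt, as described in the README.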