From b25749576b9142b17d4b5c21faed71918e214d32 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Mon, 11 Nov 2024 11:53:23 +0800
Subject: [PATCH 01/14] =?UTF-8?q?=E5=A2=9E=E5=8A=A0whisperx?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/compile_whisper.py         | 177 ++++
 .../built-in/audio/whisperX/config.py         |  18 +
 .../audio/whisperX/modeling_whisper.py        | 988 ++++++++++++++++++
 .../built-in/audio/whisperX/readme.md         |  81 ++
 4 files changed, 1264 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/config.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
new file mode 100644
index 0000000000..7e4690bb51
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -0,0 +1,177 @@
+import os
+import argparse
+import time
+import torch
+import mindietorch
+from mindietorch._enums import dtype
+from modeling_whisper import MindieWhisperForConditionalGeneration
+from config import CompileInfo
+
+def compile_encoder(model : MindieWhisperForConditionalGeneration,
+                    args: argparse.Namespace,
+                    compile_info : CompileInfo):
+    encoder = model.get_encoder()
+
+    class Encoder(torch.nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, input_features):
+            return self.model(input_features=input_features, return_dict=False)
+
+    input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames])
+    encoder_traced = torch.jit.trace(Encoder(encoder), (input_features))
+    input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))]
+    try:
+        print("Start AOE optimization_level 1.")
+        compiled = mindietorch.compile(encoder_traced,
+                                       inputs=input_info,
+                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                       soc_version=args.soc_version,
+                                       optimization_level=1
+                                       )
+    except Exception:
+        # AOE tuning is optional: fall back to a plain compile so `compiled`
+        # is always bound before torch.jit.save below
+        print("Using AOE to optimize the encoder failed; compiling the encoder without AOE.")
+        compiled = mindietorch.compile(encoder_traced,
+                                       inputs=input_info,
+                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                       soc_version=args.soc_version,
+                                       )
+    else:
+        print("AOE optimization finished, start compiling encoder with the tuned result.")
+        compiled = mindietorch.compile(encoder_traced,
+                                       inputs=input_info,
+                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                       soc_version=args.soc_version,
+                                       )
+    save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts")
+    torch.jit.save(compiled, save_file)
+    print(f"Compile encoder success, saved in {save_file}")
+
+def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration,
+                            args: argparse.Namespace, compile_info : CompileInfo):
+    print("Start compiling prefill_decoder.")
+    encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
+    decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
+    prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs))
+    input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
+                  mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
+    prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
+                                                   inputs=input_info,
+                                                   precision_policy=mindietorch.PrecisionPolicy.FP16,
+                                                   soc_version=args.soc_version)
+
+    save_file = os.path.join(args.save_path,
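+                             # the "<prefix><batch_size>.ts" naming must stay in sync with
+                             # CompileInfo.prefix_name and load_mindie_models/_check_save_path,
+                             # which look the compiled artifacts up under exactly these names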
f"{compile_info.prefix_name[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + +def compile_incre_decoder(args : argparse, compile_info : CompileInfo): + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + machine_type = args.machine_type, + is_incre_decode=True, + soc_version=args.soc_version) + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + ] * compile_info.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, compile_info.max_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.max_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.le_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16 + )] * compile_info.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids + mindietorch.Input(shape=(args.bs, compile_info.le_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)] # actual sq len + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + +def compile_scatter_update(self, args, compile_info): + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + + bs = args.bs + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + indices = torch.tensor([0] * bs) + update_states = torch.randn([bs, 1, 20, 
64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800A2") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + soc_version=args.soc_version, + machine_type=args.machine_type) + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + + compile_scatter_update(mindie_whisper, args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py new file mode 100644 index 0000000000..6d4464195e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass, field +from typing import List + +@dataclass +class CompileInfo: + prefix_name = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs"] + mel_feature_size = 128 + max_frames = 3000 + max_decode_step = 448 + head_num = 20 + head_size = 64 + encoder_seq_len = 1500 + hidden_size = 1280 + layer_nums = 32 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py new file mode 100644 index 0000000000..639c847349 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -0,0 +1,988 @@ +import argparse +import os +import copy +import math +import 
warnings
+import librosa
+from typing import Optional, Tuple, Union, List, Dict
+from collections import OrderedDict
+
+import torch
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers import WhisperForConditionalGeneration
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \
+    WhisperModel, WhisperDecoderLayer, WhisperAttention
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
+from transformers.generation.utils import GenerationMixin
+import mindietorch
+from mindietorch._enums import dtype
+
+
+class MindiePFA(WhisperAttention):
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            is_decoder: bool = False,
+            bias: bool = True,
+            is_causal: bool = False,
+            config=None ):
+        # is_decoder must be forwarded as well, otherwise bias/is_causal/config
+        # each shift one position to the left in WhisperAttention.__init__
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config )
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            key_value_states: Optional[torch.Tensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # cross_attn
+        is_cross_attn = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        if (
+                is_cross_attn
+                and past_key_value is not None
+                # past_key_value layout is BSND
+                and past_key_value[0].shape[1] == key_value_states.shape[1]
+        ):
+            # reuse k, v
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attn:
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        # self attn
+        elif past_key_value is not None:
+            # reuse k v
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=1)
+            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+        else:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16))
+
+        # B S N D
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
+
+        # B S N D
+        attn_output = torch.ops.aie.flash_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            num_head=self.num_heads,
+            scale=self.scaling,
+            layout="BSND",
+            type="PFA" )
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, _, past_key_value
+
+class MindieIFA(WhisperAttention):
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            is_decoder: bool = False,
+            bias: bool = True,
+            is_causal: bool = True,
+            config=None ):
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
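+            # the arguments above are positional, so their order must match
+            # WhisperAttention.__init__(embed_dim, num_heads, dropout,
+            # is_decoder, bias, is_causal, config) exactly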
config )

+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            past_key_value: torch.Tensor,
+            actual_seq_len: torch.Tensor,
+            **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # self attn
+        assert past_key_value is not None, \
+            "Current operation is incre_flash_attention, past_key_value is required."
+        bsz, tgt_len, _ = hidden_states.size()
+        assert tgt_len == 1, \
+            "Current operation is incre_flash_attention, query's seq length should be equal to 1."
+        query_states = self.q_proj(hidden_states)
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16)
+        values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16)
+        past_key_cache, past_value_cache = past_key_value[0], past_key_value[1]
+        indices = actual_seq_len - 1
+        past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1)
+        past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1)
+        past_key_value = (past_key_cache, past_value_cache)
+        # B S N D
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
+        # B S N D
+        attn_output = torch.ops.aie.incre_flash_attention(
+            query=query_states,
+            key=past_key_cache,
+            value=past_value_cache,
+            actual_seq_lengths=actual_seq_len,
+            num_head=self.num_heads,
+            scale=self.scaling,
+            layout="BSND")
+        attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+        attn_output = self.out_proj(attn_output)
+        return attn_output, _, past_key_value
+
+class MindieFA(WhisperAttention):
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            is_decoder: bool = False,
+            bias: bool = True,
+            is_causal: bool = False,
+            config=None ):
+        # as in MindiePFA, is_decoder has to be passed through positionally
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config)
+
+    # BSND
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            key_value_states: Optional[torch.Tensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # cross_attn
+        is_cross_attn = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        if (
+                is_cross_attn
+                and past_key_value is not None
+                # past_key_value layout is BSND
+                and past_key_value[0].shape[1] == key_value_states.shape[1]
+        ):
+            # reuse k, v
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attn:
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        # self attn
+        elif past_key_value is not None:
+            # reuse k v
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=1)
+            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+        else:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
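+            # K/V are cached in FP16: the compiled incremental decoder declares its
+            # kv-cache inputs as FLOAT16, and scatter_update writes into those buffers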
past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # MindFA only support BNSD layout + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + attn_output = torch.ops.aie.flash_attention( + query=query_states.transpose(1, 2), + key=key_states.transpose(1, 2), + value=value_states.transpose(1, 2), + num_head=self.num_heads, + scale=self.scaling, + layout="BNSD", + type="FA_HIGH_PERF") + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND", + type="PFA") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, _ + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.d_model + if config.is_incre_decode: + self.self_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "800IA2": + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "300IPro": + self.self_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + else: + raise ValueError(f"Unsupporting current parameters. eg. 
" + f"is_incre_decode: {is_incre_decode} machine_type {machine_type}") + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + actual_seq_len=None, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + actual_seq_len=actual_seq_len + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) + for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + actual_seq_len, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + # B S N D + past_key_values_length = past_key_values[0].shape[1] if past_key_values is not None else 0 + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + 
past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values: Optional[torch.Tensor] = None, + actual_seq_len: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + actual_seq_len=actual_seq_len + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, is_incre_decode=False, machine_type="800A2"): + super().__init__(config) + config.is_incre_decode = is_incre_decode + config.machine_type = machine_type + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.self_attn_scatter = None + self.encoder_attn_scatter = None + self.save_path = None + self.encoder_seq_len = 1500 + self.file_prefix_names = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs" + ] + self.hidden_size = 1280 + self.head_num = 20 + self.head_size = 64 + self.seq_len = 1 + self.layer_nums = 32 + self.max_len = 448 + self.past_key_value = [] + + def forward(self, *args): + if len(args) not in (2, 131): + raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + if len(args) == 131: + actual_seq_len = args[2] + past_key_values = args[3:] + else: + actual_seq_len = None + past_key_values = None + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + actual_seq_len=actual_seq_len, + use_cache=True, + return_dict=False, + input_features=None + ) + 
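+        # 131 = decoder_input_ids + encoder_outputs + actual_seq_len + 32 layers x 4
+        # cached tensors (self-attn K/V and cross-attn K/V per layer); the flat
+        # positional signature keeps this module traceable by torch.jit.trace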
lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def compile_encoder(self, args): + encoder = self.get_encoder() + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + mel_feature_size = 128 + max_frames = 3000 + input_features = torch.randn([args.bs, mel_feature_size, max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))] + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + print("AOE optimize finished and start compile encoder.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + + def compile_prefill_decoder(self, args): + print("Start compiling prefill_decoder.") + encoder_seq_len = 1500 + hidden_size = 1280 + encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + + def compile_decoder(self, args): + + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + is_incre_decode=True) + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + + encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]), + torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]) + ] * self.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, 
self.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size), + dtype=dtype.FLOAT16 + )] * self.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)), + mindietorch.Input(shape=(args.bs, ), dtype=dtype.INT64)] + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + + def compile_scatter_update(self, args): + + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + bs = args.bs + self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size]) + encoder_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size]) + indices = torch.tensor([0]*bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[4]}{args.bs}.ts") + print("compile scatter success.") + + def compile(self, args): + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + save_path = args.save_path + + if not save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(save_path): + os.makedirs(save_path) + print(f"Directory {save_path} created.") + else: + print(f"Directory {save_path} already exists.") + self.compile_encoder(args) + self.compile_prefill_decoder(args) + self.compile_decoder(args) + self.compile_scatter_update(args) + self.has_compile = True + + def _init_past_key_value_cache(self, bsz): + for _ in range(32): + self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) + self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) + print("init past key value 
cache success.") + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + self.batch_size = batch_size + + self._init_past_key_value_cache(batch_size) + + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts") + print(f"load {self.file_prefix_names[3]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts") + print(f"load {self.file_prefix_names[4]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." 
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu")
+
+        this_peer_finished = False  # used by synced_gpus only
+        kv_actual_step = 1
+        indices = torch.tensor([0] * input_ids.shape[0])
+        is_first_step = True
+        while True:
+
+            model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs)
+            args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]]
+            if is_first_step:
+                outputs = self.mindie_decoder_prefill(*args)
+                # outputs[0] holds the logits; the remaining tensors are the per-layer
+                # (self K, self V, cross K, cross V) caches and must be scattered into
+                # the matching slot of the flat cache list, 4 entries per layer
+                for idx in range(32):
+                    self.self_attn_scatter(self.past_key_value[4 * idx], indices, outputs[1 + 4 * idx])
+                    self.self_attn_scatter(self.past_key_value[4 * idx + 1], indices, outputs[1 + 4 * idx + 1])
+                    self.encoder_attn_scatter(self.past_key_value[4 * idx + 2], indices, outputs[1 + 4 * idx + 2])
+                    self.encoder_attn_scatter(self.past_key_value[4 * idx + 3], indices, outputs[1 + 4 * idx + 3])
+                is_first_step = False
+            else:
+                kv_actual_step += 1
+                # list.append returns None, so .to("npu") has to be applied to the tensor
+                args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
+                args.extend(self.past_key_value)
+                outputs = self.mindie_decoder(*args)
+            if isinstance(outputs, list):
+                next_token_logits = outputs[0].to("cpu")[:, -1, :]
+            else:
+                next_token_logits = outputs.to("cpu")[:, -1, :]
+
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits)
+
+            # argmax
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs
+            )
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True
+
+            # stop if we exceed the maximum length
+            if stopping_criteria(input_ids, scores):
+                this_peer_finished = True
+
+            if this_peer_finished and not synced_gpus:
+                break
+
+        return input_ids
+
+    def generate(
+            self,
+            input_features: Optional[torch.Tensor] = None,
+            generation_config=None,
+            logits_processor=None,
+            stopping_criteria=None,
+            prefix_allowed_tokens_fn=None,
+            synced_gpus=False,
+            return_dict_in_generate: Optional[bool] = None,
+            **kwargs,
+    ):
+        if generation_config is None:
+ generation_config = copy.deepcopy(self.generation_config) + num_segment_frames = 3000 + assert input_features.shape[-1] == num_segment_frames, "Only support 30s speech." + encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0] + kwargs["encoder_outputs"] = encoder_outputs + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
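+        # For Whisper, this resolves to the <|startoftranscript|> token, taken from
+        # the generation config when no explicit decoder_start_token_id is passed in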
+        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+        if device is None:
+            device = self.device
+        decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id
+
+        # no user input -> use decoder_start_token_id as decoder_input_ids
+        if decoder_input_ids is None:
+            decoder_input_ids = decoder_input_ids_start
+        # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
+        elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
+            pass
+        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+        # decoder_attention_mask if provided)
+        elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
+            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                decoder_attention_mask = torch.cat(
+                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+                    dim=-1,
+                )
+                model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device)
+
+        return decoder_input_ids.to("npu"), model_kwargs
+
+    def _check_save_path(self, save_path, batch_size):
+        file_list = os.listdir(save_path)
+        expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names]
+        for file in expected_files:
+            if file not in file_list:
+                raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}")
+
+    def prepare_inputs_for_generation(
+            self,
+            decoder_input_ids,
+            if_first_step,
+            encoder_outputs=None,
+            **kwargs
+    ):
+        if not if_first_step:
+            decoder_input_ids_shape = decoder_input_ids.shape
+            remove_prefix_length = decoder_input_ids_shape[1] - 1
+            # slice (not index) so the trailing dimension is kept: the compiled
+            # incremental decoder expects decoder_input_ids of shape (batch, 1)
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+        return {
+            "encoder_outputs": encoder_outputs,
+            "decoder_input_ids": decoder_input_ids
+        }
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md
new file mode 100644
index 0000000000..60a3c83ab5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md
@@ -0,0 +1,81 @@
+# Whisper-large-v3 Model - Inference Guide
+
+- [Overview](#overview)
+- [Inference Environment](#inference-environment)
+- [Quick Start](#quick-start)
+  - [Getting the Source](#getting-the-source)
+  - [Model Inference](#model-inference)
+
+# Overview
+
+This project deploys the whisper-large-v3 model with mindietorch.
+
+
+# Inference Environment
+
+- The model requires the following plugins and drivers
+
+  **Table 1** Version matrix
+
+  | Dependency | Version | Setup guide |
+  |-----------| ------- | ------------ |
+  | Python | 3.10.13 | - |
+  | torch | 2.1.0+cpu | - |
+  | torch_audio | 2.1.0+cpu | - |
+  | CANN | 8.0.RC3 | - |
+  | MindIE | 1.0.RC3 | - |
+
+# Quick Start
+## Getting the Source
+
+1. Install the MindIE package
+
+    ```bash
+    # install mindie
+    chmod +x ./Ascend-mindie_xxx.run
+    ./Ascend-mindie_xxx.run --install
+    source /usr/local/Ascend/mindie/set_env.sh
+    ```
+
+
+2. Download the model weights from:
+    ```bash
+    https://huggingface.co/openai/whisper-large-v3/tree/main
+    ```
+    Store the weight files in the model_path folder under the current directory; create this folder first.
+
+
+3. Install the dependencies
+    ```
+    pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+    pip3 install nltk
+    pip3 install librosa
+    pip3 install transformers==4.36.0
+    pip3 install numpy==1.26.0
+    ```
+
+## Model Inference
+1. Set the upper limit of the mindie memory pool to 32 by exporting the environment variable below. If the pool is too small, repeated memory allocation and release will hurt performance.
+    ```
+    export TORCH_AIE_NPU_CACHE_MAX_SIZE=32
+    ```
+
+2. Compile the model
+    ```
+    python3 compile_whisper.py \
+        -model_path ./model_path \
+        -bs 16 \
+        -save_path ./compiled_models \
+        -soc_version Ascend310P3
+    ```
+
+    Parameters:
+    - -model_path: path to the pretrained model, required.
+    - -bs: batch size, default 8, optional.
+    - -save_path: directory where the compiled models are saved, required.
+    - -device_id: index of the device the model runs on, default 0, optional.
+    - -soc_version: SoC (chip) type, required.
+    - -machine_type: machine type, required.
+
+    Constraints:
+    1. Dynamic batch is not supported yet; after batch_size changes, the graphs must be recompiled.
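+
+3. Inference usage (sketch)
+
+    The compiled artifacts are loaded and driven through `load_mindie_models` and `generate` from `modeling_whisper.py`. The snippet below is a minimal sketch rather than part of this repo: the feature-extraction details (`WhisperProcessor`, `librosa`, the `sample.wav` path) and the `mindietorch.set_device` call are assumptions, and only 30 s audio at the compiled batch size is supported.
+
+    ```python
+    import librosa
+    import mindietorch
+    from transformers import WhisperProcessor
+    from modeling_whisper import MindieWhisperForConditionalGeneration
+
+    mindietorch.set_device(0)  # assumption: the device id used at compile time
+    processor = WhisperProcessor.from_pretrained("./model_path")
+    model = MindieWhisperForConditionalGeneration.from_pretrained(
+        "./model_path", machine_type="800IA2")
+    model.load_mindie_models("./compiled_models", batch_size=16)
+
+    # one 30 s, 16 kHz clip; the processor pads/truncates to 3000 mel frames
+    speech, _ = librosa.load("sample.wav", sr=16000)
+    features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features
+    features = features.repeat(16, 1, 1)  # static shapes: batch must equal the compiled bs
+
+    generated_ids = model.generate(features)
+    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+    ```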
--
Gitee

From f7d0e7555113587939e63fd50db4899bed1867e6 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Mon, 11 Nov 2024 14:08:02 +0800
Subject: [PATCH 02/14] =?UTF-8?q?=E5=A2=9E=E5=8A=A0whisperx?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/whisperX/compile_whisper.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 7e4690bb51..923cfe0905 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -52,7 +52,7 @@ def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration,
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
     prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs))
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
-                  mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
+                  mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))]
     prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
                                                    inputs=input_info,
                                                    precision_policy=mindietorch.PrecisionPolicy.FP16,
@@ -83,7 +83,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
     actual_seq_len = torch.ones((args.bs))
     all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
-                          torch.randn([args.bs, compile_info.seq_len, compile_info.head_num, compile_info.head_size]),
+                          torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size])
                           ] * compile_info.layer_nums
@@ -92,17 +92,17 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     traced_decoder = torch.jit.trace(decoder, traced_args)
     # BSND
     key_value_infos = [
-        mindietorch.Input(shape=(args.bs, compile_info.max_len, compile_info.head_num, compile_info.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, compile_info.max_len, compile_info.head_num, compile_info.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
         mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, compile_info.le_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
+        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
                           dtype=dtype.FLOAT16
                           )] * compile_info.layer_nums
     input_info =
[mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids - mindietorch.Input(shape=(args.bs, compile_info.le_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)), mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)] # actual sq len input_info.extend(key_value_infos) -- Gitee From 696b6f11a3a7567a1f56e6ef850e0f07d2313a4c Mon Sep 17 00:00:00 2001 From: tiansongxue Date: Tue, 12 Nov 2024 00:10:06 +0800 Subject: [PATCH 03/14] =?UTF-8?q?=E5=A2=9E=E5=8A=A0whisperx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../audio/whisperX/compile_whisper.py | 1 + .../audio/whisperX/modeling_whisper.py | 83 ++++++++++++------- 2 files changed, 56 insertions(+), 28 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py index 923cfe0905..5733edcb30 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -1,6 +1,7 @@ import os import argparse import time +import math import torch import mindietorch from mindietorch._enums import dtype diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py index 639c847349..4800ba33d0 100644 --- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -126,7 +126,7 @@ class MindieIFA(WhisperAttention): self, hidden_states: torch.Tensor, past_key_value: torch.Tensor, - actual_seq_len: torch.Tensor, + actual_seq_len: torch.Tensor = None, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # self attn assert past_key_value is not None, \ @@ -138,25 +138,43 @@ class MindieIFA(WhisperAttention): key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16) values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16) past_key_cache, past_value_cache = past_key_value[0], past_key_value[1] - indices = actual_seq_len - 1 - past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) - past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1) - past_key_value = (past_key_cache, past_value_cache) - # B S N D - query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() - # B S N D - attn_output = torch.ops.aie.incre_flash_attention( - query=query_states, - key=past_key_cache, - value=past_value_cache, - actual_seq_lengths=actual_seq_len, - num_head=self.num_heads, - scale=self.scaling, - layout="BSND") - attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - return attn_output, _, past_key_value + if actual_seq_len is not None: + # self atten + indices = actual_seq_len - 1 + past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) + past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1) + past_key_value = (past_key_cache, past_value_cache) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D + attn_output = 
torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + else: + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + class MindieFA(WhisperAttention): @@ -334,7 +352,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): ) else: raise ValueError(f"Unsupporting current parameters. eg. " - f"is_incre_decode: {is_incre_decode} machine_type {machine_type}") + f"is_incre_decode: {config.is_incre_decode} machine_type {config.machine_type}") def forward( self, @@ -377,7 +395,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer): attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, + output_attentions=output_attentions ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -422,8 +440,8 @@ class MindieWhisperDecoder(WhisperDecoder): input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - # B S N D - past_key_values_length = past_key_values[0].shape[1] if past_key_values is not None else 0 + # B S N D todo + past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 inputs_embeds = self.embed_tokens(input_ids) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) @@ -720,12 +738,21 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen self.compile_scatter_update(args) self.has_compile = True + def _init_past_key_value_cache(self, bsz): for _ in range(32): - self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) - self.past_key_value.append(torch.ones([bsz, self.max_len, self.head_num, self.head_size])) - self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) - self.past_key_value.append(torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size])) + self.past_key_value.append( + torch.ones([bsz, self.max_len, self.head_num, self.head_size], dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([bsz, self.max_len, self.head_num, self.head_size], dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size], dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size], dtype=torch.float16).to("npu") + ) print("init past key value cache success.") def load_mindie_models(self, save_path, batch_size): -- Gitee From e9d5da5e662875a487c95363caf24cd5e93956a8 Mon Sep 17 
From e9d5da5e662875a487c95363caf24cd5e93956a8 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 00:22:54 +0800
Subject: [PATCH 04/14] Add whisperx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/compile_whisper.py         |  24 ++-
 .../audio/whisperX/modeling_whisper.py        | 161 +-----------------
 2 files changed, 13 insertions(+), 172 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 5733edcb30..a9c44d0e5d 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -75,8 +75,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
 
     mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
                                                                            machine_type = args.machine_type,
-                                                                           is_incre_decode=True,
-                                                                           soc_version=args.soc_version)
+                                                                           is_incre_decode=True)
     decoder = Decoder(mindie_whisper)
     print("Start compiling decoder.")
 
@@ -93,15 +92,15 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     traced_decoder = torch.jit.trace(decoder, traced_args)
     # BSND
     key_value_infos = [
-        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
-                          dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
-                          dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
-                          dtype=dtype.FLOAT16),
-        mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
-                          dtype=dtype.FLOAT16
-                          )] * compile_info.layer_nums
+            mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
+                              dtype=dtype.FLOAT16),
+            mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size),
+                              dtype=dtype.FLOAT16),
+            mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
+                              dtype=dtype.FLOAT16),
+            mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size),
+                              dtype=dtype.FLOAT16
+                              )] * compile_info.layer_nums
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
                   mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
                   mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)]  # actual sq len
@@ -121,7 +120,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
 
     print(f"Compile whisper_decoder success, saved in {save_file}.")
 
-def compile_scatter_update(self, args, compile_info):
+def compile_scatter_update(args, compile_info):
     class MindieScatter(torch.nn.Module):
         def forward(self, past_key_value, indices, update_states):
             out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)
             return out
@@ -160,7 +159,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
-                                                                           soc_version=args.soc_version,
                                                                            machine_type=args.machine_type)
     device = f"npu:{args.device_id}"
     print("Start compiling Mindie-Whisper, it will take some time, please wait.")
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 4800ba33d0..f1cc3bba31 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -535,7 +535,7 @@ class MindieWhisperModel(WhisperModel):
 
 class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin):
 
-    def __init__(self, config, is_incre_decode=False, machine_type="800A2"):
+    def __init__(self, config, is_incre_decode=False, machine_type="800IA2"):
         super().__init__(config)
         config.is_incre_decode = is_incre_decode
         config.machine_type = machine_type
@@ -555,12 +555,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             "mindie_self_scatter_bs",
             "mindie_encoder_scatter_bs"
         ]
-        self.hidden_size = 1280
-        self.head_num = 20
-        self.head_size = 64
-        self.seq_len = 1
-        self.layer_nums = 32
-        self.max_len = 448
+
         self.past_key_value = []
 
     def forward(self, *args):
@@ -586,158 +581,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         lm_logits = self.proj_out(outputs[0])
         return [lm_logits] + outputs[1]
 
-    def compile_encoder(self, args):
-        encoder = self.get_encoder()
-
-        class Encoder(torch.nn.Module):
-
-            def __init__(self, model):
-                super().__init__()
-                self.model = model
-
-            def forward(self, input_features):
-                return self.model(input_features=input_features, return_dict=False)
-
-        mel_feature_size = 128
-        max_frames = 3000
-        input_features = torch.randn([args.bs, mel_feature_size, max_frames])
-        encoder_traced = torch.jit.trace(Encoder(encoder), (input_features))
-        input_info = [mindietorch.Input(shape=(args.bs, mel_feature_size, max_frames))]
-        print("Start AOE optimization_level 1.")
-        compiled = mindietorch.compile(encoder_traced,
-                                       inputs=input_info,
-                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
-                                       soc_version=args.soc_version,
-                                       optimization_level=1
-                                       )
-        print("AOE optimize finished and start compile encoder.")
-        compiled = mindietorch.compile(encoder_traced,
-                                       inputs=input_info,
-                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
-                                       soc_version=args.soc_version,
-                                       )
-        save_file = os.path.join(args.save_path, f"{self.file_prefix_names[0]}{args.bs}.ts")
-        torch.jit.save(compiled, save_file)
-        print(f"Compile encoder success, saved in {save_file}")
-
-    def compile_prefill_decoder(self, args):
-        print("Start compiling prefill_decoder.")
-        encoder_seq_len = 1500
-        hidden_size = 1280
-        encoder_outputs = torch.randn([args.bs, encoder_seq_len, hidden_size])
-        decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
-        prefill_decoder_traced = torch.jit.trace(self.eval(), (decoder_input_ids, encoder_outputs))
-        input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
-                      mindietorch.Input(shape=(args.bs, encoder_seq_len, hidden_size))]
-        prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced,
-                                                       inputs=input_info,
-                                                       precision_policy=mindietorch.PrecisionPolicy.FP16,
-                                                       soc_version=args.soc_version)
-
-        save_file = os.path.join(args.save_path, f"{self.file_prefix_names[1]}{args.bs}.ts")
-        torch.jit.save(prefill_decoder_compiled, save_file)
-        print(f"Compile prefill_decoder success, saved in {save_file}.")
-
-    def compile_decoder(self, args):
-
-        class Decoder(torch.nn.Module):
-
-            def __init__(self, model):
-                super().__init__()
-                self.model = model
-
-            def forward(self, *args):
-                return self.model.forward(*args)[0]
-
-        mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
-                                                                               is_incre_decode=True)
-        decoder = Decoder(mindie_whisper)
-        print("Start compiling decoder.")
-
-        encoder_outputs = torch.randn([args.bs, self.encoder_seq_len, self.hidden_size])
-        decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
-        actual_seq_len = torch.ones((args.bs))
-        all_past_key_value = [torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]),
-                              torch.randn([args.bs, self.seq_len, self.head_num, self.head_size]),
-                              torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size]),
-                              torch.randn([args.bs, self.encoder_seq_len, self.head_num, self.head_size])
-                              ] * self.layer_nums
-        traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len]
-        traced_args.extend(all_past_key_value)
-        traced_decoder = torch.jit.trace(decoder, traced_args)
-        # BSND
-        key_value_infos = [
-            mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size),
-                              dtype=dtype.FLOAT16),
-            mindietorch.Input(shape=(args.bs, self.max_len, self.head_num, self.head_size),
-                              dtype=dtype.FLOAT16),
-            mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
-                              dtype=dtype.FLOAT16),
-            mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.head_num, self.head_size),
-                              dtype=dtype.FLOAT16
-                              )] * self.layer_nums
-        input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),
-                      mindietorch.Input(shape=(args.bs, self.encoder_seq_len, self.hidden_size)),
-                      mindietorch.Input(shape=(args.bs, ), dtype=dtype.INT64)]
-
-        input_info.extend(key_value_infos)
-        float_size = 4
-        voc_size = 51866
-        buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024)
-        print(f"Set {buffer_size}/MB for output.")
-        compiled_decoder = mindietorch.compile(traced_decoder,
-                                               inputs=input_info,
-                                               precision_policy=mindietorch.PrecisionPolicy.FP16,
-                                               soc_version=args.soc_version,
-                                               default_buffer_size_vec=[buffer_size])
-        save_file = os.path.join(args.save_path, f"{self.file_prefix_names[2]}{args.bs}.ts")
-        torch.jit.save(compiled_decoder, save_file)
-
-        print(f"Compile whisper_decoder success, saved in {save_file}.")
-
-    def compile_scatter_update(self, args):
-
-        class MindieScatter(torch.nn.Module):
-            def forward(self, past_key_value, indices, update_states):
-                out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1)
-                return out
-
-        bs = args.bs
-        self_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
-        encoder_past_key_value = torch.randn([bs, self.max_len, self.head_num, self.head_size])
-        indices = torch.tensor([0]*bs)
-        update_states = torch.randn([bs, 1, 20, 64])
-        traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
-
-        self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
-        encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
-        indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64)
-        update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
-
-        compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info],
-                                           soc_version=args.soc_version)
-        torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[3]}{args.bs}.ts")
-
-        compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info],
-                                           soc_version=args.soc_version)
-        torch.jit.save(compile_self, f"{args.save_path}/{self.file_prefix_names[4]}{args.bs}.ts")
-        print("compile scatter success.")
-
-    def compile(self, args):
-        print("Start compiling Mindie-Whisper, it will take some time, please wait.")
-        save_path = args.save_path
-
-        if not save_path:
-            raise ValueError("Please provide the directory where the compiled model saved.")
-        if not os.path.exists(save_path):
-            os.makedirs(save_path)
-            print(f"Directory {save_path} created.")
-        else:
-            print(f"Directory {save_path} already exists.")
-        self.compile_encoder(args)
-        self.compile_prefill_decoder(args)
-        self.compile_decoder(args)
-        self.compile_scatter_update(args)
-        self.has_compile = True
-
     def _init_past_key_value_cache(self, bsz):
         for _ in range(32):
-- Gitee


From d712b57560ae2f2e98179d1795604beaea1a9d2d Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 09:43:42 +0800
Subject: [PATCH 05/14] Add whisperx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/compile_whisper.py         |  4 ++--
 .../audio/whisperX/modeling_whisper.py        | 24 +++++++++++++------
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index a9c44d0e5d..8c524a23b8 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -169,8 +169,8 @@ if __name__ == "__main__":
         print(f"Directory {args.save_path} created.")
     else:
         print(f"Directory {args.save_path} already exists.")
-
-    compile_scatter_update(mindie_whisper, args, CompileInfo)
+    mindietorch.set_device(args.device_id)
+    compile_scatter_update(args, CompileInfo)
     compile_encoder(mindie_whisper, args, CompileInfo)
     compile_prefill_decoder(mindie_whisper, args, CompileInfo)
     compile_incre_decoder(args, CompileInfo)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index f1cc3bba31..3287a10472 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -351,7 +351,7 @@ class MindieWhisperDecoderLayer(WhisperDecoderLayer):
                 config=config
             )
         else:
-            raise ValueError(f"Unsupporting current parameters. eg. "
+            raise ValueError(f"Unsupported current parameters. eg. "
                              f"is_incre_decode: {config.is_incre_decode} machine_type {config.machine_type}")
 
     def forward(
@@ -476,12 +476,22 @@ class MindieWhisperDecoder(WhisperDecoder):
 class MindieWhisperEncoderLayer(WhisperEncoderLayer):
     def __init__(self, config):
         super().__init__(config)
-        self.self_attn = MindiePFA(
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
-            dropout=config.attention_dropout,
-            config=config,
-        )
+        if config.machine_type == "300IPro":
+            self.self_attn = MindieFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.encoder_attention_heads,
+                dropout=config.attention_dropout,
+                config=config,
+            )
+        elif config.machine_type == "800IA2":
+            self.self_attn = MindiePFA(
+                embed_dim=self.embed_dim,
+                num_heads=config.encoder_attention_heads,
+                dropout=config.attention_dropout,
+                config=config,
+            )
+        else:
+            raise ValueError(f"Unsupported machine_type {config.machine_type} when init encoder.")
 
 
 class MindieWhisperEncoder(WhisperEncoder):
-- Gitee


From 89aaf0c0b1222340e9f3a37740774940abf02464 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 11:04:45 +0800
Subject: [PATCH 06/14] Add whisperx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/whisperX/modeling_whisper.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 3287a10472..6dbbff7905 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -21,6 +21,7 @@ from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
 from transformers.generation.utils import GenerationMixin
 import mindietorch
 from mindietorch._enums import dtype
+from config import CompileInfo
 
 
 class MindiePFA(WhisperAttention):
@@ -567,6 +568,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         ]
 
         self.past_key_value = []
+        self.max_len = 448
 
     def forward(self, *args):
         if len(args) not in (2, 131):
@@ -595,16 +597,20 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
     def _init_past_key_value_cache(self, bsz):
         for _ in range(32):
             self.past_key_value.append(
-                torch.ones([bsz, self.max_len, self.head_num, self.head_size], dtype=torch.float16).to("npu")
+                torch.ones([bsz, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
+                           dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, self.max_len, self.head_num, self.head_size], dtype=torch.float16).to("npu")
+                torch.ones([bsz, CompileInfo.max_len, CompileInfo.head_num, CompileInfo.head_size],
+                           dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size], dtype=torch.float16).to("npu")
+                torch.ones([bsz, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
+                           dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, self.encoder_seq_len, self.head_num, self.head_size], dtype=torch.float16).to("npu")
+                torch.ones([bsz, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
+                           dtype=torch.float16).to("npu")
             )
         print("init past key value cache success.")
-- Gitee
From cec842442e55060a8ee4ed9e8880a84d31fb5435 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 15:04:57 +0800
Subject: [PATCH 07/14] Add whisperx
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/whisperX/compile_whisper.py  | 4 ++--
 .../built-in/audio/whisperX/modeling_whisper.py | 7 +++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 8c524a23b8..716277ee25 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -128,7 +128,7 @@ def compile_scatter_update(args, compile_info):
     bs = args.bs
     self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size])
-    encoder_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size])
+    encoder_past_key_value = torch.randn([bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size])
     indices = torch.tensor([0] * bs)
     update_states = torch.randn([bs, 1, 20, 64])
     traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states))
@@ -136,7 +136,7 @@ def compile_scatter_update(args, compile_info):
     self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
     encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
     indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64)
-    update_states_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16)
+    update_states_info = mindietorch.Input(shape=update_states.shape, dtype=mindietorch.dtype.FLOAT16)
 
     compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info],
                                        soc_version=args.soc_version)
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 6dbbff7905..4ea7fba27c 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -568,7 +568,6 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         ]
 
         self.past_key_value = []
-        self.max_len = 448
 
     def forward(self, *args):
         if len(args) not in (2, 131):
@@ -601,7 +600,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
                            dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, CompileInfo.max_len, CompileInfo.head_num, CompileInfo.head_size],
+                torch.ones([bsz, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
                            dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
@@ -637,7 +636,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts")
         print(f"load {self.file_prefix_names[3]}{batch_size}.ts success.")
 
-        self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts")
+        self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts")
         print(f"load {self.file_prefix_names[4]}{batch_size}.ts success.")
         self.has_load = True
     else:
@@ -677,7 +676,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         this_peer_finished = False  # used by synced_gpus only
         kv_actual_step = 1
-        indices = torch.tensor([0] * input_ids.shape[0])
+        indices = torch.tensor([0] * input_ids.shape[0]).to("npu")
         is_first_step = True
         while True:
-- Gitee


From a27a6459623af61c7c3afea58fbfd0edcb0b0b0f Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 17:07:26 +0800
Subject: [PATCH 08/14] Bring up 310P3
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/compile_whisper.py         |  3 +-
 .../audio/whisperX/modeling_whisper.py        | 39 ++++++++-----------
 2 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 716277ee25..dcd3c07e86 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -142,11 +142,12 @@ def compile_scatter_update(args, compile_info):
                                        soc_version=args.soc_version)
     torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts")
 
-    compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, update_states_info],
+    compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, encoder_attn_info],
                                        soc_version=args.soc_version)
     torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts")
     print("compile scatter success.")
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 4ea7fba27c..aa523924e7 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -592,8 +592,13 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         lm_logits = self.proj_out(outputs[0])
         return [lm_logits] + outputs[1]
 
+    def load_mindie_models(self, save_path, batch_size):
+        if not (save_path and batch_size):
+            raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\
+                but found save_path is {save_path}, batch_size is {batch_size}.")
+        self._check_save_path(save_path, batch_size)
+        self.batch_size = batch_size
 
-    def _init_past_key_value_cache(self, bsz):
         for _ in range(32):
             self.past_key_value.append(
                 torch.ones([bsz, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
@@ -613,31 +618,21 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             )
         print("init past key value cache success.")
 
-    def load_mindie_models(self, save_path, batch_size):
-        if not (save_path and batch_size):
-            raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\
-                but found save_path is {save_path}, batch_size is {batch_size}.")
-        self._check_save_path(save_path, batch_size)
-        self.batch_size = batch_size
-
-        self._init_past_key_value_cache(batch_size)
-
-        if not self.has_load:
         self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts")
-        print(f"load {self.file_prefix_names[0]}{batch_size}.ts success.")
+        print(f"load {save_path}/{self.file_prefix_names[0]}{batch_size}.ts success.")
 
         self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts")
-        print(f"load {self.file_prefix_names[1]}{batch_size}.ts success.")
+        print(f"load {save_path}/{self.file_prefix_names[1]}{batch_size}.ts success.")
 
         self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts")
-        print(f"load {self.file_prefix_names[2]}{batch_size}.ts success.")
+        print(f"load {save_path}/{self.file_prefix_names[2]}{batch_size}.ts success.")
 
         self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts")
-        print(f"load {self.file_prefix_names[3]}{batch_size}.ts success.")
+        print(f"load {save_path}/{self.file_prefix_names[3]}{batch_size}.ts success.")
 
         self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts")
-        print(f"load {self.file_prefix_names[4]}{batch_size}.ts success.")
+        print(f"load {save_path}/{self.file_prefix_names[4]}{batch_size}.ts success.")
         self.has_load = True
     else:
         print("Mindie whisper has already load.")
@@ -685,14 +680,14 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             if is_first_step:
                 outputs = self.mindie_decoder_prefill(*args)
                 for idx in range(32):
-                    self.self_attn_scatter(self.past_key_value[0], indices, outputs[1 + 4*idx])
-                    self.self_attn_scatter(self.past_key_value[1], indices, outputs[1 + 4*idx + 1])
-                    self.encoder_attn_scatter(self.past_key_value[2], indices, outputs[1 + 4*idx + 2])
-                    self.encoder_attn_scatter(self.past_key_value[3], indices, outputs[1 + 4*idx + 3])
+                    self.self_attn_scatter(self.past_key_value[4*idx], indices, outputs[1 + 4*idx])
+                    self.self_attn_scatter(self.past_key_value[4*idx + 1], indices, outputs[1 + 4*idx + 1])
+                    self.encoder_attn_scatter(self.past_key_value[4*idx + 2], indices, outputs[1 + 4*idx + 2])
+                    self.encoder_attn_scatter(self.past_key_value[4*idx + 3], indices, outputs[1 + 4*idx + 3])
                 is_first_step = False
             else:
                 kv_actual_step += 1
-                args.append(torch.tensor([kv_actual_step] * input_ids.shape[0])).to("npu")
+                args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
                 args.extend(self.past_key_value)
                 outputs = self.mindie_decoder(*args)
             if isinstance(outputs, list):
@@ -866,7 +861,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         if not if_first_step:
             decoder_input_ids_shape = decoder_input_ids.shape
             remove_prefix_length = decoder_input_ids_shape[1] - 1
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length]
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
         return {
             "encoder_outputs": encoder_outputs,
             "decoder_input_ids": decoder_input_ids
-- Gitee
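The prefill/increment split that patches 07 and 08 debug can be condensed to the skeleton below. This is only a sketch of the loop inside MindieWhisperForConditionalGeneration: greedy token selection stands in for the real stopping criteria, and prefill, incre_decoder and the two scatter handles stand for the loaded .ts artifacts.

import torch

def decode_sketch(prefill, incre_decoder, self_scatter, cross_scatter,
                  cache, input_ids, encoder_outputs, max_steps=448):
    # cache: 128 preallocated fp16 tensors, 4 per layer (self K/V, cross K/V)
    indices = torch.zeros(input_ids.shape[0], dtype=torch.int64)  # write slot 0
    outputs = prefill(input_ids, encoder_outputs)   # [logits, 128 K/V tensors]
    for idx in range(32):                           # seed the caches once
        self_scatter(cache[4 * idx], indices, outputs[1 + 4 * idx])
        self_scatter(cache[4 * idx + 1], indices, outputs[1 + 4 * idx + 1])
        cross_scatter(cache[4 * idx + 2], indices, outputs[1 + 4 * idx + 2])
        cross_scatter(cache[4 * idx + 3], indices, outputs[1 + 4 * idx + 3])
    tokens = [outputs[0].argmax(-1)]                # greedy pick, shape [bs, 1]
    for step in range(2, max_steps + 1):            # incremental phase
        args = [tokens[-1], encoder_outputs,
                torch.tensor([step] * tokens[-1].shape[0])]  # actual_seq_len
        args.extend(cache)
        logits = incre_decoder(*args)               # reads/writes cache in place
        tokens.append(logits.argmax(-1))
    return torch.cat(tokens, dim=-1)

The prefill graph produces all layer K/V tensors in one pass; everything after that only appends one step per layer, which is why the incremental graph takes 131 (later 132) inputs.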
From af074673f0c56bfd10b3de2b7e30a1a6d9c4d92d Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 19:49:47 +0800
Subject: [PATCH 09/14] Test emb_pos performance and accuracy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/compile_whisper.py         | 30 ++++++++++++---
 .../built-in/audio/whisperX/config.py         |  3 +-
 .../audio/whisperX/modeling_whisper.py        | 38 ++++++++++++-------
 3 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index dcd3c07e86..6ddedefab7 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -82,6 +82,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
     actual_seq_len = torch.ones((args.bs))
+    embed_ops = torch.tensor([1])
     all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
@@ -103,7 +104,8 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
                           )] * compile_info.layer_nums
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
                   mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
-                  mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)]  # actual sq len
+                  mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64),  # actual sq len
+                  mindietorch.Input(shape=(embed_ops.shape,), dtype=dtype.INT32)]  # past_key_values_len
 
     input_info.extend(key_value_infos)
     float_size = 4
@@ -148,6 +150,23 @@ def compile_scatter_update(args, compile_info):
     print("compile scatter success.")
 
 
+def compile_embed_pos(model : MindieWhisperForConditionalGeneration, args : argparse, compile_info : CompileInfo):
+
+    class EmbeddingPosition(torch.nn.Module):
+        def __init__(self, whisper_model : MindieWhisperForConditionalGeneration):
+            super().__init__()
+            self.embed_pos = whisper_model.model.decoder.embed_positions
+            self.dumpy_input_ids = torch.ones([args.bs, 1])
+
+        def forward(self, past_length):
+            return self.embed_pos(input_ids = self.dumpy_input_ids, past_key_values_length=past_length)
+    dump_inputs = torch.tensor([1])
+    traced = torch.jit.trace(EmbeddingPosition(), (dump_inputs,))
+    compiled = mindietorch.compile(traced, inputs=[mindietorch.Input(shape=dump_inputs.shape, dtype=dtype.INT32)])
+    torch.jit.save(compiled, f"{args.save_path}/{compile_info.prefix_name[-1]}")
+    print(f"save {args.save_path}/{compile_info.prefix_name[-1]} success.")
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
@@ -171,7 +190,8 @@ if __name__ == "__main__":
         print(f"Directory {args.save_path} created.")
     else:
         print(f"Directory {args.save_path} already exists.")
     mindietorch.set_device(args.device_id)
-    compile_scatter_update(args, CompileInfo)
-    compile_encoder(mindie_whisper, args, CompileInfo)
-    compile_prefill_decoder(mindie_whisper, args, CompileInfo)
-    compile_incre_decoder(args, CompileInfo)
\ No newline at end of file
+    # compile_scatter_update(args, CompileInfo)
+    # compile_encoder(mindie_whisper, args, CompileInfo)
+    # compile_prefill_decoder(mindie_whisper, args, CompileInfo)
+    compile_incre_decoder(args, CompileInfo)
+    compile_embed_pos(args, CompileInfo)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py
index 6d4464195e..591584d94d 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py
@@ -7,7 +7,8 @@ class CompileInfo:
         "mindie_decoder_prefill_bs",
         "mindie_whisper_decoder_bs",
         "mindie_self_scatter_bs",
-        "mindie_encoder_scatter_bs"]
+        "mindie_encoder_scatter_bs",
+        "mindie_embed_bs"]
     mel_feature_size = 128
     max_frames = 3000
     max_decode_step = 448
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index aa523924e7..540047a6d1 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -433,6 +433,7 @@ class MindieWhisperDecoder(WhisperDecoder):
             encoder_hidden_states,
             past_key_values,
             actual_seq_len,
+            embed_position = None,
             use_cache=True,
             attention_mask=None):
         if input_ids is None:
@@ -442,11 +443,15 @@ class MindieWhisperDecoder(WhisperDecoder):
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
         # B S N D todo
-        past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0
+        # past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0
         inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
-        hidden_states = inputs_embeds + positions
+        if past_key_values is None:
+            positions = self.embed_positions(input_ids, past_key_values_length=0)
+            hidden_states = inputs_embeds + positions
+        else:
+            assert embed_position is not None, "Embed_position is required, as past_key_values is not None."
+            hidden_states = inputs_embeds + embed_position
 
         past_key_value_cache = []
         for idx, decoder_layer in enumerate(self.layers):
@@ -513,6 +518,7 @@ class MindieWhisperModel(WhisperModel):
         self,
         encoder_outputs,
         decoder_input_ids,
+        embed_position: Optional[torch.Tensor] = None,
         past_key_values: Optional[torch.Tensor] = None,
         actual_seq_len: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
@@ -539,7 +545,8 @@ class MindieWhisperModel(WhisperModel):
             encoder_hidden_states=encoder_outputs,
             past_key_values=past_key_values,
             use_cache=use_cache,
-            actual_seq_len=actual_seq_len
+            actual_seq_len=actual_seq_len,
+            embed_position=embed_position
         )
         return decoder_outputs
@@ -570,15 +577,17 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         self.past_key_value = []
 
     def forward(self, *args):
-        if len(args) not in (2, 131):
+        if len(args) not in (2, 132):
             raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}")
         decoder_input_ids = args[0]
         encoder_outputs = args[1]
-        if len(args) == 131:
+        if len(args) == 132:
             actual_seq_len = args[2]
-            past_key_values = args[3:]
+            embed_position = args[3]
+            past_key_values = args[4:]
         else:
             actual_seq_len = None
+            embed_position = None
             past_key_values = None
         outputs = self.model(
             decoder_input_ids=decoder_input_ids,
@@ -587,7 +596,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             actual_seq_len=actual_seq_len,
             use_cache=True,
             return_dict=False,
-            input_features=None
+            input_features=None,
+            embed_position = None
         )
         lm_logits = self.proj_out(outputs[0])
         return [lm_logits] + outputs[1]
@@ -597,23 +607,22 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\
                 but found save_path is {save_path}, batch_size is {batch_size}.")
         self._check_save_path(save_path, batch_size)
-        self.batch_size = batch_size
 
         for _ in range(32):
             self.past_key_value.append(
-                torch.ones([bsz, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
+                torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
                            dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
+                torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size],
                            dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
+                torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
                            dtype=torch.float16).to("npu")
             )
             self.past_key_value.append(
-                torch.ones([bsz, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
+                torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size],
                            dtype=torch.float16).to("npu")
             )
         print("init past key value cache success.")
@@ -633,6 +642,9 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts")
         print(f"load {save_path}/{self.file_prefix_names[4]}{batch_size}.ts success.")
+
+        self.embed_pos = torch.jit.load(f"{save_path}/{self.file_prefix_names[-1]}{batch_size}.ts")
+        print(f"load {save_path}/{self.file_prefix_names[-1]}{batch_size}.ts success.")
         self.has_load = True
     else:
         print("Mindie whisper has already load.")
-- Gitee


From eed9343c7f338d711471bd61da8687c0f7d92357 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 20:24:20 +0800
Subject: [PATCH 10/14] Test emb_pos performance and accuracy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/whisperX/compile_whisper.py  | 12 ++++++------
 .../built-in/audio/whisperX/modeling_whisper.py |  6 +++++-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 6ddedefab7..0c764b6b6a 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -88,7 +88,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size])
                           ] * compile_info.layer_nums
-    traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len]
+    traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len, embed_ops]
     traced_args.extend(all_past_key_value)
     traced_decoder = torch.jit.trace(decoder, traced_args)
     # BSND
@@ -105,7 +105,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
                   mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
                   mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64),  # actual sq len
-                  mindietorch.Input(shape=(embed_ops.shape,), dtype=dtype.INT32)]  # past_key_values_len
+                  mindietorch.Input(shape=embed_ops.shape, dtype=dtype.INT32)]  # past_key_values_len
 
     input_info.extend(key_value_infos)
     float_size = 4
@@ -161,10 +161,10 @@ def compile_embed_pos(model : MindieWhisperForConditionalGeneration, args : argp
         def forward(self, past_length):
             return self.embed_pos(input_ids = self.dumpy_input_ids, past_key_values_length=past_length)
     dump_inputs = torch.tensor([1])
-    traced = torch.jit.trace(EmbeddingPosition(), (dump_inputs,))
+    traced = torch.jit.trace(EmbeddingPosition(model), (dump_inputs,))
     compiled = mindietorch.compile(traced, inputs=[mindietorch.Input(shape=dump_inputs.shape, dtype=dtype.INT32)])
-    torch.jit.save(compiled, f"{args.save_path}/{compile_info.prefix_name[-1]}")
-    print(f"save {args.save_path}/{compile_info.prefix_name[-1]} success.")
+    torch.jit.save(compiled, f"{args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts")
+    print(f"save {args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts success.")
@@ -194,4 +194,4 @@ if __name__ == "__main__":
     # compile_scatter_update(args, CompileInfo)
     # compile_encoder(mindie_whisper, args, CompileInfo)
     # compile_prefill_decoder(mindie_whisper, args, CompileInfo)
     compile_incre_decoder(args, CompileInfo)
-    compile_embed_pos(args, CompileInfo)
\ No newline at end of file
+    compile_embed_pos(mindie_whisper, args, CompileInfo)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 540047a6d1..d172b2db47 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -597,7 +597,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             use_cache=True,
             return_dict=False,
             input_features=None,
-            embed_position = None
+            embed_position = embed_position
         )
         lm_logits = self.proj_out(outputs[0])
         return [lm_logits] + outputs[1]
@@ -683,6 +683,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         this_peer_finished = False  # used by synced_gpus only
         kv_actual_step = 1
+        past_kv_len = 0
         indices = torch.tensor([0] * input_ids.shape[0]).to("npu")
         is_first_step = True
         while True:
@@ -697,9 +698,12 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
                     self.encoder_attn_scatter(self.past_key_value[4*idx + 2], indices, outputs[1 + 4*idx + 2])
                     self.encoder_attn_scatter(self.past_key_value[4*idx + 3], indices, outputs[1 + 4*idx + 3])
                 is_first_step = False
+                past_kv_len += 1
             else:
                 kv_actual_step += 1
                 args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
+                args.append(torch.tensor([past_kv_len]).to("npu"))
+                past_kv_len += 1
                 args.extend(self.past_key_value)
                 outputs = self.mindie_decoder(*args)
             if isinstance(outputs, list):
-- Gitee
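What patches 09 and 10 converge on: a traced graph cannot index the position-embedding table with a data-dependent length, so the lookup is hoisted into a tiny, separately compiled module and its [1, hidden_size] result is fed to the incremental decoder as a plain tensor input. A minimal plain-PyTorch sketch of the equivalence (the table size matches max_decode_step x hidden_size; the helper name is illustrative):

import torch

embed_positions = torch.nn.Embedding(448, 1280)   # max_decode_step x hidden_size

def position_for_step(past_key_values_length: int) -> torch.Tensor:
    # one decode step consumes exactly one row of the position table
    return embed_positions.weight[past_key_values_length].unsqueeze(0)  # [1, 1280]

# inside the traced decoder, this row replaces the self.embed_positions(...) call:
#     hidden_states = inputs_embeds + embed_position

Moving the lookup to the host side keeps the compiled decoder fully static while still giving each step the correct absolute position.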
From 3d303849ac8c3f5d1c1b9146880ad33d97df3e1d Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 21:04:41 +0800
Subject: [PATCH 11/14] Test embed_positions accuracy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../audio/whisperX/modeling_whisper.py        |  9 +----
 .../built-in/audio/whisperX/test.py           | 37 +++++++++++++++++++
 2 files changed, 39 insertions(+), 7 deletions(-)
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/whisperX/test.py

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index d172b2db47..2b3849cb06 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -567,12 +567,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
         self.encoder_attn_scatter = None
         self.save_path = None
         self.encoder_seq_len = 1500
-        self.file_prefix_names = ["mindie_whisper_encoder_bs",
-                                  "mindie_decoder_prefill_bs",
-                                  "mindie_whisper_decoder_bs",
-                                  "mindie_self_scatter_bs",
-                                  "mindie_encoder_scatter_bs"
-                                  ]
+        self.file_prefix_names = CompileInfo.prefix_name
 
         self.past_key_value = []
 
@@ -702,7 +697,7 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             else:
                 kv_actual_step += 1
                 args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
-                args.append(torch.tensor([past_kv_len]).to("npu"))
+                args.append(torch.tensor([past_kv_len], dtype=torch.int32).to("npu"))
                 past_kv_len += 1
                 args.extend(self.past_key_value)
                 outputs = self.mindie_decoder(*args)
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py
new file mode 100644
index 0000000000..854eefd64c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py
@@ -0,0 +1,37 @@
+import os
+import argparse
+import time
+import math
+import torch
+import torch.nn.functional as F
+import mindietorch
+from mindietorch._enums import dtype
+from modeling_whisper import MindieWhisperForConditionalGeneration
+from config import CompileInfo
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, required=True,
+                        help="please provide soc_version.")
+    parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
+    parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800IA2")
+    parser.add_argument('-device_id', type=int, default=0)
+    args = parser.parse_args()
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           machine_type=args.machine_type)
+    cpu_model = mindie_whisper.model.decoder.embed_positions
+    mindietorch.set_device(args.device_id)
+    npu_model = torch.jit.load(f"{args.save_path}/{CompileInfo.prefix_name[-1]}{args.bs}.ts")
+    print(f"load {CompileInfo.prefix_name[-1]}{args.bs}.ts success.")
+    input_ids = torch.ones([args.bs, 1])
+    for step in range(1, 448):
+        past_key_values_length = torch.tensor([step])
+        cpu_ret = cpu_model(input_ids=input_ids, past_key_values_length=past_key_values_length)
+        npu_ret = npu_model(past_key_values_length.to("npu")).to("cpu")
+
+        print(f"npu shape {npu_ret.shape}, cpu shape {cpu_ret.shape}")
+        print(f"cosine {F.cosine_similarity(cpu_ret.reshape(1, -1), npu_ret.reshape(1, -1))}")
-- Gitee


From c491f044736d58bb1ecec5d5bcc62150a0649d0d Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Tue, 12 Nov 2024 21:35:14 +0800
Subject: [PATCH 12/14] Test embed_positions accuracy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../MindIE-Torch/built-in/audio/whisperX/compile_whisper.py | 6 +++---
 .../built-in/audio/whisperX/modeling_whisper.py             | 3 ++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index 0c764b6b6a..eec476ac58 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -82,7 +82,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size])
     decoder_input_ids = torch.randint(1, 4, (args.bs, 1))
     actual_seq_len = torch.ones((args.bs))
-    embed_ops = torch.tensor([1])
+    embed_ops = torch.randn((1, 1280))
     all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]),
                           torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]),
@@ -105,7 +105,7 @@ def compile_incre_decoder(args : argparse, compile_info : CompileInfo):
     input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64),  # input ids
                   mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)),
                   mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64),  # actual sq len
-                  mindietorch.Input(shape=embed_ops.shape, dtype=dtype.INT32)]  # past_key_values_len
+                  mindietorch.Input(shape=embed_ops.shape, dtype=dtype.FLOAT)]  # embed_ops
 
     input_info.extend(key_value_infos)
     float_size = 4
@@ -194,4 +194,4 @@ if __name__ == "__main__":
     # compile_scatter_update(args, CompileInfo)
     # compile_encoder(mindie_whisper, args, CompileInfo)
     # compile_prefill_decoder(mindie_whisper, args, CompileInfo)
     compile_incre_decoder(args, CompileInfo)
-    compile_embed_pos(mindie_whisper, args, CompileInfo)
\ No newline at end of file
+    # compile_embed_pos(mindie_whisper, args, CompileInfo)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
index 2b3849cb06..5d817fcdbf 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py
@@ -697,7 +697,8 @@ class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, Gen
             else:
                 kv_actual_step += 1
                 args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu"))
-                args.append(torch.tensor([past_kv_len], dtype=torch.int32).to("npu"))
+                pos_info = self.embed_pos(torch.tensor([past_kv_len], dtype=torch.int32).to("npu"))
+                args.append(pos_info)
                 past_kv_len += 1
                 args.extend(self.past_key_value)
                 outputs = self.mindie_decoder(*args)
-- Gitee
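test.py above instantiates a pattern worth reusing for every compiled artifact in the series: drive the reference module on CPU and the compiled one on the NPU with identical inputs, then compare the flattened outputs. A generic sketch of that probe (the module handles and inputs are supplied by the caller):

import torch
import torch.nn.functional as F

def cosine_check(cpu_module, npu_module, inputs, device="npu"):
    cpu_out = cpu_module(*inputs)
    npu_out = npu_module(*[t.to(device) for t in inputs]).to("cpu")
    cos = F.cosine_similarity(cpu_out.reshape(1, -1).float(),
                              npu_out.reshape(1, -1).float())
    return cos.item()   # values near 1.0 mean the compiled graph tracks the reference

Cosine similarity is insensitive to uniform scaling, so for FP16-compiled graphs it is a practical first-pass accuracy signal before end-to-end WER checks.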
From 75c86dab2466f287c013b66b35357109e804cce5 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Wed, 13 Nov 2024 09:09:13 +0800
Subject: [PATCH 13/14] Bring up 910 and 310
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../MindIE-Torch/built-in/audio/whisperX/compile_whisper.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index eec476ac58..eb60111744 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -162,7 +162,9 @@ def compile_embed_pos(model : MindieWhisperForConditionalGeneration, args : argp
         def forward(self, past_length):
             return self.embed_pos(input_ids = self.dumpy_input_ids, past_key_values_length=past_length)
     dump_inputs = torch.tensor([1])
     traced = torch.jit.trace(EmbeddingPosition(model), (dump_inputs,))
-    compiled = mindietorch.compile(traced, inputs=[mindietorch.Input(shape=dump_inputs.shape, dtype=dtype.INT32)])
+    compiled = mindietorch.compile(traced,
+                                   inputs=[mindietorch.Input(shape=dump_inputs.shape, dtype=dtype.INT32)],
+                                   soc_version=args.soc_version)
     torch.jit.save(compiled, f"{args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts")
     print(f"save {args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts success.")
-- Gitee


From 4bcc13948410962f273a9e1edcf471306278af66 Mon Sep 17 00:00:00 2001
From: tiansongxue
Date: Wed, 13 Nov 2024 09:10:30 +0800
Subject: [PATCH 14/14] Bring up 910 and 310
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/whisperX/compile_whisper.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
index eb60111744..b220beadae 100644
--- a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py
@@ -192,8 +192,8 @@ if __name__ == "__main__":
     else:
         print(f"Directory {args.save_path} already exists.")
     mindietorch.set_device(args.device_id)
-    # compile_scatter_update(args, CompileInfo)
-    # compile_encoder(mindie_whisper, args, CompileInfo)
-    # compile_prefill_decoder(mindie_whisper, args, CompileInfo)
+    compile_scatter_update(args, CompileInfo)
+    compile_encoder(mindie_whisper, args, CompileInfo)
+    compile_prefill_decoder(mindie_whisper, args, CompileInfo)
     compile_incre_decoder(args, CompileInfo)
-    # compile_embed_pos(mindie_whisper, args, CompileInfo)
\ No newline at end of file
+    compile_embed_pos(mindie_whisper, args, CompileInfo)
\ No newline at end of file
-- Gitee
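End state of the series, seen from the caller's side: one compile pass produces five-plus artifacts per batch size, which the patched model then loads for inference. A sketch of the two steps (the model directory and batch size are placeholders; from_pretrained and load_mindie_models are the entry points patched above):

import mindietorch
from modeling_whisper import MindieWhisperForConditionalGeneration

# one-off compile step, run as:
#   python3 compile_whisper.py -model_path <whisper_dir> -bs 8 \
#           -soc_version <soc> -machine_type 800IA2 -save_path compiled_models

mindietorch.set_device(0)
model = MindieWhisperForConditionalGeneration.from_pretrained(
    "<whisper_dir>",                  # placeholder: local whisper-large weights
    machine_type="800IA2")
model.load_mindie_models("compiled_models", batch_size=8)
# generation then runs: compiled encoder -> prefill decoder (seeds the NPU KV
# cache via the scatter modules) -> incremental decoder plus embed_pos lookup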