diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..b220beadae290533c6db5933a039bb04f2281bf7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/compile_whisper.py @@ -0,0 +1,199 @@ +import os +import argparse +import time +import math +import torch +import mindietorch +from mindietorch._enums import dtype +from modeling_whisper import MindieWhisperForConditionalGeneration +from config import CompileInfo + +def compile_encoder(model : MindieWhisperForConditionalGeneration, + args : argparse, + compile_info : CompileInfo): + encoder = model.get_encoder() + + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] + try: + print("Start AOE optimization_level 1.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + optimization_level=1 + ) + except Exception: + print("Using Aoe to optimize encoder model failed, but still can compile encoder model.") + else: + print("AOE optimize finished and start compile encoder.") + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + +def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration, + args : argparse, compile_info : CompileInfo): + print("Start compiling prefill_decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + +def compile_incre_decoder(args : argparse, compile_info : CompileInfo): + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + machine_type = args.machine_type, + is_incre_decode=True) + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + 
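+    # Dummy single-step inputs used only for torch.jit.trace; their shapes must match the mindietorch.Input specs declared below.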
decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + embed_ops = torch.randn((1, 1280)) + all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + ] * compile_info.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len, embed_ops] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16 + )] * compile_info.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64), # actual sq len + mindietorch.Input(shape=embed_ops.shape, dtype=dtype.FLOAT)] # embed_ops + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + +def compile_scatter_update(args, compile_info): + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + + bs = args.bs + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + indices = torch.tensor([0] * bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=update_states.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + 
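+    # Save the compiled scatter kernel that writes one decode step of self-attention K/V into the preallocated cache along dim 1.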
torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, encoder_attn_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + + +def compile_embed_pos(model : MindieWhisperForConditionalGeneration, args : argparse, compile_info : CompileInfo): + + class EmbeddingPosition(torch.nn.Module): + def __init__(self, whisper_model : MindieWhisperForConditionalGeneration): + super().__init__() + self.embed_pos = whisper_model.model.decoder.embed_positions + self.dumpy_input_ids = torch.ones([args.bs, 1]) + + def forward(self, past_length): + return self.embed_pos(input_ids = self.dumpy_input_ids, past_key_values_length=past_length) + dump_inputs = torch.tensor([1]) + traced = torch.jit.trace(EmbeddingPosition(model), (dump_inputs,)) + compiled = mindietorch.compile(traced, + inputs=[mindietorch.Input(shape=dump_inputs.shape, dtype=dtype.INT32)], + soc_version=args.soc_version) + torch.jit.save(compiled, f"{args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts") + print(f"save {args.save_path}/{compile_info.prefix_name[-1]}{args.bs}.ts success.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800A2") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + machine_type=args.machine_type) + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + mindietorch.set_device(args.device_id) + compile_scatter_update(args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) + compile_embed_pos(mindie_whisper, args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py new file mode 100644 index 0000000000000000000000000000000000000000..591584d94dcf17651dbc039a3aa33f9a48379ff2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/config.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from typing import List + +@dataclass +class CompileInfo: + prefix_name = ["mindie_whisper_encoder_bs", + "mindie_decoder_prefill_bs", + "mindie_whisper_decoder_bs", + "mindie_self_scatter_bs", + "mindie_encoder_scatter_bs", + "mindie_embed_bs"] + mel_feature_size = 128 + max_frames = 3000 + max_decode_step = 448 + head_num = 20 + head_size = 64 + encoder_seq_len = 1500 + 
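+    # decoder width (d_model) and decoder layer count of whisper-large-v3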
hidden_size = 1280 + layer_nums = 32 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..5d817fcdbfe3905ddf29ab6d60c5528816982324 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/modeling_whisper.py @@ -0,0 +1,880 @@ +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \ + WhisperModel, WhisperDecoderLayer, WhisperAttention +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +import mindietorch +from mindietorch._enums import dtype +from config import CompileInfo + + +class MindiePFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None ): + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attn = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + if ( + is_cross_attn + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attn: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND", + type="PFA" ) + attn_output = 
attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieIFA(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = True, + config=None ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: torch.Tensor, + actual_seq_len: torch.Tensor = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self attn + assert past_key_value is not None, \ + "Current operation is incre_flash_attention, past_key_value is required." + bsz, tgt_len, _ = hidden_states.size() + assert tgt_len == 1, \ + "Current operation is incre_flash_attention, query's seq length should be equal to 1." + query_states = self.q_proj(hidden_states) + key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16) + values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16) + past_key_cache, past_value_cache = past_key_value[0], past_key_value[1] + if actual_seq_len is not None: + # self atten + indices = actual_seq_len - 1 + past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) + past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1) + past_key_value = (past_key_cache, past_value_cache) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + else: + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None ): + super().__init__( + embed_dim, + num_heads, + dropout, + bias, + is_causal, + config) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> 
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attn = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + if ( + is_cross_attn + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attn: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # MindFA only support BNSD layout + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + attn_output = torch.ops.aie.flash_attention( + query=query_states.transpose(1, 2), + key=key_states.transpose(1, 2), + value=value_states.transpose(1, 2), + num_head=self.num_heads, + scale=self.scaling, + layout="BNSD", + type="FA_HIGH_PERF") + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class MindieAttention(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND", + type="PFA") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, _ + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config): + super().__init__(config) + self.embed_dim = config.d_model + if config.is_incre_decode: + self.self_attn = MindieIFA( + 
embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "800IA2": + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + elif config.machine_type == "300IPro": + self.self_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + else: + raise ValueError(f"Unsupport current parameters. eg. " + f"is_incre_decode: {config.is_incre_decode} machine_type {config.machine_type}") + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + actual_seq_len=None, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + actual_seq_len=actual_seq_len + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config) + for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + actual_seq_len, + embed_position = None, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + 
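+        # Prefill adds the offset-0 positional embedding computed below; incremental decode adds the precomputed embed_position slice supplied by the caller.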
inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + # B S N D todo + # past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if past_key_values is None: + positions = self.embed_positions(input_ids, past_key_values_length=0) + hidden_states = inputs_embeds + positions + else: + assert embed_position is not None, "Embed_position is required, as past_key_values is not None." + hidden_states = inputs_embeds + embed_position + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + if config.machine_type == "300IPro": + self.self_attn = MindieFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + elif config.machine_type == "800IA2": + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + else: + raise ValueError(f"Unsupport current machine_type {config.machine_type} when init encoder.") + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + embed_position: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + actual_seq_len: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + 
actual_seq_len=actual_seq_len, + embed_position=embed_position + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, is_incre_decode=False, machine_type="800IA2"): + super().__init__(config) + config.is_incre_decode = is_incre_decode + config.machine_type = machine_type + self.model = MindieWhisperModel(config) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.self_attn_scatter = None + self.encoder_attn_scatter = None + self.save_path = None + self.encoder_seq_len = 1500 + self.file_prefix_names = CompileInfo.prefix_name + + self.past_key_value = [] + + def forward(self, *args): + if len(args) not in (2, 132): + raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + if len(args) == 132: + actual_seq_len = args[2] + embed_position = args[3] + past_key_values = args[4:] + else: + actual_seq_len = None + embed_position = None + past_key_values = None + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + actual_seq_len=actual_seq_len, + use_cache=True, + return_dict=False, + input_features=None, + embed_position = embed_position + ) + lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + + for _ in range(32): + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + print("init past key value cache success.") + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[2]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[3]}{batch_size}.ts success.") + + self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[4]}{batch_size}.ts success.") + + self.embed_pos = 
torch.jit.load(f"{save_path}/{self.file_prefix_names[-1]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[-1]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + kv_actual_step = 1 + past_kv_len = 0 + indices = torch.tensor([0] * input_ids.shape[0]).to("npu") + is_first_step = True + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + outputs = self.mindie_decoder_prefill(*args) + for idx in range(32): + self.self_attn_scatter(self.past_key_value[4*idx], indices, outputs[1 + 4*idx]) + self.self_attn_scatter(self.past_key_value[4*idx + 1], indices, outputs[1 + 4*idx + 1]) + self.encoder_attn_scatter(self.past_key_value[4*idx + 2], indices, outputs[1 + 4*idx + 2]) + self.encoder_attn_scatter(self.past_key_value[4*idx + 3], indices, outputs[1 + 4*idx + 3]) + is_first_step = False + past_kv_len + 1 + else: + kv_actual_step += 1 + args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu")) + pos_info = self.embed_pos(torch.tensor([past_kv_len], dtype=torch.int32).to("npu")) + args.append(pos_info) + past_kv_len += 1 + args.extend(self.past_key_value) + outputs = self.mindie_decoder(*args) + if isinstance(outputs, list): + next_token_logits = outputs[0].to("cpu")[:, -1, :] + else: + next_token_logits = outputs.to("cpu")[:, -1, :] + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise 
ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + num_segment_frames = 3000 + assert input_features.shape[-1] == num_segment_frames, "Only support 30s speech." + encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0] + kwargs["encoder_outputs"] = encoder_outputs + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. 
Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + if_first_step, + encoder_outputs=None, + **kwargs + ): + if not if_first_step: + decoder_input_ids_shape = decoder_input_ids.shape + remove_prefix_length = decoder_input_ids_shape[1] - 1 + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + return { + "encoder_outputs": encoder_outputs, + "decoder_input_ids": decoder_input_ids + } \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..60a3c83ab5f1d9e1c23a0e714e6725652a9d9d56 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/readme.md @@ -0,0 +1,81 @@ +# Whisper-large-v3模型-推理指导 + +- [概述](#概述) +- [推理环境准备](#推理环境准备) +- [快速上手](#快速上手) + - [获取源码](#获取源码) + - [模型推理](#模型推理) + +# 概述 + +该工程使用mindietorch部署whisper-large-v3模型 + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + |-----------| ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0+cpu | - | + | torch_audio | 2.1.0+cpu | - | + | CANN | 8.0.RC3 | - | + | MindIE | 1.0.RC3 | - | + +# 快速上手 +## 获取源码 + +1. 
Install the MindIE package
+
+   ```bash
+   # install MindIE
+   chmod +x ./Ascend-mindie_xxx.run
+   ./Ascend-mindie_xxx.run --install
+   source /usr/local/Ascend/mindie/set_env.sh
+   ```
+
+
+2. Download the model weights from:
+   ```bash
+   https://huggingface.co/openai/whisper-large-v3/tree/main
+   ```
+   Place the weight files in the model_path folder under the current directory; create that folder first.
+
+
+3. Install the dependencies
+   ```
+   pip3 install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+   pip3 install nltk
+   pip3 install librosa
+   pip3 install transformers==4.36.0
+   pip3 install numpy==1.26.0
+   ```
+
+## Model inference
+1. Set the upper limit of the MindIE memory pool to 32 by setting the environment variable below. If the pool is too small, repeated memory allocation and release will hurt performance.
+   ```
+   export TORCH_AIE_NPU_CACHE_MAX_SIZE=32
+   ```
+
+2. Compile and run the model
+   ```
+   python3 compile_whisper.py \
+       -model_path ./model_path \
+       -bs 16 \
+       -save_path ./compiled_models \
+       -soc_version Ascend310P3
+   ```
+
+   Parameters:
+   - -model_path: path to the pretrained model. Required.
+   - -bs: batch size, default 8. Optional.
+   - -save_path: directory where the compiled models are saved. Required.
+   - -device_id: ID of the device the model runs on, default 0. Optional.
+   - -soc_version: SoC (chip) type. Required.
+   - -machine_type: machine type (300IPro or 800IA2). Required.
+
+   Constraints:
+   1. Dynamic batch sizes are not supported yet; if batch_size changes, the models must be recompiled.
diff --git a/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py b/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..854eefd64c0200c6f146e825de65d17ca5bae81d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/whisperX/test.py
@@ -0,0 +1,37 @@
+import os
+import argparse
+import time
+import math
+import torch
+import torch.nn.functional as F
+import mindietorch
+from mindietorch._enums import dtype
+from modeling_whisper import MindieWhisperForConditionalGeneration
+from config import CompileInfo
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-model_path', type=str, required=True, help="please provide model path.")
+    parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.")
+    parser.add_argument('-soc_version', type=str, required=True,
+                        help="please provide soc_version.")
+    parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.")
+    parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800IA2")
+    parser.add_argument('-device_id', type=int, default=0)
+    args = parser.parse_args()
+    mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path,
+                                                                           machine_type=args.machine_type)
+    cpu_model = mindie_whisper.model.decoder.embed_positions
+    mindietorch.set_device(args.device_id)
+    npu_model = torch.jit.load(f"{args.save_path}/{CompileInfo.prefix_name[-1]}{args.bs}.ts")
+    print(f"load {CompileInfo.prefix_name[-1]}{args.bs}.ts success.")
+    input_ids = torch.ones([args.bs, 1])
+    for step in range(1, 448):
+        past_key_values_length = torch.tensor([step])
+        cpu_ret = cpu_model(input_ids=input_ids, past_key_values_length=past_key_values_length)
+        npu_ret = npu_model(past_key_values_length.to("npu")).to("cpu")
+
+        print(f"npu shape {npu_ret.shape}, cpu shape {cpu_ret.shape}")
+        print(f"cosine {F.cosine_similarity(cpu_ret.reshape(1, -1), npu_ret.reshape(1, -1))}")
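For reference, a minimal end-to-end usage sketch of the compiled artifacts (not part of this patch). It assumes the models were compiled with `-bs 1` on an 800IA2 machine, that `audio.wav` is a 16 kHz clip of at most 30 s, and that the Hugging Face `WhisperProcessor` files for whisper-large-v3 are available in `./model_path`; names and paths are illustrative and may need adjustment.

```python
import librosa
import mindietorch
from transformers import WhisperProcessor
from modeling_whisper import MindieWhisperForConditionalGeneration

MODEL_PATH = "./model_path"          # local whisper-large-v3 weights (assumed)
COMPILED_DIR = "./compiled_models"   # output of compile_whisper.py, assumed compiled with -bs 1
BATCH_SIZE = 1
DEVICE_ID = 0

mindietorch.set_device(DEVICE_ID)

# Rebuild the Python-side model and attach the compiled MindIE modules.
model = MindieWhisperForConditionalGeneration.from_pretrained(MODEL_PATH, machine_type="800IA2")
model.load_mindie_models(COMPILED_DIR, BATCH_SIZE)

# 30 s of 16 kHz audio -> 128-mel log spectrogram with 3000 frames, as generate() expects.
processor = WhisperProcessor.from_pretrained(MODEL_PATH)
audio, _ = librosa.load("audio.wav", sr=16000)
input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features

generated_ids = model.generate(input_features=input_features)
print(processor.batch_decode(generated_ids.to("cpu"), skip_special_tokens=True))
```

In this flow the KV caches stay resident on the NPU between decode steps; only the logits of the newly generated token are copied back to the host each iteration.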